aboutsummaryrefslogtreecommitdiff
path: root/kms
diff options
context:
space:
mode:
Diffstat (limited to 'kms')
-rw-r--r--kms/client/kms_client.c420
-rw-r--r--kms/client/kms_client.h13
-rw-r--r--kms/kms_shared.h48
-rw-r--r--kms/server/kms_server.c496
-rw-r--r--kms/server/project.conf3
5 files changed, 694 insertions, 286 deletions
diff --git a/kms/client/kms_client.c b/kms/client/kms_client.c
index 587dda3..57afd04 100644
--- a/kms/client/kms_client.c
+++ b/kms/client/kms_client.c
@@ -1,4 +1,5 @@
#include "kms_client.h"
+#include "../../include/utils.h"
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
@@ -10,28 +11,28 @@
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/wait.h>
+#include <poll.h>
+#include <sys/stat.h>
#include <sys/capability.h>
-static bool generate_random_characters(char *buffer, int buffer_size, const char *alphabet, size_t alphabet_size) {
- int fd = open("/dev/urandom", O_RDONLY);
- if(fd == -1) {
- perror("/dev/urandom");
- return false;
- }
+#define GSR_SOCKET_PAIR_LOCAL 0
+#define GSR_SOCKET_PAIR_REMOTE 1
- if(read(fd, buffer, buffer_size) < buffer_size) {
- fprintf(stderr, "Failed to read %d bytes from /dev/urandom\n", buffer_size);
- close(fd);
- return false;
- }
+static void cleanup_socket(gsr_kms_client *self, bool kill_server);
+static int gsr_kms_client_replace_connection(gsr_kms_client *self);
- for(int i = 0; i < buffer_size; ++i) {
- unsigned char c = *(unsigned char*)&buffer[i];
- buffer[i] = alphabet[c % alphabet_size];
+static void close_fds(gsr_kms_response *response) {
+ for(int i = 0; i < response->num_items; ++i) {
+ for(int j = 0; j < response->items[i].num_dma_bufs; ++j) {
+ gsr_kms_response_dma_buf *dma_buf = &response->items[i].dma_buf[j];
+ if(dma_buf->fd > 0) {
+ close(dma_buf->fd);
+ dma_buf->fd = -1;
+ }
+ }
+ response->items[i].num_dma_bufs = 0;
}
-
- close(fd);
- return true;
+ response->num_items = 0;
}
static int send_msg_to_server(int server_fd, gsr_kms_request *request) {
@@ -39,14 +40,32 @@ static int send_msg_to_server(int server_fd, gsr_kms_request *request) {
iov.iov_base = request;
iov.iov_len = sizeof(*request);
- struct msghdr request_message = {0};
- request_message.msg_iov = &iov;
- request_message.msg_iovlen = 1;
+ struct msghdr response_message = {0};
+ response_message.msg_iov = &iov;
+ response_message.msg_iovlen = 1;
+
+ char cmsgbuf[CMSG_SPACE(sizeof(int) * 1)];
+ memset(cmsgbuf, 0, sizeof(cmsgbuf));
+
+ if(request->new_connection_fd > 0) {
+ response_message.msg_control = cmsgbuf;
+ response_message.msg_controllen = sizeof(cmsgbuf);
- return sendmsg(server_fd, &request_message, 0);
+ struct cmsghdr *cmsg = CMSG_FIRSTHDR(&response_message);
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ cmsg->cmsg_len = CMSG_LEN(sizeof(int) * 1);
+
+ int *fds = (int*)CMSG_DATA(cmsg);
+ fds[0] = request->new_connection_fd;
+
+ response_message.msg_controllen = cmsg->cmsg_len;
+ }
+
+ return sendmsg(server_fd, &response_message, 0);
}
-static int recv_msg_from_server(int server_fd, gsr_kms_response *response) {
+static int recv_msg_from_server(int server_pid, int server_fd, gsr_kms_response *response) {
struct iovec iov;
iov.iov_base = response;
iov.iov_len = sizeof(*response);
@@ -55,26 +74,43 @@ static int recv_msg_from_server(int server_fd, gsr_kms_response *response) {
response_message.msg_iov = &iov;
response_message.msg_iovlen = 1;
- char cmsgbuf[CMSG_SPACE(sizeof(int) * GSR_KMS_MAX_PLANES)];
+ char cmsgbuf[CMSG_SPACE(sizeof(int) * GSR_KMS_MAX_ITEMS * GSR_KMS_MAX_DMA_BUFS)];
memset(cmsgbuf, 0, sizeof(cmsgbuf));
response_message.msg_control = cmsgbuf;
response_message.msg_controllen = sizeof(cmsgbuf);
- int res = recvmsg(server_fd, &response_message, MSG_WAITALL);
- if(res <= 0)
- return res;
+ int res = 0;
+ for(;;) {
+ res = recvmsg(server_fd, &response_message, MSG_DONTWAIT);
+ if(res <= 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
+ // If we are replacing the connection and closing the application at the same time
+ // then recvmsg can get stuck (because the server died), so we prevent that by doing
+ // non-blocking recvmsg and checking if the server died
+ int status = 0;
+ int wait_result = waitpid(server_pid, &status, WNOHANG);
+ if(wait_result != 0) {
+ res = -1;
+ break;
+ }
+ usleep(1000);
+ } else {
+ break;
+ }
+ }
- if(response->num_fds > 0) {
+ if(res > 0 && response->num_items > 0) {
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&response_message);
if(cmsg) {
int *fds = (int*)CMSG_DATA(cmsg);
- for(int i = 0; i < response->num_fds; ++i) {
- response->fds[i].fd = fds[i];
+ int fd_index = 0;
+ for(int i = 0; i < response->num_items; ++i) {
+ for(int j = 0; j < response->items[i].num_dma_bufs; ++j) {
+ gsr_kms_response_dma_buf *dma_buf = &response->items[i].dma_buf[j];
+ dma_buf->fd = fds[fd_index++];
+ }
}
} else {
- for(int i = 0; i < response->num_fds; ++i) {
- response->fds[i].fd = 0;
- }
+ close_fds(response);
}
}
@@ -89,85 +125,174 @@ static bool create_socket_path(char *output_path, size_t output_path_size) {
char random_characters[11];
random_characters[10] = '\0';
- if(!generate_random_characters(random_characters, 10, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789", 62))
+ if(!generate_random_characters_standard_alphabet(random_characters, 10))
return false;
snprintf(output_path, output_path_size, "%s/.gsr-kms-socket-%s", home, random_characters);
return true;
}
-static void strncpy_safe(char *dst, const char *src, int len) {
- int src_len = strlen(src);
- int min_len = src_len;
- if(len - 1 < min_len)
- min_len = len - 1;
- memcpy(dst, src, min_len);
- dst[min_len] = '\0';
+static bool readlink_realpath(const char *filepath, char *buffer) {
+ char symlinked_path[PATH_MAX];
+ ssize_t bytes_written = readlink(filepath, symlinked_path, sizeof(symlinked_path) - 1);
+ if(bytes_written == -1 && errno == EINVAL) {
+ /* Not a symlink */
+ snprintf(symlinked_path, sizeof(symlinked_path), "%s", filepath);
+ } else if(bytes_written == -1) {
+ return false;
+ } else {
+ symlinked_path[bytes_written] = '\0';
+ }
+
+ if(!realpath(symlinked_path, buffer))
+ return false;
+
+ return true;
+}
+
+static bool strcat_safe(char *str, int size, const char *str_to_add) {
+ const int str_len = strlen(str);
+ const int str_to_add_len = strlen(str_to_add);
+ if(str_len + str_to_add_len + 1 >= size)
+ return false;
+
+ memcpy(str + str_len, str_to_add, str_to_add_len);
+ str[str_len + str_to_add_len] = '\0';
+ return true;
+}
+
+static void file_get_directory(char *filepath) {
+ char *end = strrchr(filepath, '/');
+ if(end == NULL)
+ filepath[0] = '\0';
+ else
+ *end = '\0';
+}
+
+static bool find_program_in_path(const char *program_name, char *filepath, int filepath_len) {
+ const char *path = getenv("PATH");
+ if(!path)
+ return false;
+
+ int program_name_len = strlen(program_name);
+ const char *end = path + strlen(path);
+ while(path != end) {
+ const char *part_end = strchr(path, ':');
+ const char *next = part_end;
+ if(part_end) {
+ next = part_end + 1;
+ } else {
+ part_end = end;
+ next = end;
+ }
+
+ int len = part_end - path;
+ if(len + 1 + program_name_len < filepath_len) {
+ memcpy(filepath, path, len);
+ filepath[len] = '/';
+ memcpy(filepath + len + 1, program_name, program_name_len);
+ filepath[len + 1 + program_name_len] = '\0';
+
+ if(access(filepath, F_OK) == 0)
+ return true;
+ }
+
+ path = next;
+ }
+
+ return false;
}
int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
+ int result = -1;
self->kms_server_pid = -1;
- self->socket_fd = -1;
- self->client_fd = -1;
- self->socket_path[0] = '\0';
+ self->initial_socket_fd = -1;
+ self->initial_client_fd = -1;
+ self->initial_socket_path[0] = '\0';
+ self->socket_pair[0] = -1;
+ self->socket_pair[1] = -1;
struct sockaddr_un local_addr = {0};
struct sockaddr_un remote_addr = {0};
- if(!create_socket_path(self->socket_path, sizeof(self->socket_path))) {
+ if(!create_socket_path(self->initial_socket_path, sizeof(self->initial_socket_path))) {
fprintf(stderr, "gsr error: gsr_kms_client_init: failed to create path to kms socket\n");
return -1;
}
- // This doesn't work on nixos, but we dont want to use $PATH because we want to make this as safe as possible by running pkexec
- // on a path that only root can modify. If we use "gsr-kms-server" instead then $PATH can be modified in ~/.bashrc for example
- // which will overwrite the path to gsr-kms-server and the user can end up running a malicious program that pretends to be gsr-kms-server.
- // If there is a safe way to do this on nixos, then please tell me; or use gpu-screen-recorder flatpak instead.
- const char *server_filepath = "/usr/bin/gsr-kms-server";
- bool has_perm = 0;
- const bool inside_flatpak = getenv("FLATPAK_ID") != NULL;
- if(!inside_flatpak) {
- if(access("/usr/bin/gsr-kms-server", F_OK) != 0) {
- fprintf(stderr, "gsr error: gsr_kms_client_init: /usr/bin/gsr-kms-server not found, please install gpu-screen-recorder first\n");
+ char server_filepath[PATH_MAX];
+ if(!readlink_realpath("/proc/self/exe", server_filepath)) {
+ fprintf(stderr, "gsr error: gsr_kms_client_init: failed to resolve /proc/self/exe\n");
+ return -1;
+ }
+ file_get_directory(server_filepath);
+
+ if(!strcat_safe(server_filepath, sizeof(server_filepath), "/gsr-kms-server")) {
+ fprintf(stderr, "gsr error: gsr_kms_client_init: gsr-kms-server path too long\n");
+ return -1;
+ }
+
+ if(access(server_filepath, F_OK) != 0) {
+ fprintf(stderr, "gsr info: gsr_kms_client_init: gsr-kms-server is not installed in the same directory as gpu-screen-recorder (%s not found), looking for gsr-kms-server in PATH instead\n", server_filepath);
+ if(!find_program_in_path("gsr-kms-server", server_filepath, sizeof(server_filepath)) || access(server_filepath, F_OK) != 0) {
+ fprintf(stderr, "gsr error: gsr_kms_client_init: gsr-kms-server was not found in PATH. Please install gpu-screen-recorder properly\n");
return -1;
}
+ }
- if(geteuid() == 0) {
- has_perm = true;
- } else {
- cap_t kms_server_cap = cap_get_file(server_filepath);
- if(kms_server_cap) {
- cap_flag_value_t res = 0;
- cap_get_flag(kms_server_cap, CAP_SYS_ADMIN, CAP_PERMITTED, &res);
- if(res == CAP_SET) {
- //fprintf(stderr, "has permission!\n");
- has_perm = true;
- } else {
- //fprintf(stderr, "No permission:(\n");
- }
- cap_free(kms_server_cap);
+ fprintf(stderr, "gsr info: gsr_kms_client_init: setting up connection to %s\n", server_filepath);
+
+ const bool inside_flatpak = getenv("FLATPAK_ID") != NULL;
+ const char *home = getenv("HOME");
+ if(!home)
+ home = "/tmp";
+
+ bool has_perm = 0;
+ if(geteuid() == 0) {
+ has_perm = true;
+ } else {
+ cap_t kms_server_cap = cap_get_file(server_filepath);
+ if(kms_server_cap) {
+ cap_flag_value_t res = CAP_CLEAR;
+ cap_get_flag(kms_server_cap, CAP_SYS_ADMIN, CAP_PERMITTED, &res);
+ if(res == CAP_SET) {
+ //fprintf(stderr, "has permission!\n");
+ has_perm = true;
} else {
- if(errno == ENODATA)
- fprintf(stderr, "gsr info: gsr_kms_client_init: gsr-kms-server is missing sys_admin cap and will require root authentication. To bypass this automatically, run: sudo setcap cap_sys_admin+ep '%s'\n", server_filepath);
- else
- fprintf(stderr, "gsr info: gsr_kms_client_init: failed to get cap\n");
+ //fprintf(stderr, "No permission:(\n");
}
+ cap_free(kms_server_cap);
+ } else if(!inside_flatpak) {
+ if(errno == ENODATA)
+ fprintf(stderr, "gsr info: gsr_kms_client_init: gsr-kms-server is missing sys_admin cap and will require root authentication. To bypass this automatically, run: sudo setcap cap_sys_admin+ep '%s'\n", server_filepath);
+ else
+ fprintf(stderr, "gsr info: gsr_kms_client_init: failed to get cap\n");
}
}
- self->socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
- if(self->socket_fd == -1) {
+ if(socketpair(AF_UNIX, SOCK_STREAM, 0, self->socket_pair) == -1) {
+ fprintf(stderr, "gsr error: gsr_kms_client_init: socketpair failed, error: %s\n", strerror(errno));
+ goto err;
+ }
+
+ self->initial_socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
+ if(self->initial_socket_fd == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: socket failed, error: %s\n", strerror(errno));
goto err;
}
local_addr.sun_family = AF_UNIX;
- strncpy_safe(local_addr.sun_path, self->socket_path, sizeof(local_addr.sun_path));
- if(bind(self->socket_fd, (struct sockaddr*)&local_addr, sizeof(local_addr.sun_family) + strlen(local_addr.sun_path)) == -1) {
+ snprintf(local_addr.sun_path, sizeof(local_addr.sun_path), "%s", (const char*)self->initial_socket_path);
+
+ const mode_t prev_mask = umask(0000);
+ const int bind_res = bind(self->initial_socket_fd, (struct sockaddr*)&local_addr, sizeof(local_addr.sun_family) + strlen(local_addr.sun_path));
+ umask(prev_mask);
+
+ if(bind_res == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: failed to bind socket, error: %s\n", strerror(errno));
goto err;
}
- if(listen(self->socket_fd, 1) == -1) {
+ if(listen(self->initial_socket_fd, 1) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: failed to listen on socket, error: %s\n", strerror(errno));
goto err;
}
@@ -178,100 +303,169 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
goto err;
} else if(pid == 0) { /* child */
if(inside_flatpak) {
- const char *args[] = { "flatpak-spawn", "--host", "pkexec", "flatpak", "run", "--command=gsr-kms-server", "com.dec05eba.gpu_screen_recorder", self->socket_path, card_path, NULL };
+ const char *args[] = { "flatpak-spawn", "--host", "/var/lib/flatpak/app/com.dec05eba.gpu_screen_recorder/current/active/files/bin/kms-server-proxy", self->initial_socket_path, card_path, home, NULL };
execvp(args[0], (char *const*)args);
} else if(has_perm) {
- const char *args[] = { server_filepath, self->socket_path, card_path, NULL };
+ const char *args[] = { server_filepath, self->initial_socket_path, card_path, NULL };
execvp(args[0], (char *const*)args);
} else {
- const char *args[] = { "pkexec", server_filepath, self->socket_path, card_path, NULL };
+ const char *args[] = { "pkexec", server_filepath, self->initial_socket_path, card_path, NULL };
execvp(args[0], (char *const*)args);
}
- fprintf(stderr, "gsr error: gsr_kms_client_init: execvp failed, error: %s\n", strerror(errno));
+ fprintf(stderr, "gsr error: gsr_kms_client_init: failed to launch \"gsr-kms-server\", error: %s\n", strerror(errno));
_exit(127);
} else { /* parent */
self->kms_server_pid = pid;
}
fprintf(stderr, "gsr info: gsr_kms_client_init: waiting for server to connect\n");
+ struct pollfd poll_fd = {
+ .fd = self->initial_socket_fd,
+ .events = POLLIN,
+ .revents = 0
+ };
for(;;) {
- struct timeval tv;
- fd_set rfds;
- FD_ZERO(&rfds);
- FD_SET(self->socket_fd, &rfds);
-
- tv.tv_sec = 0;
- tv.tv_usec = 100 * 1000; // 100 ms
-
- int select_res = select(1 + self->socket_fd, &rfds, NULL, NULL, &tv);
- if(select_res > 0) {
+ int poll_res = poll(&poll_fd, 1, 100);
+ if(poll_res > 0 && (poll_fd.revents & POLLIN)) {
socklen_t sock_len = 0;
- self->client_fd = accept(self->socket_fd, (struct sockaddr*)&remote_addr, &sock_len);
- if(self->client_fd == -1) {
+ self->initial_client_fd = accept(self->initial_socket_fd, (struct sockaddr*)&remote_addr, &sock_len);
+ if(self->initial_client_fd == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: accept failed on socket, error: %s\n", strerror(errno));
goto err;
}
break;
} else {
- int status;
+ int status = 0;
int wait_result = waitpid(self->kms_server_pid, &status, WNOHANG);
if(wait_result != 0) {
- fprintf(stderr, "gsr error: gsr_kms_client_init: kms server died or never started, error: %s\n", strerror(errno));
+ int exit_code = -1;
+ if(WIFEXITED(status))
+ exit_code = WEXITSTATUS(status);
+ fprintf(stderr, "gsr error: gsr_kms_client_init: kms server died or never started, exit code: %d\n", exit_code);
self->kms_server_pid = -1;
+ if(exit_code != 0)
+ result = exit_code;
goto err;
}
}
}
fprintf(stderr, "gsr info: gsr_kms_client_init: server connected\n");
+ fprintf(stderr, "gsr info: replacing file-backed unix domain socket with socketpair\n");
+ if(gsr_kms_client_replace_connection(self) != 0)
+ goto err;
+
+ cleanup_socket(self, false);
+ fprintf(stderr, "gsr info: using socketpair\n");
+
return 0;
err:
gsr_kms_client_deinit(self);
- return -1;
+ return result;
}
-void gsr_kms_client_deinit(gsr_kms_client *self) {
- if(self->client_fd != -1) {
- close(self->client_fd);
- self->client_fd = -1;
+void cleanup_socket(gsr_kms_client *self, bool kill_server) {
+ if(self->initial_client_fd > 0) {
+ close(self->initial_client_fd);
+ self->initial_client_fd = -1;
}
- if(self->socket_fd != -1) {
- close(self->socket_fd);
- self->socket_fd = -1;
+ if(self->initial_socket_fd > 0) {
+ close(self->initial_socket_fd);
+ self->initial_socket_fd = -1;
}
- if(self->kms_server_pid != -1) {
- kill(self->kms_server_pid, SIGINT);
- int status;
- waitpid(self->kms_server_pid, &status, 0);
+ if(kill_server) {
+ for(int i = 0; i < 2; ++i) {
+ if(self->socket_pair[i] > 0) {
+ close(self->socket_pair[i]);
+ self->socket_pair[i] = -1;
+ }
+ }
+ }
+
+ if(kill_server && self->kms_server_pid > 0) {
+ kill(self->kms_server_pid, SIGKILL);
+ // TODO:
+ //int status;
+ //waitpid(self->kms_server_pid, &status, 0);
self->kms_server_pid = -1;
}
- if(self->socket_path[0] != '\0') {
- remove(self->socket_path);
- self->socket_path[0] = '\0';
+ if(self->initial_socket_path[0] != '\0') {
+ remove(self->initial_socket_path);
+ self->initial_socket_path[0] = '\0';
}
}
+void gsr_kms_client_deinit(gsr_kms_client *self) {
+ cleanup_socket(self, true);
+}
+
+int gsr_kms_client_replace_connection(gsr_kms_client *self) {
+ gsr_kms_response response;
+ response.version = 0;
+ response.result = KMS_RESULT_FAILED_TO_SEND;
+ response.err_msg[0] = '\0';
+
+ gsr_kms_request request;
+ request.version = GSR_KMS_PROTOCOL_VERSION;
+ request.type = KMS_REQUEST_TYPE_REPLACE_CONNECTION;
+ request.new_connection_fd = self->socket_pair[GSR_SOCKET_PAIR_REMOTE];
+ if(send_msg_to_server(self->initial_client_fd, &request) == -1) {
+ fprintf(stderr, "gsr error: gsr_kms_client_replace_connection: failed to send request message to server\n");
+ return -1;
+ }
+
+ const int recv_res = recv_msg_from_server(self->kms_server_pid, self->socket_pair[GSR_SOCKET_PAIR_LOCAL], &response);
+ if(recv_res == 0) {
+ fprintf(stderr, "gsr warning: gsr_kms_client_replace_connection: kms server shut down\n");
+ return -1;
+ } else if(recv_res == -1) {
+ fprintf(stderr, "gsr error: gsr_kms_client_replace_connection: failed to receive response\n");
+ return -1;
+ }
+
+ if(response.version != GSR_KMS_PROTOCOL_VERSION) {
+ fprintf(stderr, "gsr error: gsr_kms_client_replace_connection: expected gsr-kms-server protocol version to be %u, but it's %u. please reinstall gpu screen recorder\n", GSR_KMS_PROTOCOL_VERSION, response.version);
+ /*close_fds(response);*/
+ return -1;
+ }
+
+ return 0;
+}
+
int gsr_kms_client_get_kms(gsr_kms_client *self, gsr_kms_response *response) {
+ response->version = 0;
response->result = KMS_RESULT_FAILED_TO_SEND;
- strcpy(response->err_msg, "failed to send");
+ response->err_msg[0] = '\0';
gsr_kms_request request;
+ request.version = GSR_KMS_PROTOCOL_VERSION;
request.type = KMS_REQUEST_TYPE_GET_KMS;
- if(send_msg_to_server(self->client_fd, &request) == -1) {
+ request.new_connection_fd = 0;
+ if(send_msg_to_server(self->socket_pair[GSR_SOCKET_PAIR_LOCAL], &request) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to send request message to server\n");
+ strcpy(response->err_msg, "failed to send");
return -1;
}
- const int recv_res = recv_msg_from_server(self->client_fd, response);
+ const int recv_res = recv_msg_from_server(self->kms_server_pid, self->socket_pair[GSR_SOCKET_PAIR_LOCAL], response);
if(recv_res == 0) {
fprintf(stderr, "gsr warning: gsr_kms_client_get_kms: kms server shut down\n");
+ strcpy(response->err_msg, "failed to receive");
return -1;
} else if(recv_res == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to receive response\n");
+ strcpy(response->err_msg, "failed to receive");
+ return -1;
+ }
+
+ if(response->version != GSR_KMS_PROTOCOL_VERSION) {
+ fprintf(stderr, "gsr error: gsr_kms_client_get_kms: expected gsr-kms-server protocol version to be %u, but it's %u. please reinstall gpu screen recorder\n", GSR_KMS_PROTOCOL_VERSION, response->version);
+ /*close_fds(response);*/
+ strcpy(response->err_msg, "mismatching protocol version");
return -1;
}
diff --git a/kms/client/kms_client.h b/kms/client/kms_client.h
index 254637b..2d18848 100644
--- a/kms/client/kms_client.h
+++ b/kms/client/kms_client.h
@@ -5,12 +5,15 @@
#include <sys/types.h>
#include <limits.h>
-typedef struct {
+typedef struct gsr_kms_client gsr_kms_client;
+
+struct gsr_kms_client {
pid_t kms_server_pid;
- int socket_fd;
- int client_fd;
- char socket_path[PATH_MAX];
-} gsr_kms_client;
+ int initial_socket_fd;
+ int initial_client_fd;
+ char initial_socket_path[PATH_MAX];
+ int socket_pair[2];
+};
/* |card_path| should be a path to card, for example /dev/dri/card0 */
int gsr_kms_client_init(gsr_kms_client *self, const char *card_path);
diff --git a/kms/kms_shared.h b/kms/kms_shared.h
index e0687b2..2dbb655 100644
--- a/kms/kms_shared.h
+++ b/kms/kms_shared.h
@@ -3,10 +3,19 @@
#include <stdint.h>
#include <stdbool.h>
+#include <drm_mode.h>
-#define GSR_KMS_MAX_PLANES 32
+#define GSR_KMS_PROTOCOL_VERSION 4
+
+#define GSR_KMS_MAX_ITEMS 8
+#define GSR_KMS_MAX_DMA_BUFS 4
+
+typedef struct gsr_kms_response_dma_buf gsr_kms_response_dma_buf;
+typedef struct gsr_kms_response_item gsr_kms_response_item;
+typedef struct gsr_kms_response gsr_kms_response;
typedef enum {
+ KMS_REQUEST_TYPE_REPLACE_CONNECTION,
KMS_REQUEST_TYPE_GET_KMS
} gsr_kms_request_type;
@@ -14,30 +23,45 @@ typedef enum {
KMS_RESULT_OK,
KMS_RESULT_INVALID_REQUEST,
KMS_RESULT_FAILED_TO_GET_PLANE,
+ KMS_RESULT_FAILED_TO_GET_PLANES,
KMS_RESULT_FAILED_TO_SEND
} gsr_kms_result;
typedef struct {
- int type; /* gsr_kms_request_type */
+ uint32_t version; /* GSR_KMS_PROTOCOL_VERSION */
+ int type; /* gsr_kms_request_type */
+ int new_connection_fd;
} gsr_kms_request;
-typedef struct {
+struct gsr_kms_response_dma_buf {
int fd;
- uint32_t width;
- uint32_t height;
uint32_t pitch;
uint32_t offset;
+};
+
+struct gsr_kms_response_item {
+ gsr_kms_response_dma_buf dma_buf[GSR_KMS_MAX_DMA_BUFS];
+ int num_dma_bufs;
+ uint32_t width;
+ uint32_t height;
uint32_t pixel_format;
uint64_t modifier;
uint32_t connector_id; /* 0 if unknown */
- bool is_combined_plane;
-} gsr_kms_response_fd;
+ bool is_cursor;
+ bool has_hdr_metadata;
+ int x;
+ int y;
+ int src_w;
+ int src_h;
+ struct hdr_output_metadata hdr_metadata;
+};
-typedef struct {
- int result; /* gsr_kms_result */
+struct gsr_kms_response {
+ uint32_t version; /* GSR_KMS_PROTOCOL_VERSION */
+ int result; /* gsr_kms_result */
char err_msg[128];
- gsr_kms_response_fd fds[GSR_KMS_MAX_PLANES];
- int num_fds;
-} gsr_kms_response;
+ gsr_kms_response_item items[GSR_KMS_MAX_ITEMS];
+ int num_items;
+};
#endif /* #define GSR_KMS_SHARED_H */
diff --git a/kms/server/kms_server.c b/kms/server/kms_server.c
index 5aa6590..070875b 100644
--- a/kms/server/kms_server.c
+++ b/kms/server/kms_server.c
@@ -1,11 +1,17 @@
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
#include "../kms_shared.h"
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>
+#include <locale.h>
#include <unistd.h>
+#include <limits.h>
#include <fcntl.h>
#include <sys/socket.h>
#include <sys/un.h>
@@ -13,20 +19,19 @@
#include <xf86drm.h>
#include <xf86drmMode.h>
-#include <libdrm/drm_mode.h>
+#include <drm_mode.h>
+#include <drm_fourcc.h>
#define MAX_CONNECTORS 32
typedef struct {
int drmfd;
- uint32_t plane_ids[GSR_KMS_MAX_PLANES];
- uint32_t connector_ids[GSR_KMS_MAX_PLANES];
- size_t num_plane_ids;
} gsr_drm;
typedef struct {
uint32_t connector_id;
uint64_t crtc_id;
+ uint64_t hdr_metadata_blob_id;
} connector_crtc_pair;
typedef struct {
@@ -38,6 +43,14 @@ static int max_int(int a, int b) {
return a > b ? a : b;
}
+static int count_num_fds(const gsr_kms_response *response) {
+ int num_fds = 0;
+ for(int i = 0; i < response->num_items; ++i) {
+ num_fds += response->items[i].num_dma_bufs;
+ }
+ return num_fds;
+}
+
static int send_msg_to_client(int client_fd, gsr_kms_response *response) {
struct iovec iov;
iov.iov_base = response;
@@ -47,21 +60,25 @@ static int send_msg_to_client(int client_fd, gsr_kms_response *response) {
response_message.msg_iov = &iov;
response_message.msg_iovlen = 1;
- char cmsgbuf[CMSG_SPACE(sizeof(int) * max_int(1, response->num_fds))];
+ const int num_fds = count_num_fds(response);
+ char cmsgbuf[CMSG_SPACE(sizeof(int) * max_int(1, num_fds))];
memset(cmsgbuf, 0, sizeof(cmsgbuf));
- if(response->num_fds > 0) {
+ if(num_fds > 0) {
response_message.msg_control = cmsgbuf;
response_message.msg_controllen = sizeof(cmsgbuf);
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&response_message);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
- cmsg->cmsg_len = CMSG_LEN(sizeof(int) * response->num_fds);
+ cmsg->cmsg_len = CMSG_LEN(sizeof(int) * num_fds);
int *fds = (int*)CMSG_DATA(cmsg);
- for(int i = 0; i < response->num_fds; ++i) {
- fds[i] = response->fds[i].fd;
+ int fd_index = 0;
+ for(int i = 0; i < response->num_items; ++i) {
+ for(int j = 0; j < response->items[i].num_dma_bufs; ++j) {
+ fds[fd_index++] = response->items[i].dma_buf[j].fd;
+ }
}
response_message.msg_controllen = cmsg->cmsg_len;
@@ -70,6 +87,40 @@ static int send_msg_to_client(int client_fd, gsr_kms_response *response) {
return sendmsg(client_fd, &response_message, 0);
}
+static int recv_msg_from_client(int client_fd, gsr_kms_request *request) {
+ struct iovec iov;
+ iov.iov_base = request;
+ iov.iov_len = sizeof(*request);
+
+ struct msghdr response_message = {0};
+ response_message.msg_iov = &iov;
+ response_message.msg_iovlen = 1;
+
+ char cmsgbuf[CMSG_SPACE(sizeof(int) * 1)];
+ memset(cmsgbuf, 0, sizeof(cmsgbuf));
+ response_message.msg_control = cmsgbuf;
+ response_message.msg_controllen = sizeof(cmsgbuf);
+
+ int res = recvmsg(client_fd, &response_message, MSG_WAITALL);
+ if(res <= 0)
+ return res;
+
+ if(request->new_connection_fd > 0) {
+ struct cmsghdr *cmsg = CMSG_FIRSTHDR(&response_message);
+ if(cmsg) {
+ int *fds = (int*)CMSG_DATA(cmsg);
+ request->new_connection_fd = fds[0];
+ } else {
+ if(request->new_connection_fd > 0) {
+ close(request->new_connection_fd);
+ request->new_connection_fd = 0;
+ }
+ }
+ }
+
+ return res;
+}
+
static bool connector_get_property_by_name(int drmfd, drmModeConnectorPtr props, const char *name, uint64_t *result) {
for(int i = 0; i < props->count_props; ++i) {
drmModePropertyPtr prop = drmModeGetProperty(drmfd, props->props[i]);
@@ -85,128 +136,157 @@ static bool connector_get_property_by_name(int drmfd, drmModeConnectorPtr props,
return false;
}
-static bool plane_is_cursor_plane(int drmfd, uint32_t plane_id) {
+typedef enum {
+ PLANE_PROPERTY_X = 1 << 0,
+ PLANE_PROPERTY_Y = 1 << 1,
+ PLANE_PROPERTY_SRC_X = 1 << 2,
+ PLANE_PROPERTY_SRC_Y = 1 << 3,
+ PLANE_PROPERTY_SRC_W = 1 << 4,
+ PLANE_PROPERTY_SRC_H = 1 << 5,
+ PLANE_PROPERTY_IS_CURSOR = 1 << 6,
+ PLANE_PROPERTY_IS_PRIMARY = 1 << 7,
+} plane_property_mask;
+
+/* Returns plane_property_mask */
+static uint32_t plane_get_properties(int drmfd, uint32_t plane_id, int *x, int *y, int *src_x, int *src_y, int *src_w, int *src_h) {
+ *x = 0;
+ *y = 0;
+ *src_x = 0;
+ *src_y = 0;
+ *src_w = 0;
+ *src_h = 0;
+
+ plane_property_mask property_mask = 0;
+
drmModeObjectPropertiesPtr props = drmModeObjectGetProperties(drmfd, plane_id, DRM_MODE_OBJECT_PLANE);
if(!props)
- return false;
+ return property_mask;
+ // TODO: Dont do this every frame
for(uint32_t i = 0; i < props->count_props; ++i) {
drmModePropertyPtr prop = drmModeGetProperty(drmfd, props->props[i]);
- if(prop) {
- if(strcmp(prop->name, "type") == 0) {
- const uint64_t current_enum_value = props->prop_values[i];
- bool is_cursor = false;
-
- for(int j = 0; j < prop->count_enums; ++j) {
- if(prop->enums[j].value == current_enum_value && strcmp(prop->enums[j].name, "Cursor") == 0) {
- is_cursor = true;
- break;
- }
+ if(!prop)
+ continue;
+ // SRC_* values are fixed 16.16 points
+ const uint32_t type = prop->flags & (DRM_MODE_PROP_LEGACY_TYPE | DRM_MODE_PROP_EXTENDED_TYPE);
+ if((type & DRM_MODE_PROP_SIGNED_RANGE) && strcmp(prop->name, "CRTC_X") == 0) {
+ *x = (int)props->prop_values[i];
+ property_mask |= PLANE_PROPERTY_X;
+ } else if((type & DRM_MODE_PROP_SIGNED_RANGE) && strcmp(prop->name, "CRTC_Y") == 0) {
+ *y = (int)props->prop_values[i];
+ property_mask |= PLANE_PROPERTY_Y;
+ } else if((type & DRM_MODE_PROP_RANGE) && strcmp(prop->name, "SRC_X") == 0) {
+ *src_x = (int)(props->prop_values[i] >> 16);
+ property_mask |= PLANE_PROPERTY_SRC_X;
+ } else if((type & DRM_MODE_PROP_RANGE) && strcmp(prop->name, "SRC_Y") == 0) {
+ *src_y = (int)(props->prop_values[i] >> 16);
+ property_mask |= PLANE_PROPERTY_SRC_Y;
+ } else if((type & DRM_MODE_PROP_RANGE) && strcmp(prop->name, "SRC_W") == 0) {
+ *src_w = (int)(props->prop_values[i] >> 16);
+ property_mask |= PLANE_PROPERTY_SRC_W;
+ } else if((type & DRM_MODE_PROP_RANGE) && strcmp(prop->name, "SRC_H") == 0) {
+ *src_h = (int)(props->prop_values[i] >> 16);
+ property_mask |= PLANE_PROPERTY_SRC_H;
+ } else if((type & DRM_MODE_PROP_ENUM) && strcmp(prop->name, "type") == 0) {
+ const uint64_t current_enum_value = props->prop_values[i];
+ for(int j = 0; j < prop->count_enums; ++j) {
+ if(prop->enums[j].value == current_enum_value && strcmp(prop->enums[j].name, "Primary") == 0) {
+ property_mask |= PLANE_PROPERTY_IS_PRIMARY;
+ break;
+ } else if(prop->enums[j].value == current_enum_value && strcmp(prop->enums[j].name, "Cursor") == 0) {
+ property_mask |= PLANE_PROPERTY_IS_CURSOR;
+ break;
}
-
- drmModeFreeProperty(prop);
- return is_cursor;
}
- drmModeFreeProperty(prop);
}
+
+ drmModeFreeProperty(prop);
}
drmModeFreeObjectProperties(props);
- return false;
+ return property_mask;
}
-/* Returns 0 if not found */
-static uint32_t get_connector_by_crtc_id(const connector_to_crtc_map *c2crtc_map, uint32_t crtc_id) {
+/* Returns NULL if not found */
+static const connector_crtc_pair* get_connector_pair_by_crtc_id(const connector_to_crtc_map *c2crtc_map, uint32_t crtc_id) {
for(int i = 0; i < c2crtc_map->num_maps; ++i) {
if(c2crtc_map->maps[i].crtc_id == crtc_id)
- return c2crtc_map->maps[i].connector_id;
+ return &c2crtc_map->maps[i];
}
- return 0;
+ return NULL;
}
-static int kms_get_plane_ids(gsr_drm *drm) {
- drmModePlaneResPtr planes = NULL;
- drmModeResPtr resources = NULL;
- int result = -1;
-
- connector_to_crtc_map c2crtc_map;
- c2crtc_map.num_maps = 0;
+static void map_crtc_to_connector_ids(gsr_drm *drm, connector_to_crtc_map *c2crtc_map) {
+ c2crtc_map->num_maps = 0;
+ drmModeResPtr resources = drmModeGetResources(drm->drmfd);
+ if(!resources)
+ return;
- if(drmSetClientCap(drm->drmfd, DRM_CLIENT_CAP_UNIVERSAL_PLANES, 1) != 0) {
- fprintf(stderr, "kms server error: drmSetClientCap DRM_CLIENT_CAP_UNIVERSAL_PLANES failed, error: %s\n", strerror(errno));
- goto error;
- }
-
- if(drmSetClientCap(drm->drmfd, DRM_CLIENT_CAP_ATOMIC, 1) != 0) {
- fprintf(stderr, "kms server warning: drmSetClientCap DRM_CLIENT_CAP_ATOMIC failed, error: %s. The wrong monitor may be captured as a result\n", strerror(errno));
- }
+ for(int i = 0; i < resources->count_connectors && c2crtc_map->num_maps < MAX_CONNECTORS; ++i) {
+ drmModeConnectorPtr connector = drmModeGetConnectorCurrent(drm->drmfd, resources->connectors[i]);
+ if(!connector)
+ continue;
- planes = drmModeGetPlaneResources(drm->drmfd);
- if(!planes) {
- fprintf(stderr, "kms server error: failed to access planes, error: %s\n", strerror(errno));
- goto error;
- }
+ uint64_t crtc_id = 0;
+ connector_get_property_by_name(drm->drmfd, connector, "CRTC_ID", &crtc_id);
- resources = drmModeGetResources(drm->drmfd);
- if(resources) {
- for(int i = 0; i < resources->count_connectors && c2crtc_map.num_maps < MAX_CONNECTORS; ++i) {
- drmModeConnectorPtr connector = drmModeGetConnectorCurrent(drm->drmfd, resources->connectors[i]);
- if(connector) {
- uint64_t crtc_id = 0;
- connector_get_property_by_name(drm->drmfd, connector, "CRTC_ID", &crtc_id);
+ uint64_t hdr_output_metadata_blob_id = 0;
+ connector_get_property_by_name(drm->drmfd, connector, "HDR_OUTPUT_METADATA", &hdr_output_metadata_blob_id);
- c2crtc_map.maps[c2crtc_map.num_maps].connector_id = connector->connector_id;
- c2crtc_map.maps[c2crtc_map.num_maps].crtc_id = crtc_id;
- ++c2crtc_map.num_maps;
+ c2crtc_map->maps[c2crtc_map->num_maps].connector_id = connector->connector_id;
+ c2crtc_map->maps[c2crtc_map->num_maps].crtc_id = crtc_id;
+ c2crtc_map->maps[c2crtc_map->num_maps].hdr_metadata_blob_id = hdr_output_metadata_blob_id;
+ ++c2crtc_map->num_maps;
- drmModeFreeConnector(connector);
- }
- }
- drmModeFreeResources(resources);
+ drmModeFreeConnector(connector);
}
+ drmModeFreeResources(resources);
+}
- for(uint32_t i = 0; i < planes->count_planes && drm->num_plane_ids < GSR_KMS_MAX_PLANES; ++i) {
- drmModePlanePtr plane = drmModeGetPlane(drm->drmfd, planes->planes[i]);
- if(!plane) {
- fprintf(stderr, "kms server warning: failed to get drmModePlanePtr for plane %#x: %s (%d)\n", planes->planes[i], strerror(errno), errno);
+static void drm_mode_cleanup_handles(int drmfd, drmModeFB2Ptr drmfb) {
+ for(int i = 0; i < 4; ++i) {
+ if(!drmfb->handles[i])
continue;
- }
- if(!plane->fb_id) {
- drmModeFreePlane(plane);
- continue;
+ bool already_closed = false;
+ for(int j = 0; j < i; ++j) {
+ if(drmfb->handles[i] == drmfb->handles[j]) {
+ already_closed = true;
+ break;
+ }
}
- if(plane_is_cursor_plane(drm->drmfd, plane->plane_id))
+ if(already_closed)
continue;
- // TODO: Fallback to getfb(1)?
- drmModeFB2Ptr drmfb = drmModeGetFB2(drm->drmfd, plane->fb_id);
- if(drmfb) {
- drm->plane_ids[drm->num_plane_ids] = plane->plane_id;
- drm->connector_ids[drm->num_plane_ids] = get_connector_by_crtc_id(&c2crtc_map, plane->crtc_id);
- ++drm->num_plane_ids;
- drmModeFreeFB2(drmfb);
- }
- drmModeFreePlane(plane);
+ drmCloseBufferHandle(drmfd, drmfb->handles[i]);
}
+}
- result = 0;
+static bool get_hdr_metadata(int drm_fd, uint64_t hdr_metadata_blob_id, struct hdr_output_metadata *hdr_metadata) {
+ drmModePropertyBlobPtr hdr_metadata_blob = drmModeGetPropertyBlob(drm_fd, hdr_metadata_blob_id);
+ if(!hdr_metadata_blob)
+ return false;
- error:
- if(planes)
- drmModeFreePlaneResources(planes);
+ if(hdr_metadata_blob->length >= sizeof(struct hdr_output_metadata))
+ *hdr_metadata = *(struct hdr_output_metadata*)hdr_metadata_blob->data;
- return result;
+ drmModeFreePropertyBlob(hdr_metadata_blob);
+ return true;
}
-static bool drmfb_has_multiple_handles(drmModeFB2 *drmfb) {
- int num_handles = 0;
- for(uint32_t handle_index = 0; handle_index < 4 && drmfb->handles[handle_index]; ++handle_index) {
- ++num_handles;
+/* Returns the number of drm handles that we managed to get */
+static int drm_prime_handles_to_fds(gsr_drm *drm, drmModeFB2Ptr drmfb, int *fb_fds) {
+ for(int i = 0; i < GSR_KMS_MAX_DMA_BUFS; ++i) {
+ if(!drmfb->handles[i])
+ return i;
+
+ const int ret = drmPrimeHandleToFD(drm->drmfd, drmfb->handles[i], O_RDONLY, &fb_fds[i]);
+ if(ret != 0 || fb_fds[i] == -1)
+ return i;
}
- return num_handles > 1;
+ return GSR_KMS_MAX_DMA_BUFS;
}
static int kms_get_fb(gsr_drm *drm, gsr_kms_response *response) {
@@ -214,20 +294,33 @@ static int kms_get_fb(gsr_drm *drm, gsr_kms_response *response) {
response->result = KMS_RESULT_OK;
response->err_msg[0] = '\0';
- response->num_fds = 0;
+ response->num_items = 0;
+
+ connector_to_crtc_map c2crtc_map;
+ c2crtc_map.num_maps = 0;
+ map_crtc_to_connector_ids(drm, &c2crtc_map);
- for(size_t i = 0; i < drm->num_plane_ids && response->num_fds < GSR_KMS_MAX_PLANES; ++i) {
+ drmModePlaneResPtr planes = drmModeGetPlaneResources(drm->drmfd);
+ if(!planes) {
+ fprintf(stderr, "kms server error: failed to get plane resources, error: %s\n", strerror(errno));
+ goto done;
+ }
+
+ for(uint32_t i = 0; i < planes->count_planes && response->num_items < GSR_KMS_MAX_ITEMS; ++i) {
drmModePlanePtr plane = NULL;
- drmModeFB2 *drmfb = NULL;
+ drmModeFB2Ptr drmfb = NULL;
- plane = drmModeGetPlane(drm->drmfd, drm->plane_ids[i]);
+ plane = drmModeGetPlane(drm->drmfd, planes->planes[i]);
if(!plane) {
response->result = KMS_RESULT_FAILED_TO_GET_PLANE;
- snprintf(response->err_msg, sizeof(response->err_msg), "failed to get drm plane with id %u, error: %s\n", drm->plane_ids[i], strerror(errno));
+ snprintf(response->err_msg, sizeof(response->err_msg), "failed to get drm plane with id %u, error: %s\n", planes->planes[i], strerror(errno));
fprintf(stderr, "kms server error: %s\n", response->err_msg);
goto next;
}
+ if(!plane->fb_id)
+ goto next;
+
drmfb = drmModeGetFB2(drm->drmfd, plane->fb_id);
if(!drmfb) {
// Commented out for now because we get here if the cursor is moved to another monitor and we dont care about the cursor
@@ -241,31 +334,63 @@ static int kms_get_fb(gsr_drm *drm, gsr_kms_response *response) {
response->result = KMS_RESULT_FAILED_TO_GET_PLANE;
snprintf(response->err_msg, sizeof(response->err_msg), "drmfb handle is NULL");
fprintf(stderr, "kms server error: %s\n", response->err_msg);
- goto next;
+ goto cleanup_handles;
}
// TODO: Check if dimensions have changed by comparing width and height to previous time this was called.
// TODO: Support other plane formats than rgb (with multiple planes, such as direct YUV420 on wayland).
- int fb_fd = -1;
- const int ret = drmPrimeHandleToFD(drm->drmfd, drmfb->handles[0], O_RDONLY, &fb_fd);
- if(ret != 0 || fb_fd == -1) {
+ int x = 0, y = 0, src_x = 0, src_y = 0, src_w = 0, src_h = 0;
+ plane_property_mask property_mask = plane_get_properties(drm->drmfd, plane->plane_id, &x, &y, &src_x, &src_y, &src_w, &src_h);
+ if(!(property_mask & PLANE_PROPERTY_IS_PRIMARY) && !(property_mask & PLANE_PROPERTY_IS_CURSOR))
+ continue;
+
+ int fb_fds[GSR_KMS_MAX_DMA_BUFS];
+ const int num_fb_fds = drm_prime_handles_to_fds(drm, drmfb, fb_fds);
+ if(num_fb_fds == 0) {
response->result = KMS_RESULT_FAILED_TO_GET_PLANE;
snprintf(response->err_msg, sizeof(response->err_msg), "failed to get fd from drm handle, error: %s", strerror(errno));
fprintf(stderr, "kms server error: %s\n", response->err_msg);
- continue;
+ goto cleanup_handles;
}
- response->fds[response->num_fds].fd = fb_fd;
- response->fds[response->num_fds].width = drmfb->width;
- response->fds[response->num_fds].height = drmfb->height;
- response->fds[response->num_fds].pitch = drmfb->pitches[0];
- response->fds[response->num_fds].offset = drmfb->offsets[0];
- response->fds[response->num_fds].pixel_format = drmfb->pixel_format;
- response->fds[response->num_fds].modifier = drmfb->modifier;
- response->fds[response->num_fds].connector_id = drm->connector_ids[i];
- response->fds[response->num_fds].is_combined_plane = drmfb_has_multiple_handles(drmfb);
- ++response->num_fds;
+ const int item_index = response->num_items;
+
+ const connector_crtc_pair *crtc_pair = get_connector_pair_by_crtc_id(&c2crtc_map, plane->crtc_id);
+ if(crtc_pair && crtc_pair->hdr_metadata_blob_id) {
+ response->items[item_index].has_hdr_metadata = get_hdr_metadata(drm->drmfd, crtc_pair->hdr_metadata_blob_id, &response->items[item_index].hdr_metadata);
+ } else {
+ response->items[item_index].has_hdr_metadata = false;
+ }
+
+ for(int j = 0; j < num_fb_fds; ++j) {
+ response->items[item_index].dma_buf[j].fd = fb_fds[j];
+ response->items[item_index].dma_buf[j].pitch = drmfb->pitches[j];
+ response->items[item_index].dma_buf[j].offset = drmfb->offsets[j];
+ }
+ response->items[item_index].num_dma_bufs = num_fb_fds;
+
+ response->items[item_index].width = drmfb->width;
+ response->items[item_index].height = drmfb->height;
+ response->items[item_index].pixel_format = drmfb->pixel_format;
+ response->items[item_index].modifier = drmfb->flags & DRM_MODE_FB_MODIFIERS ? drmfb->modifier : DRM_FORMAT_MOD_INVALID;
+ response->items[item_index].connector_id = crtc_pair ? crtc_pair->connector_id : 0;
+ response->items[item_index].is_cursor = property_mask & PLANE_PROPERTY_IS_CURSOR;
+ if(property_mask & PLANE_PROPERTY_IS_CURSOR) {
+ response->items[item_index].x = x;
+ response->items[item_index].y = y;
+ response->items[item_index].src_w = 0;
+ response->items[item_index].src_h = 0;
+ } else {
+ response->items[item_index].x = src_x;
+ response->items[item_index].y = src_y;
+ response->items[item_index].src_w = src_w;
+ response->items[item_index].src_h = src_h;
+ }
+ ++response->num_items;
+
+ cleanup_handles:
+ drm_mode_cleanup_handles(drm->drmfd, drmfb);
next:
if(drmfb)
@@ -274,13 +399,28 @@ static int kms_get_fb(gsr_drm *drm, gsr_kms_response *response) {
drmModeFreePlane(plane);
}
- if(response->num_fds > 0 || response->result == KMS_RESULT_OK) {
+ done:
+
+ if(planes)
+ drmModeFreePlaneResources(planes);
+
+ if(response->num_items > 0)
+ response->result = KMS_RESULT_OK;
+
+ if(response->result == KMS_RESULT_OK) {
result = 0;
} else {
- for(int i = 0; i < response->num_fds; ++i) {
- close(response->fds[i].fd);
+ for(int i = 0; i < response->num_items; ++i) {
+ for(int j = 0; j < response->items[i].num_dma_bufs; ++j) {
+ gsr_kms_response_dma_buf *dma_buf = &response->items[i].dma_buf[j];
+ if(dma_buf->fd > 0) {
+ close(dma_buf->fd);
+ dma_buf->fd = -1;
+ }
+ }
+ response->items[i].num_dma_bufs = 0;
}
- response->num_fds = 0;
+ response->num_items = 0;
}
return result;
@@ -294,23 +434,21 @@ static double clock_get_monotonic_seconds(void) {
return (double)ts.tv_sec + (double)ts.tv_nsec * 0.000000001;
}
-static void strncpy_safe(char *dst, const char *src, int len) {
- int src_len = strlen(src);
- int min_len = src_len;
- if(len - 1 < min_len)
- min_len = len - 1;
- memcpy(dst, src, min_len);
- dst[min_len] = '\0';
-}
-
int main(int argc, char **argv) {
+ setlocale(LC_ALL, "C"); // Sigh... stupid C
+
+ int res = 0;
+ int socket_fd = 0;
+ gsr_drm drm;
+ drm.drmfd = 0;
+
if(argc != 3) {
- fprintf(stderr, "usage: kms_server <domain_socket_path> <card_path>\n");
+ fprintf(stderr, "usage: gsr-kms-server <domain_socket_path> <card_path>\n");
return 1;
}
const char *domain_socket_path = argv[1];
- int socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
+ socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if(socket_fd == -1) {
fprintf(stderr, "kms server error: failed to create socket, error: %s\n", strerror(errno));
return 2;
@@ -318,17 +456,21 @@ int main(int argc, char **argv) {
const char *card_path = argv[2];
- gsr_drm drm;
- drm.num_plane_ids = 0;
drm.drmfd = open(card_path, O_RDONLY);
if(drm.drmfd < 0) {
fprintf(stderr, "kms server error: failed to open %s, error: %s", card_path, strerror(errno));
- return 2;
+ res = 2;
+ goto done;
}
- if(kms_get_plane_ids(&drm) != 0) {
- close(drm.drmfd);
- return 2;
+ if(drmSetClientCap(drm.drmfd, DRM_CLIENT_CAP_UNIVERSAL_PLANES, 1) != 0) {
+ fprintf(stderr, "kms server error: drmSetClientCap DRM_CLIENT_CAP_UNIVERSAL_PLANES failed, error: %s\n", strerror(errno));
+ res = 2;
+ goto done;
+ }
+
+ if(drmSetClientCap(drm.drmfd, DRM_CLIENT_CAP_ATOMIC, 1) != 0) {
+ fprintf(stderr, "kms server warning: drmSetClientCap DRM_CLIENT_CAP_ATOMIC failed, error: %s. The wrong monitor may be captured as a result\n", strerror(errno));
}
fprintf(stderr, "kms server info: connecting to the client\n");
@@ -338,7 +480,7 @@ int main(int argc, char **argv) {
while(clock_get_monotonic_seconds() - start_time < connect_timeout_sec) {
struct sockaddr_un remote_addr = {0};
remote_addr.sun_family = AF_UNIX;
- strncpy_safe(remote_addr.sun_path, domain_socket_path, sizeof(remote_addr.sun_path));
+ snprintf(remote_addr.sun_path, sizeof(remote_addr.sun_path), "%s", domain_socket_path);
// TODO: Check if parent disconnected
if(connect(socket_fd, (struct sockaddr*)&remote_addr, sizeof(remote_addr.sun_family) + strlen(remote_addr.sun_path)) == -1) {
if(errno == ECONNREFUSED || errno == ENOENT) {
@@ -349,8 +491,8 @@ int main(int argc, char **argv) {
}
fprintf(stderr, "kms server error: connect failed, error: %s (%d)\n", strerror(errno), errno);
- close(drm.drmfd);
- return 2;
+ res = 2;
+ goto done;
}
next:
@@ -361,21 +503,17 @@ int main(int argc, char **argv) {
fprintf(stderr, "kms server info: connected to the client\n");
} else {
fprintf(stderr, "kms server error: failed to connect to the client in %f seconds\n", connect_timeout_sec);
- close(drm.drmfd);
- return 2;
+ res = 2;
+ goto done;
}
- int res = 0;
for(;;) {
gsr_kms_request request;
- struct iovec iov;
- iov.iov_base = &request;
- iov.iov_len = sizeof(request);
-
- struct msghdr request_message = {0};
- request_message.msg_iov = &iov;
- request_message.msg_iovlen = 1;
- const int recv_res = recvmsg(socket_fd, &request_message, MSG_WAITALL);
+ request.version = 0;
+ request.type = -1;
+ request.new_connection_fd = 0;
+
+ const int recv_res = recv_msg_from_client(socket_fd, &request);
if(recv_res == 0) {
fprintf(stderr, "kms server info: kms client shutdown, shutting down the server\n");
res = 3;
@@ -391,40 +529,86 @@ int main(int argc, char **argv) {
continue;
}
+ if(request.version != GSR_KMS_PROTOCOL_VERSION) {
+ fprintf(stderr, "kms server error: expected gpu screen recorder protocol version to be %u, but it's %u. please reinstall gpu screen recorder\n", GSR_KMS_PROTOCOL_VERSION, request.version);
+ /*
+ if(request.new_connection_fd > 0)
+ close(request.new_connection_fd);
+ */
+ continue;
+ }
+
switch(request.type) {
+ case KMS_REQUEST_TYPE_REPLACE_CONNECTION: {
+ gsr_kms_response response;
+ response.version = GSR_KMS_PROTOCOL_VERSION;
+ response.num_items = 0;
+
+ if(request.new_connection_fd > 0) {
+ if(socket_fd > 0)
+ close(socket_fd);
+ socket_fd = request.new_connection_fd;
+
+ response.result = KMS_RESULT_OK;
+ if(send_msg_to_client(socket_fd, &response) == -1)
+ fprintf(stderr, "kms server error: failed to respond to client KMS_REQUEST_TYPE_REPLACE_CONNECTION request\n");
+ } else {
+ response.result = KMS_RESULT_INVALID_REQUEST;
+ snprintf(response.err_msg, sizeof(response.err_msg), "received invalid connection fd");
+ fprintf(stderr, "kms server error: %s\n", response.err_msg);
+ if(send_msg_to_client(socket_fd, &response) == -1)
+ fprintf(stderr, "kms server error: failed to respond to client request\n");
+ }
+
+ break;
+ }
case KMS_REQUEST_TYPE_GET_KMS: {
gsr_kms_response response;
+ response.version = GSR_KMS_PROTOCOL_VERSION;
+ response.num_items = 0;
if(kms_get_fb(&drm, &response) == 0) {
if(send_msg_to_client(socket_fd, &response) == -1)
fprintf(stderr, "kms server error: failed to respond to client KMS_REQUEST_TYPE_GET_KMS request\n");
-
- for(int i = 0; i < response.num_fds; ++i) {
- close(response.fds[i].fd);
- }
} else {
if(send_msg_to_client(socket_fd, &response) == -1)
fprintf(stderr, "kms server error: failed to respond to client KMS_REQUEST_TYPE_GET_KMS request\n");
}
+ for(int i = 0; i < response.num_items; ++i) {
+ for(int j = 0; j < response.items[i].num_dma_bufs; ++j) {
+ gsr_kms_response_dma_buf *dma_buf = &response.items[i].dma_buf[j];
+ if(dma_buf->fd > 0) {
+ close(dma_buf->fd);
+ dma_buf->fd = -1;
+ }
+ }
+ response.items[i].num_dma_bufs = 0;
+ }
+ response.num_items = 0;
+
break;
}
default: {
gsr_kms_response response;
+ response.version = GSR_KMS_PROTOCOL_VERSION;
response.result = KMS_RESULT_INVALID_REQUEST;
+ response.num_items = 0;
+
snprintf(response.err_msg, sizeof(response.err_msg), "invalid request type %d, expected %d (%s)", request.type, KMS_REQUEST_TYPE_GET_KMS, "KMS_REQUEST_TYPE_GET_KMS");
fprintf(stderr, "kms server error: %s\n", response.err_msg);
- if(send_msg_to_client(socket_fd, &response) == -1) {
+ if(send_msg_to_client(socket_fd, &response) == -1)
fprintf(stderr, "kms server error: failed to respond to client request\n");
- break;
- }
+
break;
}
}
}
done:
- close(drm.drmfd);
- close(socket_fd);
+ if(drm.drmfd > 0)
+ close(drm.drmfd);
+ if(socket_fd > 0)
+ close(socket_fd);
return res;
}
diff --git a/kms/server/project.conf b/kms/server/project.conf
index cf863c1..26a1947 100644
--- a/kms/server/project.conf
+++ b/kms/server/project.conf
@@ -4,5 +4,8 @@ type = "executable"
version = "1.0.0"
platforms = ["posix"]
+[config]
+error_on_warning = "true"
+
[dependencies]
libdrm = ">=2"