diff options
Diffstat (limited to 'kms/server/kms_server.c')
-rw-r--r-- | kms/server/kms_server.c | 242 |
1 files changed, 180 insertions, 62 deletions
diff --git a/kms/server/kms_server.c b/kms/server/kms_server.c index 2eaa1ed..b4f3378 100644 --- a/kms/server/kms_server.c +++ b/kms/server/kms_server.c @@ -1,3 +1,7 @@ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include "../kms_shared.h" #include <stdio.h> @@ -6,6 +10,7 @@ #include <stdlib.h> #include <unistd.h> +#include <limits.h> #include <fcntl.h> #include <sys/socket.h> #include <sys/un.h> @@ -37,6 +42,14 @@ static int max_int(int a, int b) { return a > b ? a : b; } +static int count_num_fds(const gsr_kms_response *response) { + int num_fds = 0; + for(int i = 0; i < response->num_items; ++i) { + num_fds += response->items[i].num_dma_bufs; + } + return num_fds; +} + static int send_msg_to_client(int client_fd, gsr_kms_response *response) { struct iovec iov; iov.iov_base = response; @@ -46,21 +59,25 @@ static int send_msg_to_client(int client_fd, gsr_kms_response *response) { response_message.msg_iov = &iov; response_message.msg_iovlen = 1; - char cmsgbuf[CMSG_SPACE(sizeof(int) * max_int(1, response->num_fds))]; + const int num_fds = count_num_fds(response); + char cmsgbuf[CMSG_SPACE(sizeof(int) * max_int(1, num_fds))]; memset(cmsgbuf, 0, sizeof(cmsgbuf)); - if(response->num_fds > 0) { + if(num_fds > 0) { response_message.msg_control = cmsgbuf; response_message.msg_controllen = sizeof(cmsgbuf); struct cmsghdr *cmsg = CMSG_FIRSTHDR(&response_message); cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(sizeof(int) * response->num_fds); + cmsg->cmsg_len = CMSG_LEN(sizeof(int) * num_fds); int *fds = (int*)CMSG_DATA(cmsg); - for(int i = 0; i < response->num_fds; ++i) { - fds[i] = response->fds[i].fd; + int fd_index = 0; + for(int i = 0; i < response->num_items; ++i) { + for(int j = 0; j < response->items[i].num_dma_bufs; ++j) { + fds[fd_index++] = response->items[i].dma_buf[j].fd; + } } response_message.msg_controllen = cmsg->cmsg_len; @@ -258,14 +275,27 @@ static bool get_hdr_metadata(int drm_fd, uint64_t hdr_metadata_blob_id, struct h return true; } +/* Returns the number of drm handles that we managed to get */ +static int drm_prime_handles_to_fds(gsr_drm *drm, drmModeFB2Ptr drmfb, int *fb_fds) { + for(int i = 0; i < GSR_KMS_MAX_DMA_BUFS; ++i) { + if(!drmfb->handles[i]) + return i; + + const int ret = drmPrimeHandleToFD(drm->drmfd, drmfb->handles[i], O_RDONLY, &fb_fds[i]); + if(ret != 0 || fb_fds[i] == -1) + return i; + } + return GSR_KMS_MAX_DMA_BUFS; +} + static int kms_get_fb(gsr_drm *drm, gsr_kms_response *response, connector_to_crtc_map *c2crtc_map) { int result = -1; response->result = KMS_RESULT_OK; response->err_msg[0] = '\0'; - response->num_fds = 0; + response->num_items = 0; - for(uint32_t i = 0; i < drm->planes->count_planes && response->num_fds < GSR_KMS_MAX_PLANES; ++i) { + for(uint32_t i = 0; i < drm->planes->count_planes && response->num_items < GSR_KMS_MAX_ITEMS; ++i) { drmModePlanePtr plane = NULL; drmModeFB2Ptr drmfb = NULL; @@ -299,52 +329,54 @@ static int kms_get_fb(gsr_drm *drm, gsr_kms_response *response, connector_to_crt // TODO: Check if dimensions have changed by comparing width and height to previous time this was called. // TODO: Support other plane formats than rgb (with multiple planes, such as direct YUV420 on wayland). - int fb_fd = -1; - const int ret = drmPrimeHandleToFD(drm->drmfd, drmfb->handles[0], O_RDONLY, &fb_fd); - if(ret != 0 || fb_fd == -1) { + int x = 0, y = 0, src_x = 0, src_y = 0, src_w = 0, src_h = 0; + plane_property_mask property_mask = plane_get_properties(drm->drmfd, plane->plane_id, &x, &y, &src_x, &src_y, &src_w, &src_h); + if(!(property_mask & PLANE_PROPERTY_IS_PRIMARY) && !(property_mask & PLANE_PROPERTY_IS_CURSOR)) + continue; + + int fb_fds[GSR_KMS_MAX_DMA_BUFS]; + const int num_fb_fds = drm_prime_handles_to_fds(drm, drmfb, fb_fds); + if(num_fb_fds == 0) { response->result = KMS_RESULT_FAILED_TO_GET_PLANE; snprintf(response->err_msg, sizeof(response->err_msg), "failed to get fd from drm handle, error: %s", strerror(errno)); fprintf(stderr, "kms server error: %s\n", response->err_msg); goto cleanup_handles; } - const int fd_index = response->num_fds; + const int item_index = response->num_items; - int x = 0, y = 0, src_x = 0, src_y = 0, src_w = 0, src_h = 0; - plane_property_mask property_mask = plane_get_properties(drm->drmfd, plane->plane_id, &x, &y, &src_x, &src_y, &src_w, &src_h); - if((property_mask & PLANE_PROPERTY_IS_PRIMARY) || (property_mask & PLANE_PROPERTY_IS_CURSOR)) { - const connector_crtc_pair *crtc_pair = get_connector_pair_by_crtc_id(c2crtc_map, plane->crtc_id); - if(crtc_pair && crtc_pair->hdr_metadata_blob_id) { - response->fds[fd_index].has_hdr_metadata = get_hdr_metadata(drm->drmfd, crtc_pair->hdr_metadata_blob_id, &response->fds[fd_index].hdr_metadata); - } else { - response->fds[fd_index].has_hdr_metadata = false; - } + const connector_crtc_pair *crtc_pair = get_connector_pair_by_crtc_id(c2crtc_map, plane->crtc_id); + if(crtc_pair && crtc_pair->hdr_metadata_blob_id) { + response->items[item_index].has_hdr_metadata = get_hdr_metadata(drm->drmfd, crtc_pair->hdr_metadata_blob_id, &response->items[item_index].hdr_metadata); + } else { + response->items[item_index].has_hdr_metadata = false; + } - response->fds[fd_index].fd = fb_fd; - response->fds[fd_index].width = drmfb->width; - response->fds[fd_index].height = drmfb->height; - response->fds[fd_index].pitch = drmfb->pitches[0]; - response->fds[fd_index].offset = drmfb->offsets[0]; - response->fds[fd_index].pixel_format = drmfb->pixel_format; - response->fds[fd_index].modifier = drmfb->modifier; - response->fds[fd_index].connector_id = crtc_pair ? crtc_pair->connector_id : 0; - response->fds[fd_index].is_cursor = property_mask & PLANE_PROPERTY_IS_CURSOR; - response->fds[fd_index].is_combined_plane = false; - if(property_mask & PLANE_PROPERTY_IS_CURSOR) { - response->fds[fd_index].x = x; - response->fds[fd_index].y = y; - response->fds[fd_index].src_w = 0; - response->fds[fd_index].src_h = 0; - } else { - response->fds[fd_index].x = src_x; - response->fds[fd_index].y = src_y; - response->fds[fd_index].src_w = src_w; - response->fds[fd_index].src_h = src_h; - } - ++response->num_fds; + for(int j = 0; j < num_fb_fds; ++j) { + response->items[item_index].dma_buf[j].fd = fb_fds[j]; + response->items[item_index].dma_buf[j].pitch = drmfb->pitches[j]; + response->items[item_index].dma_buf[j].offset = drmfb->offsets[j]; + } + response->items[item_index].num_dma_bufs = num_fb_fds; + + response->items[item_index].width = drmfb->width; + response->items[item_index].height = drmfb->height; + response->items[item_index].pixel_format = drmfb->pixel_format; + response->items[item_index].modifier = drmfb->modifier; + response->items[item_index].connector_id = crtc_pair ? crtc_pair->connector_id : 0; + response->items[item_index].is_cursor = property_mask & PLANE_PROPERTY_IS_CURSOR; + if(property_mask & PLANE_PROPERTY_IS_CURSOR) { + response->items[item_index].x = x; + response->items[item_index].y = y; + response->items[item_index].src_w = 0; + response->items[item_index].src_h = 0; } else { - close(fb_fd); + response->items[item_index].x = src_x; + response->items[item_index].y = src_y; + response->items[item_index].src_w = src_w; + response->items[item_index].src_h = src_h; } + ++response->num_items; cleanup_handles: drm_mode_cleanup_handles(drm->drmfd, drmfb); @@ -356,16 +388,23 @@ static int kms_get_fb(gsr_drm *drm, gsr_kms_response *response, connector_to_crt drmModeFreePlane(plane); } - if(response->num_fds > 0) + if(response->num_items > 0) response->result = KMS_RESULT_OK; if(response->result == KMS_RESULT_OK) { result = 0; } else { - for(int i = 0; i < response->num_fds; ++i) { - close(response->fds[i].fd); + for(int i = 0; i < response->num_items; ++i) { + for(int j = 0; j < response->items[i].num_dma_bufs; ++j) { + gsr_kms_response_dma_buf *dma_buf = &response->items[i].dma_buf[j]; + if(dma_buf->fd > 0) { + close(dma_buf->fd); + dma_buf->fd = -1; + } + } + response->items[i].num_dma_bufs = 0; } - response->num_fds = 0; + response->num_items = 0; } return result; @@ -379,14 +418,80 @@ static double clock_get_monotonic_seconds(void) { return (double)ts.tv_sec + (double)ts.tv_nsec * 0.000000001; } -static void string_copy(char *dst, const char *src, int len) { - int src_len = strlen(src); - int min_len = src_len; - if(len - 1 < min_len) - min_len = len - 1; - memcpy(dst, src, min_len); - dst[min_len] = '\0'; -} +// static bool readlink_realpath(const char *filepath, char *buffer) { +// char symlinked_path[PATH_MAX]; +// ssize_t bytes_written = readlink(filepath, symlinked_path, sizeof(symlinked_path) - 1); +// if(bytes_written == -1 && errno == EINVAL) { +// /* Not a symlink */ +// snprintf(symlinked_path, sizeof(symlinked_path), "%s", filepath); +// } else if(bytes_written == -1) { +// return false; +// } else { +// symlinked_path[bytes_written] = '\0'; +// } + +// if(!realpath(symlinked_path, buffer)) +// return false; + +// return true; +// } + +// static void file_get_directory(char *filepath) { +// char *end = strrchr(filepath, '/'); +// if(end == NULL) +// filepath[0] = '\0'; +// else +// *end = '\0'; +// } + +// static bool string_ends_with(const char *str, const char *ends_with) { +// const int len = strlen(str); +// const int ends_with_len = strlen(ends_with); +// return len >= ends_with_len && memcmp(str + len - ends_with_len, ends_with, ends_with_len) == 0; +// } + +// This is not foolproof, but the assumption is that gsr-kms-server and gpu-screen-recorder are installed in the same directory +// in a location that only the root user can write to (usually /usr/bin or /usr/local/bin) and if the client runs from that location +// and is called gpu-screen-recorder then gsr-kms-server can only be used by a malicious program if the malicious program +// had root access, to modify that program install directory. +// static bool is_remote_peer_program_gpu_screen_recorder(int socket_fd) { +// // TODO: Use SO_PEERPIDFD on kernel >= 6.5 to avoid a race condition in the /proc/<pid> check +// struct ucred cred; +// socklen_t ucred_len = sizeof(cred); +// if(getsockopt(socket_fd, SOL_SOCKET, SO_PEERCRED, &cred, &ucred_len) == -1) { +// fprintf(stderr, "kms server error: failed to get peer credentials, error: %s\n", strerror(errno)); +// return false; +// } + +// char self_directory[PATH_MAX]; +// if(!readlink_realpath("/proc/self/exe", self_directory)) { +// fprintf(stderr, "kms server error: failed to resolve /proc/self/exe\n"); +// return false; +// } +// file_get_directory(self_directory); + +// char peer_directory[PATH_MAX]; +// char peer_exe_path[PATH_MAX]; +// snprintf(peer_exe_path, sizeof(peer_exe_path), "/proc/%d/exe", (int)cred.pid); +// if(!readlink_realpath(peer_exe_path, peer_directory)) { +// fprintf(stderr, "kms server error: failed to resolve /proc/self/exe\n"); +// return false; +// } + +// if(!string_ends_with(peer_directory, "/gpu-screen-recorder")) { +// fprintf(stderr, "kms server error: only gpu-screen-recorder can use gsr-kms-server. client program location is %s\n", peer_directory); +// return false; +// } + +// file_get_directory(peer_directory); + +// if(strcmp(self_directory, peer_directory) != 0) { +// fprintf(stderr, "kms server error: the client program is in directory %s but only programs in %s can run gsr-kms-server\n", peer_directory, self_directory); +// return false; +// } + +// return true; +// } int main(int argc, char **argv) { int res = 0; @@ -444,7 +549,7 @@ int main(int argc, char **argv) { while(clock_get_monotonic_seconds() - start_time < connect_timeout_sec) { struct sockaddr_un remote_addr = {0}; remote_addr.sun_family = AF_UNIX; - string_copy(remote_addr.sun_path, domain_socket_path, sizeof(remote_addr.sun_path)); + snprintf(remote_addr.sun_path, sizeof(remote_addr.sun_path), "%s", domain_socket_path); // TODO: Check if parent disconnected if(connect(socket_fd, (struct sockaddr*)&remote_addr, sizeof(remote_addr.sun_family) + strlen(remote_addr.sun_path)) == -1) { if(errno == ECONNREFUSED || errno == ENOENT) { @@ -471,6 +576,11 @@ int main(int argc, char **argv) { goto done; } + // if(!is_remote_peer_program_gpu_screen_recorder(socket_fd)) { + // res = 3; + // goto done; + // } + for(;;) { gsr_kms_request request; request.version = 0; @@ -494,7 +604,7 @@ int main(int argc, char **argv) { } if(request.version != GSR_KMS_PROTOCOL_VERSION) { - fprintf(stderr, "kms server error: expected gpu screen recorder protocol version to be %u, but it's %u\n", GSR_KMS_PROTOCOL_VERSION, request.version); + fprintf(stderr, "kms server error: expected gpu screen recorder protocol version to be %u, but it's %u. please reinstall gpu screen recorder\n", GSR_KMS_PROTOCOL_VERSION, request.version); /* if(request.new_connection_fd > 0) close(request.new_connection_fd); @@ -506,7 +616,7 @@ int main(int argc, char **argv) { case KMS_REQUEST_TYPE_REPLACE_CONNECTION: { gsr_kms_response response; response.version = GSR_KMS_PROTOCOL_VERSION; - response.num_fds = 0; + response.num_items = 0; if(request.new_connection_fd > 0) { if(socket_fd > 0) @@ -529,7 +639,7 @@ int main(int argc, char **argv) { case KMS_REQUEST_TYPE_GET_KMS: { gsr_kms_response response; response.version = GSR_KMS_PROTOCOL_VERSION; - response.num_fds = 0; + response.num_items = 0; if(kms_get_fb(&drm, &response, &c2crtc_map) == 0) { if(send_msg_to_client(socket_fd, &response) == -1) @@ -539,9 +649,17 @@ int main(int argc, char **argv) { fprintf(stderr, "kms server error: failed to respond to client KMS_REQUEST_TYPE_GET_KMS request\n"); } - for(int i = 0; i < response.num_fds; ++i) { - close(response.fds[i].fd); + for(int i = 0; i < response.num_items; ++i) { + for(int j = 0; j < response.items[i].num_dma_bufs; ++j) { + gsr_kms_response_dma_buf *dma_buf = &response.items[i].dma_buf[j]; + if(dma_buf->fd > 0) { + close(dma_buf->fd); + dma_buf->fd = -1; + } + } + response.items[i].num_dma_bufs = 0; } + response.num_items = 0; break; } @@ -549,7 +667,7 @@ int main(int argc, char **argv) { gsr_kms_response response; response.version = GSR_KMS_PROTOCOL_VERSION; response.result = KMS_RESULT_INVALID_REQUEST; - response.num_fds = 0; + response.num_items = 0; snprintf(response.err_msg, sizeof(response.err_msg), "invalid request type %d, expected %d (%s)", request.type, KMS_REQUEST_TYPE_GET_KMS, "KMS_REQUEST_TYPE_GET_KMS"); fprintf(stderr, "kms server error: %s\n", response.err_msg); |