aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--kms/client/kms_client.c26
-rw-r--r--kms/server/kms_server.c45
2 files changed, 59 insertions, 12 deletions
diff --git a/kms/client/kms_client.c b/kms/client/kms_client.c
index 3e60c63..8e1dafb 100644
--- a/kms/client/kms_client.c
+++ b/kms/client/kms_client.c
@@ -53,11 +53,29 @@ static int send_msg_to_server(int server_fd, gsr_kms_request *request) {
iov.iov_base = request;
iov.iov_len = sizeof(*request);
- struct msghdr request_message = {0};
- request_message.msg_iov = &iov;
- request_message.msg_iovlen = 1;
+ struct msghdr response_message = {0};
+ response_message.msg_iov = &iov;
+ response_message.msg_iovlen = 1;
+
+ char cmsgbuf[CMSG_SPACE(sizeof(int) * 1)];
+ memset(cmsgbuf, 0, sizeof(cmsgbuf));
+
+ if(request->new_connection_fd > 0) {
+ response_message.msg_control = cmsgbuf;
+ response_message.msg_controllen = sizeof(cmsgbuf);
+
+ struct cmsghdr *cmsg = CMSG_FIRSTHDR(&response_message);
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ cmsg->cmsg_len = CMSG_LEN(sizeof(int) * 1);
+
+ int *fds = (int*)CMSG_DATA(cmsg);
+ fds[0] = request->new_connection_fd;
+
+ response_message.msg_controllen = cmsg->cmsg_len;
+ }
- return sendmsg(server_fd, &request_message, 0);
+ return sendmsg(server_fd, &response_message, 0);
}
static int recv_msg_from_server(int server_fd, gsr_kms_response *response) {
diff --git a/kms/server/kms_server.c b/kms/server/kms_server.c
index 39ce446..fbd101e 100644
--- a/kms/server/kms_server.c
+++ b/kms/server/kms_server.c
@@ -68,6 +68,40 @@ static int send_msg_to_client(int client_fd, gsr_kms_response *response) {
return sendmsg(client_fd, &response_message, 0);
}
+static int recv_msg_from_client(int client_fd, gsr_kms_request *request) {
+ struct iovec iov;
+ iov.iov_base = request;
+ iov.iov_len = sizeof(*request);
+
+ struct msghdr response_message = {0};
+ response_message.msg_iov = &iov;
+ response_message.msg_iovlen = 1;
+
+ char cmsgbuf[CMSG_SPACE(sizeof(int) * 1)];
+ memset(cmsgbuf, 0, sizeof(cmsgbuf));
+ response_message.msg_control = cmsgbuf;
+ response_message.msg_controllen = sizeof(cmsgbuf);
+
+ int res = recvmsg(client_fd, &response_message, MSG_WAITALL);
+ if(res <= 0)
+ return res;
+
+ if(request->new_connection_fd > 0) {
+ struct cmsghdr *cmsg = CMSG_FIRSTHDR(&response_message);
+ if(cmsg) {
+ int *fds = (int*)CMSG_DATA(cmsg);
+ request->new_connection_fd = fds[0];
+ } else {
+ if(request->new_connection_fd > 0) {
+ close(request->new_connection_fd);
+ request->new_connection_fd = 0;
+ }
+ }
+ }
+
+ return res;
+}
+
static bool connector_get_property_by_name(int drmfd, drmModeConnectorPtr props, const char *name, uint64_t *result) {
for(int i = 0; i < props->count_props; ++i) {
drmModePropertyPtr prop = drmModeGetProperty(drmfd, props->props[i]);
@@ -410,14 +444,9 @@ int main(int argc, char **argv) {
gsr_kms_request request;
request.version = 0;
request.type = -1;
- struct iovec iov;
- iov.iov_base = &request;
- iov.iov_len = sizeof(request);
-
- struct msghdr request_message = {0};
- request_message.msg_iov = &iov;
- request_message.msg_iovlen = 1;
- const int recv_res = recvmsg(socket_fd, &request_message, MSG_WAITALL);
+ request.new_connection_fd = 0;
+
+ const int recv_res = recv_msg_from_client(socket_fd, &request);
if(recv_res == 0) {
fprintf(stderr, "kms server info: kms client shutdown, shutting down the server\n");
res = 3;