From ffe47300a1352e32a611b96f28fdbdc0bf84623f Mon Sep 17 00:00:00 2001 From: topjohnwu Date: Thu, 19 Aug 2021 04:36:47 -0700 Subject: [PATCH] Update recv/send fd function --- native/jni/core/socket.cpp | 173 ++++++++++++++++------------------ native/jni/include/socket.hpp | 8 +- 2 files changed, 84 insertions(+), 97 deletions(-) diff --git a/native/jni/core/socket.cpp b/native/jni/core/socket.cpp index 55e38cb07..9cf1fcc70 100644 --- a/native/jni/core/socket.cpp +++ b/native/jni/core/socket.cpp @@ -20,122 +20,107 @@ socklen_t setup_sockaddr(sockaddr_un *sun, const char *name) { return socket_len(sun); } -int socket_accept(int sockfd, int timeout) { - struct pollfd pfd = { - .fd = sockfd, - .events = POLL_IN - }; - return xpoll(&pfd, 1, timeout * 1000) <= 0 ? -1 : xaccept4(sockfd, nullptr, nullptr, SOCK_CLOEXEC); +void get_client_cred(int fd, ucred *cred) { + socklen_t len = sizeof(*cred); + getsockopt(fd, SOL_SOCKET, SO_PEERCRED, cred, &len); } -void get_client_cred(int fd, struct ucred *cred) { - socklen_t ucred_length = sizeof(*cred); - getsockopt(fd, SOL_SOCKET, SO_PEERCRED, cred, &ucred_length); -} - -/* - * Receive a file descriptor from a Unix socket. - * Contributed by @mkasick - * - * Returns the file descriptor on success, or -1 if a file - * descriptor was not actually included in the message - * - * On error the function terminates by calling exit(-1) - */ -int recv_fd(int sockfd) { - // Need to receive data from the message, otherwise don't care about it. - char iovbuf; - struct cmsghdr *cmsg; - - struct iovec iov = { - .iov_base = &iovbuf, - .iov_len = 1, +static int send_fds(int sockfd, void *cmsgbuf, size_t bufsz, const int *fds, int cnt) { + iovec iov = { + .iov_base = &cnt, + .iov_len = sizeof(cnt), + }; + msghdr msg = { + .msg_iov = &iov, + .msg_iovlen = 1, }; + if (cnt) { + msg.msg_control = cmsgbuf; + msg.msg_controllen = bufsz; + cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = CMSG_LEN(sizeof(int) * cnt); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + + memcpy(CMSG_DATA(cmsg), fds, sizeof(int) * cnt); + } + + return xsendmsg(sockfd, &msg, 0); +} + +int send_fds(int sockfd, const int *fds, int cnt) { + if (cnt == 0) { + return send_fds(sockfd, nullptr, 0, nullptr, 0); + } + vector cmsgbuf; + cmsgbuf.resize(CMSG_SPACE(sizeof(int) * cnt)); + return send_fds(sockfd, cmsgbuf.data(), cmsgbuf.size(), fds, cnt); +} + +int send_fd(int sockfd, int fd) { + if (fd < 0) { + return send_fds(sockfd, nullptr, 0, nullptr, 0); + } char cmsgbuf[CMSG_SPACE(sizeof(int))]; + return send_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), &fd, 1); +} - struct msghdr msg = { +static void *recv_fds(int sockfd, char *cmsgbuf, size_t bufsz, int cnt) { + iovec iov = { + .iov_base = &cnt, + .iov_len = sizeof(cnt), + }; + msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1, .msg_control = cmsgbuf, - .msg_controllen = sizeof(cmsgbuf), + .msg_controllen = bufsz }; xrecvmsg(sockfd, &msg, MSG_WAITALL); + cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - // Was a control message actually sent? - switch (msg.msg_controllen) { - case 0: - // No, so the file descriptor was closed and won't be used. - return -1; - case sizeof(cmsgbuf): - // Yes, grab the file descriptor from it. - break; - default: - goto error; + if (msg.msg_controllen != bufsz || + cmsg == nullptr || + cmsg->cmsg_len != CMSG_LEN(sizeof(int) * cnt) || + cmsg->cmsg_level != SOL_SOCKET || + cmsg->cmsg_type != SCM_RIGHTS) { + return nullptr; } - cmsg = CMSG_FIRSTHDR(&msg); - - if (cmsg == nullptr || - cmsg->cmsg_len != CMSG_LEN(sizeof(int)) || - cmsg->cmsg_level != SOL_SOCKET || - cmsg->cmsg_type != SCM_RIGHTS) { -error: - LOGE("unable to read fd\n"); - exit(-1); - } - - return *(int *)CMSG_DATA(cmsg); + return CMSG_DATA(cmsg); } -/* - * Send a file descriptor through a Unix socket. - * Contributed by @mkasick - * - * On error the function terminates by calling exit(-1) - * - * fd may be -1, in which case the dummy data is sent, - * but no control message with the FD is sent. - */ -void send_fd(int sockfd, int fd) { - // Need to send some data in the message, this will do. - char junk[] = { '\0' }; - struct iovec iov = { - .iov_base = junk, - .iov_len = 1, - }; +vector recv_fds(int sockfd) { + // Peek fd count to allocate proper buffer + int cnt; + recv(sockfd, &cnt, sizeof(cnt), MSG_PEEK); - struct msghdr msg = { - .msg_iov = &iov, - .msg_iovlen = 1, - }; + vector cmsgbuf; + cmsgbuf.resize(CMSG_SPACE(sizeof(int) * cnt)); + vector results; + void *data = recv_fds(sockfd, cmsgbuf.data(), cmsgbuf.size(), cnt); + if (data == nullptr) + return results; + + results.resize(cnt); + memcpy(results.data(), data, sizeof(int) * cnt); + + return results; +} + +int recv_fd(int sockfd) { char cmsgbuf[CMSG_SPACE(sizeof(int))]; - if (fd != -1) { - // Is the file descriptor actually open? - if (fcntl(fd, F_GETFD) == -1) { - if (errno != EBADF) { - PLOGE("unable to send fd"); - } - // It's closed, don't send a control message or sendmsg will EBADF. - } else { - // It's open, send the file descriptor in a control message. - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); + void *data = recv_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), 1); + if (data == nullptr) + return -1; - struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - - cmsg->cmsg_len = CMSG_LEN(sizeof(int)); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - - *(int *)CMSG_DATA(cmsg) = fd; - } - } - - xsendmsg(sockfd, &msg, 0); + int result; + memcpy(&result, data, sizeof(int)); + return result; } int read_int(int fd) { diff --git a/native/jni/include/socket.hpp b/native/jni/include/socket.hpp index 72b1694c6..ee7afdcb1 100644 --- a/native/jni/include/socket.hpp +++ b/native/jni/include/socket.hpp @@ -3,12 +3,14 @@ #include #include #include +#include socklen_t setup_sockaddr(sockaddr_un *sun, const char *name); -int socket_accept(int sockfd, int timeout); -void get_client_cred(int fd, struct ucred *cred); +void get_client_cred(int fd, ucred *cred); +std::vector recv_fds(int sockfd); int recv_fd(int sockfd); -void send_fd(int sockfd, int fd); +int send_fds(int sockfd, const int *fds, int cnt); +int send_fd(int sockfd, int fd); int read_int(int fd); int read_int_be(int fd); void write_int(int fd, int val);