diff --git a/native/src/core/socket.cpp b/native/src/core/socket.cpp index 57f70553b..a499b10a7 100644 --- a/native/src/core/socket.cpp +++ b/native/src/core/socket.cpp @@ -73,13 +73,27 @@ static void *recv_fds(int sockfd, char *cmsgbuf, size_t bufsz, int cnt) { }; xrecvmsg(sockfd, &msg, MSG_WAITALL); - cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); + if (msg.msg_controllen != bufsz) { + LOGE("recv_fd: msg_flags = %d, msg_controllen(%zu) != %zu\n", + msg.msg_flags, msg.msg_controllen, bufsz); + return nullptr; + } - if (msg.msg_controllen != bufsz || - cmsg == nullptr || - cmsg->cmsg_len != CMSG_LEN(sizeof(int) * cnt) || - cmsg->cmsg_level != SOL_SOCKET || - cmsg->cmsg_type != SCM_RIGHTS) { + cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); + if (cmsg == nullptr) { + LOGE("recv_fd: cmsg == nullptr\n"); + return nullptr; + } + if (cmsg->cmsg_len != CMSG_LEN(sizeof(int) * cnt)) { + LOGE("recv_fd: cmsg_len(%zu) != %zu\n", cmsg->cmsg_len, CMSG_LEN(sizeof(int) * cnt)); + return nullptr; + } + if (cmsg->cmsg_level != SOL_SOCKET) { + LOGE("recv_fd: cmsg_level != SOL_SOCKET\n"); + return nullptr; + } + if (cmsg->cmsg_type != SCM_RIGHTS) { + LOGE("recv_fd: cmsg_type != SCM_RIGHTS\n"); return nullptr; } @@ -92,8 +106,11 @@ vector recv_fds(int sockfd) { // Peek fd count to allocate proper buffer int cnt; recv(sockfd, &cnt, sizeof(cnt), MSG_PEEK); - if (cnt == 0) + if (cnt == 0) { + // Consume data + recv(sockfd, &cnt, sizeof(cnt), MSG_WAITALL); return results; + } vector cmsgbuf; cmsgbuf.resize(CMSG_SPACE(sizeof(int) * cnt)); @@ -109,6 +126,15 @@ vector recv_fds(int sockfd) { } int recv_fd(int sockfd) { + // Peek fd count + int cnt; + recv(sockfd, &cnt, sizeof(cnt), MSG_PEEK); + if (cnt == 0) { + // Consume data + recv(sockfd, &cnt, sizeof(cnt), MSG_WAITALL); + return -1; + } + char cmsgbuf[CMSG_SPACE(sizeof(int))]; void *data = recv_fds(sockfd, cmsgbuf, sizeof(cmsgbuf), 1);