diff --git a/src/server_config.c b/src/server_config.c index 758edb06..503e95f6 100644 --- a/src/server_config.c +++ b/src/server_config.c @@ -407,7 +407,6 @@ nc_server_config_free(struct nc_server_config *config) } } free(endpt->binds[j].address); - pthread_mutex_destroy(&endpt->bind_lock); } LY_ARRAY_FREE(endpt->binds); @@ -3325,7 +3324,6 @@ config_endpoint(const struct lyd_node *node, enum nc_operation parent_op, struct nc_endpt *endpt = NULL; const char *name; LY_ARRAY_COUNT_TYPE i = 0; - int r; NC_NODE_GET_OP(node, parent_op, &op); @@ -3346,12 +3344,6 @@ config_endpoint(const struct lyd_node *node, enum nc_operation parent_op, } else if (op == NC_OP_CREATE) { /* create a new endpoint */ LY_ARRAY_NEW_RET(LYD_CTX(node), config->endpts, endpt, 1); - - /* init the new endpoint */ - if ((r = pthread_mutex_init(&endpt->bind_lock, NULL))) { - ERR(NULL, "Mutex init failed (%s).", strerror(r)); - return 1; - } } /* config name */ @@ -3376,7 +3368,6 @@ config_endpoint(const struct lyd_node *node, enum nc_operation parent_op, /* all children processed, we can now delete the endpoint */ if (op == NC_OP_DELETE) { - pthread_mutex_destroy(&endpt->bind_lock); if (i < LY_ARRAY_COUNT(config->endpts) - 1) { config->endpts[i] = config->endpts[LY_ARRAY_COUNT(config->endpts) - 1]; } @@ -6007,7 +5998,6 @@ nc_server_config_dup(const struct nc_server_config *src, struct nc_server_config dst_endpt->binds[j].sock = -1; LY_ARRAY_INCREMENT(dst_endpt->binds); } - pthread_mutex_init(&dst_endpt->bind_lock, NULL); dst_endpt->ka = src_endpt->ka; diff --git a/src/session_client.c b/src/session_client.c index 38a28f02..b051800f 100644 --- a/src/session_client.c +++ b/src/session_client.c @@ -1823,7 +1823,6 @@ nc_client_ch_add_bind_listen(const char *address, uint16_t port, const char *hos client_opts.ch_binds[client_opts.ch_bind_count - 1].address = strdup(address); client_opts.ch_binds[client_opts.ch_bind_count - 1].port = port; client_opts.ch_binds[client_opts.ch_bind_count - 1].sock = sock; - client_opts.ch_binds[client_opts.ch_bind_count - 1].pollin = 0; return 0; } @@ -1886,7 +1885,7 @@ nc_accept_callhome(int timeout, struct ly_ctx *ctx, struct nc_session **session) { int ret, sock; char *host = NULL; - uint16_t port, idx; + uint16_t port, bind_idx = 0; NC_CHECK_ARG_RET(NULL, session, -1); @@ -1895,8 +1894,8 @@ nc_accept_callhome(int timeout, struct ly_ctx *ctx, struct nc_session **session) return -1; } - ret = nc_sock_accept_binds(NULL, client_opts.ch_binds, client_opts.ch_bind_count, &client_opts.ch_bind_lock, timeout, - &host, &port, &idx, &sock); + ret = nc_server_ch_accept_binds(client_opts.ch_binds, client_opts.ch_bind_count, timeout, + &host, &port, &bind_idx, &sock); if (ret < 1) { free(host); return ret; @@ -1909,11 +1908,11 @@ nc_accept_callhome(int timeout, struct ly_ctx *ctx, struct nc_session **session) return -1; } - if (client_opts.ch_binds_aux[idx].ti == NC_TI_SSH) { + if (client_opts.ch_binds_aux[bind_idx].ti == NC_TI_SSH) { *session = nc_accept_callhome_ssh_sock(sock, host, port, ctx, NC_TRANSPORT_TIMEOUT); - } else if (client_opts.ch_binds_aux[idx].ti == NC_TI_TLS) { + } else if (client_opts.ch_binds_aux[bind_idx].ti == NC_TI_TLS) { *session = nc_accept_callhome_tls_sock(sock, host, port, ctx, NC_TRANSPORT_TIMEOUT, - client_opts.ch_binds_aux[idx].hostname); + client_opts.ch_binds_aux[bind_idx].hostname); } else { close(sock); *session = NULL; @@ -1971,13 +1970,6 @@ nc_session_ntf_thread_running(const struct nc_session *session) API int nc_client_init(void) { - int r; - - if ((r = pthread_mutex_init(&client_opts.ch_bind_lock, NULL))) { - ERR(NULL, "%s: failed to init bind lock(%s).", __func__, strerror(r)); - return -1; - } - #ifdef NC_ENABLED_SSH_TLS if (nc_tls_backend_init_wrap()) { ERR(NULL, "%s: failed to init the SSL library backend.", __func__); @@ -1995,7 +1987,6 @@ nc_client_init(void) API void nc_client_destroy(void) { - pthread_mutex_destroy(&client_opts.ch_bind_lock); nc_client_set_schema_searchpath(NULL); nc_client_unix_set_username(NULL); #ifdef NC_ENABLED_SSH_TLS diff --git a/src/session_client_tls.c b/src/session_client_tls.c index 4404d95f..e30f2e8d 100644 --- a/src/session_client_tls.c +++ b/src/session_client_tls.c @@ -231,7 +231,7 @@ nc_client_tls_connect_check(int connect_ret, void *tls_session, const char *peer static void * nc_client_tls_session_new(int sock, const char *host, int timeout, struct nc_client_tls_opts *opts, void **out_tls_cfg, struct nc_tls_ctx *tls_ctx) { - int ret = 0, sock_tmp = sock; + int ret = 0; struct timespec ts_timeout; void *tls_session, *tls_cfg, *cli_cert, *cli_pkey, *cert_store, *crl_store; @@ -300,7 +300,7 @@ nc_client_tls_session_new(int sock, const char *host, int timeout, struct nc_cli if (timeout > -1) { nc_timeouttime_get(&ts_timeout, timeout); } - while ((ret = nc_client_tls_handshake_step_wrap(tls_session, sock_tmp)) == 0) { + while ((ret = nc_client_tls_handshake_step_wrap(tls_session, sock)) == 0) { usleep(NC_TIMEOUT_STEP); if ((timeout > -1) && (nc_timeouttime_cur_diff(&ts_timeout) < 1)) { ERR(NULL, "SSL connect timeout."); diff --git a/src/session_mbedtls.c b/src/session_mbedtls.c index 90ce003e..d398e103 100644 --- a/src/session_mbedtls.c +++ b/src/session_mbedtls.c @@ -978,7 +978,7 @@ nc_server_tls_send(void *ctx, const unsigned char *buf, size_t len) ret = send(sock, buf, len, MSG_NOSIGNAL); if (ret < 0) { - if ((errno == EWOULDBLOCK) || (errno = EAGAIN) || (errno == EINTR)) { + if ((errno == EWOULDBLOCK) || (errno == EAGAIN) || (errno == EINTR)) { return MBEDTLS_ERR_SSL_WANT_WRITE; } else if ((errno == EPIPE) || (errno == ECONNRESET)) { return MBEDTLS_ERR_NET_CONN_RESET; @@ -1009,7 +1009,7 @@ nc_server_tls_recv(void *ctx, unsigned char *buf, size_t len) ret = recv(sock, buf, len, 0); if (ret < 0) { - if ((errno == EWOULDBLOCK) || (errno = EAGAIN) || (errno == EINTR)) { + if ((errno == EWOULDBLOCK) || (errno == EAGAIN) || (errno == EINTR)) { return MBEDTLS_ERR_SSL_WANT_READ; } else if ((errno == EPIPE) || (errno == ECONNRESET)) { return MBEDTLS_ERR_NET_CONN_RESET; diff --git a/src/session_p.h b/src/session_p.h index 29b74167..6d7af74b 100644 --- a/src/session_p.h +++ b/src/session_p.h @@ -492,7 +492,6 @@ struct nc_bind { char *address; /**< Either IPv4/IPv6 address or path to UNIX socket. */ uint16_t port; /**< Either port number or 0 for UNIX socket. */ int sock; /**< Socket file descriptor, -1 if not created yet. */ - int pollin; /**< Specifies, which sockets to poll on. */ }; struct nc_client_unix_opts { @@ -573,7 +572,6 @@ struct nc_client_opts { struct nc_keepalives ka; struct nc_bind *ch_binds; - pthread_mutex_t ch_bind_lock; /**< To avoid concurrent calls of poll and accept on the bound sockets **/ struct { NC_TRANSPORT_IMPL ti; @@ -692,9 +690,7 @@ struct nc_server_config { struct nc_endpt { char *name; /**< Identifier of the endpoint. */ - /* ACCESS locked - bind_lock */ struct nc_bind *binds; /**< Listening binds of the endpoint (sized-array, see libyang docs). */ - pthread_mutex_t bind_lock; /**< To avoid concurrent calls of poll and accept on the bound sockets. **/ struct nc_keepalives ka; /**< TCP keepalives configuration. */ @@ -1285,23 +1281,19 @@ int nc_sock_accept(int sock, int timeout, char **peer_host, uint16_t *peer_port) int nc_sock_listen_inet(const char *address, uint16_t port); /** - * @brief Accept a new connection on a listening socket. + * @brief Accept a new connection on any of the given Call Home binds. * - * @param[in] endpt Optional endpoint the binds belong to (only for logging purposes). - * @param[in] binds Structure with the listening sockets. + * @param[in] binds Call Home binds to accept on. * @param[in] bind_count Number of @p binds. - * @param[in] bind_lock Lock for avoiding concurrent poll/accept on a single bind. * @param[in] timeout Timeout for accepting. * @param[out] host Host of the remote peer. Can be NULL. * @param[out] port Port of the new connection. Can be NULL. - * @param[out] idx Index of the bind that was accepted. Can be NULL. + * @param[out] bind_idx Index of the bind that was accepted. Can be NULL. * @param[out] sock Accepted socket, if any. - * @return -1 on error. - * @return 0 on timeout. - * @return 1 if a socket was accepted. + * @return -1 on error, 0 on timeout, 1 if a socket was accepted. */ -int nc_sock_accept_binds(struct nc_endpt *endpt, struct nc_bind *binds, uint16_t bind_count, - pthread_mutex_t *bind_lock, int timeout, char **host, uint16_t *port, uint16_t *idx, int *sock); +int nc_server_ch_accept_binds(struct nc_bind *binds, uint16_t bind_count, int timeout, char **host, + uint16_t *port, uint16_t *bind_idx, int *sock); /** * @brief Establish a UNIX transport session. diff --git a/src/session_server.c b/src/session_server.c index 90c5fc69..2d5e8937 100644 --- a/src/session_server.c +++ b/src/session_server.c @@ -332,8 +332,7 @@ nc_sock_bind_inet(int sock, const char *address, uint16_t port, int is_ipv4) int nc_sock_listen_inet(const char *address, uint16_t port) { - int opt; - int is_ipv4, sock; + int opt, flags, is_ipv4, sock; if (!strchr(address, ':')) { is_ipv4 = 1; @@ -347,6 +346,12 @@ nc_sock_listen_inet(const char *address, uint16_t port) goto fail; } + /* make the socket non-blocking */ + if (((flags = fcntl(sock, F_GETFL)) == -1) || (fcntl(sock, F_SETFL, flags | O_NONBLOCK) == -1)) { + ERR(NULL, "Fcntl failed (%s).", strerror(errno)); + goto fail; + } + /* these options will be inherited by accepted sockets */ opt = 1; if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt) == -1) { @@ -414,6 +419,7 @@ nc_session_unix_construct_socket_path(const char *filename, char **path) if (strlen(full_path) > sizeof(sun.sun_path) - 1) { ERR(NULL, "Socket path \"%s\" is too long.", full_path); + rc = 1; goto cleanup; } @@ -525,7 +531,7 @@ static int nc_sock_listen_unix(const char *address, const struct nc_server_unix_opts *opts) { struct sockaddr_un sun; - int sock = -1; + int sock = -1, flags; if (!address) { ERR(NULL, "No socket path set."); @@ -565,6 +571,12 @@ nc_sock_listen_unix(const char *address, const struct nc_server_unix_opts *opts) } } + /* make the socket non-blocking */ + if (((flags = fcntl(sock, F_GETFL)) == -1) || (fcntl(sock, F_SETFL, flags | O_NONBLOCK) == -1)) { + ERR(NULL, "Fcntl failed (%s).", strerror(errno)); + goto fail; + } + if (listen(sock, NC_REVERSE_QUEUE) == -1) { ERR(NULL, "Unable to start listening on \"%s\" (%s).", address, strerror(errno)); goto fail; @@ -588,7 +600,7 @@ nc_sock_listen_unix(const char *address, const struct nc_server_unix_opts *opts) * @return -1 in case of error. Parameter host is set to NULL. */ static int -sock_host_unix(int acc_sock_fd, char **host) +nc_sock_host_unix(int acc_sock_fd, char **host) { char *sun_path; struct sockaddr_storage saddr; @@ -623,7 +635,7 @@ sock_host_unix(int acc_sock_fd, char **host) * @return -1 in case of error. Parameter host is set to NULL and port is unchanged. */ static int -sock_host_inet(const struct sockaddr_in *addr, char **host, uint16_t *port) +nc_sock_host_inet(const struct sockaddr_in *addr, char **host, uint16_t *port) { *host = malloc(INET_ADDRSTRLEN); NC_CHECK_ERRMEM_RET(!(*host), -1); @@ -649,7 +661,7 @@ sock_host_inet(const struct sockaddr_in *addr, char **host, uint16_t *port) * @return -1 in case of error. Parameter host is set to the NULL and port is unchanged. */ static int -sock_host_inet6(const struct sockaddr_in6 *addr, char **host, uint16_t *port) +nc_sock_host_inet6(const struct sockaddr_in6 *addr, char **host, uint16_t *port) { *host = malloc(INET6_ADDRSTRLEN); NC_CHECK_ERRMEM_RET(!(*host), -1); @@ -666,150 +678,344 @@ sock_host_inet6(const struct sockaddr_in6 *addr, char **host, uint16_t *port) return 0; } -int -nc_sock_accept_binds(struct nc_endpt *endpt, struct nc_bind *binds, uint16_t bind_count, - pthread_mutex_t *bind_lock, int timeout, char **host, uint16_t *port, uint16_t *idx, int *sock) +/** + * @brief Get the client's host information from the accepted socket address. + * + * @param[in] saddr sockaddr_storage. + * @param[in] client_sock Socket FD of the accepted connection. + * @param[out] client_address Hostname or IP address of the connecting client (must be freed by the caller). + * @param[out] client_port Port number of the connecting client, if any (0 for AF_UNIX). + * @return 0 on success, -1 on error. + */ +static int +nc_sock_host_get(const struct sockaddr_storage *saddr, int client_sock, char **client_address, uint16_t *client_port) { - uint16_t i, j, pfd_count, client_port; - char *client_address, *sockpath = NULL; - struct pollfd *pfd; - struct sockaddr_storage saddr; - socklen_t saddr_len = sizeof(saddr); - int ret, client_sock, server_sock = -1, flags; - - pfd = malloc(bind_count * sizeof *pfd); - NC_CHECK_ERRMEM_RET(!pfd, -1); - - /* LOCK */ - if (nc_mutex_lock(bind_lock, timeout, __func__) != 1) { - free(pfd); - return -1; - } + int rc = 0; - for (i = 0, pfd_count = 0; i < bind_count; ++i) { - if (binds[i].sock < 0) { - /* invalid socket */ - continue; + /* learn information about the client end */ + if (saddr->ss_family == AF_UNIX) { + if ((rc = nc_sock_host_unix(client_sock, client_address))) { + goto cleanup; } - if (binds[i].pollin) { - binds[i].pollin = 0; - /* leftover pollin */ - server_sock = binds[i].sock; - break; + *client_port = 0; + } else if (saddr->ss_family == AF_INET) { + if ((rc = nc_sock_host_inet((struct sockaddr_in *)saddr, client_address, client_port))) { + goto cleanup; } - pfd[pfd_count].fd = binds[i].sock; - pfd[pfd_count].events = POLLIN; - pfd[pfd_count].revents = 0; - - ++pfd_count; + } else if (saddr->ss_family == AF_INET6) { + if ((rc = nc_sock_host_inet6((struct sockaddr_in6 *)saddr, client_address, client_port))) { + goto cleanup; + } + } else { + ERR(NULL, "Source host of an unknown protocol family."); + rc = -1; + goto cleanup; } - if (server_sock == -1) { - /* poll for a new connection */ - ret = nc_poll(pfd, pfd_count, timeout); - if (ret < 1) { - free(pfd); +cleanup: + return rc; +} - /* UNLOCK */ - nc_mutex_unlock(bind_lock, __func__); +/** + * @brief Log the accepted connection. + * + * @param[in] saddr sockaddr_storage. + * @param[in] endpt Endpoint on which the connection was accepted (optional, used for logging). + * @param[in] bind Bind on which the connection was accepted. + * @param[in] client_address Hostname or IP address of the connecting client. + * @param[in] client_port Port number of the connecting client, if any. + * @return 0 on success, -1 on error. + */ +static int +nc_sock_log_accepted(const struct sockaddr_storage *saddr, const struct nc_endpt *endpt, const struct nc_bind *bind, + const char *client_address, uint16_t client_port) +{ + char *unix_sockpath = NULL; + + if (saddr->ss_family == AF_UNIX) { + /* UNIX socket, get the socket path for logging, + * UNIX socket connection can NOT be over call home (caller = client connect), so endpt is always available */ + assert(endpt); + unix_sockpath = nc_server_unix_get_socket_path(endpt); + VRB(NULL, "Accepted a new connection on %s.", unix_sockpath ? unix_sockpath : "UNIX socket"); + free(unix_sockpath); + } else if (saddr->ss_family == AF_INET) { + /* IPv4 socket */ + VRB(NULL, "Accepted a new connection on %s:%" PRIu16 " from %s:%" PRIu16 ".", bind->address, bind->port, + client_address, client_port); + } else if (saddr->ss_family == AF_INET6) { + /* IPv6 socket */ + VRB(NULL, "Accepted a new connection on [%s]:%" PRIu16 " from [%s]:%" PRIu16 ".", bind->address, bind->port, + client_address, client_port); + } else { + ERR(NULL, "Source host of an unknown protocol family."); + return -1; + } - return ret; - } + return 0; +} - for (i = 0, j = 0; j < pfd_count; ++i, ++j) { - /* adjust i so that indices in binds and pfd always match */ - while (binds[i].sock != pfd[j].fd) { - ++i; - } +/** + * @brief Accept the first available connection on the given pollfds. + * + * @param[in] pfd Array of pollfds to check for incoming connections. + * @param[in] pfd_count Number of pollfds in the array. + * @param[out] client_sock Socket file descriptor of the accepted connection, -1 if no connection was accepted. + * @param[out] saddr sockaddr_storage to store the address of the connecting client. + * @param[out] saddr_len Length of the sockaddr_storage structure. + * @param[out] fd_idx Index of the pollfd on which the connection was accepted, valid only if client_sock is not -1. + * @return 0 on success, -1 on error (client_sock will be -1 on error or if no connection was accepted). + */ +static int +nc_sock_accept_first(struct pollfd *pfd, uint16_t pfd_count, int *client_sock, + struct sockaddr_storage *saddr, socklen_t *saddr_len, uint16_t *fd_idx) +{ + int sock = -1; + uint16_t i; - if (pfd[j].revents & POLLIN) { - --ret; + *client_sock = -1; - if (!ret) { - /* the last socket with an event, use it */ - server_sock = pfd[j].fd; - break; - } else { - /* just remember the event for next time */ - binds[i].pollin = 1; + for (i = 0; i < pfd_count; i++) { + if (pfd[i].revents & POLLIN) { + sock = accept(pfd[i].fd, (struct sockaddr *)saddr, saddr_len); + if (sock < 0) { + if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) { + /* another thread already accepted the connection, try another one */ + continue; } + ERR(NULL, "Accept failed (%s).", strerror(errno)); + return -1; } + + /* successfully accepted a connection! */ + break; } } - free(pfd); - if (server_sock == -1) { - ERRINT; - /* UNLOCK */ - nc_mutex_unlock(bind_lock, __func__); - return -1; + + if (sock != -1) { + *client_sock = sock; + *fd_idx = i; } - /* accept connection */ - client_sock = accept(server_sock, (struct sockaddr *)&saddr, &saddr_len); - if (client_sock < 0) { - ERR(NULL, "Accept failed (%s).", strerror(errno)); - /* UNLOCK */ - nc_mutex_unlock(bind_lock, __func__); - return -1; + return 0; +} + +/** + * @brief Accept a new connection on any of the given pollfds. + * + * Can be called by multiple threads. If there is only a single event, one thread will accept the connection, + * others will timeout. + * + * @param[in] pollfds FDs to poll for new connections. + * @param[in] pollfd_count Number of FDs in the pollfds array. + * @param[in] endpt_map Map of pollfd indices to endpoints (optional, used for logging). + * @param[in] bind_map Map of pollfd indices to binds (optional, used for logging). + * @param[in] timeout Timeout for accepting a connection. + * @param[out] host Hostname or IP address of the connecting client. + * @param[out] port Port number of the connecting client, if any. + * @param[out] fd_idx Index of the pollfd on which the connection was accepted. + * @param[out] sock Socket file descriptor of the accepted connection. + * @return 1 on success, 0 on timeout, -1 on error. + */ +static int +nc_sock_accept_pollfds(struct pollfd *pollfds, uint16_t pollfd_count, struct nc_endpt **endpt_map, + struct nc_bind **bind_map, int timeout, char **host, uint16_t *port, + uint16_t *fd_idx, int *sock) +{ + uint16_t client_port = 0, matched_pollfd_idx = 0; + char *client_address = NULL; + struct sockaddr_storage client_saddr; + socklen_t saddr_len = sizeof(client_saddr); + struct nc_endpt *endpt; + struct nc_bind *bind; + int client_sock = -1, ret = 1, r, flags; + + if (!pollfd_count) { + /* no FDs to poll, treat as a timeout */ + ret = 0; + goto cleanup; + } + + /* poll for a new connection */ + r = nc_poll(pollfds, pollfd_count, timeout); + if (r < 1) { + /* either 0 (timeout) or -1 (error) */ + ret = r; + goto cleanup; + } + + /* try to accept the first available connection */ + if ((r = nc_sock_accept_first(pollfds, pollfd_count, &client_sock, &client_saddr, &saddr_len, &matched_pollfd_idx))) { + ret = r; + goto cleanup; + } + if (client_sock == -1) { + /* all events were stolen by other threads, treat as a timeout */ + ret = 0; + goto cleanup; } + bind = bind_map[matched_pollfd_idx]; + endpt = endpt_map ? endpt_map[matched_pollfd_idx] : NULL; + /* make the socket non-blocking */ if (((flags = fcntl(client_sock, F_GETFL)) == -1) || (fcntl(client_sock, F_SETFL, flags | O_NONBLOCK) == -1)) { ERR(NULL, "Fcntl failed (%s).", strerror(errno)); - goto fail; + goto cleanup; } - /* learn information about the client end */ - if (saddr.ss_family == AF_UNIX) { - if (sock_host_unix(client_sock, &client_address)) { - goto fail; - } - client_port = 0; - } else if (saddr.ss_family == AF_INET) { - if (sock_host_inet((struct sockaddr_in *)&saddr, &client_address, &client_port)) { - goto fail; - } - } else if (saddr.ss_family == AF_INET6) { - if (sock_host_inet6((struct sockaddr_in6 *)&saddr, &client_address, &client_port)) { - goto fail; - } - } else { - ERR(NULL, "Source host of an unknown protocol family."); - goto fail; + /* learn information about the peer */ + if ((r = nc_sock_host_get(&client_saddr, client_sock, &client_address, &client_port))) { + ret = r; + goto cleanup; } - if (saddr.ss_family == AF_UNIX) { - if (endpt) { - sockpath = nc_server_unix_get_socket_path(endpt); - } - VRB(NULL, "Accepted a connection on %s.", sockpath ? sockpath : "UNIX socket"); - free(sockpath); - } else { - VRB(NULL, "Accepted a connection on %s:%u from %s:%u.", binds[i].address, binds[i].port, client_address, client_port); + /* log the new accepted connection */ + if ((r = nc_sock_log_accepted(&client_saddr, endpt, bind, client_address, client_port))) { + ret = r; + goto cleanup; } if (host) { *host = client_address; - } else { - free(client_address); + client_address = NULL; } if (port) { *port = client_port; } - if (idx) { - *idx = i; + if (fd_idx) { + *fd_idx = matched_pollfd_idx; } - /* UNLOCK */ - nc_mutex_unlock(bind_lock, __func__); - *sock = client_sock; - return 1; + client_sock = -1; -fail: - close(client_sock); - /* UNLOCK */ - nc_mutex_unlock(bind_lock, __func__); - return -1; +cleanup: + free(client_address); + if (client_sock > -1) { + close(client_sock); + } + return ret; +} + +/** + * @brief Accept a new connection on any of the server's listening binds. + * + * @param[in] config Server configuration. + * @param[in] timeout Timeout for accepting a connection. + * @param[out] host Hostname or IP address of the connecting client. + * @param[out] port Port number of the connecting client, if any. + * @param[out] idx Index of the endpoint on which the connection was accepted (optional). + * @param[out] sock Socket file descriptor of the accepted connection. + * @return 1 on success, 0 on timeout, -1 on error. + */ +static int +nc_server_accept_binds(struct nc_server_config *config, int timeout, char **host, + uint16_t *port, LY_ARRAY_COUNT_TYPE *idx, int *sock) +{ + struct pollfd *pollfds = NULL; + uint16_t pollfd_count = 0, fd_idx = 0, bind_count = 0; + LY_ARRAY_COUNT_TYPE i; + struct nc_endpt *endpt; + struct nc_bind *bind; + int ret = 1; + struct nc_endpt **endpt_map = NULL; + struct nc_bind **bind_map = NULL; + + /* count the number of valid binds and prepare the pollfd and map parallel arrays */ + LY_ARRAY_FOR(config->endpts, i) { + bind_count += LY_ARRAY_COUNT(config->endpts[i].binds); + } + if (!bind_count) { + /* no binds to accept on, treat as a timeout */ + ret = 0; + goto cleanup; + } + + pollfds = malloc(bind_count * sizeof *pollfds); + NC_CHECK_ERRMEM_RET(!pollfds, -1); + endpt_map = malloc(bind_count * sizeof *endpt_map); + NC_CHECK_ERRMEM_GOTO(!endpt_map, ret = -1, cleanup); + bind_map = malloc(bind_count * sizeof *bind_map); + NC_CHECK_ERRMEM_GOTO(!bind_map, ret = -1, cleanup); + + /* fill the arrays */ + LY_ARRAY_FOR(config->endpts, struct nc_endpt, endpt) { + LY_ARRAY_FOR(endpt->binds, struct nc_bind, bind) { + if (bind->sock < 0) { + /* invalid socket */ + continue; + } + + pollfds[pollfd_count].fd = bind->sock; + pollfds[pollfd_count].events = POLLIN; + pollfds[pollfd_count].revents = 0; + + endpt_map[pollfd_count] = endpt; + bind_map[pollfd_count] = bind; + + ++pollfd_count; + } + } + + /* accept a new connection on any of the sockets */ + ret = nc_sock_accept_pollfds(pollfds, pollfd_count, endpt_map, bind_map, timeout, host, port, &fd_idx, sock); + if (idx && (ret > 0)) { + *idx = endpt_map[fd_idx] - config->endpts; + } + +cleanup: + free(pollfds); + free(endpt_map); + free(bind_map); + return ret; +} + +int +nc_server_ch_accept_binds(struct nc_bind *binds, uint16_t bind_count, int timeout, char **host, + uint16_t *port, uint16_t *bind_idx, int *sock) +{ + struct pollfd *pollfds = NULL; + uint16_t pollfd_count = 0, fd_idx = 0, i; + int ret = 1; + struct nc_bind **bind_map = NULL; + + if (!bind_count) { + /* no binds to accept on, treat as a timeout */ + ret = 0; + goto cleanup; + } + + /* prepare the pollfd and map parallel arrays */ + pollfds = malloc(bind_count * sizeof *pollfds); + NC_CHECK_ERRMEM_RET(!pollfds, -1); + bind_map = malloc(bind_count * sizeof *bind_map); + NC_CHECK_ERRMEM_GOTO(!bind_map, ret = -1, cleanup); + + /* fill the arrays */ + for (i = 0; i < bind_count; ++i) { + if (binds[i].sock < 0) { + /* invalid socket */ + continue; + } + + pollfds[pollfd_count].fd = binds[i].sock; + pollfds[pollfd_count].events = POLLIN; + pollfds[pollfd_count].revents = 0; + + bind_map[pollfd_count] = &binds[i]; + + ++pollfd_count; + } + + ret = nc_sock_accept_pollfds(pollfds, pollfd_count, NULL, bind_map, timeout, host, port, &fd_idx, sock); + if (bind_idx && (ret > 0)) { + *bind_idx = bind_map[fd_idx] - binds; + } + +cleanup: + free(pollfds); + free(bind_map); + return ret; } API struct nc_server_reply * @@ -2758,19 +2964,12 @@ nc_accept(int timeout, const struct ly_ctx *ctx, struct nc_session **session) goto cleanup; } - /* try to accept a new connection on any of the endpoints */ - LY_ARRAY_FOR(config->endpts, endpt_idx) { - ret = nc_sock_accept_binds(&config->endpts[endpt_idx], config->endpts[endpt_idx].binds, LY_ARRAY_COUNT(config->endpts[endpt_idx].binds), - &config->endpts[endpt_idx].bind_lock, timeout, &host, &port, NULL, &sock); - if (ret < 0) { - msgtype = NC_MSG_ERROR; - goto cleanup; - } else if (ret > 0) { - /* accepted */ - break; - } - } - if (sock == -1) { + /* try to accept a new connection on any of the listening endpoints */ + ret = nc_server_accept_binds(config, timeout, &host, &port, &endpt_idx, &sock); + if (ret < 0) { + msgtype = NC_MSG_ERROR; + goto cleanup; + } else if (!ret) { /* timeout, no connection established */ msgtype = NC_MSG_WOULDBLOCK; goto cleanup; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cbea5b60..f2a520ac 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -81,6 +81,7 @@ if(ENABLE_SSH_TLS) endif() libnetconf2_test(NAME test_replace) libnetconf2_test(NAME test_runtime_changes PORT_COUNT 2) + libnetconf2_test(NAME test_server_thread) libnetconf2_test(NAME test_ssh) libnetconf2_test(NAME test_tls) libnetconf2_test(NAME test_two_channels) diff --git a/tests/test_server_thread.c b/tests/test_server_thread.c new file mode 100644 index 00000000..2f2b25bf --- /dev/null +++ b/tests/test_server_thread.c @@ -0,0 +1,300 @@ +/** + * @file test_server_thread.c + * @author Roman Janota + * @brief libnetconf2 parallel server accept thread test. + * + * @copyright + * Copyright (c) 2026 CESNET, z.s.p.o. + * + * This source code is licensed under BSD 3-Clause License (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + */ + +#define _GNU_SOURCE + +#include +#include +#include +#include +#include +#include + +#include + +#include "ln2_test.h" + +#define PARALLEL_SERVER_THREADS 4 +#define PARALLEL_CLIENT_THREADS 4 +#define NOCLIENT_ATTEMPTS 8 +#define NONBLOCK_BACKOFF_USECS 1000 +#define SHORT_ACCEPT_TIMEOUT 400 + +int TEST_PORT = 10050; +const char *TEST_PORT_STR = "10050"; + +struct accept_state { + pthread_barrier_t start_barrier; + pthread_mutex_t lock; + int accept_timeout; + int client_count; + int accepted_count; + int timeout_count; + struct ln2_test_ctx *test_ctx; +}; + +struct no_client_state { + pthread_barrier_t start_barrier; + int accept_timeout; + struct ln2_test_ctx *test_ctx; +}; + +static void * +server_thread_accept_all(void *arg) +{ + int done; + NC_MSG_TYPE msgtype; + struct nc_session *session = NULL; + struct accept_state *state = arg; + + /* wait until all server and client threads are ready to start the test */ + pthread_barrier_wait(&state->start_barrier); + + while (1) { + pthread_mutex_lock(&state->lock); + done = state->accepted_count >= state->client_count; + pthread_mutex_unlock(&state->lock); + if (done) { + break; + } + + msgtype = nc_accept(state->accept_timeout, state->test_ctx->ctx, &session); + if (msgtype == NC_MSG_HELLO) { + assert_non_null(session); + nc_session_free(session, NULL); + + pthread_mutex_lock(&state->lock); + ++state->accepted_count; + pthread_mutex_unlock(&state->lock); + } else if (msgtype == NC_MSG_WOULDBLOCK) { + assert_null(session); + + pthread_mutex_lock(&state->lock); + ++state->timeout_count; + pthread_mutex_unlock(&state->lock); + + usleep(NONBLOCK_BACKOFF_USECS); + } else { + fail_msg("Unexpected nc_accept return code %d", msgtype); + } + } + + return NULL; +} + +static void * +client_thread_connect(void *arg) +{ + int ret = 0; + struct nc_session *session = NULL; + struct accept_state *state = arg; + + ret = nc_client_set_schema_searchpath(MODULES_DIR); + assert_int_equal(ret, 0); + + ret = nc_client_ssh_set_username("parallel_client"); + assert_int_equal(ret, 0); + + ret = nc_client_ssh_add_keypair(TESTS_DIR "/data/id_ed25519.pub", TESTS_DIR "/data/id_ed25519"); + assert_int_equal(ret, 0); + + nc_client_ssh_set_knownhosts_mode(NC_SSH_KNOWNHOSTS_SKIP); + + /* wait until the server threads are ready to accept connections */ + pthread_barrier_wait(&state->start_barrier); + + session = nc_connect_ssh("127.0.0.1", TEST_PORT, NULL); + assert_non_null(session); + nc_session_free(session, NULL); + + return NULL; +} + +static void * +server_thread_timeout_only(void *arg) +{ + int i; + NC_MSG_TYPE msgtype; + struct nc_session *session = NULL; + struct no_client_state *state = arg; + + /* wait until all server threads are ready to start the test */ + pthread_barrier_wait(&state->start_barrier); + + for (i = 0; i < NOCLIENT_ATTEMPTS; ++i) { + msgtype = nc_accept(state->accept_timeout, state->test_ctx->ctx, &session); + assert_int_equal(msgtype, NC_MSG_WOULDBLOCK); + assert_null(session); + + usleep(NONBLOCK_BACKOFF_USECS); + } + + return NULL; +} + +static void +run_parallel_accept(void **state, int accept_timeout) +{ + int i, ret; + pthread_t server_tids[PARALLEL_SERVER_THREADS]; + pthread_t client_tids[PARALLEL_CLIENT_THREADS]; + struct ln2_test_ctx *test_ctx = *state; + struct accept_state accept_state; + + accept_state.accept_timeout = accept_timeout; + accept_state.client_count = PARALLEL_CLIENT_THREADS; + accept_state.accepted_count = 0; + accept_state.timeout_count = 0; + accept_state.test_ctx = test_ctx; + + /* sync all threads to start at the same time */ + ret = pthread_barrier_init(&accept_state.start_barrier, NULL, + PARALLEL_SERVER_THREADS + PARALLEL_CLIENT_THREADS + 1); + assert_int_equal(ret, 0); + ret = pthread_mutex_init(&accept_state.lock, NULL); + assert_int_equal(ret, 0); + + /* start server threads */ + for (i = 0; i < PARALLEL_SERVER_THREADS; ++i) { + ret = pthread_create(&server_tids[i], NULL, server_thread_accept_all, &accept_state); + assert_int_equal(ret, 0); + } + + /* start client threads */ + for (i = 0; i < PARALLEL_CLIENT_THREADS; ++i) { + ret = pthread_create(&client_tids[i], NULL, client_thread_connect, &accept_state); + assert_int_equal(ret, 0); + } + + /* wait until all threads are ready to start the test */ + pthread_barrier_wait(&accept_state.start_barrier); + + for (i = 0; i < PARALLEL_CLIENT_THREADS; ++i) { + pthread_join(client_tids[i], NULL); + } + for (i = 0; i < PARALLEL_SERVER_THREADS; ++i) { + pthread_join(server_tids[i], NULL); + } + + /* all clients should have been accepted */ + assert_int_equal(accept_state.accepted_count, PARALLEL_CLIENT_THREADS); + + pthread_mutex_destroy(&accept_state.lock); + pthread_barrier_destroy(&accept_state.start_barrier); +} + +static void +run_timeout_only(void **state, int accept_timeout) +{ + int i, ret; + pthread_t server_tids[PARALLEL_SERVER_THREADS]; + struct ln2_test_ctx *test_ctx = *state; + struct no_client_state no_client_state; + + no_client_state.accept_timeout = accept_timeout; + no_client_state.test_ctx = test_ctx; + + /* sync all threads to start at the same time */ + ret = pthread_barrier_init(&no_client_state.start_barrier, NULL, PARALLEL_SERVER_THREADS + 1); + assert_int_equal(ret, 0); + + /* start server threads, no client threads will be started, so all threads should only experience timeouts */ + for (i = 0; i < PARALLEL_SERVER_THREADS; ++i) { + ret = pthread_create(&server_tids[i], NULL, server_thread_timeout_only, &no_client_state); + assert_int_equal(ret, 0); + } + + /* wait until all threads are ready to start the test */ + pthread_barrier_wait(&no_client_state.start_barrier); + + for (i = 0; i < PARALLEL_SERVER_THREADS; ++i) { + pthread_join(server_tids[i], NULL); + } + + pthread_barrier_destroy(&no_client_state.start_barrier); +} + +static void +test_parallel_accept_nonblocking(void **state) +{ + run_parallel_accept(state, 0); +} + +static void +test_parallel_accept_timed(void **state) +{ + run_parallel_accept(state, NC_ACCEPT_TIMEOUT); +} + +static void +test_parallel_accept_timeout_only_nonblocking(void **state) +{ + run_timeout_only(state, 0); +} + +static void +test_parallel_accept_timeout_only_timed(void **state) +{ + run_timeout_only(state, SHORT_ACCEPT_TIMEOUT); +} + +static int +setup_ssh(void **state) +{ + int ret; + struct lyd_node *tree = NULL; + struct ln2_test_ctx *test_ctx; + + ret = ln2_glob_test_setup(&test_ctx); + assert_int_equal(ret, 0); + + *state = test_ctx; + + /* setup server with single SSH endpoint and one user with public key authentication */ + ret = nc_server_config_add_address_port(test_ctx->ctx, "endpt", NC_TI_SSH, "127.0.0.1", TEST_PORT, &tree); + assert_int_equal(ret, 0); + + ret = nc_server_config_add_ssh_hostkey(test_ctx->ctx, "endpt", "hostkey", TESTS_DIR "/data/key_ecdsa", NULL, &tree); + assert_int_equal(ret, 0); + + ret = nc_server_config_add_ssh_user_pubkey(test_ctx->ctx, "endpt", "parallel_client", "pubkey", + TESTS_DIR "/data/id_ed25519.pub", &tree); + assert_int_equal(ret, 0); + + ret = nc_server_config_setup_data(tree); + assert_int_equal(ret, 0); + + lyd_free_all(tree); + + return 0; +} + +int +main(void) +{ + const struct CMUnitTest tests[] = { + cmocka_unit_test_setup_teardown(test_parallel_accept_nonblocking, setup_ssh, ln2_glob_test_teardown), + cmocka_unit_test_setup_teardown(test_parallel_accept_timed, setup_ssh, ln2_glob_test_teardown), + cmocka_unit_test_setup_teardown(test_parallel_accept_timeout_only_nonblocking, setup_ssh, ln2_glob_test_teardown), + cmocka_unit_test_setup_teardown(test_parallel_accept_timeout_only_timed, setup_ssh, ln2_glob_test_teardown), + }; + + if (ln2_glob_test_get_ports(1, &TEST_PORT, &TEST_PORT_STR)) { + return 1; + } + + setenv("CMOCKA_TEST_ABORT", "1", 1); + return cmocka_run_group_tests(tests, NULL, NULL); +}