diff --git a/app/src/server.c b/app/src/server.c index 77978d19..e737aa2b 100644 --- a/app/src/server.c +++ b/app/src/server.c @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -337,6 +338,10 @@ static int run_wait_server(void *data) { struct server *server = data; cmd_simple_wait(server->process, NULL); // ignore exit code + + // wake up any net_select_interruptible() + close(server->pipe_intr[1]); + LOGD("Server terminated"); return 0; } @@ -361,10 +366,16 @@ server_start(struct server *server, const char *serial, goto error1; } + bool ok = net_pipe(server->pipe_intr); + if (!ok) { + perror("pipe"); + goto error2; + } + // server will connect to our server socket server->process = execute_server(server, params); if (server->process == PROCESS_NONE) { - goto error2; + goto error3; } server->wait_server_thread = @@ -374,13 +385,16 @@ server_start(struct server *server, const char *serial, LOGW("Could not terminate server"); } cmd_simple_wait(server->process, NULL); // ignore exit code - goto error2; + goto error3; } server->tunnel_enabled = true; return true; +error3: + close(server->pipe_intr[0]); + close(server->pipe_intr[1]); error2: if (!server->tunnel_forward) { close_socket(&server->server_socket); @@ -394,17 +408,32 @@ error1: bool server_connect_to(struct server *server) { if (!server->tunnel_forward) { + bool acceptable = net_select_interruptible(server->server_socket, + server->pipe_intr[0]); + if (!acceptable) { + // the process died, accept() would never succeed + return false; + } + server->video_socket = net_accept(server->server_socket); if (server->video_socket == INVALID_SOCKET) { return false; } + acceptable = net_select_interruptible(server->server_socket, + server->pipe_intr[0]); + if (!acceptable) { + // the process died, accept() would never succeed + return false; + } server->control_socket = net_accept(server->server_socket); if (server->control_socket == INVALID_SOCKET) { // the video_socket will be cleaned up on destroy return false; } + close(server->pipe_intr[0]); + // we don't need the server socket anymore close_socket(&server->server_socket); } else { diff --git a/app/src/server.h b/app/src/server.h index f7c063a2..2c311879 100644 --- a/app/src/server.h +++ b/app/src/server.h @@ -14,6 +14,7 @@ struct server { char *serial; process_t process; SDL_Thread *wait_server_thread; + int pipe_intr[2]; // to wake up blocking accept() on process exit socket_t server_socket; // only used if !tunnel_forward socket_t video_socket; socket_t control_socket; diff --git a/app/src/util/net.c b/app/src/util/net.c index efce6fa9..f7e827f5 100644 --- a/app/src/util/net.c +++ b/app/src/util/net.c @@ -1,16 +1,22 @@ #include "net.h" +#include +#include #include #include #include "config.h" +#include "common.h" #include "log.h" #ifdef __WINDOWS__ +# include +# include typedef int socklen_t; #else -# include +# include # include +# include # include # include # include @@ -145,3 +151,38 @@ net_close(socket_t socket) { return !close(socket); #endif } + +bool +net_select_interruptible(int fd, int fd_intr) { + fd_set rfds; + + FD_ZERO(&rfds); + FD_SET(fd, &rfds); + FD_SET(fd_intr, &rfds); + + int nfds = MAX(fd, fd_intr) + 1; + + // use select() because it's available on supported platforms + int r = select(nfds, &rfds, NULL, NULL, NULL); + if (r == -1) { + // failure + return false; + } + assert(r > 0); + if (FD_ISSET(fd_intr, &rfds)) { + // interrupted is set + return false; + } + + assert(FD_ISSET(fd, &rfds)); + return true; +} + +bool +net_pipe(int fds[static 2]) { +#ifdef __WINDOWS__ + return !_pipe(fds, 4096, 0); +#else + return !pipe(fds); +#endif +} diff --git a/app/src/util/net.h b/app/src/util/net.h index ffd5dd89..ddc245c4 100644 --- a/app/src/util/net.h +++ b/app/src/util/net.h @@ -54,4 +54,12 @@ net_shutdown(socket_t socket, int how); bool net_close(socket_t socket); +// wait for fd or fd_intr to be readable +// return true if fd is readable and fd_intr is not +bool +net_select_interruptible(int fd, int fd_intr); + +bool +net_pipe(int fd[static 2]); + #endif