[llvm][Support] Implement raw_socket_stream::read with optional timeout (#92308)

This PR implements `raw_socket_stream::read`, which overloads the base
class `raw_fd_stream::read`. `raw_socket_stream::read` provides a way to
timeout the underlying `::read`. The timeout functionality was not added
to `raw_fd_stream::read` to avoid needlessly increasing compile times
and allow for convenient code reuse with `raw_socket_stream::accept`,
which also requires timeout functionality. This PR supports the module
build daemon and will help guarantee it never becomes a zombie process.
This commit is contained in:
Connor Sughrue 2024-07-21 23:50:28 -04:00 committed by GitHub
parent 324fea9baa
commit 76321b9f08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 154 additions and 71 deletions

View File

@ -92,13 +92,14 @@ public:
/// Accepts an incoming connection on the listening socket. This method can
/// optionally either block until a connection is available or timeout after a
/// specified amount of time has passed. By default the method will block
/// until the socket has recieved a connection.
/// until the socket has recieved a connection. If the accept timesout this
/// method will return std::errc:timed_out
///
/// \param Timeout An optional timeout duration in milliseconds. Setting
/// Timeout to -1 causes accept to block indefinitely
/// Timeout to a negative number causes ::accept to block indefinitely
///
Expected<std::unique_ptr<raw_socket_stream>>
accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
Expected<std::unique_ptr<raw_socket_stream>> accept(
const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));
/// Creates a listening socket bound to the specified file system path.
/// Handles the socket creation, binding, and immediately starts listening for
@ -124,11 +125,28 @@ class raw_socket_stream : public raw_fd_stream {
public:
raw_socket_stream(int SocketFD);
~raw_socket_stream();
/// Create a \p raw_socket_stream connected to the UNIX domain socket at \p
/// SocketPath.
static Expected<std::unique_ptr<raw_socket_stream>>
createConnectedUnix(StringRef SocketPath);
~raw_socket_stream();
/// Attempt to read from the raw_socket_stream's file descriptor.
///
/// This method can optionally either block until data is read or an error has
/// occurred or timeout after a specified amount of time has passed. By
/// default the method will block until the socket has read data or
/// encountered an error. If the read times out this method will return
/// std::errc:timed_out
///
/// \param Ptr The start of the buffer that will hold any read data
/// \param Size The number of bytes to be read
/// \param Timeout An optional timeout duration in milliseconds
///
ssize_t read(
char *Ptr, size_t Size,
const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));
};
} // end namespace llvm

View File

@ -18,6 +18,7 @@
#include <atomic>
#include <fcntl.h>
#include <functional>
#include <thread>
#ifndef _WIN32
@ -177,70 +178,89 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
#endif // _WIN32
}
Expected<std::unique_ptr<raw_socket_stream>>
ListeningSocket::accept(std::chrono::milliseconds Timeout) {
struct pollfd FDs[2];
FDs[0].events = POLLIN;
// If a file descriptor being monitored by ::poll is closed by another thread,
// the result is unspecified. In the case ::poll does not unblock and return,
// when ActiveFD is closed, you can provide another file descriptor via CancelFD
// that when written to will cause poll to return. Typically CancelFD is the
// read end of a unidirectional pipe.
//
// Timeout should be -1 to block indefinitly
//
// getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int
static std::error_code
manageTimeout(const std::chrono::milliseconds &Timeout,
const std::function<int()> &getActiveFD,
const std::optional<int> &CancelFD = std::nullopt) {
struct pollfd FD[2];
FD[0].events = POLLIN;
#ifdef _WIN32
SOCKET WinServerSock = _get_osfhandle(FD);
FDs[0].fd = WinServerSock;
SOCKET WinServerSock = _get_osfhandle(getActiveFD());
FD[0].fd = WinServerSock;
#else
FDs[0].fd = FD;
FD[0].fd = getActiveFD();
#endif
FDs[1].events = POLLIN;
FDs[1].fd = PipeFD[0];
// Keep track of how much time has passed in case poll is interupted by a
// signal and needs to be recalled
int RemainingTime = Timeout.count();
std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
int PollStatus = -1;
while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
if (Timeout.count() != -1)
RemainingTime -= ElapsedTime.count();
auto Start = std::chrono::steady_clock::now();
#ifdef _WIN32
PollStatus = WSAPoll(FDs, 2, RemainingTime);
#else
PollStatus = ::poll(FDs, 2, RemainingTime);
#endif
// If FD equals -1 then ListeningSocket::shutdown has been called and it is
// appropriate to return operation_canceled
if (FD.load() == -1)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::operation_canceled),
"Accept canceled");
#if _WIN32
if (PollStatus == SOCKET_ERROR) {
#else
if (PollStatus == -1) {
#endif
std::error_code PollErrCode = getLastSocketErrorCode();
// Ignore EINTR (signal occured before any request event) and retry
if (PollErrCode != std::errc::interrupted)
return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
}
if (PollStatus == 0)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::timed_out),
"No client requests within timeout window");
if (FDs[0].revents & POLLNVAL)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::bad_file_descriptor));
auto Stop = std::chrono::steady_clock::now();
ElapsedTime +=
std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
uint8_t FDCount = 1;
if (CancelFD.has_value()) {
FD[1].events = POLLIN;
FD[1].fd = CancelFD.value();
FDCount++;
}
// Keep track of how much time has passed in case ::poll or WSAPoll are
// interupted by a signal and need to be recalled
auto Start = std::chrono::steady_clock::now();
auto RemainingTimeout = Timeout;
int PollStatus = 0;
do {
// If Timeout is -1 then poll should block and RemainingTimeout does not
// need to be recalculated
if (PollStatus != 0 && Timeout != std::chrono::milliseconds(-1)) {
auto TotalElapsedTime =
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - Start);
if (TotalElapsedTime >= Timeout)
return std::make_error_code(std::errc::operation_would_block);
RemainingTimeout = Timeout - TotalElapsedTime;
}
#ifdef _WIN32
PollStatus = WSAPoll(FD, FDCount, RemainingTimeout.count());
} while (PollStatus == SOCKET_ERROR &&
getLastSocketErrorCode() == std::errc::interrupted);
#else
PollStatus = ::poll(FD, FDCount, RemainingTimeout.count());
} while (PollStatus == -1 &&
getLastSocketErrorCode() == std::errc::interrupted);
#endif
// If ActiveFD equals -1 or CancelFD has data to be read then the operation
// has been canceled by another thread
if (getActiveFD() == -1 || (CancelFD.has_value() && FD[1].revents & POLLIN))
return std::make_error_code(std::errc::operation_canceled);
#if _WIN32
if (PollStatus == SOCKET_ERROR)
#else
if (PollStatus == -1)
#endif
return getLastSocketErrorCode();
if (PollStatus == 0)
return std::make_error_code(std::errc::timed_out);
if (FD[0].revents & POLLNVAL)
return std::make_error_code(std::errc::bad_file_descriptor);
return std::error_code();
}
Expected<std::unique_ptr<raw_socket_stream>>
ListeningSocket::accept(const std::chrono::milliseconds &Timeout) {
auto getActiveFD = [this]() -> int { return FD; };
std::error_code TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]);
if (TimeoutErr)
return llvm::make_error<StringError>(TimeoutErr, "Timeout error");
int AcceptFD;
#ifdef _WIN32
SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL);
AcceptFD = _open_osfhandle(WinAcceptSock, 0);
#else
AcceptFD = ::accept(FD, NULL, NULL);
@ -295,6 +315,8 @@ ListeningSocket::~ListeningSocket() {
raw_socket_stream::raw_socket_stream(int SocketFD)
: raw_fd_stream(SocketFD, true) {}
raw_socket_stream::~raw_socket_stream() {}
Expected<std::unique_ptr<raw_socket_stream>>
raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
#ifdef _WIN32
@ -306,4 +328,14 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
return std::make_unique<raw_socket_stream>(*FD);
}
raw_socket_stream::~raw_socket_stream() {}
ssize_t raw_socket_stream::read(char *Ptr, size_t Size,
const std::chrono::milliseconds &Timeout) {
auto getActiveFD = [this]() -> int { return this->get_fd(); };
std::error_code Err = manageTimeout(Timeout, getActiveFD);
// Mimic raw_fd_stream::read error handling behavior
if (Err) {
raw_fd_stream::error_detected(Err);
return -1;
}
return raw_fd_stream::read(Ptr, Size);
}

View File

@ -62,17 +62,18 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
ssize_t BytesRead = Server.read(Bytes, 8);
std::string string(Bytes, 8);
ASSERT_EQ(Server.has_error(), false);
ASSERT_EQ(8, BytesRead);
ASSERT_EQ("01234567", string);
}
TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) {
if (!hasUnixSocketSupport())
GTEST_SKIP();
SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("timout_provided.sock", SocketPath, true);
llvm::sys::fs::createUniquePath("read_with_timeout.sock", SocketPath, true);
// Make sure socket file does not exist. May still be there from the last test
std::remove(SocketPath.c_str());
@ -82,19 +83,51 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
ListeningSocket ServerListener = std::move(*MaybeServerListener);
std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
raw_socket_stream::createConnectedUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept(Timeout);
ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
std::errc::timed_out);
ServerListener.accept();
ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
raw_socket_stream &Server = **MaybeServer;
char Bytes[8];
ssize_t BytesRead = Server.read(Bytes, 8, std::chrono::milliseconds(100));
ASSERT_EQ(BytesRead, -1);
ASSERT_EQ(Server.has_error(), true);
ASSERT_EQ(Server.error(), std::errc::timed_out);
Server.clear_error();
}
TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
TEST(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) {
if (!hasUnixSocketSupport())
GTEST_SKIP();
SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true);
llvm::sys::fs::createUniquePath("accept_with_timeout.sock", SocketPath, true);
// Make sure socket file does not exist. May still be there from the last test
std::remove(SocketPath.c_str());
Expected<ListeningSocket> MaybeServerListener =
ListeningSocket::createUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
ListeningSocket ServerListener = std::move(*MaybeServerListener);
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept(std::chrono::milliseconds(100));
ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
std::errc::timed_out);
}
TEST(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) {
if (!hasUnixSocketSupport())
GTEST_SKIP();
SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("accept_with_shutdown.sock", SocketPath,
true);
// Make sure socket file does not exist. May still be there from the last test
std::remove(SocketPath.c_str());