[llvm][Support] ListeningSocket::accept returns operation_canceled if FD is set to -1 (#89479)

If `::poll` returns and `FD` equals -1, then `ListeningSocket::shutdown`
has been called. So, regardless of any other information that could be
gleaned from `FDs.revents` or `PollStatus`, it is appropriate to return
`std::errc::operation_canceled`. `ListeningSocket::shutdown` copies
`FD`'s value to `ObservedFD` then sets `FD` to -1 before canceling
`::poll` by calling `::close(ObservedFD)` and writing to the pipe.
This commit is contained in:
Connor Sughrue 2024-05-21 20:32:11 -04:00 committed by GitHub
parent 0170bd5d11
commit 203232ffbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 25 deletions

View File

@ -204,17 +204,26 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
auto Start = std::chrono::steady_clock::now();
#ifdef _WIN32
PollStatus = WSAPoll(FDs, 2, RemainingTime);
if (PollStatus == SOCKET_ERROR) {
#else
PollStatus = ::poll(FDs, 2, RemainingTime);
#endif
// If FD equals -1 then ListeningSocket::shutdown has been called and it is
// appropriate to return operation_canceled
if (FD.load() == -1)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::operation_canceled),
"Accept canceled");
#if _WIN32
if (PollStatus == SOCKET_ERROR) {
#else
if (PollStatus == -1) {
#endif
// Ignore error if caused by interupting signal
std::error_code PollErrCode = getLastSocketErrorCode();
// Ignore EINTR (signal occured before any request event) and retry
if (PollErrCode != std::errc::interrupted)
return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
}
if (PollStatus == 0)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::timed_out),
@ -222,13 +231,7 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
if (FDs[0].revents & POLLNVAL)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::bad_file_descriptor),
"File descriptor closed by another thread");
if (FDs[1].revents & POLLIN)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::operation_canceled),
"Accept canceled");
std::make_error_code(std::errc::bad_file_descriptor));
auto Stop = std::chrono::steady_clock::now();
ElapsedTime +=

View File

@ -7,7 +7,6 @@
#include "llvm/Testing/Support/Error.h"
#include "gtest/gtest.h"
#include <future>
#include <iostream>
#include <stdlib.h>
#include <thread>
@ -86,13 +85,8 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept(Timeout);
ASSERT_THAT_EXPECTED(MaybeServer, Failed());
llvm::Error Err = MaybeServer.takeError();
llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
std::error_code EC = SE.convertToErrorCode();
ASSERT_EQ(EC, std::errc::timed_out);
});
ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
std::errc::timed_out);
}
TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
@ -122,12 +116,7 @@ TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
// Wait for the CloseThread to finish
CloseThread.join();
ASSERT_THAT_EXPECTED(MaybeServer, Failed());
llvm::Error Err = MaybeServer.takeError();
llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
std::error_code EC = SE.convertToErrorCode();
ASSERT_EQ(EC, std::errc::operation_canceled);
});
ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
std::errc::operation_canceled);
}
} // namespace