From 94b70bfda7fa32ed6b5c247df06de936d96443dc Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Wed, 11 Mar 2026 21:43:33 +0800 Subject: [PATCH 1/5] Fix crash due to asio object lifetime and thread safety issue --- .github/workflows/ci-pr-validation.yaml | 9 + lib/ClientConnection.cc | 534 +++++++++++------------- lib/ClientConnection.h | 48 ++- lib/ConnectionPool.cc | 37 +- lib/ExecutorService.cc | 2 - lib/ExecutorService.h | 18 +- lib/PeriodicTask.h | 2 +- tests/ClientTest.cc | 2 +- tests/MultiTopicsConsumerTest.cc | 2 +- 9 files changed, 346 insertions(+), 308 deletions(-) diff --git a/.github/workflows/ci-pr-validation.yaml b/.github/workflows/ci-pr-validation.yaml index b5a0973c..d209b533 100644 --- a/.github/workflows/ci-pr-validation.yaml +++ b/.github/workflows/ci-pr-validation.yaml @@ -260,6 +260,15 @@ jobs: Pop-Location } + - name: Ensure vcpkg has full history(windows) + if: runner.os == 'Windows' + shell: pwsh + run: | + $isShallow = (git -C "${{ env.VCPKG_ROOT }}" rev-parse --is-shallow-repository).Trim() + if ($isShallow -eq "true") { + git -C "${{ env.VCPKG_ROOT }}" fetch --unshallow + } + - name: remove system vcpkg(windows) if: runner.os == 'Windows' run: rm -rf "$VCPKG_INSTALLATION_ROOT" diff --git a/lib/ClientConnection.cc b/lib/ClientConnection.cc index 0a850ed0..b1d45b3f 100644 --- a/lib/ClientConnection.cc +++ b/lib/ClientConnection.cc @@ -189,10 +189,9 @@ ClientConnection::ClientConnection(const std::string& logicalAddress, const std: executor_(executor), resolver_(executor_->createTcpResolver()), socket_(executor_->createSocket()), - strand_(ASIO::make_strand(executor_->getIOService().get_executor())), logicalAddress_(logicalAddress), physicalAddress_(physicalAddress), - cnxString_("[ -> " + physicalAddress + "] "), + cnxStringPtr_(std::make_shared("[ -> " + physicalAddress + "] ")), incomingBuffer_(SharedBuffer::allocate(DefaultBufferSize)), connectTimeoutTask_( std::make_shared(*executor_, clientConfiguration.getConnectionTimeout())), @@ 
-203,7 +202,8 @@ ClientConnection::ClientConnection(const std::string& logicalAddress, const std: clientVersion_(clientVersion), pool_(pool), poolIndex_(poolIndex) { - LOG_INFO(cnxString_ << "Create ClientConnection, timeout=" << clientConfiguration.getConnectionTimeout()); + LOG_INFO(cnxString() << "Create ClientConnection, timeout=" + << clientConfiguration.getConnectionTimeout()); if (!authentication_) { LOG_ERROR("Invalid authentication plugin"); throw ResultAuthenticationError; @@ -295,12 +295,12 @@ ClientConnection::ClientConnection(const std::string& logicalAddress, const std: } ClientConnection::~ClientConnection() { - LOG_INFO(cnxString_ << "Destroyed connection to " << logicalAddress_ << "-" << poolIndex_); + LOG_INFO(cnxString() << "Destroyed connection to " << logicalAddress_ << "-" << poolIndex_); } void ClientConnection::handlePulsarConnected(const proto::CommandConnected& cmdConnected) { if (!cmdConnected.has_server_version()) { - LOG_ERROR(cnxString_ << "Server version is not set"); + LOG_ERROR(cnxString() << "Server version is not set"); close(); return; } @@ -314,11 +314,10 @@ void ClientConnection::handlePulsarConnected(const proto::CommandConnected& cmdC Lock lock(mutex_); if (isClosed()) { - LOG_INFO(cnxString_ << "Connection already closed"); + LOG_INFO(cnxString() << "Connection already closed"); return; } state_ = Ready; - connectTimeoutTask_->stop(); serverProtocolVersion_ = cmdConnected.protocol_version(); if (serverProtocolVersion_ >= proto::v1) { @@ -326,13 +325,8 @@ void ClientConnection::handlePulsarConnected(const proto::CommandConnected& cmdC keepAliveTimer_ = executor_->createDeadlineTimer(); if (keepAliveTimer_) { keepAliveTimer_->expires_after(std::chrono::seconds(keepAliveIntervalInSeconds_)); - auto weakSelf = weak_from_this(); - keepAliveTimer_->async_wait([weakSelf](const ASIO_ERROR&) { - auto self = weakSelf.lock(); - if (self) { - self->handleKeepAliveTimeout(); - } - }); + keepAliveTimer_->async_wait( + [this, 
self{shared_from_this()}](const ASIO_ERROR& err) { handleKeepAliveTimeout(err); }); } } @@ -352,12 +346,12 @@ void ClientConnection::startConsumerStatsTimer(std::vector consumerSta for (int i = 0; i < consumerStatsRequests.size(); i++) { PendingConsumerStatsMap::iterator it = pendingConsumerStatsMap_.find(consumerStatsRequests[i]); if (it != pendingConsumerStatsMap_.end()) { - LOG_DEBUG(cnxString_ << " removing request_id " << it->first - << " from the pendingConsumerStatsMap_"); + LOG_DEBUG(cnxString() << " removing request_id " << it->first + << " from the pendingConsumerStatsMap_"); consumerStatsPromises.push_back(it->second); pendingConsumerStatsMap_.erase(it); } else { - LOG_DEBUG(cnxString_ << "request_id " << it->first << " already fulfilled - not removing it"); + LOG_DEBUG(cnxString() << "request_id " << it->first << " already fulfilled - not removing it"); } } @@ -371,19 +365,16 @@ void ClientConnection::startConsumerStatsTimer(std::vector consumerSta // Check if we have a timer still before we set the request timer to pop again. 
if (consumerStatsRequestTimer_) { consumerStatsRequestTimer_->expires_after(operationsTimeout_); - auto weakSelf = weak_from_this(); - consumerStatsRequestTimer_->async_wait([weakSelf, consumerStatsRequests](const ASIO_ERROR& err) { - auto self = weakSelf.lock(); - if (self) { - self->handleConsumerStatsTimeout(err, consumerStatsRequests); - } - }); + consumerStatsRequestTimer_->async_wait( + [this, self{shared_from_this()}, consumerStatsRequests](const ASIO_ERROR& err) { + handleConsumerStatsTimeout(err, consumerStatsRequests); + }); } lock.unlock(); // Complex logic since promises need to be fulfilled outside the lock for (int i = 0; i < consumerStatsPromises.size(); i++) { consumerStatsPromises[i].setFailed(ResultTimeout); - LOG_WARN(cnxString_ << " Operation timedout, didn't get response from broker"); + LOG_WARN(cnxString() << " Operation timedout, didn't get response from broker"); } } @@ -416,37 +407,38 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp try { cnxStringStream << "[" << socket_->local_endpoint() << " -> " << socket_->remote_endpoint() << "] "; - cnxString_ = cnxStringStream.str(); + std::atomic_store(&cnxStringPtr_, std::make_shared(cnxStringStream.str())); } catch (const ASIO_SYSTEM_ERROR& e) { LOG_ERROR("Failed to get endpoint: " << e.what()); close(ResultRetryable); return; } if (logicalAddress_ == physicalAddress_) { - LOG_INFO(cnxString_ << "Connected to broker"); + LOG_INFO(cnxString() << "Connected to broker"); } else { - LOG_INFO(cnxString_ << "Connected to broker through proxy. Logical broker: " << logicalAddress_ - << ", proxy: " << proxyServiceUrl_ - << ", physical address:" << physicalAddress_); + LOG_INFO(cnxString() << "Connected to broker through proxy. 
Logical broker: " << logicalAddress_ + << ", proxy: " << proxyServiceUrl_ + << ", physical address:" << physicalAddress_); } Lock lock(mutex_); - if (isClosed()) { - LOG_INFO(cnxString_ << "Connection already closed"); + if (isClosed() || !connectTimeoutTask_) { + LOG_INFO(cnxString() << "Connection already closed"); return; } + connectTimeoutTask_->stop(); state_ = TcpConnected; lock.unlock(); ASIO_ERROR error; socket_->set_option(tcp::no_delay(true), error); if (error) { - LOG_WARN(cnxString_ << "Socket failed to set tcp::no_delay: " << error.message()); + LOG_WARN(cnxString() << "Socket failed to set tcp::no_delay: " << error.message()); } socket_->set_option(tcp::socket::keep_alive(true), error); if (error) { - LOG_WARN(cnxString_ << "Socket failed to set tcp::socket::keep_alive: " << error.message()); + LOG_WARN(cnxString() << "Socket failed to set tcp::socket::keep_alive: " << error.message()); } // Start TCP keep-alive probes after connection has been idle after 1 minute. Ideally this @@ -454,19 +446,19 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp // connection) every 30 seconds socket_->set_option(tcp_keep_alive_idle(1 * 60), error); if (error) { - LOG_DEBUG(cnxString_ << "Socket failed to set tcp_keep_alive_idle: " << error.message()); + LOG_DEBUG(cnxString() << "Socket failed to set tcp_keep_alive_idle: " << error.message()); } // Send up to 10 probes before declaring the connection broken socket_->set_option(tcp_keep_alive_count(10), error); if (error) { - LOG_DEBUG(cnxString_ << "Socket failed to set tcp_keep_alive_count: " << error.message()); + LOG_DEBUG(cnxString() << "Socket failed to set tcp_keep_alive_count: " << error.message()); } // Interval between probes: 6 seconds socket_->set_option(tcp_keep_alive_interval(6), error); if (error) { - LOG_DEBUG(cnxString_ << "Socket failed to set tcp_keep_alive_interval: " << error.message()); + LOG_DEBUG(cnxString() << "Socket failed to set tcp_keep_alive_interval: " << 
error.message()); } if (tlsSocket_) { @@ -474,29 +466,29 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp ASIO_ERROR err; Url service_url; if (!Url::parse(physicalAddress_, service_url)) { - LOG_ERROR(cnxString_ << "Invalid Url, unable to parse: " << err << " " << err.message()); + LOG_ERROR(cnxString() << "Invalid Url, unable to parse: " << err << " " << err.message()); close(); return; } } - auto weakSelf = weak_from_this(); - auto socket = socket_; - auto tlsSocket = tlsSocket_; // socket and ssl::stream objects must exist until async_handshake is done, otherwise segmentation // fault might happen - auto callback = [weakSelf, socket, tlsSocket](const ASIO_ERROR& err) { - auto self = weakSelf.lock(); - if (self) { - self->handleHandshake(err); - } - }; - tlsSocket_->async_handshake(ASIO::ssl::stream::client, - ASIO::bind_executor(strand_, callback)); + tlsSocket_->async_handshake( + ASIO::ssl::stream::client, + [this, self{shared_from_this()}](const auto& err) { handleHandshake(err); }); } else { handleHandshake(ASIO_SUCCESS); } } else { - LOG_ERROR(cnxString_ << "Failed to establish connection to " << endpoint << ": " << err.message()); + LOG_ERROR(cnxString() << "Failed to establish connection to " << endpoint << ": " << err.message()); + { + std::lock_guard lock{mutex_}; + if (isClosed() || !connectTimeoutTask_) { + return; + } + connectTimeoutTask_->stop(); + connectTimeoutTask_.reset(); // clear the callback, which holds a `shared_from_this()` + } if (err == ASIO::error::operation_aborted) { close(); } else { @@ -508,10 +500,10 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp void ClientConnection::handleHandshake(const ASIO_ERROR& err) { if (err) { if (err.value() == ASIO::ssl::error::stream_truncated) { - LOG_WARN(cnxString_ << "Handshake failed: " << err.message()); + LOG_WARN(cnxString() << "Handshake failed: " << err.message()); close(ResultRetryable); } else { - LOG_ERROR(cnxString_ 
<< "Handshake failed: " << err.message()); + LOG_ERROR(cnxString() << "Handshake failed: " << err.message()); close(); } return; @@ -524,12 +516,12 @@ void ClientConnection::handleHandshake(const ASIO_ERROR& err) { buffer = Commands::newConnect(authentication_, logicalAddress_, connectingThroughProxy, clientVersion_, result); } catch (const std::exception& e) { - LOG_ERROR(cnxString_ << "Failed to create Connect command: " << e.what()); + LOG_ERROR(cnxString() << "Failed to create Connect command: " << e.what()); close(ResultAuthenticationError); return; } if (result != ResultOk) { - LOG_ERROR(cnxString_ << "Failed to establish connection: " << result); + LOG_ERROR(cnxString() << "Failed to establish connection: " << result); close(result); return; } @@ -546,7 +538,7 @@ void ClientConnection::handleSentPulsarConnect(const ASIO_ERROR& err, const Shar return; } if (err) { - LOG_ERROR(cnxString_ << "Failed to establish connection: " << err.message()); + LOG_ERROR(cnxString() << "Failed to establish connection: " << err.message()); close(); return; } @@ -560,7 +552,7 @@ void ClientConnection::handleSentAuthResponse(const ASIO_ERROR& err, const Share return; } if (err) { - LOG_WARN(cnxString_ << "Failed to send auth response: " << err.message()); + LOG_WARN(cnxString() << "Failed to send auth response: " << err.message()); close(); return; } @@ -581,73 +573,68 @@ void ClientConnection::tcpConnectAsync() { Url service_url; std::string hostUrl = isSniProxy_ ? proxyServiceUrl_ : physicalAddress_; if (!Url::parse(hostUrl, service_url)) { - LOG_ERROR(cnxString_ << "Invalid Url, unable to parse: " << err << " " << err.message()); + LOG_ERROR(cnxString() << "Invalid Url, unable to parse: " << err << " " << err.message()); close(); return; } if (service_url.protocol() != "pulsar" && service_url.protocol() != "pulsar+ssl") { - LOG_ERROR(cnxString_ << "Invalid Url protocol '" << service_url.protocol() - << "'. 
Valid values are 'pulsar' and 'pulsar+ssl'"); + LOG_ERROR(cnxString() << "Invalid Url protocol '" << service_url.protocol() + << "'. Valid values are 'pulsar' and 'pulsar+ssl'"); close(); return; } - LOG_DEBUG(cnxString_ << "Resolving " << service_url.host() << ":" << service_url.port()); + LOG_DEBUG(cnxString() << "Resolving " << service_url.host() << ":" << service_url.port()); - auto weakSelf = weak_from_this(); - resolver_->async_resolve(service_url.host(), std::to_string(service_url.port()), - [weakSelf](auto err, const auto& results) { - auto self = weakSelf.lock(); - if (self) { - self->handleResolve(err, results); - } - }); + resolver_->async_resolve( + service_url.host(), std::to_string(service_url.port()), + [this, self{shared_from_this()}](auto err, const auto& results) { handleResolve(err, results); }); } void ClientConnection::handleResolve(ASIO_ERROR err, const tcp::resolver::results_type& results) { if (err) { - std::string hostUrl = isSniProxy_ ? cnxString_ : proxyServiceUrl_; + std::string hostUrl = isSniProxy_ ? 
cnxString() : proxyServiceUrl_; LOG_ERROR(hostUrl << "Resolve error: " << err << " : " << err.message()); close(); return; } if (!results.empty()) { - LOG_DEBUG(cnxString_ << "Resolved " << results.size() << " endpoints"); + LOG_DEBUG(cnxString() << "Resolved " << results.size() << " endpoints"); for (const auto& entry : results) { const auto& ep = entry.endpoint(); - LOG_DEBUG(cnxString_ << " " << ep.address().to_string() << ":" << ep.port()); + LOG_DEBUG(cnxString() << " " << ep.address().to_string() << ":" << ep.port()); } } - auto weakSelf = weak_from_this(); - connectTimeoutTask_->setCallback([weakSelf, results = tcp::resolver::results_type(results)]( - const PeriodicTask::ErrorCode& ec) { - ClientConnectionPtr ptr = weakSelf.lock(); - if (!ptr) { - LOG_DEBUG("Connect timeout callback skipped: connection was already destroyed"); - return; - } - - if (ptr->state_ != Ready) { - LOG_ERROR(ptr->cnxString_ << "Connection to " << results << " was not established in " - << ptr->connectTimeoutTask_->getPeriodMs() << " ms, close the socket"); - PeriodicTask::ErrorCode err; - ptr->socket_->close(err); - if (err) { - LOG_WARN(ptr->cnxString_ << "Failed to close socket: " << err.message()); + // Acquire the lock to prevent the race: + // 1. thread 1: isClosed() returns false + // 2. thread 2: call `connectTimeoutTask_->stop()` and `connectTimeoutTask_.reset()` in `close()` + // 3. thread 1: call `connectTimeoutTask_->setCallback()` and `connectTimeoutTask_->start()` + // Then the self captured in the callback of `connectTimeoutTask_` would be kept alive unexpectedly and + // cannot be cancelled until the executor is destroyed. 
+ std::lock_guard lock{mutex_}; + if (isClosed() || !connectTimeoutTask_) { + return; + } + connectTimeoutTask_->setCallback( + [this, self{shared_from_this()}, + results = tcp::resolver::results_type(results)](const PeriodicTask::ErrorCode& ec) { + if (state_ != Ready) { + LOG_ERROR(cnxString() << "Connection to " << results << " was not established in " + << connectTimeoutTask_->getPeriodMs() << " ms"); + close(); + } else { + connectTimeoutTask_->stop(); } - } - ptr->connectTimeoutTask_->stop(); - }); + }); connectTimeoutTask_->start(); - ASIO::async_connect(*socket_, results, [weakSelf](const ASIO_ERROR& err, const tcp::endpoint& endpoint) { - auto self = weakSelf.lock(); - if (self) { - self->handleTcpConnected(err, endpoint); - } - }); + ASIO::async_connect( + *socket_, results, + [this, self{shared_from_this()}](const ASIO_ERROR& err, const tcp::endpoint& endpoint) { + handleTcpConnected(err, endpoint); + }); } void ClientConnection::readNextCommand() { @@ -668,11 +655,11 @@ void ClientConnection::handleRead(const ASIO_ERROR& err, size_t bytesTransferred if (err || bytesTransferred == 0) { if (err == ASIO::error::operation_aborted) { - LOG_DEBUG(cnxString_ << "Read operation was canceled: " << err.message()); + LOG_DEBUG(cnxString() << "Read operation was canceled: " << err.message()); } else if (bytesTransferred == 0 || err == ASIO::error::eof) { - LOG_DEBUG(cnxString_ << "Server closed the connection: " << err.message()); + LOG_DEBUG(cnxString() << "Server closed the connection: " << err.message()); } else { - LOG_ERROR(cnxString_ << "Read operation failed: " << err.message()); + LOG_ERROR(cnxString() << "Read operation failed: " << err.message()); } close(ResultDisconnected); } else if (bytesTransferred < minReadSize) { @@ -724,7 +711,7 @@ void ClientConnection::processIncomingBuffer() { uint32_t cmdSize = incomingBuffer_.readUnsignedInt(); proto::BaseCommand incomingCmd; if (!incomingCmd.ParseFromArray(incomingBuffer_.data(), cmdSize)) { - 
LOG_ERROR(cnxString_ << "Error parsing protocol buffer command"); + LOG_ERROR(cnxString() << "Error parsing protocol buffer command"); close(ResultDisconnected); return; } @@ -744,11 +731,11 @@ void ClientConnection::processIncomingBuffer() { // broker entry metadata is present uint32_t brokerEntryMetadataSize = incomingBuffer_.readUnsignedInt(); if (!brokerEntryMetadata.ParseFromArray(incomingBuffer_.data(), brokerEntryMetadataSize)) { - LOG_ERROR(cnxString_ << "[consumer id " << incomingCmd.message().consumer_id() - << ", message ledger id " - << incomingCmd.message().message_id().ledgerid() << ", entry id " - << incomingCmd.message().message_id().entryid() - << "] Error parsing broker entry metadata"); + LOG_ERROR(cnxString() + << "[consumer id " << incomingCmd.message().consumer_id() + << ", message ledger id " << incomingCmd.message().message_id().ledgerid() + << ", entry id " << incomingCmd.message().message_id().entryid() + << "] Error parsing broker entry metadata"); close(ResultDisconnected); return; } @@ -762,11 +749,11 @@ void ClientConnection::processIncomingBuffer() { uint32_t metadataSize = incomingBuffer_.readUnsignedInt(); if (!msgMetadata.ParseFromArray(incomingBuffer_.data(), metadataSize)) { - LOG_ERROR(cnxString_ << "[consumer id " << incomingCmd.message().consumer_id() // - << ", message ledger id " - << incomingCmd.message().message_id().ledgerid() // - << ", entry id " << incomingCmd.message().message_id().entryid() - << "] Error parsing message metadata"); + LOG_ERROR(cnxString() + << "[consumer id " << incomingCmd.message().consumer_id() // + << ", message ledger id " << incomingCmd.message().message_id().ledgerid() // + << ", entry id " << incomingCmd.message().message_id().entryid() + << "] Error parsing message metadata"); close(ResultDisconnected); return; } @@ -839,8 +826,8 @@ bool ClientConnection::verifyChecksum(SharedBuffer& incomingBuffer_, uint32_t& r } void ClientConnection::handleActiveConsumerChange(const 
proto::CommandActiveConsumerChange& change) { - LOG_DEBUG(cnxString_ << "Received notification about active consumer change, consumer_id: " - << change.consumer_id() << " isActive: " << change.is_active()); + LOG_DEBUG(cnxString() << "Received notification about active consumer change, consumer_id: " + << change.consumer_id() << " isActive: " << change.is_active()); Lock lock(mutex_); ConsumersMap::iterator it = consumers_.find(change.consumer_id()); if (it != consumers_.end()) { @@ -851,19 +838,19 @@ void ClientConnection::handleActiveConsumerChange(const proto::CommandActiveCons consumer->activeConsumerChanged(change.is_active()); } else { consumers_.erase(change.consumer_id()); - LOG_DEBUG(cnxString_ << "Ignoring incoming message for already destroyed consumer " - << change.consumer_id()); + LOG_DEBUG(cnxString() << "Ignoring incoming message for already destroyed consumer " + << change.consumer_id()); } } else { - LOG_DEBUG(cnxString_ << "Got invalid consumer Id in " << change.consumer_id() - << " -- isActive: " << change.is_active()); + LOG_DEBUG(cnxString() << "Got invalid consumer Id in " << change.consumer_id() + << " -- isActive: " << change.is_active()); } } void ClientConnection::handleIncomingMessage(const proto::CommandMessage& msg, bool isChecksumValid, proto::BrokerEntryMetadata& brokerEntryMetadata, proto::MessageMetadata& msgMetadata, SharedBuffer& payload) { - LOG_DEBUG(cnxString_ << "Received a message from the server for consumer: " << msg.consumer_id()); + LOG_DEBUG(cnxString() << "Received a message from the server for consumer: " << msg.consumer_id()); Lock lock(mutex_); ConsumersMap::iterator it = consumers_.find(msg.consumer_id()); @@ -878,21 +865,21 @@ void ClientConnection::handleIncomingMessage(const proto::CommandMessage& msg, b msgMetadata, payload); } else { consumers_.erase(msg.consumer_id()); - LOG_DEBUG(cnxString_ << "Ignoring incoming message for already destroyed consumer " - << msg.consumer_id()); + LOG_DEBUG(cnxString() << 
"Ignoring incoming message for already destroyed consumer " + << msg.consumer_id()); } } else { - LOG_DEBUG(cnxString_ << "Got invalid consumer Id in " // - << msg.consumer_id() << " -- msg: " << msgMetadata.sequence_id()); + LOG_DEBUG(cnxString() << "Got invalid consumer Id in " // + << msg.consumer_id() << " -- msg: " << msgMetadata.sequence_id()); } } void ClientConnection::handleIncomingCommand(BaseCommand& incomingCmd) { - LOG_DEBUG(cnxString_ << "Handling incoming command: " << Commands::messageType(incomingCmd.type())); + LOG_DEBUG(cnxString() << "Handling incoming command: " << Commands::messageType(incomingCmd.type())); switch (state_.load()) { case Pending: { - LOG_ERROR(cnxString_ << "Connection is not ready yet"); + LOG_ERROR(cnxString() << "Connection is not ready yet"); break; } @@ -908,7 +895,7 @@ void ClientConnection::handleIncomingCommand(BaseCommand& incomingCmd) { } case Disconnected: { - LOG_ERROR(cnxString_ << "Connection already disconnected"); + LOG_ERROR(cnxString() << "Connection already disconnected"); break; } @@ -967,12 +954,12 @@ void ClientConnection::handleIncomingCommand(BaseCommand& incomingCmd) { case BaseCommand::PING: // Respond to ping request - LOG_DEBUG(cnxString_ << "Replying to ping command"); + LOG_DEBUG(cnxString() << "Replying to ping command"); sendCommand(Commands::newPong()); break; case BaseCommand::PONG: - LOG_DEBUG(cnxString_ << "Received response to ping message"); + LOG_DEBUG(cnxString() << "Received response to ping message"); break; case BaseCommand::AUTH_CHALLENGE: @@ -1000,7 +987,7 @@ void ClientConnection::handleIncomingCommand(BaseCommand& incomingCmd) { break; default: - LOG_WARN(cnxString_ << "Received invalid message from server"); + LOG_WARN(cnxString() << "Received invalid message from server"); close(ResultDisconnected); break; } @@ -1014,7 +1001,7 @@ Future ClientConnection::newConsumerStats(uint6 Promise promise; if (isClosed()) { lock.unlock(); - LOG_ERROR(cnxString_ << " Client is not connected to 
the broker"); + LOG_ERROR(cnxString() << " Client is not connected to the broker"); promise.setFailed(ResultNotConnected); return promise.getFuture(); } @@ -1059,18 +1046,14 @@ void ClientConnection::newLookup(const SharedBuffer& cmd, uint64_t requestId, co requestData.promise = promise; requestData.timer = executor_->createDeadlineTimer(); requestData.timer->expires_after(operationsTimeout_); - auto weakSelf = weak_from_this(); - requestData.timer->async_wait([weakSelf, requestData](const ASIO_ERROR& ec) { - auto self = weakSelf.lock(); - if (self) { - self->handleLookupTimeout(ec, requestData); - } + requestData.timer->async_wait([this, self{shared_from_this()}, requestData](const ASIO_ERROR& ec) { + handleLookupTimeout(ec, requestData); }); pendingLookupRequests_.insert(std::make_pair(requestId, requestData)); numOfPendingLookupRequest_++; lock.unlock(); - LOG_DEBUG(cnxString_ << "Inserted lookup request " << requestType << " (req_id: " << requestId << ")"); + LOG_DEBUG(cnxString() << "Inserted lookup request " << requestType << " (req_id: " << requestId << ")"); sendCommand(cmd); } @@ -1079,18 +1062,7 @@ void ClientConnection::sendCommand(const SharedBuffer& cmd) { if (pendingWriteOperations_++ == 0) { // Write immediately to socket - if (tlsSocket_) { - auto weakSelf = weak_from_this(); - auto callback = [weakSelf, cmd]() { - auto self = weakSelf.lock(); - if (self) { - self->sendCommandInternal(cmd); - } - }; - ASIO::post(strand_, callback); - } else { - sendCommandInternal(cmd); - } + executor_->dispatch([this, cmd, self{shared_from_this()}] { sendCommandInternal(cmd); }); } else { // Queue to send later pendingWriteBuffers_.push_back(cmd); @@ -1122,11 +1094,7 @@ void ClientConnection::sendMessage(const std::shared_ptr& args) { handleSendPair(err); })); }; - if (tlsSocket_) { - ASIO::post(strand_, sendMessageInternal); - } else { - sendMessageInternal(); - } + executor_->dispatch(sendMessageInternal); } void ClientConnection::handleSend(const ASIO_ERROR& 
err, const SharedBuffer&) { @@ -1134,7 +1102,7 @@ void ClientConnection::handleSend(const ASIO_ERROR& err, const SharedBuffer&) { return; } if (err) { - LOG_WARN(cnxString_ << "Could not send message on connection: " << err << " " << err.message()); + LOG_WARN(cnxString() << "Could not send message on connection: " << err << " " << err.message()); close(ResultDisconnected); } else { sendPendingCommands(); @@ -1146,7 +1114,7 @@ void ClientConnection::handleSendPair(const ASIO_ERROR& err) { return; } if (err) { - LOG_WARN(cnxString_ << "Could not send pair message on connection: " << err << " " << err.message()); + LOG_WARN(cnxString() << "Could not send pair message on connection: " << err << " " << err.message()); close(ResultDisconnected); } else { sendPendingCommands(); @@ -1194,8 +1162,8 @@ Future ClientConnection::sendRequestWithId(const SharedBuf if (isClosed()) { lock.unlock(); Promise promise; - LOG_DEBUG(cnxString_ << "Fail " << requestType << "(req_id: " << requestId - << ") to a closed connection"); + LOG_DEBUG(cnxString() << "Fail " << requestType << "(req_id: " << requestId + << ") to a closed connection"); promise.setFailed(ResultNotConnected); return promise.getFuture(); } @@ -1203,21 +1171,17 @@ Future ClientConnection::sendRequestWithId(const SharedBuf PendingRequestData requestData; requestData.timer = executor_->createDeadlineTimer(); requestData.timer->expires_after(operationsTimeout_); - auto weakSelf = weak_from_this(); - requestData.timer->async_wait([weakSelf, requestData](const ASIO_ERROR& ec) { - auto self = weakSelf.lock(); - if (self) { - self->handleRequestTimeout(ec, requestData); - } + requestData.timer->async_wait([this, self{shared_from_this()}, requestData](const ASIO_ERROR& ec) { + handleRequestTimeout(ec, requestData); }); pendingRequests_.insert(std::make_pair(requestId, requestData)); lock.unlock(); - LOG_DEBUG(cnxString_ << "Inserted request " << requestType << " (req_id: " << requestId << ")"); + LOG_DEBUG(cnxString() << 
"Inserted request " << requestType << " (req_id: " << requestId << ")"); if (mockingRequests_.load(std::memory_order_acquire)) { if (mockServer_ == nullptr) { - LOG_WARN(cnxString_ << "Mock server is unexpectedly null when processing " << requestType); + LOG_WARN(cnxString() << "Mock server is unexpectedly null when processing " << requestType); sendCommand(cmd); } else if (!mockServer_->sendRequest(requestType, requestId)) { sendCommand(cmd); @@ -1231,7 +1195,7 @@ Future ClientConnection::sendRequestWithId(const SharedBuf void ClientConnection::handleRequestTimeout(const ASIO_ERROR& ec, const PendingRequestData& pendingRequestData) { if (!ec && !pendingRequestData.hasGotResponse->load()) { - LOG_WARN(cnxString_ << "Network request timeout to broker, remote: " << physicalAddress_); + LOG_WARN(cnxString() << "Network request timeout to broker, remote: " << physicalAddress_); pendingRequestData.promise.setFailed(ResultTimeout); } } @@ -1239,7 +1203,7 @@ void ClientConnection::handleRequestTimeout(const ASIO_ERROR& ec, void ClientConnection::handleLookupTimeout(const ASIO_ERROR& ec, const LookupRequestData& pendingRequestData) { if (!ec) { - LOG_WARN(cnxString_ << "Lookup request timeout to broker, remote: " << physicalAddress_); + LOG_WARN(cnxString() << "Lookup request timeout to broker, remote: " << physicalAddress_); pendingRequestData.promise->setFailed(ResultTimeout); } } @@ -1247,22 +1211,22 @@ void ClientConnection::handleLookupTimeout(const ASIO_ERROR& ec, void ClientConnection::handleGetLastMessageIdTimeout(const ASIO_ERROR& ec, const ClientConnection::LastMessageIdRequestData& data) { if (!ec) { - LOG_WARN(cnxString_ << "GetLastMessageId request timeout to broker, remote: " << physicalAddress_); + LOG_WARN(cnxString() << "GetLastMessageId request timeout to broker, remote: " << physicalAddress_); data.promise->setFailed(ResultTimeout); } } -void ClientConnection::handleKeepAliveTimeout() { - if (isClosed()) { +void 
ClientConnection::handleKeepAliveTimeout(const ASIO_ERROR& ec) { + if (isClosed() || ec) { return; } if (havePendingPingRequest_) { - LOG_WARN(cnxString_ << "Forcing connection to close after keep-alive timeout"); + LOG_WARN(cnxString() << "Forcing connection to close after keep-alive timeout"); close(ResultDisconnected); } else { // Send keep alive probe to peer - LOG_DEBUG(cnxString_ << "Sending ping message"); + LOG_DEBUG(cnxString() << "Sending ping message"); havePendingPingRequest_ = true; sendCommand(Commands::newPing()); @@ -1271,13 +1235,8 @@ void ClientConnection::handleKeepAliveTimeout() { Lock lock(mutex_); if (keepAliveTimer_) { keepAliveTimer_->expires_after(std::chrono::seconds(keepAliveIntervalInSeconds_)); - auto weakSelf = weak_from_this(); - keepAliveTimer_->async_wait([weakSelf](const ASIO_ERROR&) { - auto self = weakSelf.lock(); - if (self) { - self->handleKeepAliveTimeout(); - } - }); + keepAliveTimer_->async_wait( + [this, self{shared_from_this()}](const auto& err) { handleKeepAliveTimeout(err); }); } lock.unlock(); } @@ -1286,39 +1245,32 @@ void ClientConnection::handleKeepAliveTimeout() { void ClientConnection::handleConsumerStatsTimeout(const ASIO_ERROR& ec, const std::vector& consumerStatsRequests) { if (ec) { - LOG_DEBUG(cnxString_ << " Ignoring timer cancelled event, code[" << ec << "]"); + LOG_DEBUG(cnxString() << " Ignoring timer cancelled event, code[" << ec << "]"); return; } startConsumerStatsTimer(consumerStatsRequests); } -void ClientConnection::close(Result result, bool detach) { +const std::future& ClientConnection::close(Result result) { Lock lock(mutex_); - if (isClosed()) { - return; - } + if (closeFuture_) { + connectPromise_.setFailed(result); + return *closeFuture_; + } + auto promise = std::make_shared>(); + closeFuture_ = promise->get_future(); + // The atomic update on state_ guarantees the previous modification on closeFuture_ is visible once the + // atomic read on state_ returns Disconnected `isClosed()`. 
+ // However, it cannot prevent the race like: + // 1. thread 1: Check `isClosed()`, which returns false. + // 2. thread 2: call `close()`, now, `state_` becomes Disconnected, and `closeFuture_` is set. + // 3. thread 1: post the `async_write` to the `io_context`, + // 4. io thread: call `socket_->close()` + // 5. io thread: execute `async_write` on `socket_`, which has been closed + // However, even the race happens, it's still safe because all the socket operations happen in the same + // io thread, the `async_write` operation will simply fail with an error, no crash will happen. state_ = Disconnected; - if (socket_) { - ASIO_ERROR err; - socket_->shutdown(ASIO::socket_base::shutdown_both, err); - socket_->close(err); - if (err) { - LOG_WARN(cnxString_ << "Failed to close socket: " << err.message()); - } - } - if (tlsSocket_) { - ASIO_ERROR err; - tlsSocket_->lowest_layer().close(err); - if (err) { - LOG_WARN(cnxString_ << "Failed to close TLS socket: " << err.message()); - } - } - - if (executor_) { - executor_.reset(); - } - // Move the internal fields to process them after `mutex_` was unlocked auto consumers = std::move(consumers_); auto producers = std::move(producers_); @@ -1343,19 +1295,38 @@ void ClientConnection::close(Result result, bool detach) { if (connectTimeoutTask_) { connectTimeoutTask_->stop(); + connectTimeoutTask_.reset(); // clear the callback, which holds a `shared_from_this()` } lock.unlock(); int refCount = weak_from_this().use_count(); if (!isResultRetryable(result)) { - LOG_ERROR(cnxString_ << "Connection closed with " << result << " (refCnt: " << refCount << ")"); + LOG_ERROR(cnxString() << "Connection closed with " << result << " (refCnt: " << refCount << ")"); } else { - LOG_INFO(cnxString_ << "Connection disconnected (refCnt: " << refCount << ")"); + LOG_INFO(cnxString() << "Connection disconnected (refCnt: " << refCount << ")"); } // Remove the connection from the pool before completing any promise - if (detach) { - 
pool_.remove(logicalAddress_, physicalAddress_, poolIndex_, this); - } + pool_.remove(logicalAddress_, physicalAddress_, poolIndex_, this); + + // Close the socket after removing itself from the pool so that other requests won't be able to acquire + // this connection after the socket is closed. + executor_->dispatch([this, promise, self{shared_from_this()}] { + // According to asio document, ip::tcp::socket and ssl::stream are unsafe as shared objects, so the + // methods must be called within the same implicit or explicit strand. + // The implementation of `ExecutorService` guarantees the internal `io_context::run()` is only called + // in one thread, so we can safely call the socket methods without posting to a strand instance. + ASIO_ERROR err; + socket_->shutdown(ASIO::socket_base::shutdown_both, err); + socket_->close(err); + if (err) { + LOG_WARN(cnxString() << "Failed to close socket: " << err.message()); + } + if (tlsSocket_) { + tlsSocket_->async_shutdown([promise](const auto&) { promise->set_value(); }); + } else { + promise->set_value(); + } + }); auto self = shared_from_this(); for (ProducersMap::iterator it = producers.begin(); it != producers.end(); ++it) { @@ -1377,24 +1348,25 @@ void ClientConnection::close(Result result, bool detach) { // Fail all pending requests, all these type are map whose value type contains the Promise object for (auto& kv : pendingRequests) { - kv.second.promise.setFailed(result); + kv.second.fail(result); } for (auto& kv : pendingLookupRequests) { - kv.second.promise->setFailed(result); + kv.second.fail(result); } for (auto& kv : pendingConsumerStatsMap) { - LOG_ERROR(cnxString_ << " Closing Client Connection, please try again later"); + LOG_ERROR(cnxString() << " Closing Client Connection, please try again later"); kv.second.setFailed(result); } for (auto& kv : pendingGetLastMessageIdRequests) { - kv.second.promise->setFailed(result); + kv.second.fail(result); } for (auto& kv : pendingGetNamespaceTopicsRequests) { 
kv.second.setFailed(result); } for (auto& kv : pendingGetSchemaRequests) { - kv.second.promise.setFailed(result); + kv.second.fail(result); } + return *closeFuture_; } bool ClientConnection::isClosed() const { return state_ == Disconnected; } @@ -1425,8 +1397,6 @@ void ClientConnection::removeConsumer(int consumerId) { const std::string& ClientConnection::brokerAddress() const { return physicalAddress_; } -const std::string& ClientConnection::cnxString() const { return cnxString_; } - int ClientConnection::getServerProtocolVersion() const { return serverProtocolVersion_; } int32_t ClientConnection::getMaxMessageSize() { return maxMessageSize_.load(std::memory_order_acquire); } @@ -1441,7 +1411,7 @@ Future ClientConnection::newGetLastMessageId(u auto promise = std::make_shared(); if (isClosed()) { lock.unlock(); - LOG_ERROR(cnxString_ << " Client is not connected to the broker"); + LOG_ERROR(cnxString() << " Client is not connected to the broker"); promise->setFailed(ResultNotConnected); return promise->getFuture(); } @@ -1450,12 +1420,8 @@ Future ClientConnection::newGetLastMessageId(u requestData.promise = promise; requestData.timer = executor_->createDeadlineTimer(); requestData.timer->expires_after(operationsTimeout_); - auto weakSelf = weak_from_this(); - requestData.timer->async_wait([weakSelf, requestData](const ASIO_ERROR& ec) { - auto self = weakSelf.lock(); - if (self) { - self->handleGetLastMessageIdTimeout(ec, requestData); - } + requestData.timer->async_wait([this, self{shared_from_this()}, requestData](const ASIO_ERROR& ec) { + handleGetLastMessageIdTimeout(ec, requestData); }); pendingGetLastMessageIdRequests_.insert(std::make_pair(requestId, requestData)); lock.unlock(); @@ -1469,7 +1435,7 @@ Future ClientConnection::newGetTopicsOfNamespace( Promise promise; if (isClosed()) { lock.unlock(); - LOG_ERROR(cnxString_ << "Client is not connected to the broker"); + LOG_ERROR(cnxString() << "Client is not connected to the broker"); 
promise.setFailed(ResultNotConnected); return promise.getFuture(); } @@ -1487,7 +1453,7 @@ Future ClientConnection::newGetSchema(const std::string& top Promise promise; if (isClosed()) { lock.unlock(); - LOG_ERROR(cnxString_ << "Client is not connected to the broker"); + LOG_ERROR(cnxString() << "Client is not connected to the broker"); promise.setFailed(ResultNotConnected); return promise.getFuture(); } @@ -1496,11 +1462,9 @@ Future ClientConnection::newGetSchema(const std::string& top pendingGetSchemaRequests_.emplace(requestId, GetSchemaRequest{promise, timer}); lock.unlock(); - auto weakSelf = weak_from_this(); timer->expires_after(operationsTimeout_); - timer->async_wait([this, weakSelf, requestId](const ASIO_ERROR& ec) { - auto self = weakSelf.lock(); - if (!self) { + timer->async_wait([this, self{shared_from_this()}, requestId](const ASIO_ERROR& ec) { + if (ec) { return; } Lock lock(mutex_); @@ -1527,8 +1491,8 @@ void ClientConnection::handleSendReceipt(const proto::CommandSendReceipt& sendRe const proto::MessageIdData& messageIdData = sendReceipt.message_id(); auto messageId = toMessageId(messageIdData); - LOG_DEBUG(cnxString_ << "Got receipt for producer: " << producerId << " -- msg: " << sequenceId - << "-- message id: " << messageId); + LOG_DEBUG(cnxString() << "Got receipt for producer: " << producerId << " -- msg: " << sequenceId + << "-- message id: " << messageId); Lock lock(mutex_); auto it = producers_.find(producerId); @@ -1544,13 +1508,13 @@ void ClientConnection::handleSendReceipt(const proto::CommandSendReceipt& sendRe } } } else { - LOG_ERROR(cnxString_ << "Got invalid producer Id in SendReceipt: " // - << producerId << " -- msg: " << sequenceId); + LOG_ERROR(cnxString() << "Got invalid producer Id in SendReceipt: " // + << producerId << " -- msg: " << sequenceId); } } void ClientConnection::handleSendError(const proto::CommandSendError& error) { - LOG_WARN(cnxString_ << "Received send error from server: " << error.message()); + 
LOG_WARN(cnxString() << "Received send error from server: " << error.message()); if (ChecksumError == error.error()) { long producerId = error.producer_id(); long sequenceId = error.sequence_id(); @@ -1574,7 +1538,7 @@ void ClientConnection::handleSendError(const proto::CommandSendError& error) { } void ClientConnection::handleSuccess(const proto::CommandSuccess& success) { - LOG_DEBUG(cnxString_ << "Received success response from server. req_id: " << success.request_id()); + LOG_DEBUG(cnxString() << "Received success response from server. req_id: " << success.request_id()); Lock lock(mutex_); auto it = pendingRequests_.find(success.request_id()); @@ -1590,8 +1554,8 @@ void ClientConnection::handleSuccess(const proto::CommandSuccess& success) { void ClientConnection::handlePartitionedMetadataResponse( const proto::CommandPartitionedTopicMetadataResponse& partitionMetadataResponse) { - LOG_DEBUG(cnxString_ << "Received partition-metadata response from server. req_id: " - << partitionMetadataResponse.request_id()); + LOG_DEBUG(cnxString() << "Received partition-metadata response from server. 
req_id: " + << partitionMetadataResponse.request_id()); Lock lock(mutex_); auto it = pendingLookupRequests_.find(partitionMetadataResponse.request_id()); @@ -1607,16 +1571,16 @@ void ClientConnection::handlePartitionedMetadataResponse( (partitionMetadataResponse.response() == proto::CommandPartitionedTopicMetadataResponse::Failed)) { if (partitionMetadataResponse.has_error()) { - LOG_ERROR(cnxString_ << "Failed partition-metadata lookup req_id: " - << partitionMetadataResponse.request_id() - << " error: " << partitionMetadataResponse.error() - << " msg: " << partitionMetadataResponse.message()); + LOG_ERROR(cnxString() << "Failed partition-metadata lookup req_id: " + << partitionMetadataResponse.request_id() + << " error: " << partitionMetadataResponse.error() + << " msg: " << partitionMetadataResponse.message()); checkServerError(partitionMetadataResponse.error(), partitionMetadataResponse.message()); lookupDataPromise->setFailed( getResult(partitionMetadataResponse.error(), partitionMetadataResponse.message())); } else { - LOG_ERROR(cnxString_ << "Failed partition-metadata lookup req_id: " - << partitionMetadataResponse.request_id() << " with empty response: "); + LOG_ERROR(cnxString() << "Failed partition-metadata lookup req_id: " + << partitionMetadataResponse.request_id() << " with empty response: "); lookupDataPromise->setFailed(ResultConnectError); } } else { @@ -1632,9 +1596,9 @@ void ClientConnection::handlePartitionedMetadataResponse( void ClientConnection::handleConsumerStatsResponse( const proto::CommandConsumerStatsResponse& consumerStatsResponse) { - LOG_DEBUG(cnxString_ << "ConsumerStatsResponse command - Received consumer stats " - "response from server. req_id: " - << consumerStatsResponse.request_id()); + LOG_DEBUG(cnxString() << "ConsumerStatsResponse command - Received consumer stats " + "response from server. 
req_id: " + << consumerStatsResponse.request_id()); Lock lock(mutex_); auto it = pendingConsumerStatsMap_.find(consumerStatsResponse.request_id()); if (it != pendingConsumerStatsMap_.end()) { @@ -1644,15 +1608,15 @@ void ClientConnection::handleConsumerStatsResponse( if (consumerStatsResponse.has_error_code()) { if (consumerStatsResponse.has_error_message()) { - LOG_ERROR(cnxString_ << " Failed to get consumer stats - " - << consumerStatsResponse.error_message()); + LOG_ERROR(cnxString() + << " Failed to get consumer stats - " << consumerStatsResponse.error_message()); } consumerStatsPromise.setFailed( getResult(consumerStatsResponse.error_code(), consumerStatsResponse.error_message())); } else { - LOG_DEBUG(cnxString_ << "ConsumerStatsResponse command - Received consumer stats " - "response from server. req_id: " - << consumerStatsResponse.request_id() << " Stats: "); + LOG_DEBUG(cnxString() << "ConsumerStatsResponse command - Received consumer stats " + "response from server. req_id: " + << consumerStatsResponse.request_id() << " Stats: "); BrokerConsumerStatsImpl brokerStats( consumerStatsResponse.msgrateout(), consumerStatsResponse.msgthroughputout(), consumerStatsResponse.msgrateredeliver(), consumerStatsResponse.consumername(), @@ -1682,25 +1646,25 @@ void ClientConnection::handleLookupTopicRespose( if (!lookupTopicResponse.has_response() || (lookupTopicResponse.response() == proto::CommandLookupTopicResponse::Failed)) { if (lookupTopicResponse.has_error()) { - LOG_ERROR(cnxString_ << "Failed lookup req_id: " << lookupTopicResponse.request_id() - << " error: " << lookupTopicResponse.error() - << " msg: " << lookupTopicResponse.message()); + LOG_ERROR(cnxString() << "Failed lookup req_id: " << lookupTopicResponse.request_id() + << " error: " << lookupTopicResponse.error() + << " msg: " << lookupTopicResponse.message()); checkServerError(lookupTopicResponse.error(), lookupTopicResponse.message()); lookupDataPromise->setFailed( 
getResult(lookupTopicResponse.error(), lookupTopicResponse.message())); } else { - LOG_ERROR(cnxString_ << "Failed lookup req_id: " << lookupTopicResponse.request_id() - << " with empty response: "); + LOG_ERROR(cnxString() << "Failed lookup req_id: " << lookupTopicResponse.request_id() + << " with empty response: "); lookupDataPromise->setFailed(ResultConnectError); } } else { - LOG_DEBUG(cnxString_ << "Received lookup response from server. req_id: " - << lookupTopicResponse.request_id() // - << " -- broker-url: " << lookupTopicResponse.brokerserviceurl() - << " -- broker-tls-url: " // - << lookupTopicResponse.brokerserviceurltls() - << " authoritative: " << lookupTopicResponse.authoritative() // - << " redirect: " << lookupTopicResponse.response()); + LOG_DEBUG(cnxString() << "Received lookup response from server. req_id: " + << lookupTopicResponse.request_id() // + << " -- broker-url: " << lookupTopicResponse.brokerserviceurl() + << " -- broker-tls-url: " // + << lookupTopicResponse.brokerserviceurltls() + << " authoritative: " << lookupTopicResponse.authoritative() // + << " redirect: " << lookupTopicResponse.response()); LookupDataResultPtr lookupResultPtr = std::make_shared(); if (tlsSocket_) { @@ -1723,17 +1687,18 @@ void ClientConnection::handleLookupTopicRespose( } void ClientConnection::handleProducerSuccess(const proto::CommandProducerSuccess& producerSuccess) { - LOG_DEBUG(cnxString_ << "Received success producer response from server. req_id: " - << producerSuccess.request_id() // - << " -- producer name: " << producerSuccess.producer_name()); + LOG_DEBUG(cnxString() << "Received success producer response from server. 
req_id: " + << producerSuccess.request_id() // + << " -- producer name: " << producerSuccess.producer_name()); Lock lock(mutex_); auto it = pendingRequests_.find(producerSuccess.request_id()); if (it != pendingRequests_.end()) { PendingRequestData requestData = it->second; if (!producerSuccess.producer_ready()) { - LOG_INFO(cnxString_ << " Producer " << producerSuccess.producer_name() - << " has been queued up at broker. req_id: " << producerSuccess.request_id()); + LOG_INFO(cnxString() << " Producer " << producerSuccess.producer_name() + << " has been queued up at broker. req_id: " + << producerSuccess.request_id()); requestData.hasGotResponse->store(true); lock.unlock(); } else { @@ -1758,9 +1723,9 @@ void ClientConnection::handleProducerSuccess(const proto::CommandProducerSuccess void ClientConnection::handleError(const proto::CommandError& error) { Result result = getResult(error.error(), error.message()); - LOG_WARN(cnxString_ << "Received error response from server: " << result - << (error.has_message() ? (" (" + error.message() + ")") : "") - << " -- req_id: " << error.request_id()); + LOG_WARN(cnxString() << "Received error response from server: " << result + << (error.has_message() ? 
(" (" + error.message() + ")") : "") + << " -- req_id: " << error.request_id()); Lock lock(mutex_); @@ -1890,7 +1855,7 @@ void ClientConnection::handleCloseProducer(const proto::CommandCloseProducer& cl producer->disconnectProducer(assignedBrokerServiceUrl); } } else { - LOG_ERROR(cnxString_ << "Got invalid producer Id in closeProducer command: " << producerId); + LOG_ERROR(cnxString() << "Got invalid producer Id in closeProducer command: " << producerId); } } @@ -1911,17 +1876,17 @@ void ClientConnection::handleCloseConsumer(const proto::CommandCloseConsumer& cl consumer->disconnectConsumer(assignedBrokerServiceUrl); } } else { - LOG_ERROR(cnxString_ << "Got invalid consumer Id in closeConsumer command: " << consumerId); + LOG_ERROR(cnxString() << "Got invalid consumer Id in closeConsumer command: " << consumerId); } } void ClientConnection::handleAuthChallenge() { - LOG_DEBUG(cnxString_ << "Received auth challenge from broker"); + LOG_DEBUG(cnxString() << "Received auth challenge from broker"); Result result; SharedBuffer buffer = Commands::newAuthResponse(authentication_, result); if (result != ResultOk) { - LOG_ERROR(cnxString_ << "Failed to send auth response: " << result); + LOG_ERROR(cnxString() << "Failed to send auth response: " << result); close(result); return; } @@ -1934,8 +1899,8 @@ void ClientConnection::handleAuthChallenge() { void ClientConnection::handleGetLastMessageIdResponse( const proto::CommandGetLastMessageIdResponse& getLastMessageIdResponse) { - LOG_DEBUG(cnxString_ << "Received getLastMessageIdResponse from server. req_id: " - << getLastMessageIdResponse.request_id()); + LOG_DEBUG(cnxString() << "Received getLastMessageIdResponse from server. 
req_id: " + << getLastMessageIdResponse.request_id()); Lock lock(mutex_); auto it = pendingGetLastMessageIdRequests_.find(getLastMessageIdResponse.request_id()); @@ -1961,8 +1926,8 @@ void ClientConnection::handleGetLastMessageIdResponse( void ClientConnection::handleGetTopicOfNamespaceResponse( const proto::CommandGetTopicsOfNamespaceResponse& response) { - LOG_DEBUG(cnxString_ << "Received GetTopicsOfNamespaceResponse from server. req_id: " - << response.request_id() << " topicsSize" << response.topics_size()); + LOG_DEBUG(cnxString() << "Received GetTopicsOfNamespaceResponse from server. req_id: " + << response.request_id() << " topicsSize" << response.topics_size()); Lock lock(mutex_); auto it = pendingGetNamespaceTopicsRequests_.find(response.request_id()); @@ -2001,7 +1966,7 @@ void ClientConnection::handleGetTopicOfNamespaceResponse( } void ClientConnection::handleGetSchemaResponse(const proto::CommandGetSchemaResponse& response) { - LOG_DEBUG(cnxString_ << "Received GetSchemaResponse from server. req_id: " << response.request_id()); + LOG_DEBUG(cnxString() << "Received GetSchemaResponse from server. req_id: " << response.request_id()); Lock lock(mutex_); auto it = pendingGetSchemaRequests_.find(response.request_id()); if (it != pendingGetSchemaRequests_.end()) { @@ -2012,10 +1977,11 @@ void ClientConnection::handleGetSchemaResponse(const proto::CommandGetSchemaResp if (response.has_error_code()) { Result result = getResult(response.error_code(), response.error_message()); if (response.error_code() != proto::TopicNotFound) { - LOG_WARN(cnxString_ << "Received error GetSchemaResponse from server " << result - << (response.has_error_message() ? (" (" + response.error_message() + ")") - : "") - << " -- req_id: " << response.request_id()); + LOG_WARN(cnxString() << "Received error GetSchemaResponse from server " << result + << (response.has_error_message() + ? 
(" (" + response.error_message() + ")") + : "") + << " -- req_id: " << response.request_id()); } getSchemaPromise.setFailed(result); return; @@ -2039,7 +2005,7 @@ void ClientConnection::handleGetSchemaResponse(const proto::CommandGetSchemaResp } void ClientConnection::handleAckResponse(const proto::CommandAckResponse& response) { - LOG_DEBUG(cnxString_ << "Received AckResponse from server. req_id: " << response.request_id()); + LOG_DEBUG(cnxString() << "Received AckResponse from server. req_id: " << response.request_id()); Lock lock(mutex_); auto it = pendingRequests_.find(response.request_id()); diff --git a/lib/ClientConnection.h b/lib/ClientConnection.h index b2770006..7d52ef1a 100644 --- a/lib/ClientConnection.h +++ b/lib/ClientConnection.h @@ -25,6 +25,8 @@ #include #include #include +#include +#include #ifdef USE_ASIO #include #include @@ -156,11 +158,8 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this& close(Result result = ResultConnectError); bool isClosed() const; @@ -193,7 +192,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this promise; DeadlineTimerPtr timer; std::shared_ptr hasGotResponse{std::make_shared(false)}; + + void fail(Result result) { + cancelTimer(*timer); + ; + promise.setFailed(result); + } }; struct LookupRequestData { LookupDataResultPromisePtr promise; DeadlineTimerPtr timer; + + void fail(Result result) { + cancelTimer(*timer); + ; + promise->setFailed(result); + } }; struct LastMessageIdRequestData { GetLastMessageIdResponsePromisePtr promise; DeadlineTimerPtr timer; + + void fail(Result result) { + cancelTimer(*timer); + ; + promise->setFailed(result); + } }; struct GetSchemaRequest { Promise promise; DeadlineTimerPtr timer; + + void fail(Result result) { + cancelTimer(*timer); + promise.setFailed(result); + } }; /* @@ -297,26 +319,26 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this - inline void asyncWrite(const ConstBufferSequence& buffers, 
WriteHandler handler) { + inline void asyncWrite(const ConstBufferSequence& buffers, WriteHandler&& handler) { if (isClosed()) { return; } if (tlsSocket_) { - ASIO::async_write(*tlsSocket_, buffers, ASIO::bind_executor(strand_, handler)); + ASIO::async_write(*tlsSocket_, buffers, std::forward(handler)); } else { ASIO::async_write(*socket_, buffers, handler); } } template - inline void asyncReceive(const MutableBufferSequence& buffers, ReadHandler handler) { + inline void asyncReceive(const MutableBufferSequence& buffers, ReadHandler&& handler) { if (isClosed()) { return; } if (tlsSocket_) { - tlsSocket_->async_read_some(buffers, ASIO::bind_executor(strand_, handler)); + tlsSocket_->async_read_some(buffers, std::forward(handler)); } else { - socket_->async_receive(buffers, handler); + socket_->async_receive(buffers, std::forward(handler)); } } @@ -337,7 +359,6 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this strand_; const std::string logicalAddress_; /* @@ -350,7 +371,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this server:6650] - std::string cnxString_; + std::shared_ptr cnxStringPtr_; /* * indicates if async connection establishment failed @@ -419,6 +440,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this> closeFuture_; friend class PulsarFriend; friend class ConsumerTest; diff --git a/lib/ConnectionPool.cc b/lib/ConnectionPool.cc index df050b0e..c814cf85 100644 --- a/lib/ConnectionPool.cc +++ b/lib/ConnectionPool.cc @@ -54,16 +54,43 @@ bool ConnectionPool::close() { return false; } + std::vector connectionsToClose; + // ClientConnection::close() will remove the connection from the pool, which is not allowed when iterating + // over a map, so we store the connections to close in a vector first and don't iterate the pool when + // closing the connections. 
std::unique_lock lock(mutex_); + connectionsToClose.reserve(pool_.size()); + for (auto&& kv : pool_) { + connectionsToClose.emplace_back(kv.second); + } + pool_.clear(); + lock.unlock(); - for (auto cnxIt = pool_.begin(); cnxIt != pool_.end(); cnxIt++) { - auto& cnx = cnxIt->second; + for (auto&& cnx : connectionsToClose) { if (cnx) { - // The 2nd argument is false because removing a value during the iteration will cause segfault - cnx->close(ResultDisconnected, false); + // Close with a fatal error to not let client retry + auto& future = cnx->close(ResultAlreadyClosed); + using namespace std::chrono_literals; + if (auto status = future.wait_for(5s); status != std::future_status::ready) { + LOG_WARN("Connection close timed out for " << cnx.get()->cnxString()); + } + if (cnx.use_count() > 1) { + // There are some asynchronous operations that hold the reference on the connection, we should + // wait for them to finish. Otherwise, `io_context::stop()` will be called in + // `ClientImpl::shutdown()` when closing the `ExecutorServiceProvider`. Then + // `io_context::run()` will return and the `io_context` object will be destroyed. In this + // case, if there is any pending handler, it will crash.
+ for (int i = 0; i < 500 && cnx.use_count() > 1; i++) { + std::this_thread::sleep_for(10ms); + } + if (cnx.use_count() > 1) { + LOG_WARN("Connection still has " << (cnx.use_count() - 1) + << " references after waiting for 5 seconds for " + << cnx.get()->cnxString()); + } + } } } - pool_.clear(); return true; } diff --git a/lib/ExecutorService.cc b/lib/ExecutorService.cc index 99e2393f..eba74861 100644 --- a/lib/ExecutorService.cc +++ b/lib/ExecutorService.cc @@ -125,8 +125,6 @@ void ExecutorService::close(long timeoutMs) { } } -void ExecutorService::postWork(std::function task) { ASIO::post(io_context_, std::move(task)); } - ///////////////////// ExecutorServiceProvider::ExecutorServiceProvider(int nthreads) diff --git a/lib/ExecutorService.h b/lib/ExecutorService.h index 626cb203..ba8a877f 100644 --- a/lib/ExecutorService.h +++ b/lib/ExecutorService.h @@ -23,12 +23,16 @@ #include #ifdef USE_ASIO +#include #include #include +#include #include #else +#include #include #include +#include #include #endif #include @@ -62,7 +66,19 @@ class PULSAR_PUBLIC ExecutorService : public std::enable_shared_from_this task); + + // Execute the task in the event loop thread asynchronously, i.e. the task will be put in the event loop + // queue and executed later. + template + void postWork(T &&task) { + ASIO::post(io_context_, std::forward(task)); + } + + // Different from `postWork`, if it's already in the event loop, execute the task immediately + template + void dispatch(T &&task) { + ASIO::dispatch(io_context_, std::forward(task)); + } // See TimeoutProcessor for the semantics of the parameter. 
void close(long timeoutMs = 3000); diff --git a/lib/PeriodicTask.h b/lib/PeriodicTask.h index bc186348..ee19182a 100644 --- a/lib/PeriodicTask.h +++ b/lib/PeriodicTask.h @@ -53,7 +53,7 @@ class PeriodicTask : public std::enable_shared_from_this { void stop() noexcept; - void setCallback(CallbackType callback) noexcept { callback_ = callback; } + void setCallback(CallbackType&& callback) noexcept { callback_ = std::move(callback); } State getState() const noexcept { return state_; } int getPeriodMs() const noexcept { return periodMs_; } diff --git a/tests/ClientTest.cc b/tests/ClientTest.cc index dd892686..78c42c80 100644 --- a/tests/ClientTest.cc +++ b/tests/ClientTest.cc @@ -413,7 +413,7 @@ TEST(ClientTest, testConnectionClose) { LOG_INFO("Connection refcnt: " << cnx.use_count() << " before close"); auto executor = PulsarFriend::getExecutor(*cnx); // Simulate the close() happens in the event loop - executor->postWork([cnx, &client, numConnections] { + executor->dispatch([cnx, &client, numConnections] { cnx->close(); ASSERT_EQ(PulsarFriend::getConnections(client).size(), numConnections - 1); LOG_INFO("Connection refcnt: " << cnx.use_count() << " after close"); diff --git a/tests/MultiTopicsConsumerTest.cc b/tests/MultiTopicsConsumerTest.cc index db3bc963..8aae321b 100644 --- a/tests/MultiTopicsConsumerTest.cc +++ b/tests/MultiTopicsConsumerTest.cc @@ -166,7 +166,7 @@ TEST(MultiTopicsConsumerTest, testGetConsumerStatsFail) { future.wait_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100)); - connection->handleKeepAliveTimeout(); + connection->handleKeepAliveTimeout(ASIO_SUCCESS); ASSERT_EQ(ResultDisconnected, future.get()); mockServer->close(); From 7fd7fed1fa09de7292cff293aa7015ad47653e3b Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Wed, 11 Mar 2026 22:14:05 +0800 Subject: [PATCH 2/5] fix regression for connection timeout --- lib/ClientConnection.cc | 8 ++- tests/ClientTest.cc | 105 
++++++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/lib/ClientConnection.cc b/lib/ClientConnection.cc index b1d45b3f..f8348d9a 100644 --- a/lib/ClientConnection.cc +++ b/lib/ClientConnection.cc @@ -317,6 +317,10 @@ void ClientConnection::handlePulsarConnected(const proto::CommandConnected& cmdC LOG_INFO(cnxString() << "Connection already closed"); return; } + if (connectTimeoutTask_) { + connectTimeoutTask_->stop(); + connectTimeoutTask_.reset(); // clear the callback once the Pulsar handshake is fully complete + } state_ = Ready; serverProtocolVersion_ = cmdConnected.protocol_version(); @@ -426,7 +430,6 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp LOG_INFO(cnxString() << "Connection already closed"); return; } - connectTimeoutTask_->stop(); state_ = TcpConnected; lock.unlock(); @@ -1322,7 +1325,8 @@ const std::future& ClientConnection::close(Result result) { LOG_WARN(cnxString() << "Failed to close socket: " << err.message()); } if (tlsSocket_) { - tlsSocket_->async_shutdown([promise](const auto&) { promise->set_value(); }); + auto tlsSocket = tlsSocket_; + tlsSocket->async_shutdown([promise, self, tlsSocket](const auto&) { promise->set_value(); }); } else { promise->set_value(); } diff --git a/tests/ClientTest.cc b/tests/ClientTest.cc index 78c42c80..5bb0ecbf 100644 --- a/tests/ClientTest.cc +++ b/tests/ClientTest.cc @@ -22,13 +22,17 @@ #include #include +#include #include +#include #include +#include #include "MockClientImpl.h" #include "PulsarAdminHelper.h" #include "PulsarFriend.h" #include "WaitUtils.h" +#include "lib/AsioDefines.h" #include "lib/ClientConnection.h" #include "lib/LogUtils.h" #include "lib/checksum/ChecksumProvider.h" @@ -42,6 +46,81 @@ using testing::AtLeast; static std::string lookupUrl = "pulsar://localhost:6650"; static std::string adminUrl = "http://localhost:8080/"; +namespace { + +class SilentTcpServer { + public: + SilentTcpServer() + : 
acceptor_(ioContext_, ASIO::ip::tcp::endpoint(ASIO::ip::tcp::v4(), 0)), + acceptedFuture_(acceptedPromise_.get_future()) {} + + ~SilentTcpServer() { stop(); } + + int getPort() const { return acceptor_.local_endpoint().port(); } + + void start() { + serverThread_ = std::thread([this] { + socket_.reset(new ASIO::ip::tcp::socket(ioContext_)); + + ASIO_ERROR acceptError; + acceptor_.accept(*socket_, acceptError); + acceptedPromise_.set_value(acceptError); + + std::unique_lock lock(mutex_); + cond_.wait(lock, [this] { return stopped_; }); + lock.unlock(); + + if (socket_) { + ASIO_ERROR closeError; + socket_->close(closeError); + } + + ASIO_ERROR closeError; + acceptor_.close(closeError); + }); + } + + bool waitUntilAccepted(std::chrono::milliseconds timeout) const { + return acceptedFuture_.wait_for(timeout) == std::future_status::ready; + } + + ASIO_ERROR acceptedError() const { return acceptedFuture_.get(); } + + void stop() { + { + std::lock_guard lock(mutex_); + if (stopped_) { + return; + } + stopped_ = true; + } + + ASIO_ERROR closeError; + acceptor_.close(closeError); + if (socket_) { + socket_->close(closeError); + } + + cond_.notify_all(); + if (serverThread_.joinable()) { + serverThread_.join(); + } + } + + private: + ASIO::io_context ioContext_; + ASIO::ip::tcp::acceptor acceptor_; + std::shared_ptr socket_; + std::promise acceptedPromise_; + std::shared_future acceptedFuture_; + std::mutex mutex_; + std::condition_variable cond_; + bool stopped_{false}; + std::thread serverThread_; +}; + +} // namespace + TEST(ClientTest, testChecksumComputation) { std::string data = "test"; std::string doubleData = "testtest"; @@ -137,6 +216,32 @@ TEST(ClientTest, testConnectTimeout) { ASSERT_EQ(futureDefault.get(), ResultDisconnected); } +TEST(ClientTest, testConnectTimeoutAfterTcpConnected) { + std::unique_ptr server; + try { + server.reset(new SilentTcpServer); + } catch (const ASIO_SYSTEM_ERROR &e) { + GTEST_SKIP() << "Cannot bind local test server in this 
environment: " << e.what(); + } + server->start(); + + const std::string serviceUrl = "pulsar://127.0.0.1:" + std::to_string(server->getPort()); + Client client(serviceUrl, ClientConfiguration().setConnectionTimeout(200)); + + std::promise promise; + auto future = promise.get_future(); + client.createProducerAsync("test-connect-timeout-after-tcp-connected", + [&promise](Result result, const Producer &) { promise.set_value(result); }); + + ASSERT_TRUE(server->waitUntilAccepted(std::chrono::seconds(1))); + ASSERT_FALSE(server->acceptedError()); + ASSERT_EQ(future.wait_for(std::chrono::seconds(2)), std::future_status::ready); + ASSERT_EQ(future.get(), ResultConnectError); + + client.close(); + server->stop(); +} + TEST(ClientTest, testGetNumberOfReferences) { Client client("pulsar://localhost:6650"); From 3850259d5bfd76084c17f64822adbf5d65971d6a Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Wed, 11 Mar 2026 22:37:22 +0800 Subject: [PATCH 3/5] fix unstable tests --- tests/BasicEndToEndTest.cc | 44 +++++++++++++++++++++++++++----- tests/MultiTopicsConsumerTest.cc | 9 ++++--- tests/PulsarFriend.h | 5 ++++ 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/tests/BasicEndToEndTest.cc b/tests/BasicEndToEndTest.cc index 9a02df0c..d3c6e612 100644 --- a/tests/BasicEndToEndTest.cc +++ b/tests/BasicEndToEndTest.cc @@ -3188,7 +3188,17 @@ static void expectTimeoutOnRecv(Consumer &consumer) { ASSERT_EQ(ResultTimeout, res); } -void testNegativeAcks(const std::string &topic, bool batchingEnabled) { +static std::vector expectedNegativeAckMessages(size_t numMessages) { + std::vector expected; + expected.reserve(numMessages); + for (size_t i = 0; i < numMessages; i++) { + expected.emplace_back("test-" + std::to_string(i)); + } + return expected; +} + +void testNegativeAcks(const std::string &topic, bool batchingEnabled, bool expectOrdered = true) { + constexpr size_t numMessages = 10; Client client(lookupUrl); Consumer consumer; ConsumerConfiguration conf; @@ -3202,22 
+3212,32 @@ void testNegativeAcks(const std::string &topic, bool batchingEnabled) { result = client.createProducer(topic, producerConf, producer); ASSERT_EQ(ResultOk, result); - for (int i = 0; i < 10; i++) { + for (size_t i = 0; i < numMessages; i++) { Message msg = MessageBuilder().setContent("test-" + std::to_string(i)).build(); producer.sendAsync(msg, nullptr); } producer.flush(); + std::vector receivedMessages; + receivedMessages.reserve(numMessages); std::vector toNeg; - for (int i = 0; i < 10; i++) { + for (size_t i = 0; i < numMessages; i++) { Message msg; consumer.receive(msg); LOG_INFO("Received message " << msg.getDataAsString()); - ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i)); + if (expectOrdered) { + ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i)); + } + receivedMessages.emplace_back(msg.getDataAsString()); toNeg.push_back(msg.getMessageId()); } + if (!expectOrdered) { + auto expectedMessages = expectedNegativeAckMessages(numMessages); + std::sort(receivedMessages.begin(), receivedMessages.end()); + ASSERT_EQ(expectedMessages, receivedMessages); + } // No more messages expected expectTimeoutOnRecv(consumer); @@ -3228,15 +3248,25 @@ void testNegativeAcks(const std::string &topic, bool batchingEnabled) { } PulsarFriend::setNegativeAckEnabled(consumer, true); - for (int i = 0; i < 10; i++) { + std::vector redeliveredMessages; + redeliveredMessages.reserve(numMessages); + for (size_t i = 0; i < numMessages; i++) { Message msg; consumer.receive(msg); LOG_INFO("-- Redelivery -- Received message " << msg.getDataAsString()); - ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i)); + if (expectOrdered) { + ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i)); + } + redeliveredMessages.emplace_back(msg.getDataAsString()); consumer.acknowledge(msg); } + if (!expectOrdered) { + auto expectedMessages = expectedNegativeAckMessages(numMessages); + std::sort(redeliveredMessages.begin(), redeliveredMessages.end()); + 
ASSERT_EQ(expectedMessages, redeliveredMessages); + } // No more messages expected expectTimeoutOnRecv(consumer); @@ -3262,7 +3292,7 @@ TEST(BasicEndToEndTest, testNegativeAcksWithPartitions) { LOG_INFO("res = " << res); ASSERT_FALSE(res != 204 && res != 409); - testNegativeAcks(topicName, true); + testNegativeAcks(topicName, true, false); } void testNegativeAckPrecisionBitCnt(const std::string &topic, int precisionBitCnt) { diff --git a/tests/MultiTopicsConsumerTest.cc b/tests/MultiTopicsConsumerTest.cc index 8aae321b..57407fbd 100644 --- a/tests/MultiTopicsConsumerTest.cc +++ b/tests/MultiTopicsConsumerTest.cc @@ -162,11 +162,12 @@ TEST(MultiTopicsConsumerTest, testGetConsumerStatsFail) { BrokerConsumerStats stats; return consumer.getBrokerConsumerStats(stats); }); - // Trigger the `getBrokerConsumerStats` in a new thread - future.wait_for(std::chrono::milliseconds(100)); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + const auto expectedRequests = topics.size(); + ASSERT_TRUE(waitUntil(std::chrono::seconds(1), [connection, expectedRequests] { + return PulsarFriend::getPendingConsumerStatsRequests(*connection) == expectedRequests; + })); - connection->handleKeepAliveTimeout(ASIO_SUCCESS); + connection->close(ResultDisconnected); ASSERT_EQ(ResultDisconnected, future.get()); mockServer->close(); diff --git a/tests/PulsarFriend.h b/tests/PulsarFriend.h index e7084050..1f351d16 100644 --- a/tests/PulsarFriend.h +++ b/tests/PulsarFriend.h @@ -162,6 +162,11 @@ class PulsarFriend { return consumers; } + static size_t getPendingConsumerStatsRequests(const ClientConnection& cnx) { + std::lock_guard lock(cnx.mutex_); + return cnx.pendingConsumerStatsMap_.size(); + } + static void setNegativeAckEnabled(Consumer consumer, bool enabled) { consumer.impl_->setNegativeAcknowledgeEnabledForTesting(enabled); } From 505417d97c63ca4ac402583d4a194d2309736068 Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Thu, 12 Mar 2026 16:45:53 +0800 Subject: [PATCH 4/5] address 
comments and improve tests --- lib/ClientConnection.cc | 57 +++++++++++++++++------------------------ lib/ClientConnection.h | 10 ++++---- lib/ExecutorService.h | 1 + tests/ClientTest.cc | 11 +++++--- 4 files changed, 37 insertions(+), 42 deletions(-) diff --git a/lib/ClientConnection.cc b/lib/ClientConnection.cc index f8348d9a..c373c25c 100644 --- a/lib/ClientConnection.cc +++ b/lib/ClientConnection.cc @@ -193,8 +193,8 @@ ClientConnection::ClientConnection(const std::string& logicalAddress, const std: physicalAddress_(physicalAddress), cnxStringPtr_(std::make_shared("[ -> " + physicalAddress + "] ")), incomingBuffer_(SharedBuffer::allocate(DefaultBufferSize)), - connectTimeoutTask_( - std::make_shared(*executor_, clientConfiguration.getConnectionTimeout())), + connectTimeout_(std::chrono::milliseconds(clientConfiguration.getConnectionTimeout())), + connectTimer_(executor_->createDeadlineTimer()), outgoingBuffer_(SharedBuffer::allocate(DefaultBufferSize)), keepAliveIntervalInSeconds_(clientConfiguration.getKeepAliveIntervalInSeconds()), consumerStatsRequestTimer_(executor_->createDeadlineTimer()), @@ -317,10 +317,7 @@ void ClientConnection::handlePulsarConnected(const proto::CommandConnected& cmdC LOG_INFO(cnxString() << "Connection already closed"); return; } - if (connectTimeoutTask_) { - connectTimeoutTask_->stop(); - connectTimeoutTask_.reset(); // clear the callback once the Pulsar handshake is fully complete - } + cancelTimer(*connectTimer_); state_ = Ready; serverProtocolVersion_ = cmdConnected.protocol_version(); @@ -426,7 +423,7 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp } Lock lock(mutex_); - if (isClosed() || !connectTimeoutTask_) { + if (isClosed()) { LOG_INFO(cnxString() << "Connection already closed"); return; } @@ -486,11 +483,10 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp LOG_ERROR(cnxString() << "Failed to establish connection to " << endpoint << ": " << 
err.message()); { std::lock_guard lock{mutex_}; - if (isClosed() || !connectTimeoutTask_) { + if (isClosed()) { return; } - connectTimeoutTask_->stop(); - connectTimeoutTask_.reset(); // clear the callback, which holds a `shared_from_this()` + cancelTimer(*connectTimer_); } if (err == ASIO::error::operation_aborted) { close(); @@ -611,28 +607,25 @@ void ClientConnection::handleResolve(ASIO_ERROR err, const tcp::resolver::result } } - // Acquire the lock to prevent the race: - // 1. thread 1: isClosed() returns false - // 2. thread 2: call `connectTimeoutTask_->stop()` and `connectTimeoutTask_.reset()` in `close()` - // 3. thread 1: call `connectTimeoutTask_->setCallback()` and `connectTimeoutTask_->start()` - // Then the self captured in the callback of `connectTimeoutTask_` would be kept alive unexpectedly and - // cannot be cancelled until the executor is destroyed. std::lock_guard lock{mutex_}; - if (isClosed() || !connectTimeoutTask_) { + if (isClosed()) { return; } - connectTimeoutTask_->setCallback( - [this, self{shared_from_this()}, - results = tcp::resolver::results_type(results)](const PeriodicTask::ErrorCode& ec) { - if (state_ != Ready) { - LOG_ERROR(cnxString() << "Connection to " << results << " was not established in " - << connectTimeoutTask_->getPeriodMs() << " ms"); - close(); - } else { - connectTimeoutTask_->stop(); - } - }); - connectTimeoutTask_->start(); + + connectTimer_->expires_after(connectTimeout_); + connectTimer_->async_wait([this, results, self{shared_from_this()}](const auto& err) { + if (err) { + return; + } + Lock lock{mutex_}; + if (!isClosed() && state_ != Ready) { + LOG_ERROR(cnxString() << "Connection to " << results << " was not established in " + << connectTimeout_.count() << " ms"); + lock.unlock(); + close(); + } // else: the connection is closed or already established + }); + ASIO::async_connect( *socket_, results, [this, self{shared_from_this()}](const ASIO_ERROR& err, const tcp::endpoint& endpoint) { @@ -1296,11 +1289,7 
@@ const std::future& ClientConnection::close(Result result) { consumerStatsRequestTimer_.reset(); } - if (connectTimeoutTask_) { - connectTimeoutTask_->stop(); - connectTimeoutTask_.reset(); // clear the callback, which holds a `shared_from_this()` - } - + cancelTimer(*connectTimer_); lock.unlock(); int refCount = weak_from_this().use_count(); if (!isResultRetryable(result)) { diff --git a/lib/ClientConnection.h b/lib/ClientConnection.h index 7d52ef1a..b9880ee2 100644 --- a/lib/ClientConnection.h +++ b/lib/ClientConnection.h @@ -43,8 +43,10 @@ #include #include #include +#include #include #include +#include #include #include "AsioTimer.h" @@ -228,7 +230,6 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_thissetFailed(result); } }; @@ -250,7 +250,6 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_thissetFailed(result); } }; @@ -326,7 +325,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this(handler)); } else { - ASIO::async_write(*socket_, buffers, handler); + ASIO::async_write(*socket_, buffers, std::forward(handler)); } } @@ -381,7 +380,8 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this connectPromise_; - std::shared_ptr connectTimeoutTask_; + const std::chrono::milliseconds connectTimeout_; + const DeadlineTimerPtr connectTimer_; typedef std::map PendingRequestsMap; PendingRequestsMap pendingRequests_; diff --git a/lib/ExecutorService.h b/lib/ExecutorService.h index ba8a877f..80659d4b 100644 --- a/lib/ExecutorService.h +++ b/lib/ExecutorService.h @@ -41,6 +41,7 @@ #include #include #include +#include #include "AsioTimer.h" diff --git a/tests/ClientTest.cc b/tests/ClientTest.cc index 5bb0ecbf..08d26ce9 100644 --- a/tests/ClientTest.cc +++ b/tests/ClientTest.cc @@ -414,14 +414,19 @@ TEST(ClientTest, testMultiBrokerUrl) { TEST(ClientTest, testCloseClient) { const std::string topic = "client-test-close-client-" + std::to_string(time(nullptr)); + using namespace 
std::chrono; for (int i = 0; i < 1000; ++i) { Client client(lookupUrl); client.createProducerAsync(topic, [](Result result, Producer producer) { producer.close(); }); // simulate different time interval before close - auto t0 = std::chrono::steady_clock::now(); - while ((std::chrono::steady_clock::now() - t0) < std::chrono::microseconds(i)) { + auto t0 = steady_clock::now(); + while ((steady_clock::now() - t0) < microseconds(i)) { } - client.close(); + + auto t1 = std::chrono::steady_clock::now(); + ASSERT_EQ(ResultOk, client.close()); + auto closeTimeMs = duration_cast(steady_clock::now() - t1).count(); + ASSERT_TRUE(closeTimeMs < 1000) << "close time: " << closeTimeMs << " ms"; } } From c97312d48c06968bf909aebc8358479ee0de305e Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Thu, 12 Mar 2026 16:57:33 +0800 Subject: [PATCH 5/5] simplify tests and fix thread safety when closing socket --- tests/ClientTest.cc | 71 +++++++++++++++++++-------------------------- 1 file changed, 30 insertions(+), 41 deletions(-) diff --git a/tests/ClientTest.cc b/tests/ClientTest.cc index 08d26ce9..6bd6cc8a 100644 --- a/tests/ClientTest.cc +++ b/tests/ClientTest.cc @@ -21,12 +21,12 @@ #include #include +#include #include -#include #include -#include #include #include +#include #include "MockClientImpl.h" #include "PulsarAdminHelper.h" @@ -52,31 +52,21 @@ class SilentTcpServer { public: SilentTcpServer() : acceptor_(ioContext_, ASIO::ip::tcp::endpoint(ASIO::ip::tcp::v4(), 0)), - acceptedFuture_(acceptedPromise_.get_future()) {} + acceptedFuture_(acceptedPromise_.get_future()), + port_(acceptor_.local_endpoint().port()), + workGuard_(ASIO::make_work_guard(ioContext_)) {} ~SilentTcpServer() { stop(); } - int getPort() const { return acceptor_.local_endpoint().port(); } + int getPort() const noexcept { return port_; } void start() { serverThread_ = std::thread([this] { socket_.reset(new ASIO::ip::tcp::socket(ioContext_)); + acceptor_.async_accept( + *socket_, [this](const ASIO_ERROR 
&acceptError) { acceptedPromise_.set_value(acceptError); }); - ASIO_ERROR acceptError; - acceptor_.accept(*socket_, acceptError); - acceptedPromise_.set_value(acceptError); - - std::unique_lock lock(mutex_); - cond_.wait(lock, [this] { return stopped_; }); - lock.unlock(); - - if (socket_) { - ASIO_ERROR closeError; - socket_->close(closeError); - } - - ASIO_ERROR closeError; - acceptor_.close(closeError); + ioContext_.run(); }); } @@ -84,38 +74,37 @@ class SilentTcpServer { return acceptedFuture_.wait_for(timeout) == std::future_status::ready; } - ASIO_ERROR acceptedError() const { return acceptedFuture_.get(); } + auto acceptedError() const { return acceptedFuture_.get(); } void stop() { - { - std::lock_guard lock(mutex_); - if (stopped_) { - return; - } - stopped_ = true; - } - - ASIO_ERROR closeError; - acceptor_.close(closeError); - if (socket_) { - socket_->close(closeError); - } - - cond_.notify_all(); - if (serverThread_.joinable()) { - serverThread_.join(); + bool expected = false; + if (!stopped_.compare_exchange_strong(expected, true) || !serverThread_.joinable()) { + return; } + ASIO::post(ioContext_, [this] { + ASIO_ERROR closeError; + if (socket_ && socket_->is_open()) { + socket_->close(closeError); + } + if (acceptor_.is_open()) { + acceptor_.close(closeError); + } + workGuard_.reset(); + }); + serverThread_.join(); } private: + using WorkGuard = decltype(ASIO::make_work_guard(std::declval())); + ASIO::io_context ioContext_; ASIO::ip::tcp::acceptor acceptor_; - std::shared_ptr socket_; + std::unique_ptr socket_; std::promise acceptedPromise_; std::shared_future acceptedFuture_; - std::mutex mutex_; - std::condition_variable cond_; - bool stopped_{false}; + const int port_; + WorkGuard workGuard_; + std::atomic_bool stopped_{false}; std::thread serverThread_; };