diff --git a/.github/workflows/ci-pr-validation.yaml b/.github/workflows/ci-pr-validation.yaml index b5a0973c..d209b533 100644 --- a/.github/workflows/ci-pr-validation.yaml +++ b/.github/workflows/ci-pr-validation.yaml @@ -260,6 +260,15 @@ jobs: Pop-Location } + - name: Ensure vcpkg has full history(windows) + if: runner.os == 'Windows' + shell: pwsh + run: | + $isShallow = (git -C "${{ env.VCPKG_ROOT }}" rev-parse --is-shallow-repository).Trim() + if ($isShallow -eq "true") { + git -C "${{ env.VCPKG_ROOT }}" fetch --unshallow + } + - name: remove system vcpkg(windows) if: runner.os == 'Windows' run: rm -rf "$VCPKG_INSTALLATION_ROOT" diff --git a/lib/ClientConnection.cc b/lib/ClientConnection.cc index 0a850ed0..c373c25c 100644 --- a/lib/ClientConnection.cc +++ b/lib/ClientConnection.cc @@ -189,13 +189,12 @@ ClientConnection::ClientConnection(const std::string& logicalAddress, const std: executor_(executor), resolver_(executor_->createTcpResolver()), socket_(executor_->createSocket()), - strand_(ASIO::make_strand(executor_->getIOService().get_executor())), logicalAddress_(logicalAddress), physicalAddress_(physicalAddress), - cnxString_("[ -> " + physicalAddress + "] "), + cnxStringPtr_(std::make_shared("[ -> " + physicalAddress + "] ")), incomingBuffer_(SharedBuffer::allocate(DefaultBufferSize)), - connectTimeoutTask_( - std::make_shared(*executor_, clientConfiguration.getConnectionTimeout())), + connectTimeout_(std::chrono::milliseconds(clientConfiguration.getConnectionTimeout())), + connectTimer_(executor_->createDeadlineTimer()), outgoingBuffer_(SharedBuffer::allocate(DefaultBufferSize)), keepAliveIntervalInSeconds_(clientConfiguration.getKeepAliveIntervalInSeconds()), consumerStatsRequestTimer_(executor_->createDeadlineTimer()), @@ -203,7 +202,8 @@ ClientConnection::ClientConnection(const std::string& logicalAddress, const std: clientVersion_(clientVersion), pool_(pool), poolIndex_(poolIndex) { - LOG_INFO(cnxString_ << "Create ClientConnection, timeout=" << 
clientConfiguration.getConnectionTimeout()); + LOG_INFO(cnxString() << "Create ClientConnection, timeout=" + << clientConfiguration.getConnectionTimeout()); if (!authentication_) { LOG_ERROR("Invalid authentication plugin"); throw ResultAuthenticationError; @@ -295,12 +295,12 @@ ClientConnection::ClientConnection(const std::string& logicalAddress, const std: } ClientConnection::~ClientConnection() { - LOG_INFO(cnxString_ << "Destroyed connection to " << logicalAddress_ << "-" << poolIndex_); + LOG_INFO(cnxString() << "Destroyed connection to " << logicalAddress_ << "-" << poolIndex_); } void ClientConnection::handlePulsarConnected(const proto::CommandConnected& cmdConnected) { if (!cmdConnected.has_server_version()) { - LOG_ERROR(cnxString_ << "Server version is not set"); + LOG_ERROR(cnxString() << "Server version is not set"); close(); return; } @@ -314,11 +314,11 @@ void ClientConnection::handlePulsarConnected(const proto::CommandConnected& cmdC Lock lock(mutex_); if (isClosed()) { - LOG_INFO(cnxString_ << "Connection already closed"); + LOG_INFO(cnxString() << "Connection already closed"); return; } + cancelTimer(*connectTimer_); state_ = Ready; - connectTimeoutTask_->stop(); serverProtocolVersion_ = cmdConnected.protocol_version(); if (serverProtocolVersion_ >= proto::v1) { @@ -326,13 +326,8 @@ void ClientConnection::handlePulsarConnected(const proto::CommandConnected& cmdC keepAliveTimer_ = executor_->createDeadlineTimer(); if (keepAliveTimer_) { keepAliveTimer_->expires_after(std::chrono::seconds(keepAliveIntervalInSeconds_)); - auto weakSelf = weak_from_this(); - keepAliveTimer_->async_wait([weakSelf](const ASIO_ERROR&) { - auto self = weakSelf.lock(); - if (self) { - self->handleKeepAliveTimeout(); - } - }); + keepAliveTimer_->async_wait( + [this, self{shared_from_this()}](const ASIO_ERROR& err) { handleKeepAliveTimeout(err); }); } } @@ -352,12 +347,12 @@ void ClientConnection::startConsumerStatsTimer(std::vector consumerSta for (int i = 0; i < 
consumerStatsRequests.size(); i++) { PendingConsumerStatsMap::iterator it = pendingConsumerStatsMap_.find(consumerStatsRequests[i]); if (it != pendingConsumerStatsMap_.end()) { - LOG_DEBUG(cnxString_ << " removing request_id " << it->first - << " from the pendingConsumerStatsMap_"); + LOG_DEBUG(cnxString() << " removing request_id " << it->first + << " from the pendingConsumerStatsMap_"); consumerStatsPromises.push_back(it->second); pendingConsumerStatsMap_.erase(it); } else { - LOG_DEBUG(cnxString_ << "request_id " << it->first << " already fulfilled - not removing it"); + LOG_DEBUG(cnxString() << "request_id " << it->first << " already fulfilled - not removing it"); } } @@ -371,19 +366,16 @@ void ClientConnection::startConsumerStatsTimer(std::vector consumerSta // Check if we have a timer still before we set the request timer to pop again. if (consumerStatsRequestTimer_) { consumerStatsRequestTimer_->expires_after(operationsTimeout_); - auto weakSelf = weak_from_this(); - consumerStatsRequestTimer_->async_wait([weakSelf, consumerStatsRequests](const ASIO_ERROR& err) { - auto self = weakSelf.lock(); - if (self) { - self->handleConsumerStatsTimeout(err, consumerStatsRequests); - } - }); + consumerStatsRequestTimer_->async_wait( + [this, self{shared_from_this()}, consumerStatsRequests](const ASIO_ERROR& err) { + handleConsumerStatsTimeout(err, consumerStatsRequests); + }); } lock.unlock(); // Complex logic since promises need to be fulfilled outside the lock for (int i = 0; i < consumerStatsPromises.size(); i++) { consumerStatsPromises[i].setFailed(ResultTimeout); - LOG_WARN(cnxString_ << " Operation timedout, didn't get response from broker"); + LOG_WARN(cnxString() << " Operation timedout, didn't get response from broker"); } } @@ -416,23 +408,23 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp try { cnxStringStream << "[" << socket_->local_endpoint() << " -> " << socket_->remote_endpoint() << "] "; - cnxString_ = 
cnxStringStream.str(); + std::atomic_store(&cnxStringPtr_, std::make_shared(cnxStringStream.str())); } catch (const ASIO_SYSTEM_ERROR& e) { LOG_ERROR("Failed to get endpoint: " << e.what()); close(ResultRetryable); return; } if (logicalAddress_ == physicalAddress_) { - LOG_INFO(cnxString_ << "Connected to broker"); + LOG_INFO(cnxString() << "Connected to broker"); } else { - LOG_INFO(cnxString_ << "Connected to broker through proxy. Logical broker: " << logicalAddress_ - << ", proxy: " << proxyServiceUrl_ - << ", physical address:" << physicalAddress_); + LOG_INFO(cnxString() << "Connected to broker through proxy. Logical broker: " << logicalAddress_ + << ", proxy: " << proxyServiceUrl_ + << ", physical address:" << physicalAddress_); } Lock lock(mutex_); if (isClosed()) { - LOG_INFO(cnxString_ << "Connection already closed"); + LOG_INFO(cnxString() << "Connection already closed"); return; } state_ = TcpConnected; @@ -441,12 +433,12 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp ASIO_ERROR error; socket_->set_option(tcp::no_delay(true), error); if (error) { - LOG_WARN(cnxString_ << "Socket failed to set tcp::no_delay: " << error.message()); + LOG_WARN(cnxString() << "Socket failed to set tcp::no_delay: " << error.message()); } socket_->set_option(tcp::socket::keep_alive(true), error); if (error) { - LOG_WARN(cnxString_ << "Socket failed to set tcp::socket::keep_alive: " << error.message()); + LOG_WARN(cnxString() << "Socket failed to set tcp::socket::keep_alive: " << error.message()); } // Start TCP keep-alive probes after connection has been idle after 1 minute. 
Ideally this @@ -454,19 +446,19 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp // connection) every 30 seconds socket_->set_option(tcp_keep_alive_idle(1 * 60), error); if (error) { - LOG_DEBUG(cnxString_ << "Socket failed to set tcp_keep_alive_idle: " << error.message()); + LOG_DEBUG(cnxString() << "Socket failed to set tcp_keep_alive_idle: " << error.message()); } // Send up to 10 probes before declaring the connection broken socket_->set_option(tcp_keep_alive_count(10), error); if (error) { - LOG_DEBUG(cnxString_ << "Socket failed to set tcp_keep_alive_count: " << error.message()); + LOG_DEBUG(cnxString() << "Socket failed to set tcp_keep_alive_count: " << error.message()); } // Interval between probes: 6 seconds socket_->set_option(tcp_keep_alive_interval(6), error); if (error) { - LOG_DEBUG(cnxString_ << "Socket failed to set tcp_keep_alive_interval: " << error.message()); + LOG_DEBUG(cnxString() << "Socket failed to set tcp_keep_alive_interval: " << error.message()); } if (tlsSocket_) { @@ -474,29 +466,28 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp ASIO_ERROR err; Url service_url; if (!Url::parse(physicalAddress_, service_url)) { - LOG_ERROR(cnxString_ << "Invalid Url, unable to parse: " << err << " " << err.message()); + LOG_ERROR(cnxString() << "Invalid Url, unable to parse: " << err << " " << err.message()); close(); return; } } - auto weakSelf = weak_from_this(); - auto socket = socket_; - auto tlsSocket = tlsSocket_; // socket and ssl::stream objects must exist until async_handshake is done, otherwise segmentation // fault might happen - auto callback = [weakSelf, socket, tlsSocket](const ASIO_ERROR& err) { - auto self = weakSelf.lock(); - if (self) { - self->handleHandshake(err); - } - }; - tlsSocket_->async_handshake(ASIO::ssl::stream::client, - ASIO::bind_executor(strand_, callback)); + tlsSocket_->async_handshake( + ASIO::ssl::stream::client, + [this, 
self{shared_from_this()}](const auto& err) { handleHandshake(err); }); } else { handleHandshake(ASIO_SUCCESS); } } else { - LOG_ERROR(cnxString_ << "Failed to establish connection to " << endpoint << ": " << err.message()); + LOG_ERROR(cnxString() << "Failed to establish connection to " << endpoint << ": " << err.message()); + { + std::lock_guard lock{mutex_}; + if (isClosed()) { + return; + } + cancelTimer(*connectTimer_); + } if (err == ASIO::error::operation_aborted) { close(); } else { @@ -508,10 +499,10 @@ void ClientConnection::handleTcpConnected(const ASIO_ERROR& err, const tcp::endp void ClientConnection::handleHandshake(const ASIO_ERROR& err) { if (err) { if (err.value() == ASIO::ssl::error::stream_truncated) { - LOG_WARN(cnxString_ << "Handshake failed: " << err.message()); + LOG_WARN(cnxString() << "Handshake failed: " << err.message()); close(ResultRetryable); } else { - LOG_ERROR(cnxString_ << "Handshake failed: " << err.message()); + LOG_ERROR(cnxString() << "Handshake failed: " << err.message()); close(); } return; @@ -524,12 +515,12 @@ void ClientConnection::handleHandshake(const ASIO_ERROR& err) { buffer = Commands::newConnect(authentication_, logicalAddress_, connectingThroughProxy, clientVersion_, result); } catch (const std::exception& e) { - LOG_ERROR(cnxString_ << "Failed to create Connect command: " << e.what()); + LOG_ERROR(cnxString() << "Failed to create Connect command: " << e.what()); close(ResultAuthenticationError); return; } if (result != ResultOk) { - LOG_ERROR(cnxString_ << "Failed to establish connection: " << result); + LOG_ERROR(cnxString() << "Failed to establish connection: " << result); close(result); return; } @@ -546,7 +537,7 @@ void ClientConnection::handleSentPulsarConnect(const ASIO_ERROR& err, const Shar return; } if (err) { - LOG_ERROR(cnxString_ << "Failed to establish connection: " << err.message()); + LOG_ERROR(cnxString() << "Failed to establish connection: " << err.message()); close(); return; } @@ -560,7 +551,7 @@ 
void ClientConnection::handleSentAuthResponse(const ASIO_ERROR& err, const Share return; } if (err) { - LOG_WARN(cnxString_ << "Failed to send auth response: " << err.message()); + LOG_WARN(cnxString() << "Failed to send auth response: " << err.message()); close(); return; } @@ -581,73 +572,65 @@ void ClientConnection::tcpConnectAsync() { Url service_url; std::string hostUrl = isSniProxy_ ? proxyServiceUrl_ : physicalAddress_; if (!Url::parse(hostUrl, service_url)) { - LOG_ERROR(cnxString_ << "Invalid Url, unable to parse: " << err << " " << err.message()); + LOG_ERROR(cnxString() << "Invalid Url, unable to parse: " << err << " " << err.message()); close(); return; } if (service_url.protocol() != "pulsar" && service_url.protocol() != "pulsar+ssl") { - LOG_ERROR(cnxString_ << "Invalid Url protocol '" << service_url.protocol() - << "'. Valid values are 'pulsar' and 'pulsar+ssl'"); + LOG_ERROR(cnxString() << "Invalid Url protocol '" << service_url.protocol() + << "'. Valid values are 'pulsar' and 'pulsar+ssl'"); close(); return; } - LOG_DEBUG(cnxString_ << "Resolving " << service_url.host() << ":" << service_url.port()); + LOG_DEBUG(cnxString() << "Resolving " << service_url.host() << ":" << service_url.port()); - auto weakSelf = weak_from_this(); - resolver_->async_resolve(service_url.host(), std::to_string(service_url.port()), - [weakSelf](auto err, const auto& results) { - auto self = weakSelf.lock(); - if (self) { - self->handleResolve(err, results); - } - }); + resolver_->async_resolve( + service_url.host(), std::to_string(service_url.port()), + [this, self{shared_from_this()}](auto err, const auto& results) { handleResolve(err, results); }); } void ClientConnection::handleResolve(ASIO_ERROR err, const tcp::resolver::results_type& results) { if (err) { - std::string hostUrl = isSniProxy_ ? cnxString_ : proxyServiceUrl_; + std::string hostUrl = isSniProxy_ ? 
cnxString() : proxyServiceUrl_; LOG_ERROR(hostUrl << "Resolve error: " << err << " : " << err.message()); close(); return; } if (!results.empty()) { - LOG_DEBUG(cnxString_ << "Resolved " << results.size() << " endpoints"); + LOG_DEBUG(cnxString() << "Resolved " << results.size() << " endpoints"); for (const auto& entry : results) { const auto& ep = entry.endpoint(); - LOG_DEBUG(cnxString_ << " " << ep.address().to_string() << ":" << ep.port()); + LOG_DEBUG(cnxString() << " " << ep.address().to_string() << ":" << ep.port()); } } - auto weakSelf = weak_from_this(); - connectTimeoutTask_->setCallback([weakSelf, results = tcp::resolver::results_type(results)]( - const PeriodicTask::ErrorCode& ec) { - ClientConnectionPtr ptr = weakSelf.lock(); - if (!ptr) { - LOG_DEBUG("Connect timeout callback skipped: connection was already destroyed"); - return; - } + std::lock_guard lock{mutex_}; + if (isClosed()) { + return; + } - if (ptr->state_ != Ready) { - LOG_ERROR(ptr->cnxString_ << "Connection to " << results << " was not established in " - << ptr->connectTimeoutTask_->getPeriodMs() << " ms, close the socket"); - PeriodicTask::ErrorCode err; - ptr->socket_->close(err); - if (err) { - LOG_WARN(ptr->cnxString_ << "Failed to close socket: " << err.message()); - } - } - ptr->connectTimeoutTask_->stop(); - }); - connectTimeoutTask_->start(); - ASIO::async_connect(*socket_, results, [weakSelf](const ASIO_ERROR& err, const tcp::endpoint& endpoint) { - auto self = weakSelf.lock(); - if (self) { - self->handleTcpConnected(err, endpoint); + connectTimer_->expires_after(connectTimeout_); + connectTimer_->async_wait([this, results, self{shared_from_this()}](const auto& err) { + if (err) { + return; } + Lock lock{mutex_}; + if (!isClosed() && state_ != Ready) { + LOG_ERROR(cnxString() << "Connection to " << results << " was not established in " + << connectTimeout_.count() << " ms"); + lock.unlock(); + close(); + } // else: the connection is closed or already established }); + + 
ASIO::async_connect( + *socket_, results, + [this, self{shared_from_this()}](const ASIO_ERROR& err, const tcp::endpoint& endpoint) { + handleTcpConnected(err, endpoint); + }); } void ClientConnection::readNextCommand() { @@ -668,11 +651,11 @@ void ClientConnection::handleRead(const ASIO_ERROR& err, size_t bytesTransferred if (err || bytesTransferred == 0) { if (err == ASIO::error::operation_aborted) { - LOG_DEBUG(cnxString_ << "Read operation was canceled: " << err.message()); + LOG_DEBUG(cnxString() << "Read operation was canceled: " << err.message()); } else if (bytesTransferred == 0 || err == ASIO::error::eof) { - LOG_DEBUG(cnxString_ << "Server closed the connection: " << err.message()); + LOG_DEBUG(cnxString() << "Server closed the connection: " << err.message()); } else { - LOG_ERROR(cnxString_ << "Read operation failed: " << err.message()); + LOG_ERROR(cnxString() << "Read operation failed: " << err.message()); } close(ResultDisconnected); } else if (bytesTransferred < minReadSize) { @@ -724,7 +707,7 @@ void ClientConnection::processIncomingBuffer() { uint32_t cmdSize = incomingBuffer_.readUnsignedInt(); proto::BaseCommand incomingCmd; if (!incomingCmd.ParseFromArray(incomingBuffer_.data(), cmdSize)) { - LOG_ERROR(cnxString_ << "Error parsing protocol buffer command"); + LOG_ERROR(cnxString() << "Error parsing protocol buffer command"); close(ResultDisconnected); return; } @@ -744,11 +727,11 @@ void ClientConnection::processIncomingBuffer() { // broker entry metadata is present uint32_t brokerEntryMetadataSize = incomingBuffer_.readUnsignedInt(); if (!brokerEntryMetadata.ParseFromArray(incomingBuffer_.data(), brokerEntryMetadataSize)) { - LOG_ERROR(cnxString_ << "[consumer id " << incomingCmd.message().consumer_id() - << ", message ledger id " - << incomingCmd.message().message_id().ledgerid() << ", entry id " - << incomingCmd.message().message_id().entryid() - << "] Error parsing broker entry metadata"); + LOG_ERROR(cnxString() + << "[consumer id " << 
incomingCmd.message().consumer_id() + << ", message ledger id " << incomingCmd.message().message_id().ledgerid() + << ", entry id " << incomingCmd.message().message_id().entryid() + << "] Error parsing broker entry metadata"); close(ResultDisconnected); return; } @@ -762,11 +745,11 @@ void ClientConnection::processIncomingBuffer() { uint32_t metadataSize = incomingBuffer_.readUnsignedInt(); if (!msgMetadata.ParseFromArray(incomingBuffer_.data(), metadataSize)) { - LOG_ERROR(cnxString_ << "[consumer id " << incomingCmd.message().consumer_id() // - << ", message ledger id " - << incomingCmd.message().message_id().ledgerid() // - << ", entry id " << incomingCmd.message().message_id().entryid() - << "] Error parsing message metadata"); + LOG_ERROR(cnxString() + << "[consumer id " << incomingCmd.message().consumer_id() // + << ", message ledger id " << incomingCmd.message().message_id().ledgerid() // + << ", entry id " << incomingCmd.message().message_id().entryid() + << "] Error parsing message metadata"); close(ResultDisconnected); return; } @@ -839,8 +822,8 @@ bool ClientConnection::verifyChecksum(SharedBuffer& incomingBuffer_, uint32_t& r } void ClientConnection::handleActiveConsumerChange(const proto::CommandActiveConsumerChange& change) { - LOG_DEBUG(cnxString_ << "Received notification about active consumer change, consumer_id: " - << change.consumer_id() << " isActive: " << change.is_active()); + LOG_DEBUG(cnxString() << "Received notification about active consumer change, consumer_id: " + << change.consumer_id() << " isActive: " << change.is_active()); Lock lock(mutex_); ConsumersMap::iterator it = consumers_.find(change.consumer_id()); if (it != consumers_.end()) { @@ -851,19 +834,19 @@ void ClientConnection::handleActiveConsumerChange(const proto::CommandActiveCons consumer->activeConsumerChanged(change.is_active()); } else { consumers_.erase(change.consumer_id()); - LOG_DEBUG(cnxString_ << "Ignoring incoming message for already destroyed consumer " - << 
change.consumer_id()); + LOG_DEBUG(cnxString() << "Ignoring incoming message for already destroyed consumer " + << change.consumer_id()); } } else { - LOG_DEBUG(cnxString_ << "Got invalid consumer Id in " << change.consumer_id() - << " -- isActive: " << change.is_active()); + LOG_DEBUG(cnxString() << "Got invalid consumer Id in " << change.consumer_id() + << " -- isActive: " << change.is_active()); } } void ClientConnection::handleIncomingMessage(const proto::CommandMessage& msg, bool isChecksumValid, proto::BrokerEntryMetadata& brokerEntryMetadata, proto::MessageMetadata& msgMetadata, SharedBuffer& payload) { - LOG_DEBUG(cnxString_ << "Received a message from the server for consumer: " << msg.consumer_id()); + LOG_DEBUG(cnxString() << "Received a message from the server for consumer: " << msg.consumer_id()); Lock lock(mutex_); ConsumersMap::iterator it = consumers_.find(msg.consumer_id()); @@ -878,21 +861,21 @@ void ClientConnection::handleIncomingMessage(const proto::CommandMessage& msg, b msgMetadata, payload); } else { consumers_.erase(msg.consumer_id()); - LOG_DEBUG(cnxString_ << "Ignoring incoming message for already destroyed consumer " - << msg.consumer_id()); + LOG_DEBUG(cnxString() << "Ignoring incoming message for already destroyed consumer " + << msg.consumer_id()); } } else { - LOG_DEBUG(cnxString_ << "Got invalid consumer Id in " // - << msg.consumer_id() << " -- msg: " << msgMetadata.sequence_id()); + LOG_DEBUG(cnxString() << "Got invalid consumer Id in " // + << msg.consumer_id() << " -- msg: " << msgMetadata.sequence_id()); } } void ClientConnection::handleIncomingCommand(BaseCommand& incomingCmd) { - LOG_DEBUG(cnxString_ << "Handling incoming command: " << Commands::messageType(incomingCmd.type())); + LOG_DEBUG(cnxString() << "Handling incoming command: " << Commands::messageType(incomingCmd.type())); switch (state_.load()) { case Pending: { - LOG_ERROR(cnxString_ << "Connection is not ready yet"); + LOG_ERROR(cnxString() << "Connection is not 
ready yet"); break; } @@ -908,7 +891,7 @@ void ClientConnection::handleIncomingCommand(BaseCommand& incomingCmd) { } case Disconnected: { - LOG_ERROR(cnxString_ << "Connection already disconnected"); + LOG_ERROR(cnxString() << "Connection already disconnected"); break; } @@ -967,12 +950,12 @@ void ClientConnection::handleIncomingCommand(BaseCommand& incomingCmd) { case BaseCommand::PING: // Respond to ping request - LOG_DEBUG(cnxString_ << "Replying to ping command"); + LOG_DEBUG(cnxString() << "Replying to ping command"); sendCommand(Commands::newPong()); break; case BaseCommand::PONG: - LOG_DEBUG(cnxString_ << "Received response to ping message"); + LOG_DEBUG(cnxString() << "Received response to ping message"); break; case BaseCommand::AUTH_CHALLENGE: @@ -1000,7 +983,7 @@ void ClientConnection::handleIncomingCommand(BaseCommand& incomingCmd) { break; default: - LOG_WARN(cnxString_ << "Received invalid message from server"); + LOG_WARN(cnxString() << "Received invalid message from server"); close(ResultDisconnected); break; } @@ -1014,7 +997,7 @@ Future ClientConnection::newConsumerStats(uint6 Promise promise; if (isClosed()) { lock.unlock(); - LOG_ERROR(cnxString_ << " Client is not connected to the broker"); + LOG_ERROR(cnxString() << " Client is not connected to the broker"); promise.setFailed(ResultNotConnected); return promise.getFuture(); } @@ -1059,18 +1042,14 @@ void ClientConnection::newLookup(const SharedBuffer& cmd, uint64_t requestId, co requestData.promise = promise; requestData.timer = executor_->createDeadlineTimer(); requestData.timer->expires_after(operationsTimeout_); - auto weakSelf = weak_from_this(); - requestData.timer->async_wait([weakSelf, requestData](const ASIO_ERROR& ec) { - auto self = weakSelf.lock(); - if (self) { - self->handleLookupTimeout(ec, requestData); - } + requestData.timer->async_wait([this, self{shared_from_this()}, requestData](const ASIO_ERROR& ec) { + handleLookupTimeout(ec, requestData); }); 
pendingLookupRequests_.insert(std::make_pair(requestId, requestData)); numOfPendingLookupRequest_++; lock.unlock(); - LOG_DEBUG(cnxString_ << "Inserted lookup request " << requestType << " (req_id: " << requestId << ")"); + LOG_DEBUG(cnxString() << "Inserted lookup request " << requestType << " (req_id: " << requestId << ")"); sendCommand(cmd); } @@ -1079,18 +1058,7 @@ void ClientConnection::sendCommand(const SharedBuffer& cmd) { if (pendingWriteOperations_++ == 0) { // Write immediately to socket - if (tlsSocket_) { - auto weakSelf = weak_from_this(); - auto callback = [weakSelf, cmd]() { - auto self = weakSelf.lock(); - if (self) { - self->sendCommandInternal(cmd); - } - }; - ASIO::post(strand_, callback); - } else { - sendCommandInternal(cmd); - } + executor_->dispatch([this, cmd, self{shared_from_this()}] { sendCommandInternal(cmd); }); } else { // Queue to send later pendingWriteBuffers_.push_back(cmd); @@ -1122,11 +1090,7 @@ void ClientConnection::sendMessage(const std::shared_ptr& args) { handleSendPair(err); })); }; - if (tlsSocket_) { - ASIO::post(strand_, sendMessageInternal); - } else { - sendMessageInternal(); - } + executor_->dispatch(sendMessageInternal); } void ClientConnection::handleSend(const ASIO_ERROR& err, const SharedBuffer&) { @@ -1134,7 +1098,7 @@ void ClientConnection::handleSend(const ASIO_ERROR& err, const SharedBuffer&) { return; } if (err) { - LOG_WARN(cnxString_ << "Could not send message on connection: " << err << " " << err.message()); + LOG_WARN(cnxString() << "Could not send message on connection: " << err << " " << err.message()); close(ResultDisconnected); } else { sendPendingCommands(); @@ -1146,7 +1110,7 @@ void ClientConnection::handleSendPair(const ASIO_ERROR& err) { return; } if (err) { - LOG_WARN(cnxString_ << "Could not send pair message on connection: " << err << " " << err.message()); + LOG_WARN(cnxString() << "Could not send pair message on connection: " << err << " " << err.message()); close(ResultDisconnected); } else 
{ sendPendingCommands(); @@ -1194,8 +1158,8 @@ Future ClientConnection::sendRequestWithId(const SharedBuf if (isClosed()) { lock.unlock(); Promise promise; - LOG_DEBUG(cnxString_ << "Fail " << requestType << "(req_id: " << requestId - << ") to a closed connection"); + LOG_DEBUG(cnxString() << "Fail " << requestType << "(req_id: " << requestId + << ") to a closed connection"); promise.setFailed(ResultNotConnected); return promise.getFuture(); } @@ -1203,21 +1167,17 @@ Future ClientConnection::sendRequestWithId(const SharedBuf PendingRequestData requestData; requestData.timer = executor_->createDeadlineTimer(); requestData.timer->expires_after(operationsTimeout_); - auto weakSelf = weak_from_this(); - requestData.timer->async_wait([weakSelf, requestData](const ASIO_ERROR& ec) { - auto self = weakSelf.lock(); - if (self) { - self->handleRequestTimeout(ec, requestData); - } + requestData.timer->async_wait([this, self{shared_from_this()}, requestData](const ASIO_ERROR& ec) { + handleRequestTimeout(ec, requestData); }); pendingRequests_.insert(std::make_pair(requestId, requestData)); lock.unlock(); - LOG_DEBUG(cnxString_ << "Inserted request " << requestType << " (req_id: " << requestId << ")"); + LOG_DEBUG(cnxString() << "Inserted request " << requestType << " (req_id: " << requestId << ")"); if (mockingRequests_.load(std::memory_order_acquire)) { if (mockServer_ == nullptr) { - LOG_WARN(cnxString_ << "Mock server is unexpectedly null when processing " << requestType); + LOG_WARN(cnxString() << "Mock server is unexpectedly null when processing " << requestType); sendCommand(cmd); } else if (!mockServer_->sendRequest(requestType, requestId)) { sendCommand(cmd); @@ -1231,7 +1191,7 @@ Future ClientConnection::sendRequestWithId(const SharedBuf void ClientConnection::handleRequestTimeout(const ASIO_ERROR& ec, const PendingRequestData& pendingRequestData) { if (!ec && !pendingRequestData.hasGotResponse->load()) { - LOG_WARN(cnxString_ << "Network request timeout to broker, 
remote: " << physicalAddress_); + LOG_WARN(cnxString() << "Network request timeout to broker, remote: " << physicalAddress_); pendingRequestData.promise.setFailed(ResultTimeout); } } @@ -1239,7 +1199,7 @@ void ClientConnection::handleRequestTimeout(const ASIO_ERROR& ec, void ClientConnection::handleLookupTimeout(const ASIO_ERROR& ec, const LookupRequestData& pendingRequestData) { if (!ec) { - LOG_WARN(cnxString_ << "Lookup request timeout to broker, remote: " << physicalAddress_); + LOG_WARN(cnxString() << "Lookup request timeout to broker, remote: " << physicalAddress_); pendingRequestData.promise->setFailed(ResultTimeout); } } @@ -1247,22 +1207,22 @@ void ClientConnection::handleLookupTimeout(const ASIO_ERROR& ec, void ClientConnection::handleGetLastMessageIdTimeout(const ASIO_ERROR& ec, const ClientConnection::LastMessageIdRequestData& data) { if (!ec) { - LOG_WARN(cnxString_ << "GetLastMessageId request timeout to broker, remote: " << physicalAddress_); + LOG_WARN(cnxString() << "GetLastMessageId request timeout to broker, remote: " << physicalAddress_); data.promise->setFailed(ResultTimeout); } } -void ClientConnection::handleKeepAliveTimeout() { - if (isClosed()) { +void ClientConnection::handleKeepAliveTimeout(const ASIO_ERROR& ec) { + if (isClosed() || ec) { return; } if (havePendingPingRequest_) { - LOG_WARN(cnxString_ << "Forcing connection to close after keep-alive timeout"); + LOG_WARN(cnxString() << "Forcing connection to close after keep-alive timeout"); close(ResultDisconnected); } else { // Send keep alive probe to peer - LOG_DEBUG(cnxString_ << "Sending ping message"); + LOG_DEBUG(cnxString() << "Sending ping message"); havePendingPingRequest_ = true; sendCommand(Commands::newPing()); @@ -1271,13 +1231,8 @@ void ClientConnection::handleKeepAliveTimeout() { Lock lock(mutex_); if (keepAliveTimer_) { keepAliveTimer_->expires_after(std::chrono::seconds(keepAliveIntervalInSeconds_)); - auto weakSelf = weak_from_this(); - 
keepAliveTimer_->async_wait([weakSelf](const ASIO_ERROR&) { - auto self = weakSelf.lock(); - if (self) { - self->handleKeepAliveTimeout(); - } - }); + keepAliveTimer_->async_wait( + [this, self{shared_from_this()}](const auto& err) { handleKeepAliveTimeout(err); }); } lock.unlock(); } @@ -1286,39 +1241,32 @@ void ClientConnection::handleKeepAliveTimeout() { void ClientConnection::handleConsumerStatsTimeout(const ASIO_ERROR& ec, const std::vector& consumerStatsRequests) { if (ec) { - LOG_DEBUG(cnxString_ << " Ignoring timer cancelled event, code[" << ec << "]"); + LOG_DEBUG(cnxString() << " Ignoring timer cancelled event, code[" << ec << "]"); return; } startConsumerStatsTimer(consumerStatsRequests); } -void ClientConnection::close(Result result, bool detach) { +const std::future& ClientConnection::close(Result result) { Lock lock(mutex_); - if (isClosed()) { - return; - } + if (closeFuture_) { + connectPromise_.setFailed(result); + return *closeFuture_; + } + auto promise = std::make_shared>(); + closeFuture_ = promise->get_future(); + // The atomic update on state_ guarantees the previous modification on closeFuture_ is visible once the + // atomic read on state_ returns Disconnected `isClosed()`. + // However, it cannot prevent the race like: + // 1. thread 1: Check `isClosed()`, which returns false. + // 2. thread 2: call `close()`, now, `state_` becomes Disconnected, and `closeFuture_` is set. + // 3. thread 1: post the `async_write` to the `io_context`, + // 4. io thread: call `socket_->close()` + // 5. io thread: execute `async_write` on `socket_`, which has been closed + // However, even the race happens, it's still safe because all the socket operations happen in the same + // io thread, the `async_write` operation will simply fail with an error, no crash will happen. 
state_ = Disconnected; - if (socket_) { - ASIO_ERROR err; - socket_->shutdown(ASIO::socket_base::shutdown_both, err); - socket_->close(err); - if (err) { - LOG_WARN(cnxString_ << "Failed to close socket: " << err.message()); - } - } - if (tlsSocket_) { - ASIO_ERROR err; - tlsSocket_->lowest_layer().close(err); - if (err) { - LOG_WARN(cnxString_ << "Failed to close TLS socket: " << err.message()); - } - } - - if (executor_) { - executor_.reset(); - } - // Move the internal fields to process them after `mutex_` was unlocked auto consumers = std::move(consumers_); auto producers = std::move(producers_); @@ -1341,21 +1289,37 @@ void ClientConnection::close(Result result, bool detach) { consumerStatsRequestTimer_.reset(); } - if (connectTimeoutTask_) { - connectTimeoutTask_->stop(); - } - + cancelTimer(*connectTimer_); lock.unlock(); int refCount = weak_from_this().use_count(); if (!isResultRetryable(result)) { - LOG_ERROR(cnxString_ << "Connection closed with " << result << " (refCnt: " << refCount << ")"); + LOG_ERROR(cnxString() << "Connection closed with " << result << " (refCnt: " << refCount << ")"); } else { - LOG_INFO(cnxString_ << "Connection disconnected (refCnt: " << refCount << ")"); + LOG_INFO(cnxString() << "Connection disconnected (refCnt: " << refCount << ")"); } // Remove the connection from the pool before completing any promise - if (detach) { - pool_.remove(logicalAddress_, physicalAddress_, poolIndex_, this); - } + pool_.remove(logicalAddress_, physicalAddress_, poolIndex_, this); + + // Close the socket after removing itself from the pool so that other requests won't be able to acquire + // this connection after the socket is closed. + executor_->dispatch([this, promise, self{shared_from_this()}] { + // According to asio document, ip::tcp::socket and ssl::stream are unsafe as shared objects, so the + // methods must be called within the same implicit or explicit strand. 
+ // The implementation of `ExecutorService` guarantees the internal `io_context::run()` is only called + // in one thread, so we can safely call the socket methods without posting to a strand instance. + ASIO_ERROR err; + socket_->shutdown(ASIO::socket_base::shutdown_both, err); + socket_->close(err); + if (err) { + LOG_WARN(cnxString() << "Failed to close socket: " << err.message()); + } + if (tlsSocket_) { + auto tlsSocket = tlsSocket_; + tlsSocket->async_shutdown([promise, self, tlsSocket](const auto&) { promise->set_value(); }); + } else { + promise->set_value(); + } + }); auto self = shared_from_this(); for (ProducersMap::iterator it = producers.begin(); it != producers.end(); ++it) { @@ -1377,24 +1341,25 @@ void ClientConnection::close(Result result, bool detach) { // Fail all pending requests, all these type are map whose value type contains the Promise object for (auto& kv : pendingRequests) { - kv.second.promise.setFailed(result); + kv.second.fail(result); } for (auto& kv : pendingLookupRequests) { - kv.second.promise->setFailed(result); + kv.second.fail(result); } for (auto& kv : pendingConsumerStatsMap) { - LOG_ERROR(cnxString_ << " Closing Client Connection, please try again later"); + LOG_ERROR(cnxString() << " Closing Client Connection, please try again later"); kv.second.setFailed(result); } for (auto& kv : pendingGetLastMessageIdRequests) { - kv.second.promise->setFailed(result); + kv.second.fail(result); } for (auto& kv : pendingGetNamespaceTopicsRequests) { kv.second.setFailed(result); } for (auto& kv : pendingGetSchemaRequests) { - kv.second.promise.setFailed(result); + kv.second.fail(result); } + return *closeFuture_; } bool ClientConnection::isClosed() const { return state_ == Disconnected; } @@ -1425,8 +1390,6 @@ void ClientConnection::removeConsumer(int consumerId) { const std::string& ClientConnection::brokerAddress() const { return physicalAddress_; } -const std::string& ClientConnection::cnxString() const { return cnxString_; } - int 
ClientConnection::getServerProtocolVersion() const { return serverProtocolVersion_; } int32_t ClientConnection::getMaxMessageSize() { return maxMessageSize_.load(std::memory_order_acquire); } @@ -1441,7 +1404,7 @@ Future ClientConnection::newGetLastMessageId(u auto promise = std::make_shared(); if (isClosed()) { lock.unlock(); - LOG_ERROR(cnxString_ << " Client is not connected to the broker"); + LOG_ERROR(cnxString() << " Client is not connected to the broker"); promise->setFailed(ResultNotConnected); return promise->getFuture(); } @@ -1450,12 +1413,8 @@ Future ClientConnection::newGetLastMessageId(u requestData.promise = promise; requestData.timer = executor_->createDeadlineTimer(); requestData.timer->expires_after(operationsTimeout_); - auto weakSelf = weak_from_this(); - requestData.timer->async_wait([weakSelf, requestData](const ASIO_ERROR& ec) { - auto self = weakSelf.lock(); - if (self) { - self->handleGetLastMessageIdTimeout(ec, requestData); - } + requestData.timer->async_wait([this, self{shared_from_this()}, requestData](const ASIO_ERROR& ec) { + handleGetLastMessageIdTimeout(ec, requestData); }); pendingGetLastMessageIdRequests_.insert(std::make_pair(requestId, requestData)); lock.unlock(); @@ -1469,7 +1428,7 @@ Future ClientConnection::newGetTopicsOfNamespace( Promise promise; if (isClosed()) { lock.unlock(); - LOG_ERROR(cnxString_ << "Client is not connected to the broker"); + LOG_ERROR(cnxString() << "Client is not connected to the broker"); promise.setFailed(ResultNotConnected); return promise.getFuture(); } @@ -1487,7 +1446,7 @@ Future ClientConnection::newGetSchema(const std::string& top Promise promise; if (isClosed()) { lock.unlock(); - LOG_ERROR(cnxString_ << "Client is not connected to the broker"); + LOG_ERROR(cnxString() << "Client is not connected to the broker"); promise.setFailed(ResultNotConnected); return promise.getFuture(); } @@ -1496,11 +1455,9 @@ Future ClientConnection::newGetSchema(const std::string& top 
pendingGetSchemaRequests_.emplace(requestId, GetSchemaRequest{promise, timer}); lock.unlock(); - auto weakSelf = weak_from_this(); timer->expires_after(operationsTimeout_); - timer->async_wait([this, weakSelf, requestId](const ASIO_ERROR& ec) { - auto self = weakSelf.lock(); - if (!self) { + timer->async_wait([this, self{shared_from_this()}, requestId](const ASIO_ERROR& ec) { + if (ec) { return; } Lock lock(mutex_); @@ -1527,8 +1484,8 @@ void ClientConnection::handleSendReceipt(const proto::CommandSendReceipt& sendRe const proto::MessageIdData& messageIdData = sendReceipt.message_id(); auto messageId = toMessageId(messageIdData); - LOG_DEBUG(cnxString_ << "Got receipt for producer: " << producerId << " -- msg: " << sequenceId - << "-- message id: " << messageId); + LOG_DEBUG(cnxString() << "Got receipt for producer: " << producerId << " -- msg: " << sequenceId + << "-- message id: " << messageId); Lock lock(mutex_); auto it = producers_.find(producerId); @@ -1544,13 +1501,13 @@ void ClientConnection::handleSendReceipt(const proto::CommandSendReceipt& sendRe } } } else { - LOG_ERROR(cnxString_ << "Got invalid producer Id in SendReceipt: " // - << producerId << " -- msg: " << sequenceId); + LOG_ERROR(cnxString() << "Got invalid producer Id in SendReceipt: " // + << producerId << " -- msg: " << sequenceId); } } void ClientConnection::handleSendError(const proto::CommandSendError& error) { - LOG_WARN(cnxString_ << "Received send error from server: " << error.message()); + LOG_WARN(cnxString() << "Received send error from server: " << error.message()); if (ChecksumError == error.error()) { long producerId = error.producer_id(); long sequenceId = error.sequence_id(); @@ -1574,7 +1531,7 @@ void ClientConnection::handleSendError(const proto::CommandSendError& error) { } void ClientConnection::handleSuccess(const proto::CommandSuccess& success) { - LOG_DEBUG(cnxString_ << "Received success response from server. 
req_id: " << success.request_id()); + LOG_DEBUG(cnxString() << "Received success response from server. req_id: " << success.request_id()); Lock lock(mutex_); auto it = pendingRequests_.find(success.request_id()); @@ -1590,8 +1547,8 @@ void ClientConnection::handleSuccess(const proto::CommandSuccess& success) { void ClientConnection::handlePartitionedMetadataResponse( const proto::CommandPartitionedTopicMetadataResponse& partitionMetadataResponse) { - LOG_DEBUG(cnxString_ << "Received partition-metadata response from server. req_id: " - << partitionMetadataResponse.request_id()); + LOG_DEBUG(cnxString() << "Received partition-metadata response from server. req_id: " + << partitionMetadataResponse.request_id()); Lock lock(mutex_); auto it = pendingLookupRequests_.find(partitionMetadataResponse.request_id()); @@ -1607,16 +1564,16 @@ void ClientConnection::handlePartitionedMetadataResponse( (partitionMetadataResponse.response() == proto::CommandPartitionedTopicMetadataResponse::Failed)) { if (partitionMetadataResponse.has_error()) { - LOG_ERROR(cnxString_ << "Failed partition-metadata lookup req_id: " - << partitionMetadataResponse.request_id() - << " error: " << partitionMetadataResponse.error() - << " msg: " << partitionMetadataResponse.message()); + LOG_ERROR(cnxString() << "Failed partition-metadata lookup req_id: " + << partitionMetadataResponse.request_id() + << " error: " << partitionMetadataResponse.error() + << " msg: " << partitionMetadataResponse.message()); checkServerError(partitionMetadataResponse.error(), partitionMetadataResponse.message()); lookupDataPromise->setFailed( getResult(partitionMetadataResponse.error(), partitionMetadataResponse.message())); } else { - LOG_ERROR(cnxString_ << "Failed partition-metadata lookup req_id: " - << partitionMetadataResponse.request_id() << " with empty response: "); + LOG_ERROR(cnxString() << "Failed partition-metadata lookup req_id: " + << partitionMetadataResponse.request_id() << " with empty response: "); 
lookupDataPromise->setFailed(ResultConnectError); } } else { @@ -1632,9 +1589,9 @@ void ClientConnection::handlePartitionedMetadataResponse( void ClientConnection::handleConsumerStatsResponse( const proto::CommandConsumerStatsResponse& consumerStatsResponse) { - LOG_DEBUG(cnxString_ << "ConsumerStatsResponse command - Received consumer stats " - "response from server. req_id: " - << consumerStatsResponse.request_id()); + LOG_DEBUG(cnxString() << "ConsumerStatsResponse command - Received consumer stats " + "response from server. req_id: " + << consumerStatsResponse.request_id()); Lock lock(mutex_); auto it = pendingConsumerStatsMap_.find(consumerStatsResponse.request_id()); if (it != pendingConsumerStatsMap_.end()) { @@ -1644,15 +1601,15 @@ void ClientConnection::handleConsumerStatsResponse( if (consumerStatsResponse.has_error_code()) { if (consumerStatsResponse.has_error_message()) { - LOG_ERROR(cnxString_ << " Failed to get consumer stats - " - << consumerStatsResponse.error_message()); + LOG_ERROR(cnxString() + << " Failed to get consumer stats - " << consumerStatsResponse.error_message()); } consumerStatsPromise.setFailed( getResult(consumerStatsResponse.error_code(), consumerStatsResponse.error_message())); } else { - LOG_DEBUG(cnxString_ << "ConsumerStatsResponse command - Received consumer stats " - "response from server. req_id: " - << consumerStatsResponse.request_id() << " Stats: "); + LOG_DEBUG(cnxString() << "ConsumerStatsResponse command - Received consumer stats " + "response from server. 
req_id: " + << consumerStatsResponse.request_id() << " Stats: "); BrokerConsumerStatsImpl brokerStats( consumerStatsResponse.msgrateout(), consumerStatsResponse.msgthroughputout(), consumerStatsResponse.msgrateredeliver(), consumerStatsResponse.consumername(), @@ -1682,25 +1639,25 @@ void ClientConnection::handleLookupTopicRespose( if (!lookupTopicResponse.has_response() || (lookupTopicResponse.response() == proto::CommandLookupTopicResponse::Failed)) { if (lookupTopicResponse.has_error()) { - LOG_ERROR(cnxString_ << "Failed lookup req_id: " << lookupTopicResponse.request_id() - << " error: " << lookupTopicResponse.error() - << " msg: " << lookupTopicResponse.message()); + LOG_ERROR(cnxString() << "Failed lookup req_id: " << lookupTopicResponse.request_id() + << " error: " << lookupTopicResponse.error() + << " msg: " << lookupTopicResponse.message()); checkServerError(lookupTopicResponse.error(), lookupTopicResponse.message()); lookupDataPromise->setFailed( getResult(lookupTopicResponse.error(), lookupTopicResponse.message())); } else { - LOG_ERROR(cnxString_ << "Failed lookup req_id: " << lookupTopicResponse.request_id() - << " with empty response: "); + LOG_ERROR(cnxString() << "Failed lookup req_id: " << lookupTopicResponse.request_id() + << " with empty response: "); lookupDataPromise->setFailed(ResultConnectError); } } else { - LOG_DEBUG(cnxString_ << "Received lookup response from server. req_id: " - << lookupTopicResponse.request_id() // - << " -- broker-url: " << lookupTopicResponse.brokerserviceurl() - << " -- broker-tls-url: " // - << lookupTopicResponse.brokerserviceurltls() - << " authoritative: " << lookupTopicResponse.authoritative() // - << " redirect: " << lookupTopicResponse.response()); + LOG_DEBUG(cnxString() << "Received lookup response from server. 
req_id: " + << lookupTopicResponse.request_id() // + << " -- broker-url: " << lookupTopicResponse.brokerserviceurl() + << " -- broker-tls-url: " // + << lookupTopicResponse.brokerserviceurltls() + << " authoritative: " << lookupTopicResponse.authoritative() // + << " redirect: " << lookupTopicResponse.response()); LookupDataResultPtr lookupResultPtr = std::make_shared(); if (tlsSocket_) { @@ -1723,17 +1680,18 @@ void ClientConnection::handleLookupTopicRespose( } void ClientConnection::handleProducerSuccess(const proto::CommandProducerSuccess& producerSuccess) { - LOG_DEBUG(cnxString_ << "Received success producer response from server. req_id: " - << producerSuccess.request_id() // - << " -- producer name: " << producerSuccess.producer_name()); + LOG_DEBUG(cnxString() << "Received success producer response from server. req_id: " + << producerSuccess.request_id() // + << " -- producer name: " << producerSuccess.producer_name()); Lock lock(mutex_); auto it = pendingRequests_.find(producerSuccess.request_id()); if (it != pendingRequests_.end()) { PendingRequestData requestData = it->second; if (!producerSuccess.producer_ready()) { - LOG_INFO(cnxString_ << " Producer " << producerSuccess.producer_name() - << " has been queued up at broker. req_id: " << producerSuccess.request_id()); + LOG_INFO(cnxString() << " Producer " << producerSuccess.producer_name() + << " has been queued up at broker. req_id: " + << producerSuccess.request_id()); requestData.hasGotResponse->store(true); lock.unlock(); } else { @@ -1758,9 +1716,9 @@ void ClientConnection::handleProducerSuccess(const proto::CommandProducerSuccess void ClientConnection::handleError(const proto::CommandError& error) { Result result = getResult(error.error(), error.message()); - LOG_WARN(cnxString_ << "Received error response from server: " << result - << (error.has_message() ? 
(" (" + error.message() + ")") : "") - << " -- req_id: " << error.request_id()); + LOG_WARN(cnxString() << "Received error response from server: " << result + << (error.has_message() ? (" (" + error.message() + ")") : "") + << " -- req_id: " << error.request_id()); Lock lock(mutex_); @@ -1890,7 +1848,7 @@ void ClientConnection::handleCloseProducer(const proto::CommandCloseProducer& cl producer->disconnectProducer(assignedBrokerServiceUrl); } } else { - LOG_ERROR(cnxString_ << "Got invalid producer Id in closeProducer command: " << producerId); + LOG_ERROR(cnxString() << "Got invalid producer Id in closeProducer command: " << producerId); } } @@ -1911,17 +1869,17 @@ void ClientConnection::handleCloseConsumer(const proto::CommandCloseConsumer& cl consumer->disconnectConsumer(assignedBrokerServiceUrl); } } else { - LOG_ERROR(cnxString_ << "Got invalid consumer Id in closeConsumer command: " << consumerId); + LOG_ERROR(cnxString() << "Got invalid consumer Id in closeConsumer command: " << consumerId); } } void ClientConnection::handleAuthChallenge() { - LOG_DEBUG(cnxString_ << "Received auth challenge from broker"); + LOG_DEBUG(cnxString() << "Received auth challenge from broker"); Result result; SharedBuffer buffer = Commands::newAuthResponse(authentication_, result); if (result != ResultOk) { - LOG_ERROR(cnxString_ << "Failed to send auth response: " << result); + LOG_ERROR(cnxString() << "Failed to send auth response: " << result); close(result); return; } @@ -1934,8 +1892,8 @@ void ClientConnection::handleAuthChallenge() { void ClientConnection::handleGetLastMessageIdResponse( const proto::CommandGetLastMessageIdResponse& getLastMessageIdResponse) { - LOG_DEBUG(cnxString_ << "Received getLastMessageIdResponse from server. req_id: " - << getLastMessageIdResponse.request_id()); + LOG_DEBUG(cnxString() << "Received getLastMessageIdResponse from server. 
req_id: " + << getLastMessageIdResponse.request_id()); Lock lock(mutex_); auto it = pendingGetLastMessageIdRequests_.find(getLastMessageIdResponse.request_id()); @@ -1961,8 +1919,8 @@ void ClientConnection::handleGetLastMessageIdResponse( void ClientConnection::handleGetTopicOfNamespaceResponse( const proto::CommandGetTopicsOfNamespaceResponse& response) { - LOG_DEBUG(cnxString_ << "Received GetTopicsOfNamespaceResponse from server. req_id: " - << response.request_id() << " topicsSize" << response.topics_size()); + LOG_DEBUG(cnxString() << "Received GetTopicsOfNamespaceResponse from server. req_id: " + << response.request_id() << " topicsSize" << response.topics_size()); Lock lock(mutex_); auto it = pendingGetNamespaceTopicsRequests_.find(response.request_id()); @@ -2001,7 +1959,7 @@ void ClientConnection::handleGetTopicOfNamespaceResponse( } void ClientConnection::handleGetSchemaResponse(const proto::CommandGetSchemaResponse& response) { - LOG_DEBUG(cnxString_ << "Received GetSchemaResponse from server. req_id: " << response.request_id()); + LOG_DEBUG(cnxString() << "Received GetSchemaResponse from server. req_id: " << response.request_id()); Lock lock(mutex_); auto it = pendingGetSchemaRequests_.find(response.request_id()); if (it != pendingGetSchemaRequests_.end()) { @@ -2012,10 +1970,11 @@ void ClientConnection::handleGetSchemaResponse(const proto::CommandGetSchemaResp if (response.has_error_code()) { Result result = getResult(response.error_code(), response.error_message()); if (response.error_code() != proto::TopicNotFound) { - LOG_WARN(cnxString_ << "Received error GetSchemaResponse from server " << result - << (response.has_error_message() ? (" (" + response.error_message() + ")") - : "") - << " -- req_id: " << response.request_id()); + LOG_WARN(cnxString() << "Received error GetSchemaResponse from server " << result + << (response.has_error_message() + ? 
(" (" + response.error_message() + ")") + : "") + << " -- req_id: " << response.request_id()); } getSchemaPromise.setFailed(result); return; @@ -2039,7 +1998,7 @@ void ClientConnection::handleGetSchemaResponse(const proto::CommandGetSchemaResp } void ClientConnection::handleAckResponse(const proto::CommandAckResponse& response) { - LOG_DEBUG(cnxString_ << "Received AckResponse from server. req_id: " << response.request_id()); + LOG_DEBUG(cnxString() << "Received AckResponse from server. req_id: " << response.request_id()); Lock lock(mutex_); auto it = pendingRequests_.find(response.request_id()); diff --git a/lib/ClientConnection.h b/lib/ClientConnection.h index b2770006..b9880ee2 100644 --- a/lib/ClientConnection.h +++ b/lib/ClientConnection.h @@ -25,6 +25,8 @@ #include #include #include +#include +#include #ifdef USE_ASIO #include #include @@ -41,8 +43,10 @@ #include #include #include +#include #include #include +#include #include #include "AsioTimer.h" @@ -156,11 +160,8 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this& close(Result result = ResultConnectError); bool isClosed() const; @@ -193,7 +194,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this promise; DeadlineTimerPtr timer; std::shared_ptr hasGotResponse{std::make_shared(false)}; + + void fail(Result result) { + cancelTimer(*timer); + promise.setFailed(result); + } }; struct LookupRequestData { LookupDataResultPromisePtr promise; DeadlineTimerPtr timer; + + void fail(Result result) { + cancelTimer(*timer); + promise->setFailed(result); + } }; struct LastMessageIdRequestData { GetLastMessageIdResponsePromisePtr promise; DeadlineTimerPtr timer; + + void fail(Result result) { + cancelTimer(*timer); + promise->setFailed(result); + } }; struct GetSchemaRequest { Promise promise; DeadlineTimerPtr timer; + + void fail(Result result) { + cancelTimer(*timer); + promise.setFailed(result); + } }; /* @@ -297,26 +318,26 @@ class PULSAR_PUBLIC ClientConnection : 
public std::enable_shared_from_this - inline void asyncWrite(const ConstBufferSequence& buffers, WriteHandler handler) { + inline void asyncWrite(const ConstBufferSequence& buffers, WriteHandler&& handler) { if (isClosed()) { return; } if (tlsSocket_) { - ASIO::async_write(*tlsSocket_, buffers, ASIO::bind_executor(strand_, handler)); + ASIO::async_write(*tlsSocket_, buffers, std::forward(handler)); } else { - ASIO::async_write(*socket_, buffers, handler); + ASIO::async_write(*socket_, buffers, std::forward(handler)); } } template - inline void asyncReceive(const MutableBufferSequence& buffers, ReadHandler handler) { + inline void asyncReceive(const MutableBufferSequence& buffers, ReadHandler&& handler) { if (isClosed()) { return; } if (tlsSocket_) { - tlsSocket_->async_read_some(buffers, ASIO::bind_executor(strand_, handler)); + tlsSocket_->async_read_some(buffers, std::forward(handler)); } else { - socket_->async_receive(buffers, handler); + socket_->async_receive(buffers, std::forward(handler)); } } @@ -337,7 +358,6 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this strand_; const std::string logicalAddress_; /* @@ -350,7 +370,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this server:6650] - std::string cnxString_; + std::shared_ptr cnxStringPtr_; /* * indicates if async connection establishment failed @@ -360,7 +380,8 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this connectPromise_; - std::shared_ptr connectTimeoutTask_; + const std::chrono::milliseconds connectTimeout_; + const DeadlineTimerPtr connectTimer_; typedef std::map PendingRequestsMap; PendingRequestsMap pendingRequests_; @@ -419,6 +440,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this> closeFuture_; friend class PulsarFriend; friend class ConsumerTest; diff --git a/lib/ConnectionPool.cc b/lib/ConnectionPool.cc index df050b0e..c814cf85 100644 --- a/lib/ConnectionPool.cc +++ 
b/lib/ConnectionPool.cc @@ -54,16 +54,43 @@ bool ConnectionPool::close() { return false; } + std::vector connectionsToClose; + // ClientConnection::close() will remove the connection from the pool, which is not allowed when iterating + // over a map, so we store the connections to close in a vector first and don't iterate the pool when + // closing the connections. std::unique_lock lock(mutex_); + connectionsToClose.reserve(pool_.size()); + for (auto&& kv : pool_) { + connectionsToClose.emplace_back(kv.second); + } + pool_.clear(); + lock.unlock(); - for (auto cnxIt = pool_.begin(); cnxIt != pool_.end(); cnxIt++) { - auto& cnx = cnxIt->second; + for (auto&& cnx : connectionsToClose) { if (cnx) { - // The 2nd argument is false because removing a value during the iteration will cause segfault - cnx->close(ResultDisconnected, false); + // Close with a fatal error to not let the client retry + auto& future = cnx->close(ResultAlreadyClosed); + using namespace std::chrono_literals; + if (auto status = future.wait_for(5s); status != std::future_status::ready) { + LOG_WARN("Connection close timed out for " << cnx.get()->cnxString()); + } + if (cnx.use_count() > 1) { + // There are some asynchronous operations that hold the reference on the connection, we should + // wait for them to finish. Otherwise, `io_context::stop()` will be called in + // `ClientImpl::shutdown()` when closing the `ExecutorServiceProvider`. Then + // `io_context::run()` will return and the `io_context` object will be destroyed. In this + // case, if there is any pending handler, it will crash. 
+ for (int i = 0; i < 500 && cnx.use_count() > 1; i++) { + std::this_thread::sleep_for(10ms); + } + if (cnx.use_count() > 1) { + LOG_WARN("Connection still has " << (cnx.use_count() - 1) + << " references after waiting for 5 seconds for " + << cnx.get()->cnxString()); + } + } } } - pool_.clear(); return true; } diff --git a/lib/ExecutorService.cc b/lib/ExecutorService.cc index 99e2393f..eba74861 100644 --- a/lib/ExecutorService.cc +++ b/lib/ExecutorService.cc @@ -125,8 +125,6 @@ void ExecutorService::close(long timeoutMs) { } } -void ExecutorService::postWork(std::function task) { ASIO::post(io_context_, std::move(task)); } - ///////////////////// ExecutorServiceProvider::ExecutorServiceProvider(int nthreads) diff --git a/lib/ExecutorService.h b/lib/ExecutorService.h index 626cb203..80659d4b 100644 --- a/lib/ExecutorService.h +++ b/lib/ExecutorService.h @@ -23,12 +23,16 @@ #include #ifdef USE_ASIO +#include #include #include +#include #include #else +#include #include #include +#include #include #endif #include @@ -37,6 +41,7 @@ #include #include #include +#include #include "AsioTimer.h" @@ -62,7 +67,19 @@ class PULSAR_PUBLIC ExecutorService : public std::enable_shared_from_this task); + + // Execute the task in the event loop thread asynchronously, i.e. the task will be put in the event loop + // queue and executed later. + template + void postWork(T &&task) { + ASIO::post(io_context_, std::forward(task)); + } + + // Different from `postWork`, if it's already in the event loop, execute the task immediately + template + void dispatch(T &&task) { + ASIO::dispatch(io_context_, std::forward(task)); + } // See TimeoutProcessor for the semantics of the parameter. 
void close(long timeoutMs = 3000); diff --git a/lib/PeriodicTask.h b/lib/PeriodicTask.h index bc186348..ee19182a 100644 --- a/lib/PeriodicTask.h +++ b/lib/PeriodicTask.h @@ -53,7 +53,7 @@ class PeriodicTask : public std::enable_shared_from_this { void stop() noexcept; - void setCallback(CallbackType callback) noexcept { callback_ = callback; } + void setCallback(CallbackType&& callback) noexcept { callback_ = std::move(callback); } State getState() const noexcept { return state_; } int getPeriodMs() const noexcept { return periodMs_; } diff --git a/tests/BasicEndToEndTest.cc b/tests/BasicEndToEndTest.cc index 9a02df0c..d3c6e612 100644 --- a/tests/BasicEndToEndTest.cc +++ b/tests/BasicEndToEndTest.cc @@ -3188,7 +3188,17 @@ static void expectTimeoutOnRecv(Consumer &consumer) { ASSERT_EQ(ResultTimeout, res); } -void testNegativeAcks(const std::string &topic, bool batchingEnabled) { +static std::vector expectedNegativeAckMessages(size_t numMessages) { + std::vector expected; + expected.reserve(numMessages); + for (size_t i = 0; i < numMessages; i++) { + expected.emplace_back("test-" + std::to_string(i)); + } + return expected; +} + +void testNegativeAcks(const std::string &topic, bool batchingEnabled, bool expectOrdered = true) { + constexpr size_t numMessages = 10; Client client(lookupUrl); Consumer consumer; ConsumerConfiguration conf; @@ -3202,22 +3212,32 @@ void testNegativeAcks(const std::string &topic, bool batchingEnabled) { result = client.createProducer(topic, producerConf, producer); ASSERT_EQ(ResultOk, result); - for (int i = 0; i < 10; i++) { + for (size_t i = 0; i < numMessages; i++) { Message msg = MessageBuilder().setContent("test-" + std::to_string(i)).build(); producer.sendAsync(msg, nullptr); } producer.flush(); + std::vector receivedMessages; + receivedMessages.reserve(numMessages); std::vector toNeg; - for (int i = 0; i < 10; i++) { + for (size_t i = 0; i < numMessages; i++) { Message msg; consumer.receive(msg); LOG_INFO("Received message " << 
msg.getDataAsString()); - ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i)); + if (expectOrdered) { + ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i)); + } + receivedMessages.emplace_back(msg.getDataAsString()); toNeg.push_back(msg.getMessageId()); } + if (!expectOrdered) { + auto expectedMessages = expectedNegativeAckMessages(numMessages); + std::sort(receivedMessages.begin(), receivedMessages.end()); + ASSERT_EQ(expectedMessages, receivedMessages); + } // No more messages expected expectTimeoutOnRecv(consumer); @@ -3228,15 +3248,25 @@ void testNegativeAcks(const std::string &topic, bool batchingEnabled) { } PulsarFriend::setNegativeAckEnabled(consumer, true); - for (int i = 0; i < 10; i++) { + std::vector redeliveredMessages; + redeliveredMessages.reserve(numMessages); + for (size_t i = 0; i < numMessages; i++) { Message msg; consumer.receive(msg); LOG_INFO("-- Redelivery -- Received message " << msg.getDataAsString()); - ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i)); + if (expectOrdered) { + ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i)); + } + redeliveredMessages.emplace_back(msg.getDataAsString()); consumer.acknowledge(msg); } + if (!expectOrdered) { + auto expectedMessages = expectedNegativeAckMessages(numMessages); + std::sort(redeliveredMessages.begin(), redeliveredMessages.end()); + ASSERT_EQ(expectedMessages, redeliveredMessages); + } // No more messages expected expectTimeoutOnRecv(consumer); @@ -3262,7 +3292,7 @@ TEST(BasicEndToEndTest, testNegativeAcksWithPartitions) { LOG_INFO("res = " << res); ASSERT_FALSE(res != 204 && res != 409); - testNegativeAcks(topicName, true); + testNegativeAcks(topicName, true, false); } void testNegativeAckPrecisionBitCnt(const std::string &topic, int precisionBitCnt) { diff --git a/tests/ClientTest.cc b/tests/ClientTest.cc index dd892686..6bd6cc8a 100644 --- a/tests/ClientTest.cc +++ b/tests/ClientTest.cc @@ -21,14 +21,18 @@ #include #include +#include #include 
#include #include +#include +#include #include "MockClientImpl.h" #include "PulsarAdminHelper.h" #include "PulsarFriend.h" #include "WaitUtils.h" +#include "lib/AsioDefines.h" #include "lib/ClientConnection.h" #include "lib/LogUtils.h" #include "lib/checksum/ChecksumProvider.h" @@ -42,6 +46,70 @@ using testing::AtLeast; static std::string lookupUrl = "pulsar://localhost:6650"; static std::string adminUrl = "http://localhost:8080/"; +namespace { + +class SilentTcpServer { + public: + SilentTcpServer() + : acceptor_(ioContext_, ASIO::ip::tcp::endpoint(ASIO::ip::tcp::v4(), 0)), + acceptedFuture_(acceptedPromise_.get_future()), + port_(acceptor_.local_endpoint().port()), + workGuard_(ASIO::make_work_guard(ioContext_)) {} + + ~SilentTcpServer() { stop(); } + + int getPort() const noexcept { return port_; } + + void start() { + serverThread_ = std::thread([this] { + socket_.reset(new ASIO::ip::tcp::socket(ioContext_)); + acceptor_.async_accept( + *socket_, [this](const ASIO_ERROR &acceptError) { acceptedPromise_.set_value(acceptError); }); + + ioContext_.run(); + }); + } + + bool waitUntilAccepted(std::chrono::milliseconds timeout) const { + return acceptedFuture_.wait_for(timeout) == std::future_status::ready; + } + + auto acceptedError() const { return acceptedFuture_.get(); } + + void stop() { + bool expected = false; + if (!stopped_.compare_exchange_strong(expected, true) || !serverThread_.joinable()) { + return; + } + ASIO::post(ioContext_, [this] { + ASIO_ERROR closeError; + if (socket_ && socket_->is_open()) { + socket_->close(closeError); + } + if (acceptor_.is_open()) { + acceptor_.close(closeError); + } + workGuard_.reset(); + }); + serverThread_.join(); + } + + private: + using WorkGuard = decltype(ASIO::make_work_guard(std::declval())); + + ASIO::io_context ioContext_; + ASIO::ip::tcp::acceptor acceptor_; + std::unique_ptr socket_; + std::promise acceptedPromise_; + std::shared_future acceptedFuture_; + const int port_; + WorkGuard workGuard_; + 
std::atomic_bool stopped_{false}; + std::thread serverThread_; +}; + +} // namespace + TEST(ClientTest, testChecksumComputation) { std::string data = "test"; std::string doubleData = "testtest"; @@ -137,6 +205,32 @@ TEST(ClientTest, testConnectTimeout) { ASSERT_EQ(futureDefault.get(), ResultDisconnected); } +TEST(ClientTest, testConnectTimeoutAfterTcpConnected) { + std::unique_ptr server; + try { + server.reset(new SilentTcpServer); + } catch (const ASIO_SYSTEM_ERROR &e) { + GTEST_SKIP() << "Cannot bind local test server in this environment: " << e.what(); + } + server->start(); + + const std::string serviceUrl = "pulsar://127.0.0.1:" + std::to_string(server->getPort()); + Client client(serviceUrl, ClientConfiguration().setConnectionTimeout(200)); + + std::promise promise; + auto future = promise.get_future(); + client.createProducerAsync("test-connect-timeout-after-tcp-connected", + [&promise](Result result, const Producer &) { promise.set_value(result); }); + + ASSERT_TRUE(server->waitUntilAccepted(std::chrono::seconds(1))); + ASSERT_FALSE(server->acceptedError()); + ASSERT_EQ(future.wait_for(std::chrono::seconds(2)), std::future_status::ready); + ASSERT_EQ(future.get(), ResultConnectError); + + client.close(); + server->stop(); +} + TEST(ClientTest, testGetNumberOfReferences) { Client client("pulsar://localhost:6650"); @@ -309,14 +403,19 @@ TEST(ClientTest, testMultiBrokerUrl) { TEST(ClientTest, testCloseClient) { const std::string topic = "client-test-close-client-" + std::to_string(time(nullptr)); + using namespace std::chrono; for (int i = 0; i < 1000; ++i) { Client client(lookupUrl); client.createProducerAsync(topic, [](Result result, Producer producer) { producer.close(); }); // simulate different time interval before close - auto t0 = std::chrono::steady_clock::now(); - while ((std::chrono::steady_clock::now() - t0) < std::chrono::microseconds(i)) { + auto t0 = steady_clock::now(); + while ((steady_clock::now() - t0) < microseconds(i)) { } - 
client.close(); + + auto t1 = std::chrono::steady_clock::now(); + ASSERT_EQ(ResultOk, client.close()); + auto closeTimeMs = duration_cast(steady_clock::now() - t1).count(); + ASSERT_TRUE(closeTimeMs < 1000) << "close time: " << closeTimeMs << " ms"; } } @@ -413,7 +512,7 @@ TEST(ClientTest, testConnectionClose) { LOG_INFO("Connection refcnt: " << cnx.use_count() << " before close"); auto executor = PulsarFriend::getExecutor(*cnx); // Simulate the close() happens in the event loop - executor->postWork([cnx, &client, numConnections] { + executor->dispatch([cnx, &client, numConnections] { cnx->close(); ASSERT_EQ(PulsarFriend::getConnections(client).size(), numConnections - 1); LOG_INFO("Connection refcnt: " << cnx.use_count() << " after close"); diff --git a/tests/MultiTopicsConsumerTest.cc b/tests/MultiTopicsConsumerTest.cc index db3bc963..57407fbd 100644 --- a/tests/MultiTopicsConsumerTest.cc +++ b/tests/MultiTopicsConsumerTest.cc @@ -162,11 +162,12 @@ TEST(MultiTopicsConsumerTest, testGetConsumerStatsFail) { BrokerConsumerStats stats; return consumer.getBrokerConsumerStats(stats); }); - // Trigger the `getBrokerConsumerStats` in a new thread - future.wait_for(std::chrono::milliseconds(100)); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + const auto expectedRequests = topics.size(); + ASSERT_TRUE(waitUntil(std::chrono::seconds(1), [connection, expectedRequests] { + return PulsarFriend::getPendingConsumerStatsRequests(*connection) == expectedRequests; + })); - connection->handleKeepAliveTimeout(); + connection->close(ResultDisconnected); ASSERT_EQ(ResultDisconnected, future.get()); mockServer->close(); diff --git a/tests/PulsarFriend.h b/tests/PulsarFriend.h index e7084050..1f351d16 100644 --- a/tests/PulsarFriend.h +++ b/tests/PulsarFriend.h @@ -162,6 +162,11 @@ class PulsarFriend { return consumers; } + static size_t getPendingConsumerStatsRequests(const ClientConnection& cnx) { + std::lock_guard lock(cnx.mutex_); + return 
cnx.pendingConsumerStatsMap_.size(); + } + static void setNegativeAckEnabled(Consumer consumer, bool enabled) { consumer.impl_->setNegativeAcknowledgeEnabledForTesting(enabled); }