From 40510daeca17b3db2cad0c9101d8f513df7127d1 Mon Sep 17 00:00:00 2001 From: dec05eba Date: Sun, 21 Oct 2018 08:52:53 +0200 Subject: Fix concurrent connection to the same address --- src/DirectConnection.cpp | 107 ++++++++++++++++++++++++++--------------------- 1 file changed, 59 insertions(+), 48 deletions(-) (limited to 'src') diff --git a/src/DirectConnection.cpp b/src/DirectConnection.cpp index 010b8af..3ffff22 100644 --- a/src/DirectConnection.cpp +++ b/src/DirectConnection.cpp @@ -109,58 +109,77 @@ namespace sibs { connect(address, true, true, connectCallbackFunc, receiveDataCallbackFunc, true); } + + static bool isReady(const std::shared_future &connectionResultFuture) + { + return connectionResultFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready; + } void DirectConnections::connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc, bool bind) { - std::thread([this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]() + connectionResultsMutex.lock(); + auto it = connectionResults.find(address); + if(it != connectionResults.end()) { - std::shared_ptr peer = getPeerByAddress(address); - if(peer) + std::shared_future connectionResultFuture = it->second; + connectionResultsMutex.unlock(); + if(isReady(connectionResultFuture)) { - // this doesn't really matter, we always call connect with same callback function - peer->receiveDataCallbackFunc = receiveDataCallbackFunc; + PubSubConnectResult connectResult = connectionResultFuture.get(); if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::OK, ""); - return; + connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr); } else { - peer = std::make_shared(); - peerByAddressMap[address] = peer; - } - - std::unique_ptr socket; - try - { - socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind); - } - catch(SocketCreateException &e) - { - if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::ERROR, e.what()); - return; + std::thread([connectCallbackFunc, connectionResultFuture]() + { + PubSubConnectResult connectResult = connectionResultFuture.get(); + if(connectCallbackFunc) + connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr); + }).detach(); } - - Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no"); - if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + } + else + { + connectionResults[address] = std::async(std::launch::async, [this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]() { + std::shared_ptr peer = std::make_shared(); + std::unique_ptr socket; + + try + { + socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind); + } + catch(SocketCreateException &e) + { + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::ERROR, e.what()); + return PubSubConnectResult { peer, PubSubResult::ERROR, e.what() }; + } + + Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no"); + if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + { + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc()); + return PubSubConnectResult{ peer, PubSubResult::ERROR, UDT::getlasterror_desc() }; + } + + UDT::epoll_add_usock(eid, socket->udtSocket); + socket->eid = eid; + peersMutex.lock(); + peers[socket->udtSocket] = peer; + peer->socket = std::move(socket); + peer->address = address; + peer->receiveDataCallbackFunc = receiveDataCallbackFunc; + peersMutex.unlock(); + if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc()); - return; - } - - UDT::epoll_add_usock(eid, socket->udtSocket); - socket->eid = eid; - peersMutex.lock(); - peers[socket->udtSocket] = peer; - peer->socket = std::move(socket); - peer->address = address; - peer->receiveDataCallbackFunc = receiveDataCallbackFunc; - peersMutex.unlock(); - - if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::OK, ""); - }).detach(); + connectCallbackFunc(peer, PubSubResult::OK, ""); + return PubSubConnectResult { peer, PubSubResult::OK, "" }; + }); + connectionResultsMutex.unlock(); + } } bool DirectConnections::send(const std::shared_ptr peer, std::shared_ptr data, PubSubSendDataCallback sendDataCallbackFunc) @@ -241,14 +260,6 @@ namespace sibs } peersMutex.unlock(); } - - std::shared_ptr DirectConnections::getPeerByAddress(const Ipv4 &address) const - { - auto it = peerByAddressMap.find(address); - if(it != peerByAddressMap.end()) - return it->second; - return nullptr; - } void DirectConnections::receiveData() { -- cgit v1.2.3