diff options
author | dec05eba <0xdec05eba@gmail.com> | 2018-10-21 08:52:53 +0200 |
---|---|---|
committer | dec05eba <0xdec05eba@gmail.com> | 2018-10-21 08:52:55 +0200 |
commit | e81aea0fc96a2eb44e09d3aa0aad1e7d11878cae (patch) | |
tree | 62006080a5747622e83df76eefb7eca230f40a6d /src | |
parent | 54254462e432dcc6ef2bb306a9ee773d21314d19 (diff) |
Fix concurrent connection to the same address
Diffstat (limited to 'src')
-rw-r--r-- | src/DirectConnection.cpp | 107 |
1 files changed, 59 insertions, 48 deletions
diff --git a/src/DirectConnection.cpp b/src/DirectConnection.cpp index 010b8af..3ffff22 100644 --- a/src/DirectConnection.cpp +++ b/src/DirectConnection.cpp @@ -109,58 +109,77 @@ namespace sibs { connect(address, true, true, connectCallbackFunc, receiveDataCallbackFunc, true); } + + static bool isReady(const std::shared_future<PubSubConnectResult> &connectionResultFuture) + { + return connectionResultFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready; + } void DirectConnections::connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc, bool bind) { - std::thread([this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]() + connectionResultsMutex.lock(); + auto it = connectionResults.find(address); + if(it != connectionResults.end()) { - std::shared_ptr<DirectConnectionPeer> peer = getPeerByAddress(address); - if(peer) + std::shared_future<PubSubConnectResult> connectionResultFuture = it->second; + connectionResultsMutex.unlock(); + if(isReady(connectionResultFuture)) { - // this doesn't really matter, we always call connect with same callback function - peer->receiveDataCallbackFunc = receiveDataCallbackFunc; + PubSubConnectResult connectResult = connectionResultFuture.get(); if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::OK, ""); - return; + connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr); } else { - peer = std::make_shared<DirectConnectionPeer>(); - peerByAddressMap[address] = peer; - } - - std::unique_ptr<Socket> socket; - try - { - socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind); - } - catch(SocketCreateException &e) - { - if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::ERROR, e.what()); - return; + std::thread([connectCallbackFunc, connectionResultFuture]() + { + PubSubConnectResult connectResult = connectionResultFuture.get(); + if(connectCallbackFunc) + connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr); + }).detach(); } - - Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no"); - if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + } + else + { + connectionResults[address] = std::async(std::launch::async, [this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]() { + std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>(); + std::unique_ptr<Socket> socket; + + try + { + socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind); + } + catch(SocketCreateException &e) + { + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::ERROR, e.what()); + return PubSubConnectResult { peer, PubSubResult::ERROR, e.what() }; + } + + Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no"); + if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + { + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc()); + return PubSubConnectResult{ peer, PubSubResult::ERROR, UDT::getlasterror_desc() }; + } + + UDT::epoll_add_usock(eid, socket->udtSocket); + socket->eid = eid; + peersMutex.lock(); + peers[socket->udtSocket] = peer; + peer->socket = std::move(socket); + peer->address = address; + peer->receiveDataCallbackFunc = receiveDataCallbackFunc; + peersMutex.unlock(); + if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc()); - return; - } - - UDT::epoll_add_usock(eid, socket->udtSocket); - socket->eid = eid; - peersMutex.lock(); - peers[socket->udtSocket] = peer; - peer->socket = std::move(socket); - peer->address = address; - peer->receiveDataCallbackFunc = receiveDataCallbackFunc; - peersMutex.unlock(); - - if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::OK, ""); - }).detach(); + connectCallbackFunc(peer, PubSubResult::OK, ""); + return PubSubConnectResult { peer, PubSubResult::OK, "" }; + }); + connectionResultsMutex.unlock(); + } } bool DirectConnections::send(const std::shared_ptr<DirectConnectionPeer> peer, std::shared_ptr<Message> data, PubSubSendDataCallback sendDataCallbackFunc) @@ -241,14 +260,6 @@ namespace sibs } peersMutex.unlock(); } - - std::shared_ptr<DirectConnectionPeer> DirectConnections::getPeerByAddress(const Ipv4 &address) const - { - auto it = peerByAddressMap.find(address); - if(it != peerByAddressMap.end()) - return it->second; - return nullptr; - } void DirectConnections::receiveData() { |