From 40510daeca17b3db2cad0c9101d8f513df7127d1 Mon Sep 17 00:00:00 2001 From: dec05eba Date: Sun, 21 Oct 2018 08:52:53 +0200 Subject: Fix concurrent connection to the same address --- include/sibs/DirectConnection.hpp | 13 +++-- src/DirectConnection.cpp | 107 +++++++++++++++++++++----------------- tests/main.cpp | 21 ++++++++ 3 files changed, 90 insertions(+), 51 deletions(-) diff --git a/include/sibs/DirectConnection.hpp b/include/sibs/DirectConnection.hpp index 9be55f1..2137fd2 100644 --- a/include/sibs/DirectConnection.hpp +++ b/include/sibs/DirectConnection.hpp @@ -13,6 +13,7 @@ #include "../utils.hpp" #include "Socket.hpp" #include "Message.hpp" +#include namespace sibs { @@ -29,6 +30,13 @@ namespace sibs }; struct DirectConnectionPeer; + + struct PubSubConnectResult + { + std::shared_ptr peer; + PubSubResult result; + std::string resultStr; + }; using PubSubConnectCallback = std::function peer, PubSubResult result, const std::string &resultStr)>; using PubSubReceiveDataCallback = std::function peer, MessageType messageType, const void *data, const usize size)>; @@ -65,8 +73,6 @@ namespace sibs bool removePeer(int peerSocket); std::vector> getPeers(); - - std::shared_ptr getPeerByAddress(const Ipv4 &address) const; protected: std::unique_ptr createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind = true); private: @@ -82,7 +88,8 @@ namespace sibs std::mutex peersMutex; bool alive; PubSubOnRemoveDisconnectedPeerCallback removeDisconnectedPeerCallback; - Ipv4Map> peerByAddressMap; + Ipv4Map> connectionResults; + std::mutex connectionResultsMutex; }; struct DirectConnectionsUtils diff --git a/src/DirectConnection.cpp b/src/DirectConnection.cpp index 010b8af..3ffff22 100644 --- a/src/DirectConnection.cpp +++ b/src/DirectConnection.cpp @@ -109,58 +109,77 @@ namespace sibs { connect(address, true, true, connectCallbackFunc, receiveDataCallbackFunc, true); } + + static bool isReady(const std::shared_future &connectionResultFuture) + { + return connectionResultFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready; + } void DirectConnections::connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc, bool bind) { - std::thread([this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]() + connectionResultsMutex.lock(); + auto it = connectionResults.find(address); + if(it != connectionResults.end()) { - std::shared_ptr peer = getPeerByAddress(address); - if(peer) + std::shared_future connectionResultFuture = it->second; + connectionResultsMutex.unlock(); + if(isReady(connectionResultFuture)) { - // this doesn't really matter, we always call connect with same callback function - peer->receiveDataCallbackFunc = receiveDataCallbackFunc; + PubSubConnectResult connectResult = connectionResultFuture.get(); if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::OK, ""); - return; + connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr); } else { - peer = std::make_shared(); - peerByAddressMap[address] = peer; - } - - std::unique_ptr socket; - try - { - socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind); - } - catch(SocketCreateException &e) - { - if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::ERROR, e.what()); - return; + std::thread([connectCallbackFunc, connectionResultFuture]() + { + PubSubConnectResult connectResult = connectionResultFuture.get(); + if(connectCallbackFunc) + connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr); + }).detach(); } - - Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no"); - if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + } + else + { + connectionResults[address] = std::async(std::launch::async, [this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]() { + std::shared_ptr peer = std::make_shared(); + std::unique_ptr socket; + + try + { + socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind); + } + catch(SocketCreateException &e) + { + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::ERROR, e.what()); + return PubSubConnectResult { peer, PubSubResult::ERROR, e.what() }; + } + + Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no"); + if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + { + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc()); + return PubSubConnectResult{ peer, PubSubResult::ERROR, UDT::getlasterror_desc() }; + } + + UDT::epoll_add_usock(eid, socket->udtSocket); + socket->eid = eid; + peersMutex.lock(); + peers[socket->udtSocket] = peer; + peer->socket = std::move(socket); + peer->address = address; + peer->receiveDataCallbackFunc = receiveDataCallbackFunc; + peersMutex.unlock(); + if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc()); - return; - } - - UDT::epoll_add_usock(eid, socket->udtSocket); - socket->eid = eid; - peersMutex.lock(); - peers[socket->udtSocket] = peer; - peer->socket = std::move(socket); - peer->address = address; - peer->receiveDataCallbackFunc = receiveDataCallbackFunc; - peersMutex.unlock(); - - if(connectCallbackFunc) - connectCallbackFunc(peer, PubSubResult::OK, ""); - }).detach(); + connectCallbackFunc(peer, PubSubResult::OK, ""); + return PubSubConnectResult { peer, PubSubResult::OK, "" }; + }); + connectionResultsMutex.unlock(); + } } bool DirectConnections::send(const std::shared_ptr peer, std::shared_ptr data, PubSubSendDataCallback sendDataCallbackFunc) @@ -241,14 +260,6 @@ namespace sibs } peersMutex.unlock(); } - - std::shared_ptr DirectConnections::getPeerByAddress(const Ipv4 &address) const - { - auto it = peerByAddressMap.find(address); - if(it != peerByAddressMap.end()) - return it->second; - return nullptr; - } void DirectConnections::receiveData() { diff --git a/tests/main.cpp b/tests/main.cpp index d6c37e2..e262e02 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -11,6 +11,7 @@ const int PORT = 35231; int main() { const sibs::PubsubKey key("abcdefghijklmnopqrstuvxyz0123456789", 35); + const sibs::PubsubKey key2("zbcdefghcjklmn3pqrs5uvx2z0123F56789", 35); sibs::BootstrapNode boostrapNode(sibs::Ipv4(nullptr, PORT)); sibs::BootstrapConnection connection1(sibs::Ipv4("127.0.0.1", PORT)); @@ -36,9 +37,26 @@ int main() gotAsdf2 = true; return true; }); + + bool gotListen = false; + connection1.listen(key2, [&gotListen](const sibs::DirectConnectionPeer *peer, const void *data, const sibs::usize size) + { + if(size == 14 && strncmp((const char*)data, "secondListener", 14) == 0) + gotListen = true; + return true; + }); + + bool gotListen2 = false; + connection2.listen(key2, [&gotListen2](const sibs::DirectConnectionPeer *peer, const void *data, const sibs::usize size) + { + if(size == 14 && strncmp((const char*)data, "secondListener", 14) == 0) + gotListen2 = true; + return true; + }); connection1.put(key, "hello", 5); connection2.put(key, "asdf", 4); + connection1.put(key2, "secondListener", 14); std::this_thread::sleep_for(std::chrono::seconds(6)); REQUIRE(gotData1); @@ -46,5 +64,8 @@ int main() REQUIRE(gotData2); REQUIRE(gotAsdf2); + REQUIRE(gotListen); + REQUIRE(gotListen2); + return 0; } -- cgit v1.2.3