aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordec05eba <dec05eba@protonmail.com>2018-10-21 08:52:53 +0200
committerdec05eba <dec05eba@protonmail.com>2020-08-18 22:56:48 +0200
commit40510daeca17b3db2cad0c9101d8f513df7127d1 (patch)
tree608232662f9f0c8abbff1af1aa4bfb0ef1a84282
parent980312b2a6e96c6d301d30d38922f8a2cc315c92 (diff)
Fix concurrent connection to the same address
-rw-r--r--include/sibs/DirectConnection.hpp13
-rw-r--r--src/DirectConnection.cpp107
-rw-r--r--tests/main.cpp21
3 files changed, 90 insertions, 51 deletions
diff --git a/include/sibs/DirectConnection.hpp b/include/sibs/DirectConnection.hpp
index 9be55f1..2137fd2 100644
--- a/include/sibs/DirectConnection.hpp
+++ b/include/sibs/DirectConnection.hpp
@@ -13,6 +13,7 @@
#include "../utils.hpp"
#include "Socket.hpp"
#include "Message.hpp"
+#include <future>
namespace sibs
{
@@ -29,6 +30,13 @@ namespace sibs
};
struct DirectConnectionPeer;
+
+ struct PubSubConnectResult
+ {
+ std::shared_ptr<DirectConnectionPeer> peer;
+ PubSubResult result;
+ std::string resultStr;
+ };
using PubSubConnectCallback = std::function<void(std::shared_ptr<DirectConnectionPeer> peer, PubSubResult result, const std::string &resultStr)>;
using PubSubReceiveDataCallback = std::function<void(std::shared_ptr<DirectConnectionPeer> peer, MessageType messageType, const void *data, const usize size)>;
@@ -65,8 +73,6 @@ namespace sibs
bool removePeer(int peerSocket);
std::vector<std::shared_ptr<DirectConnectionPeer>> getPeers();
-
- std::shared_ptr<DirectConnectionPeer> getPeerByAddress(const Ipv4 &address) const;
protected:
std::unique_ptr<Socket> createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind = true);
private:
@@ -82,7 +88,8 @@ namespace sibs
std::mutex peersMutex;
bool alive;
PubSubOnRemoveDisconnectedPeerCallback removeDisconnectedPeerCallback;
- Ipv4Map<std::shared_ptr<DirectConnectionPeer>> peerByAddressMap;
+ Ipv4Map<std::shared_future<PubSubConnectResult>> connectionResults;
+ std::mutex connectionResultsMutex;
};
struct DirectConnectionsUtils
diff --git a/src/DirectConnection.cpp b/src/DirectConnection.cpp
index 010b8af..3ffff22 100644
--- a/src/DirectConnection.cpp
+++ b/src/DirectConnection.cpp
@@ -109,58 +109,77 @@ namespace sibs
{
connect(address, true, true, connectCallbackFunc, receiveDataCallbackFunc, true);
}
+
+ static bool isReady(const std::shared_future<PubSubConnectResult> &connectionResultFuture)
+ {
+ return connectionResultFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
+ }
void DirectConnections::connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc, bool bind)
{
- std::thread([this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]()
+ connectionResultsMutex.lock();
+ auto it = connectionResults.find(address);
+ if(it != connectionResults.end())
{
- std::shared_ptr<DirectConnectionPeer> peer = getPeerByAddress(address);
- if(peer)
+ std::shared_future<PubSubConnectResult> connectionResultFuture = it->second;
+ connectionResultsMutex.unlock();
+ if(isReady(connectionResultFuture))
{
- // this doesn't really matter, we always call connect with same callback function
- peer->receiveDataCallbackFunc = receiveDataCallbackFunc;
+ PubSubConnectResult connectResult = connectionResultFuture.get();
if(connectCallbackFunc)
- connectCallbackFunc(peer, PubSubResult::OK, "");
- return;
+ connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr);
}
else
{
- peer = std::make_shared<DirectConnectionPeer>();
- peerByAddressMap[address] = peer;
- }
-
- std::unique_ptr<Socket> socket;
- try
- {
- socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind);
- }
- catch(SocketCreateException &e)
- {
- if(connectCallbackFunc)
- connectCallbackFunc(peer, PubSubResult::ERROR, e.what());
- return;
+ std::thread([connectCallbackFunc, connectionResultFuture]()
+ {
+ PubSubConnectResult connectResult = connectionResultFuture.get();
+ if(connectCallbackFunc)
+ connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr);
+ }).detach();
}
-
- Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no");
- if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR)
+ }
+ else
+ {
+ connectionResults[address] = std::async(std::launch::async, [this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]()
{
+ std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>();
+ std::unique_ptr<Socket> socket;
+
+ try
+ {
+ socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind);
+ }
+ catch(SocketCreateException &e)
+ {
+ if(connectCallbackFunc)
+ connectCallbackFunc(peer, PubSubResult::ERROR, e.what());
+ return PubSubConnectResult { peer, PubSubResult::ERROR, e.what() };
+ }
+
+ Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no");
+ if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR)
+ {
+ if(connectCallbackFunc)
+ connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc());
+ return PubSubConnectResult{ peer, PubSubResult::ERROR, UDT::getlasterror_desc() };
+ }
+
+ UDT::epoll_add_usock(eid, socket->udtSocket);
+ socket->eid = eid;
+ peersMutex.lock();
+ peers[socket->udtSocket] = peer;
+ peer->socket = std::move(socket);
+ peer->address = address;
+ peer->receiveDataCallbackFunc = receiveDataCallbackFunc;
+ peersMutex.unlock();
+
if(connectCallbackFunc)
- connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc());
- return;
- }
-
- UDT::epoll_add_usock(eid, socket->udtSocket);
- socket->eid = eid;
- peersMutex.lock();
- peers[socket->udtSocket] = peer;
- peer->socket = std::move(socket);
- peer->address = address;
- peer->receiveDataCallbackFunc = receiveDataCallbackFunc;
- peersMutex.unlock();
-
- if(connectCallbackFunc)
- connectCallbackFunc(peer, PubSubResult::OK, "");
- }).detach();
+ connectCallbackFunc(peer, PubSubResult::OK, "");
+ return PubSubConnectResult { peer, PubSubResult::OK, "" };
+ });
+ connectionResultsMutex.unlock();
+ }
}
bool DirectConnections::send(const std::shared_ptr<DirectConnectionPeer> peer, std::shared_ptr<Message> data, PubSubSendDataCallback sendDataCallbackFunc)
@@ -241,14 +260,6 @@ namespace sibs
}
peersMutex.unlock();
}
-
- std::shared_ptr<DirectConnectionPeer> DirectConnections::getPeerByAddress(const Ipv4 &address) const
- {
- auto it = peerByAddressMap.find(address);
- if(it != peerByAddressMap.end())
- return it->second;
- return nullptr;
- }
void DirectConnections::receiveData()
{
diff --git a/tests/main.cpp b/tests/main.cpp
index d6c37e2..e262e02 100644
--- a/tests/main.cpp
+++ b/tests/main.cpp
@@ -11,6 +11,7 @@ const int PORT = 35231;
int main()
{
const sibs::PubsubKey key("abcdefghijklmnopqrstuvxyz0123456789", 35);
+ const sibs::PubsubKey key2("zbcdefghcjklmn3pqrs5uvx2z0123F56789", 35);
sibs::BootstrapNode boostrapNode(sibs::Ipv4(nullptr, PORT));
sibs::BootstrapConnection connection1(sibs::Ipv4("127.0.0.1", PORT));
@@ -36,9 +37,26 @@ int main()
gotAsdf2 = true;
return true;
});
+
+ bool gotListen = false;
+ connection1.listen(key2, [&gotListen](const sibs::DirectConnectionPeer *peer, const void *data, const sibs::usize size)
+ {
+ if(size == 14 && strncmp((const char*)data, "secondListener", 14) == 0)
+ gotListen = true;
+ return true;
+ });
+
+ bool gotListen2 = false;
+ connection2.listen(key2, [&gotListen2](const sibs::DirectConnectionPeer *peer, const void *data, const sibs::usize size)
+ {
+ if(size == 14 && strncmp((const char*)data, "secondListener", 14) == 0)
+ gotListen2 = true;
+ return true;
+ });
connection1.put(key, "hello", 5);
connection2.put(key, "asdf", 4);
+ connection1.put(key2, "secondListener", 14);
std::this_thread::sleep_for(std::chrono::seconds(6));
REQUIRE(gotData1);
@@ -46,5 +64,8 @@ int main()
REQUIRE(gotData2);
REQUIRE(gotAsdf2);
+ REQUIRE(gotListen);
+ REQUIRE(gotListen2);
+
return 0;
}