diff options
-rw-r--r-- | include/sibs/BootstrapConnection.hpp | 8 | ||||
-rw-r--r-- | include/sibs/BootstrapNode.hpp | 2 | ||||
-rw-r--r-- | include/sibs/DirectConnection.hpp | 14 | ||||
-rw-r--r-- | src/BootstrapConnection.cpp | 42 | ||||
-rw-r--r-- | src/DirectConnection.cpp | 64 |
5 files changed, 93 insertions, 37 deletions
diff --git a/include/sibs/BootstrapConnection.hpp b/include/sibs/BootstrapConnection.hpp index 094ee32..e77222d 100644 --- a/include/sibs/BootstrapConnection.hpp +++ b/include/sibs/BootstrapConnection.hpp @@ -7,6 +7,12 @@ namespace sibs { + class BootstrapConnectionException : public std::runtime_error + { + public: + BootstrapConnectionException(const std::string &errMsg) : std::runtime_error(errMsg) {} + }; + class PubsubKeyAlreadyListeningException : public std::runtime_error { public: @@ -19,6 +25,7 @@ namespace sibs { DISABLE_COPY(BootstrapConnection) public: + // Throws BootstrapConnectionException on error BootstrapConnection(const Ipv4 &bootstrapAddress); // Throws PubsubKeyAlreadyListeningException if we are already listening on the key @pubsubKey @@ -33,5 +40,6 @@ namespace sibs PubsubKeyMap<BoostrapConnectionListenCallbackFunc> listenCallbackFuncs; PubsubKeyMap<std::vector<std::shared_ptr<DirectConnectionPeer>>> subscribedPeers; std::mutex listenerCallbackFuncMutex; + std::mutex subscribedPeersMutex; }; } diff --git a/include/sibs/BootstrapNode.hpp b/include/sibs/BootstrapNode.hpp index 13a62e1..48e527c 100644 --- a/include/sibs/BootstrapNode.hpp +++ b/include/sibs/BootstrapNode.hpp @@ -17,7 +17,7 @@ namespace sibs { DISABLE_COPY(BootstrapNode) public: - // Throws BootstrapException on error + // Throws SocketCreateException or BootstrapException on error BootstrapNode(const Ipv4 &address); ~BootstrapNode(); private: diff --git a/include/sibs/DirectConnection.hpp b/include/sibs/DirectConnection.hpp index 45ba5d9..b7c467e 100644 --- a/include/sibs/DirectConnection.hpp +++ b/include/sibs/DirectConnection.hpp @@ -1,6 +1,5 @@ #pragma once -#include <stdexcept> #include <unordered_map> #include <functional> #include <string> @@ -8,16 +7,17 @@ #include <thread> #include <mutex> #include <vector> +#include <stdexcept> #include "IpAddress.hpp" #include "../types.hpp" #include "../utils.hpp" namespace sibs { - class ConnectionException : public std::runtime_error + class SocketCreateException : public std::runtime_error { public: - ConnectionException(const std::string &errMsg) : std::runtime_error(errMsg) {} + SocketCreateException(const std::string &errMsg) : std::runtime_error(errMsg) {} }; enum class PubSubResult @@ -28,7 +28,7 @@ namespace sibs struct DirectConnectionPeer; - using PubSubConnectCallback = std::function<void(PubSubResult result, const std::string &resultStr)>; + using PubSubConnectCallback = std::function<void(std::shared_ptr<DirectConnectionPeer> peer, PubSubResult result, const std::string &resultStr)>; using PubSubReceiveDataCallback = std::function<void(std::shared_ptr<DirectConnectionPeer> peer, const void *data, const usize size)>; using PubSubSendDataCallback = std::function<void(PubSubResult result, const std::string &resultStr)>; @@ -48,14 +48,14 @@ namespace sibs ~DirectConnections(); // Throws ConnectionException on error - std::shared_ptr<DirectConnectionPeer> connectServer(const Ipv4 &address, PubSubReceiveDataCallback receiveDataCallbackFunc); + void connectServer(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc); // Throws ConnectionException on error - std::shared_ptr<DirectConnectionPeer> connect(const Ipv4 &address, PubSubReceiveDataCallback receiveDataCallbackFunc); + void connect(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc); void send(const std::shared_ptr<DirectConnectionPeer> &peer, std::shared_ptr<std::vector<u8>> data, PubSubSendDataCallback sendDataCallbackFunc = nullptr); protected: int createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr); private: - std::shared_ptr<DirectConnectionPeer> connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubReceiveDataCallback receiveDataCallbackFunc); + void connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc); void receiveData(); int receiveDataFromPeer(const int socket, char *output); bool removePeer(int peerSocket); diff --git a/src/BootstrapConnection.cpp b/src/BootstrapConnection.cpp index 395e9e0..2920440 100644 --- a/src/BootstrapConnection.cpp +++ b/src/BootstrapConnection.cpp @@ -6,7 +6,28 @@ namespace sibs { BootstrapConnection::BootstrapConnection(const Ipv4 &bootstrapAddress) { - serverPeer = connections.connectServer(bootstrapAddress, std::bind(&BootstrapConnection::receiveDataFromServer, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + PubSubResult connectResult = PubSubResult::OK; + std::string connectResultStr; + bool connected = false; + connections.connectServer(bootstrapAddress, [this, &connectResult, &connectResultStr, &connected](std::shared_ptr<DirectConnectionPeer> peer, PubSubResult result, const std::string &resultStr) + { + serverPeer = peer; + connectResult = result; + connectResultStr = resultStr; + connected = true; + }, std::bind(&BootstrapConnection::receiveDataFromServer, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + + while(!connected) + { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + if(connectResult != PubSubResult::OK) + { + std::string errMsg = "Failed to connect to bootstrap node, error: "; + errMsg += connectResultStr; + throw BootstrapConnectionException(errMsg); + } } // TODO: This is vulnerable against MitM attack, replace with asymmetric cryptography, get data signed with server private key and verify against known server public key @@ -27,7 +48,6 @@ namespace sibs auto listenerCallbackFunc = listenerFuncIt->second; listenerCallbackFuncMutex.unlock(); - auto &peers = subscribedPeers[pubsubKey]; while(!deserializer.empty()) { sa_family_t addressFamily = deserializer.extract<u32>(); @@ -41,8 +61,17 @@ namespace sibs newPeerAddress.address.sin_port = port; memset(newPeerAddress.address.sin_zero, 0, sizeof(newPeerAddress.address.sin_zero)); // TODO: Move connection to thread and add callback function, just like @receiveData and @send - std::shared_ptr<DirectConnectionPeer> newPeer = connections.connect(newPeerAddress, std::bind(&BootstrapConnection::receiveDataFromPeer, this, listenerCallbackFunc, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); - peers.push_back(newPeer); + connections.connect(newPeerAddress, [this, pubsubKey](std::shared_ptr<DirectConnectionPeer> newPeer, PubSubResult result, const std::string &resultStr) + { + if(result == PubSubResult::OK) + { + subscribedPeersMutex.lock(); + subscribedPeers[pubsubKey].push_back(newPeer); + subscribedPeersMutex.unlock(); + } + else + Log::error("UDT: Failed to connect to peer given by bootstrap node, error: %s", resultStr.c_str()); + }, std::bind(&BootstrapConnection::receiveDataFromPeer, this, listenerCallbackFunc, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); } else Log::error("BootstrapConnection: Unknown address family: %d", addressFamily); @@ -75,9 +104,14 @@ namespace sibs listenCallbackFuncIt->second(data->data(), data->size()); } + subscribedPeersMutex.lock(); auto peersIt = subscribedPeers.find(pubsubKey); if(peersIt == subscribedPeers.end()) + { + subscribedPeersMutex.unlock(); return; + } + subscribedPeersMutex.unlock(); for(auto &peer : peersIt->second) { diff --git a/src/DirectConnection.cpp b/src/DirectConnection.cpp index 748a936..d10c990 100644 --- a/src/DirectConnection.cpp +++ b/src/DirectConnection.cpp @@ -47,7 +47,7 @@ namespace sibs { std::string errMsg = "UDT: Failed to create socket, error: "; errMsg += UDT::getlasterror_desc(); - throw ConnectionException(errMsg); + throw SocketCreateException(errMsg); } UDT::setsockopt(socket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool)); @@ -64,43 +64,57 @@ namespace sibs { std::string errMsg = "UDT: Failed to bind, error: "; errMsg += UDT::getlasterror_desc(); - throw ConnectionException(errMsg); + throw SocketCreateException(errMsg); } return socket; } - std::shared_ptr<DirectConnectionPeer> DirectConnections::connectServer(const Ipv4 &address, PubSubReceiveDataCallback receiveDataCallbackFunc) + void DirectConnections::connectServer(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc) { - return connect(address, false, true, receiveDataCallbackFunc); + connect(address, false, true, connectCallbackFunc, receiveDataCallbackFunc); } - std::shared_ptr<DirectConnectionPeer> DirectConnections::connect(const Ipv4 &address, PubSubReceiveDataCallback receiveDataCallbackFunc) + void DirectConnections::connect(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc) { - return connect(address, true, true, receiveDataCallbackFunc); + connect(address, true, true, connectCallbackFunc, receiveDataCallbackFunc); } - std::shared_ptr<DirectConnectionPeer> DirectConnections::connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubReceiveDataCallback receiveDataCallbackFunc) + void DirectConnections::connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc) { - UDTSOCKET socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr); - - if(UDT::connect(socket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + std::thread([this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc]() { - UDT::close(socket); - std::string errMsg = "UDT: Failed to connect, error: "; - errMsg += UDT::getlasterror_desc(); - throw ConnectionException(errMsg); - } - - UDT::epoll_add_usock(eid, socket); - std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>(); - peer->socket = socket; - peer->address = address; - peer->receiveDataCallbackFunc = receiveDataCallbackFunc; - peersMutex.lock(); - peers[socket] = peer; - peersMutex.unlock(); - return peer; + std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>(); + UDTSOCKET socket; + try + { + socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr); + } + catch(SocketCreateException &e) + { + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::ERROR, e.what()); + return; + } + + if(UDT::connect(socket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + { + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc()); + return; + } + + UDT::epoll_add_usock(eid, socket); + peer->socket = socket; + peer->address = address; + peer->receiveDataCallbackFunc = receiveDataCallbackFunc; + peersMutex.lock(); + peers[socket] = peer; + peersMutex.unlock(); + + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::OK, ""); + }).detach(); } void DirectConnections::send(const std::shared_ptr<DirectConnectionPeer> &peer, std::shared_ptr<std::vector<u8>> data, PubSubSendDataCallback sendDataCallbackFunc) |