From eda94456add9a65d1821302e343bef4021d2a773 Mon Sep 17 00:00:00 2001 From: dec05eba <0xdec05eba@gmail.com> Date: Tue, 16 Oct 2018 00:37:21 +0200 Subject: Reuse peer connection if subscribed to same key --- include/FnvHash.hpp | 15 ++++ include/sibs/BootstrapConnection.hpp | 28 ++++-- include/sibs/BootstrapNode.hpp | 5 +- include/sibs/DirectConnection.hpp | 30 +++++-- include/sibs/IpAddress.hpp | 13 +++ include/sibs/Message.hpp | 30 +++++++ include/sibs/PubsubKey.hpp | 15 ++-- include/sibs/Socket.hpp | 19 ++++ project.conf | 2 +- src/BootstrapConnection.cpp | 165 +++++++++++++++++++++++++++++------ src/BootstrapNode.cpp | 117 ++++++++++++++++--------- src/DirectConnection.cpp | 161 +++++++++++++++++++++++++--------- src/IpAddress.cpp | 7 +- src/Message.cpp | 15 ++++ src/PubsubKey.cpp | 20 ++++- src/Socket.cpp | 40 +++++++++ tests/main.cpp | 40 ++++++++- 17 files changed, 579 insertions(+), 143 deletions(-) create mode 100644 include/FnvHash.hpp create mode 100644 include/sibs/Message.hpp create mode 100644 include/sibs/Socket.hpp create mode 100644 src/Message.cpp create mode 100644 src/Socket.cpp diff --git a/include/FnvHash.hpp b/include/FnvHash.hpp new file mode 100644 index 0000000..7766756 --- /dev/null +++ b/include/FnvHash.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "types.hpp" + +namespace sibs +{ + // Source: https://stackoverflow.com/a/11414104 (public license) + static usize fnvHash(const unsigned char *key, int len) + { + usize h = 2166136261ULL; + for (int i = 0; i < len; i++) + h = (h * 16777619ULL) ^ key[i]; + return h; + } +} \ No newline at end of file diff --git a/include/sibs/BootstrapConnection.hpp b/include/sibs/BootstrapConnection.hpp index dd90b7d..08af775 100644 --- a/include/sibs/BootstrapConnection.hpp +++ b/include/sibs/BootstrapConnection.hpp @@ -19,8 +19,14 @@ namespace sibs PubsubKeyAlreadyListeningException(const std::string &errMsg) : std::runtime_error(errMsg) {} }; - // @peer is nullptr is data was sent by local user - using BoostrapConnectionListenCallbackFunc = std::function; + // @peer is nullptr is data was sent by local user. + // Return false if you want to stop listening on the key + using BoostrapConnectionListenCallbackFunc = std::function; + + struct ListenHandle + { + PubsubKey key; + }; class BootstrapConnection { @@ -30,17 +36,23 @@ namespace sibs BootstrapConnection(const Ipv4 &bootstrapAddress); // Throws PubsubKeyAlreadyListeningException if we are already listening on the key @pubsubKey - void listen(const PubsubKey &pubsubKey, BoostrapConnectionListenCallbackFunc callbackFunc); - void put(const PubsubKey &pubsubKey, std::shared_ptr> data); + ListenHandle listen(const PubsubKey &pubsubKey, BoostrapConnectionListenCallbackFunc callbackFunc); + // Returns false if data is larger than 800kb. + // Note: @data is copied in this function. + // Note: You can't put data on a pubsubkey that you are not listening on. Call @listen first. + bool put(const PubsubKey &pubsubKey, const void *data, const usize size); + bool cancelListen(const ListenHandle &listener); + + std::vector> getPeers(); private: - void receiveDataFromServer(std::shared_ptr peer, const void *data, const usize size); - void receiveDataFromPeer(BoostrapConnectionListenCallbackFunc listenCallbackFunc, std::shared_ptr peer, const void *data, const usize size); + void receiveDataFromServer(std::shared_ptr peer, MessageType messageType, const void *data, const usize size); + void receiveDataFromPeer(std::shared_ptr peer, MessageType messageType, const void *data, const usize size); private: DirectConnections connections; std::shared_ptr serverPeer; PubsubKeyMap listenCallbackFuncs; PubsubKeyMap>> subscribedPeers; - std::mutex listenerCallbackFuncMutex; - std::mutex subscribedPeersMutex; + std::recursive_mutex listenerCallbackFuncMutex; + std::recursive_mutex subscribedPeersMutex; }; } diff --git a/include/sibs/BootstrapNode.hpp b/include/sibs/BootstrapNode.hpp index ab3d6b3..c824a84 100644 --- a/include/sibs/BootstrapNode.hpp +++ b/include/sibs/BootstrapNode.hpp @@ -1,5 +1,6 @@ #pragma once +#include "Socket.hpp" #include "DirectConnection.hpp" #include "IpAddress.hpp" #include "PubsubKey.hpp" @@ -23,10 +24,10 @@ namespace sibs ~BootstrapNode(); private: void acceptConnections(); - void peerSubscribe(std::shared_ptr peer, const void *data, const usize size); + void messageFromClient(std::shared_ptr peer, MessageType messageType, const void *data, const usize size); private: DirectConnections connections; - int socket; + std::unique_ptr socket; std::thread acceptConnectionsThread; PubsubKeyMap>> subscribedPeers; std::mutex subscribedPeersMutex; diff --git a/include/sibs/DirectConnection.hpp b/include/sibs/DirectConnection.hpp index 8e3865f..9be55f1 100644 --- a/include/sibs/DirectConnection.hpp +++ b/include/sibs/DirectConnection.hpp @@ -11,6 +11,8 @@ #include "IpAddress.hpp" #include "../types.hpp" #include "../utils.hpp" +#include "Socket.hpp" +#include "Message.hpp" namespace sibs { @@ -29,15 +31,19 @@ namespace sibs struct DirectConnectionPeer; using PubSubConnectCallback = std::function peer, PubSubResult result, const std::string &resultStr)>; - using PubSubReceiveDataCallback = std::function peer, const void *data, const usize size)>; + using PubSubReceiveDataCallback = std::function peer, MessageType messageType, const void *data, const usize size)>; using PubSubSendDataCallback = std::function; using PubSubOnRemoveDisconnectedPeerCallback = std::function peer)>; struct DirectConnectionPeer { - int socket; + std::unique_ptr socket; Ipv4 address; PubSubReceiveDataCallback receiveDataCallbackFunc; + int sharedKeys = 0; + + bool operator == (const DirectConnectionPeer &other) const; + bool operator != (const DirectConnectionPeer &other) const; }; class DirectConnections @@ -52,18 +58,22 @@ namespace sibs void connectServer(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc); // Throws ConnectionException on error void connect(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc); - - void send(const std::shared_ptr peer, std::shared_ptr> data, PubSubSendDataCallback sendDataCallbackFunc = nullptr); + // Returns false if data is larger than 800kb + bool send(const std::shared_ptr peer, std::shared_ptr data, PubSubSendDataCallback sendDataCallbackFunc = nullptr); void onRemoveDisconnectedPeer(PubSubOnRemoveDisconnectedPeerCallback callbackFunc); + bool removePeer(int peerSocket); + + std::vector> getPeers(); + + std::shared_ptr getPeerByAddress(const Ipv4 &address) const; protected: - int createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind = true); + std::unique_ptr createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind = true); private: void connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc, bool bind); void removeDisconnectedPeers(); void receiveData(); int receiveDataFromPeer(const int socket, char *output, usize *receivedTotalSize); - bool removePeer(int peerSocket); private: u16 port; int eid; @@ -72,5 +82,13 @@ namespace sibs std::mutex peersMutex; bool alive; PubSubOnRemoveDisconnectedPeerCallback removeDisconnectedPeerCallback; + Ipv4Map> peerByAddressMap; + }; + + struct DirectConnectionsUtils + { + static std::vector serializePeers(const std::vector> &peers); + // Throws DeserializeException on error + static std::vector> deserializePeers(const u8 *data, const usize size); }; } diff --git a/include/sibs/IpAddress.hpp b/include/sibs/IpAddress.hpp index c3b43c4..4403e83 100644 --- a/include/sibs/IpAddress.hpp +++ b/include/sibs/IpAddress.hpp @@ -2,6 +2,7 @@ #include #include +#include #ifndef WIN32 #include #include @@ -32,7 +33,19 @@ namespace sibs unsigned short getPort() const; bool operator == (const Ipv4 &other) const; + bool operator != (const Ipv4 &other) const; struct sockaddr_in address; }; + + struct Ipv4Hasher + { + size_t operator()(const Ipv4 &address) const + { + return address.address.sin_addr.s_addr ^ address.address.sin_port; + } + }; + + template + using Ipv4Map = std::unordered_map; } diff --git a/include/sibs/Message.hpp b/include/sibs/Message.hpp new file mode 100644 index 0000000..b6eb858 --- /dev/null +++ b/include/sibs/Message.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "../types.hpp" +#include + +namespace sibs +{ + enum class MessageType : u8 + { + NONE, + DATA, + SUBSCRIBE, + UNSUBSCRIBE + }; + + class Message + { + public: + Message(MessageType messageType); + + void append(const void *data, const usize size); + + usize getDataSize() const { return rawData.size() - 1; } + usize getRawSize() const { return rawData.size(); } + + const u8* data() const { return rawData.data(); } + private: + std::vector rawData; + }; +} \ No newline at end of file diff --git a/include/sibs/PubsubKey.hpp b/include/sibs/PubsubKey.hpp index d123331..f1239ca 100644 --- a/include/sibs/PubsubKey.hpp +++ b/include/sibs/PubsubKey.hpp @@ -1,21 +1,14 @@ #pragma once +#include "../FnvHash.hpp" #include "../types.hpp" #include #include +#include namespace sibs { - const usize PUBSUB_KEY_LENGTH = 32; - - // Source: https://stackoverflow.com/a/11414104 (public license) - static size_t fnvHash(const unsigned char *key, int len) - { - size_t h = 2166136261; - for (int i = 0; i < len; i++) - h = (h * 16777619) ^ key[i]; - return h; - } + const usize PUBSUB_KEY_LENGTH = 20; class PubsubKey { @@ -26,6 +19,8 @@ namespace sibs bool operator == (const PubsubKey &other) const; bool operator != (const PubsubKey &other) const; + std::string toString() const; + std::array data; }; diff --git a/include/sibs/Socket.hpp b/include/sibs/Socket.hpp new file mode 100644 index 0000000..0bc9ec3 --- /dev/null +++ b/include/sibs/Socket.hpp @@ -0,0 +1,19 @@ +#pragma once + +namespace sibs +{ + class Socket + { + public: + Socket(); + Socket(int udtSocket); + Socket(int eid, int udtSocket); + Socket(Socket &&other); + Socket(const Socket&) = delete; + Socket& operator = (const Socket&) = delete; + ~Socket(); + + int eid; + int udtSocket; + }; +} \ No newline at end of file diff --git a/project.conf b/project.conf index 6659455..5863279 100644 --- a/project.conf +++ b/project.conf @@ -9,4 +9,4 @@ expose_include_dirs = ["include"] [dependencies] udt = "4.11" -sibs-serializer = "1.0.1" +sibs-serializer = "2.0.0" diff --git a/src/BootstrapConnection.cpp b/src/BootstrapConnection.cpp index e5f5178..0237a90 100644 --- a/src/BootstrapConnection.cpp +++ b/src/BootstrapConnection.cpp @@ -8,7 +8,7 @@ namespace sibs { connections.onRemoveDisconnectedPeer([this](std::shared_ptr peer) { - std::lock_guard lock(subscribedPeersMutex); + std::lock_guard lock(subscribedPeersMutex); for(auto &topicUsers : subscribedPeers) { for(auto it = topicUsers.second.begin(); it != topicUsers.second.end(); ) @@ -30,7 +30,7 @@ namespace sibs connectResult = result; connectResultStr = resultStr; connected = true; - }, std::bind(&BootstrapConnection::receiveDataFromServer, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + }, std::bind(&BootstrapConnection::receiveDataFromServer, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); while(!connected) { @@ -46,23 +46,27 @@ namespace sibs } // TODO: This is vulnerable against MitM attack, replace with asymmetric cryptography, get data signed with server private key and verify against known server public key - void BootstrapConnection::receiveDataFromServer(std::shared_ptr peer, const void *data, const usize size) + void BootstrapConnection::receiveDataFromServer(std::shared_ptr peer, MessageType messageType, const void *data, const usize size) { + if(messageType != MessageType::SUBSCRIBE) + { + Log::warn("BootstrapConnection: received message from server that was not subscribe"); + return; + } + Log::debug("BootstrapConnection: Received subscriber(s) from bootstrap node"); sibs::SafeDeserializer deserializer((const u8*)data, size); PubsubKey pubsubKey; deserializer.extract(pubsubKey.data.data(), PUBSUB_KEY_LENGTH); - listenerCallbackFuncMutex.lock(); + // we want lock to live this whole scope so we dont connect to peer when cancelListen is called + std::lock_guard lock(listenerCallbackFuncMutex); auto listenerFuncIt = listenCallbackFuncs.find(pubsubKey); if(listenerFuncIt == listenCallbackFuncs.end()) { - Log::debug("BoostrapConnection: No listener found for key XXX, ignoring..."); - listenerCallbackFuncMutex.unlock(); + Log::debug("BoostrapConnection: No listener found for key '%s', ignoring...", pubsubKey.toString().c_str()); return; } - auto listenerCallbackFunc = listenerFuncIt->second; - listenerCallbackFuncMutex.unlock(); while(!deserializer.empty()) { @@ -76,18 +80,19 @@ namespace sibs newPeerAddress.address.sin_addr.s_addr = ipv4Address; newPeerAddress.address.sin_port = port; memset(newPeerAddress.address.sin_zero, 0, sizeof(newPeerAddress.address.sin_zero)); - connections.connect(newPeerAddress, [this, pubsubKey](std::shared_ptr newPeer, PubSubResult result, const std::string &resultStr) + Log::debug("BootstrapConnection: received subscriber (ip: %s, port: %d) from bootstrap node", newPeerAddress.getAddress().c_str(), newPeerAddress.getPort()); + connections.connect(newPeerAddress, [this, pubsubKey](std::shared_ptr peer, PubSubResult result, const std::string &resultStr) { if(result == PubSubResult::OK) { - subscribedPeersMutex.lock(); - subscribedPeers[pubsubKey].push_back(newPeer); - subscribedPeersMutex.unlock(); - Log::debug("BootstrapConnection: Connected to peer (ip: %s, port: %d) given by bootstrap node", newPeer->address.getAddress().c_str(), newPeer->address.getPort()); + std::lock_guard lock(subscribedPeersMutex); + subscribedPeers[pubsubKey].push_back(peer); + ++peer->sharedKeys; + Log::debug("BootstrapConnection: Connected to peer (ip: %s, port: %d) given by bootstrap node", peer->address.getAddress().c_str(), peer->address.getPort()); } else Log::error("BootstrapConnection: Failed to connect to peer given by bootstrap node, error: %s", resultStr.c_str()); - }, std::bind(&BootstrapConnection::receiveDataFromPeer, this, listenerCallbackFunc, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + }, std::bind(&BootstrapConnection::receiveDataFromPeer, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); } else { @@ -97,46 +102,152 @@ namespace sibs } } - void BootstrapConnection::receiveDataFromPeer(BoostrapConnectionListenCallbackFunc listenCallbackFunc, std::shared_ptr peer, const void *data, const usize size) + void BootstrapConnection::receiveDataFromPeer(std::shared_ptr peer, MessageType messageType, const void *data, const usize size) { - if(listenCallbackFunc) - listenCallbackFunc(peer.get(), data, size); + if(size < PUBSUB_KEY_LENGTH) + return; + + PubsubKey pubsubKey; + memcpy(pubsubKey.data.data(), data, PUBSUB_KEY_LENGTH); + if(messageType == MessageType::DATA) + { + listenerCallbackFuncMutex.lock(); + auto listenerFuncIt = listenCallbackFuncs.find(pubsubKey); + if(listenerFuncIt == listenCallbackFuncs.end()) + { + listenerCallbackFuncMutex.unlock(); + Log::debug("BoostrapConnection: No listener found for key '%s', ignoring...", pubsubKey.toString().c_str()); + return; + } + auto listenCallbackFunc = listenerFuncIt->second; + listenerCallbackFuncMutex.unlock(); + + if(listenCallbackFunc) + { + bool continueListening = listenCallbackFunc(peer.get(), (const u8*)data + PUBSUB_KEY_LENGTH, size - PUBSUB_KEY_LENGTH); + if(!continueListening) + cancelListen({ pubsubKey }); + } + } + else if(messageType == MessageType::UNSUBSCRIBE) + { + Log::debug("BootstrapConnection: peer (ip: %s, port: %d) unsubscribed from key '%s'", peer->address.getAddress().c_str(), peer->address.getPort(), pubsubKey.toString().c_str()); + std::lock_guard subscribersMutex(subscribedPeersMutex); + auto peersListIt = subscribedPeers.find(pubsubKey); + if(peersListIt == subscribedPeers.end()) + return; + + for(auto it = peersListIt->second.begin(); it != peersListIt->second.end(); ++it) + { + auto existingPeer = *it; + if(*existingPeer == *peer) + { + peersListIt->second.erase(it); + --peer->sharedKeys; + if(peer->sharedKeys <= 0) + connections.removePeer(peer->socket->udtSocket); + break; + } + } + } + else + { + Log::warn("BootstrapConnection: received message from peer that was not data or unsubscribe"); + } } - void BootstrapConnection::listen(const PubsubKey &pubsubKey, BoostrapConnectionListenCallbackFunc callbackFunc) + ListenHandle BootstrapConnection::listen(const PubsubKey &pubsubKey, BoostrapConnectionListenCallbackFunc callbackFunc) { { - std::lock_guard lock(listenerCallbackFuncMutex); + std::lock_guard lock(listenerCallbackFuncMutex); if(listenCallbackFuncs.find(pubsubKey) != listenCallbackFuncs.end()) throw PubsubKeyAlreadyListeningException(""); listenCallbackFuncs[pubsubKey] = callbackFunc; } - connections.send(serverPeer, std::make_shared>(pubsubKey.data.begin(), pubsubKey.data.end()), - [](PubSubResult result, const std::string &resultStr) + + auto message = std::make_shared(MessageType::SUBSCRIBE); + message->append(pubsubKey.data.data(), pubsubKey.data.size()); + connections.send(serverPeer, message, [](PubSubResult result, const std::string &resultStr) { Log::debug("BootstrapConnection::listen: PubSubResult: %d, result string: %s", result, resultStr.c_str()); }); + + return { pubsubKey }; } - void BootstrapConnection::put(const PubsubKey &pubsubKey, std::shared_ptr> data) + bool BootstrapConnection::put(const PubsubKey &pubsubKey, const void *data, const usize size) { + if(size > 819200) // 800kb + return false; + { - std::lock_guard lock(listenerCallbackFuncMutex); + std::lock_guard lock(listenerCallbackFuncMutex); auto listenCallbackFuncIt = listenCallbackFuncs.find(pubsubKey); if(listenCallbackFuncIt != listenCallbackFuncs.end() && listenCallbackFuncIt->second) - listenCallbackFuncIt->second(nullptr, data->data(), data->size()); + listenCallbackFuncIt->second(nullptr, data, size); + + if(listenCallbackFuncIt == listenCallbackFuncs.end()) + Log::warn("BootstrapConnection::put on key '%s' which we are not listening to", pubsubKey.toString().c_str()); } - std::lock_guard lock(subscribedPeersMutex); + std::lock_guard lock(subscribedPeersMutex); auto peersIt = subscribedPeers.find(pubsubKey); if(peersIt == subscribedPeers.end()) { - return; + Log::warn("BootstrapConnection::put with no subscribers on same key '%s'", pubsubKey.toString().c_str()); + return true; } + auto message = std::make_shared(MessageType::DATA); + message->append(pubsubKey.data.data(), pubsubKey.data.size()); + message->append(data, size); for(auto &peer : peersIt->second) { - connections.send(peer, data); + connections.send(peer, message); + } + return true; + } + + bool BootstrapConnection::cancelListen(const ListenHandle &listener) + { + { + std::lock_guard lock(listenerCallbackFuncMutex); + auto it = listenCallbackFuncs.find(listener.key); + if(it == listenCallbackFuncs.end()) + return false; + listenCallbackFuncs.erase(it); + + auto message = std::make_shared(MessageType::UNSUBSCRIBE); + message->append(listener.key.data.data(), listener.key.data.size()); + connections.send(serverPeer, message); + + std::lock_guard subscribersMutex(subscribedPeersMutex); + auto peersListIt = subscribedPeers.find(listener.key); + // this will happen if there are no other peers subscribed to the key + if(peersListIt == subscribedPeers.end()) + return true; + + for(auto &peer : peersListIt->second) + { + --peer->sharedKeys; + if(peer->sharedKeys <= 0) + { + // disconnect from peer + connections.removePeer(peer->socket->udtSocket); + } + else + { + // unsubscribe from peers request, even if they dont accept it we wont listen to messages from the key anymore + connections.send(peer, message); + } + } + subscribedPeers.erase(peersListIt); } + return true; + } + + std::vector> BootstrapConnection::getPeers() + { + return connections.getPeers(); } } diff --git a/src/BootstrapNode.cpp b/src/BootstrapNode.cpp index 273abf0..5c62e64 100644 --- a/src/BootstrapNode.cpp +++ b/src/BootstrapNode.cpp @@ -14,9 +14,14 @@ namespace sibs { BootstrapNode::BootstrapNode(const Ipv4 &address) : - connections(27130), + connections(address.getPort()), socket(connections.createSocket(address, false, true)) { + if(connections.port != address.getPort()) + { + throw SocketCreateException("BootstrapNode: Failed to bind port " + std::to_string(address.getPort())); + } + connections.onRemoveDisconnectedPeer([this](std::shared_ptr peer) { std::lock_guard lock(subscribedPeersMutex); @@ -32,7 +37,7 @@ namespace sibs } }); - if(UDT::listen(socket, 10) == UDT::ERROR) + if(UDT::listen(socket->udtSocket, 10) == UDT::ERROR) { std::string errMsg = "UDT: Failed to listen, error: "; errMsg += UDT::getlasterror_desc(); @@ -43,8 +48,9 @@ namespace sibs BootstrapNode::~BootstrapNode() { + std::lock_guard lock(connections.peersMutex); connections.alive = false; - UDT::close(socket); + socket.reset(); acceptConnectionsThread.join(); } @@ -55,8 +61,8 @@ namespace sibs while(connections.alive) { - UDTSOCKET clientSocket = UDT::accept(socket, (sockaddr*)&clientAddr, &addrLen); - if(clientSocket == UDT::INVALID_SOCK) + UDTSOCKET clientUdtSocket = UDT::accept(socket->udtSocket, (sockaddr*)&clientAddr, &addrLen); + if(clientUdtSocket == UDT::INVALID_SOCK) { // Connection was killed because bootstrap node was taken down if(!connections.alive) @@ -65,61 +71,90 @@ namespace sibs std::this_thread::sleep_for(std::chrono::milliseconds(10)); continue; } + auto clientSocket = std::make_unique(clientUdtSocket); char clientHost[NI_MAXHOST]; char clientService[NI_MAXSERV]; getnameinfo((sockaddr *)&clientAddr, addrLen, clientHost, sizeof(clientHost), clientService, sizeof(clientService), NI_NUMERICHOST | NI_NUMERICSERV); - Log::debug("UDT: New connection: %s:%s (socket: %d)", clientHost, clientService, clientSocket); + Log::debug("UDT: New connection: %s:%s (socket: %d)", clientHost, clientService, clientUdtSocket); - UDT::epoll_add_usock(connections.eid, clientSocket); + std::lock_guard lock(connections.peersMutex); + UDT::epoll_add_usock(connections.eid, clientUdtSocket); + clientSocket->eid = connections.eid; std::shared_ptr peer = std::make_shared(); - peer->socket = clientSocket; + peer->socket = std::move(clientSocket); sockaddr_in *clientAddrSock = (sockaddr_in*)&clientAddr; memcpy(&peer->address.address, clientAddrSock, sizeof(peer->address.address)); - peer->receiveDataCallbackFunc = std::bind(&BootstrapNode::peerSubscribe, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); - connections.peersMutex.lock(); - connections.peers[clientSocket] = peer; - connections.peersMutex.unlock(); + peer->receiveDataCallbackFunc = std::bind(&BootstrapNode::messageFromClient, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); + connections.peers[clientUdtSocket] = peer; } } - void BootstrapNode::peerSubscribe(std::shared_ptr newPeer, const void *data, const usize size) + void BootstrapNode::messageFromClient(std::shared_ptr peer, MessageType messageType, const void *data, const usize size) { - Log::debug("BootstrapNode: Received peer subscribe from (ip: %s, port: %d)", newPeer->address.getAddress().c_str(), newPeer->address.getPort()); sibs::SafeDeserializer deserializer((const u8*)data, size); PubsubKey pubsubKey; deserializer.extract(pubsubKey.data.data(), PUBSUB_KEY_LENGTH); - - std::lock_guard lock(subscribedPeersMutex); - auto &peers = subscribedPeers[pubsubKey]; - for(auto &peer : peers) + + if(messageType == MessageType::SUBSCRIBE) { - if(peer->address == newPeer->address) - return; + Log::debug("BootstrapNode: Received peer subscribe from (ip: %s, port: %d)", peer->address.getAddress().c_str(), peer->address.getPort()); + std::lock_guard lock(subscribedPeersMutex); + auto &peers = subscribedPeers[pubsubKey]; + for(auto &existingPeer : peers) + { + if(existingPeer->address == peer->address) + return; + } + + sibs::SafeSerializer serializer; + serializer.add((u32)peer->address.address.sin_family); + serializer.add((u32)peer->address.address.sin_addr.s_addr); + serializer.add((u16)peer->address.address.sin_port); + auto newPeerMessage = std::make_shared(MessageType::SUBSCRIBE); + newPeerMessage->append(pubsubKey.data.data(), pubsubKey.data.size()); + newPeerMessage->append(serializer.getBuffer().data(), serializer.getBuffer().size()); + auto sendCallbackFunc = [](PubSubResult result, const std::string &resultStr) + { + Log::debug("BootstrapNode::peerSubscribe send result: %d, result string: %s", result, resultStr.c_str()); + }; + for(auto &existingPeer : peers) + { + connections.send(existingPeer, newPeerMessage, sendCallbackFunc); + } + + sibs::SafeSerializer newPeerSerializer; + for(auto &existingPeer : peers) + { + newPeerSerializer.add((u32)existingPeer->address.address.sin_family); + newPeerSerializer.add((u32)existingPeer->address.address.sin_addr.s_addr); + newPeerSerializer.add((u16)existingPeer->address.address.sin_port); + } + peers.push_back(peer); + + auto existingPeerMessage = std::make_shared(MessageType::SUBSCRIBE); + existingPeerMessage->append(pubsubKey.data.data(), pubsubKey.data.size()); + existingPeerMessage->append(newPeerSerializer.getBuffer().data(), newPeerSerializer.getBuffer().size()); + connections.send(peer, existingPeerMessage, sendCallbackFunc); } - - sibs::SafeSerializer serializer; - serializer.add(pubsubKey.data.data(), pubsubKey.data.size()); - serializer.add((u32)newPeer->address.address.sin_family); - serializer.add((u32)newPeer->address.address.sin_addr.s_addr); - serializer.add((u16)newPeer->address.address.sin_port); - std::shared_ptr> serializerData = std::make_shared>(std::move(serializer.getBuffer())); - - auto sendCallbackFunc = [](PubSubResult result, const std::string &resultStr) + else if(messageType == MessageType::UNSUBSCRIBE) { - Log::debug("BootstrapNode::peerSubscribe send result: %d, result string: %s", result, resultStr.c_str()); - }; - - sibs::SafeSerializer newPeerSerializer; - newPeerSerializer.add(pubsubKey.data.data(), pubsubKey.data.size()); - for(auto &peer : peers) + Log::debug("BootstrapNode: Received peer unsubscribe from (ip: %s, port: %d)", peer->address.getAddress().c_str(), peer->address.getPort()); + std::lock_guard lock(subscribedPeersMutex); + auto &peers = subscribedPeers[pubsubKey]; + for(auto it = peers.begin(); it != peers.end(); ++it) + { + auto existingPeer = *it; + if(existingPeer->address == peer->address) + { + peers.erase(it); + break; + } + } + } + else { - connections.send(peer, serializerData, sendCallbackFunc); - newPeerSerializer.add((u32)peer->address.address.sin_family); - newPeerSerializer.add((u32)peer->address.address.sin_addr.s_addr); - newPeerSerializer.add((u16)peer->address.address.sin_port); + Log::warn("BootstrapNode: received message from client that was not subscribe or unsubscribe"); } - peers.push_back(newPeer); - connections.send(newPeer, std::make_shared>(std::move(newPeerSerializer.getBuffer())), sendCallbackFunc); } } diff --git a/src/DirectConnection.cpp b/src/DirectConnection.cpp index 083f557..e41a4a5 100644 --- a/src/DirectConnection.cpp +++ b/src/DirectConnection.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include #ifndef WIN32 #include @@ -26,6 +28,16 @@ namespace sibs // Max received data size allowed when receiving regular data, receive data as file to receive more data const int MAX_RECEIVED_DATA_SIZE = 1024 * 1024 * 1; // 1Mb + + bool DirectConnectionPeer::operator == (const DirectConnectionPeer &other) const + { + return socket->udtSocket == other.socket->udtSocket; + } + + bool DirectConnectionPeer::operator != (const DirectConnectionPeer &other) const + { + return !(*this == other); + } DirectConnections::DirectConnections(u16 _port) : port(_port == 0 ? (u16)generateRandomNumber(2000, 32000) : _port), @@ -40,29 +52,26 @@ namespace sibs DirectConnections::~DirectConnections() { alive = false; - receiveDataThread.join(); - - for(auto &peer : peers) - { - UDT::close(peer.first); - } + peers.clear(); UDT::epoll_release(eid); UDT::cleanup(); + receiveDataThread.join(); } - int DirectConnections::createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind) + std::unique_ptr DirectConnections::createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind) { Log::debug("UDT: Creating socket for ipv4 address %s, port: %d", addressToBind.getAddress().c_str(), addressToBind.getPort()); - UDTSOCKET socket = UDT::socket(AF_INET, SOCK_STREAM, 0); - if(socket == UDT::INVALID_SOCK) + UDTSOCKET udtSocket = UDT::socket(AF_INET, SOCK_STREAM, 0); + if(udtSocket == UDT::INVALID_SOCK) { std::string errMsg = "UDT: Failed to create socket, error: "; errMsg += UDT::getlasterror_desc(); throw SocketCreateException(errMsg); } - UDT::setsockopt(socket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool)); - UDT::setsockopt(socket, 0, UDT_REUSEADDR, &reuseAddr, sizeof(bool)); + auto socket = std::make_unique(udtSocket); + UDT::setsockopt(udtSocket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool)); + UDT::setsockopt(udtSocket, 0, UDT_REUSEADDR, &reuseAddr, sizeof(bool)); // Windows UDP issue // For better performance, modify HKLM\System\CurrentControlSet\Services\Afd\Parameters\FastSendDatagramThreshold @@ -86,9 +95,11 @@ namespace sibs Ipv4 myAddr = addressToBind; for(int i = 0; i < 2000; ++i) { - if(UDT::bind(socket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR) + if(UDT::bind(udtSocket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR) { - port = (u16)generateRandomNumber(2000, 32000); + u16 newPort = (u16)generateRandomNumber(2000, 32000); + Log::warn("DirectConnections: failed to bind socket to port %d, trying port %d. Fail reason: %s", port, newPort, UDT::getlasterror_desc()); + port = newPort; myAddr.address.sin_port = htons(port); } else @@ -101,7 +112,7 @@ namespace sibs Ipv4 myAddr = addressToBind; for(int i = 0; i < 2000; ++i) { - if(UDT::bind(socket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR) + if(UDT::bind(udtSocket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR) { port = (u16)generateRandomNumber(2000, 32000); myAddr.address.sin_port = htons(port); @@ -130,8 +141,22 @@ namespace sibs { std::thread([this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]() { - std::shared_ptr peer = std::make_shared(); - UDTSOCKET socket; + std::shared_ptr peer = getPeerByAddress(address); + if(peer) + { + // this doesn't really matter, we always call connect with same callback function + peer->receiveDataCallbackFunc = receiveDataCallbackFunc; + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::OK, ""); + return; + } + else + { + peer = std::make_shared(); + peerByAddressMap[address] = peer; + } + + std::unique_ptr socket; try { socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind); @@ -144,19 +169,20 @@ namespace sibs } Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no"); - if(UDT::connect(socket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) { if(connectCallbackFunc) connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc()); return; } - UDT::epoll_add_usock(eid, socket); - peer->socket = socket; + UDT::epoll_add_usock(eid, socket->udtSocket); + socket->eid = eid; + peersMutex.lock(); + peers[socket->udtSocket] = peer; + peer->socket = std::move(socket); peer->address = address; peer->receiveDataCallbackFunc = receiveDataCallbackFunc; - peersMutex.lock(); - peers[socket] = peer; peersMutex.unlock(); if(connectCallbackFunc) @@ -164,27 +190,27 @@ namespace sibs }).detach(); } - void DirectConnections::send(const std::shared_ptr peer, std::shared_ptr> data, PubSubSendDataCallback sendDataCallbackFunc) + bool DirectConnections::send(const std::shared_ptr peer, std::shared_ptr data, PubSubSendDataCallback sendDataCallbackFunc) { + if(data->getDataSize() == 0) + return true; + + if(data->getDataSize() > 819200) // 800kb + return false; + // TODO: Replace this with light-weight threads (fibers)? std::thread([peer, data, sendDataCallbackFunc]() { - usize sentSizeTotal = 0; - while(sentSizeTotal < data->size() || data->size() == 0) + int sentSize = UDT::send(peer->socket->udtSocket, (char*)data->data(), data->getRawSize(), 0); + if(sentSize == UDT::ERROR) { - int sentSize = UDT::send(peer->socket, (char*)data->data() + sentSizeTotal, data->size() - sentSizeTotal, 0); - if(sentSize == UDT::ERROR) - { - if(sendDataCallbackFunc) - sendDataCallbackFunc(PubSubResult::ERROR, UDT::getlasterror_desc()); - } - sentSizeTotal += sentSize; - if(data->size() == 0) - break; + if(sendDataCallbackFunc) + sendDataCallbackFunc(PubSubResult::ERROR, UDT::getlasterror_desc()); } - if(sendDataCallbackFunc) + else if(sendDataCallbackFunc) sendDataCallbackFunc(PubSubResult::OK, ""); }).detach(); + return true; } void DirectConnections::onRemoveDisconnectedPeer(PubSubOnRemoveDisconnectedPeerCallback callbackFunc) @@ -195,19 +221,29 @@ namespace sibs bool DirectConnections::removePeer(int peerSocket) { bool wasRemoved = false; - peersMutex.lock(); + std::lock_guard lock(peersMutex); auto peerIt = peers.find(peerSocket); if(peerIt != peers.end()) { if(removeDisconnectedPeerCallback) removeDisconnectedPeerCallback(peerIt->second); - UDT::epoll_remove_usock(eid, peerSocket); peers.erase(peerIt); wasRemoved = true; } - peersMutex.unlock(); return wasRemoved; } + + std::vector> DirectConnections::getPeers() + { + std::vector> result; + result.reserve(peers.size()); + std::lock_guard lock(peersMutex); + for(auto &it : peers) + { + result.push_back(it.second); + } + return result; + } void DirectConnections::removeDisconnectedPeers() { @@ -221,9 +257,6 @@ namespace sibs if(removeDisconnectedPeerCallback) removeDisconnectedPeerCallback(it->second); - if(peerSocketStatus == UDTSTATUS::BROKEN) - UDT::epoll_remove_usock(eid, socket); - Log::debug("UDT: Connection was broken to socket %d (peer most likely disconnected), removing peer", socket); it = peers.erase(it); Log::debug("UDT: Removed peer socket %d", socket); @@ -233,11 +266,19 @@ namespace sibs } peersMutex.unlock(); } + + std::shared_ptr DirectConnections::getPeerByAddress(const Ipv4 &address) const + { + auto it = peerByAddressMap.find(address); + if(it != peerByAddressMap.end()) + return it->second; + return nullptr; + } void DirectConnections::receiveData() { std::vector data; - data.reserve(MAX_RECEIVED_DATA_SIZE); + data.resize(MAX_RECEIVED_DATA_SIZE); Log::debug("DirectConnections::receiveData(): waiting for events"); std::set readfds; @@ -251,6 +292,9 @@ namespace sibs } else if(numfsReady == -1) { + if(!alive) + continue; + if(UDT::getlasterror_code() == UDT::ERRORINFO::ETIMEOUT) { continue; @@ -275,8 +319,12 @@ namespace sibs try { Log::debug("DirectConnection: Received data from peer: (ip: %s, port: %d)", peer->address.getAddress().c_str(), peer->address.getPort()); - if(peer->receiveDataCallbackFunc) - peer->receiveDataCallbackFunc(peer, data.data(), receivedTotalSize); + if(peer->receiveDataCallbackFunc && receivedTotalSize > 0) + { + static_assert(sizeof(MessageType) == sizeof(u8), ""); + MessageType messageType = (MessageType)data[0]; + peer->receiveDataCallbackFunc(peer, messageType, data.data() + 1, receivedTotalSize - 1); + } } catch(std::exception &e) { @@ -340,4 +388,31 @@ namespace sibs Log::error("UDT: Received too much data, ignoring..."); return 0; } + + std::vector DirectConnectionsUtils::serializePeers(const std::vector> &peers) + { + sibs::SafeSerializer serializer; + for(const auto &it : peers) + { + serializer.add((u32)it->address.address.sin_family); + serializer.add((u32)it->address.address.sin_addr.s_addr); + serializer.add((u16)it->address.address.sin_port); + } + return serializer.getBuffer(); + } + + std::vector> DirectConnectionsUtils::deserializePeers(const u8 *data, const usize size) + { + std::vector> result; + sibs::SafeDeserializer deserializer(data, size); + while(!deserializer.empty()) + { + std::shared_ptr peer = std::make_shared(); + peer->address.address.sin_family = deserializer.extract(); + peer->address.address.sin_addr.s_addr = deserializer.extract(); + peer->address.address.sin_port = deserializer.extract(); + result.push_back(peer); + } + return result; + } } diff --git a/src/IpAddress.cpp b/src/IpAddress.cpp index c90c924..58ed661 100644 --- a/src/IpAddress.cpp +++ b/src/IpAddress.cpp @@ -56,6 +56,11 @@ namespace sibs bool Ipv4::operator == (const Ipv4 &other) const { - return address.sin_addr.s_addr == other.address.sin_addr.s_addr && getPort() == other.getPort(); + return address.sin_addr.s_addr == other.address.sin_addr.s_addr && address.sin_port == other.address.sin_port; + } + + bool Ipv4::operator != (const Ipv4 &other) const + { + return !(*this == other); } } diff --git a/src/Message.cpp b/src/Message.cpp new file mode 100644 index 0000000..58b39da --- /dev/null +++ b/src/Message.cpp @@ -0,0 +1,15 @@ +#include "../include/sibs/Message.hpp" + +namespace sibs +{ + Message::Message(MessageType messageType) + { + static_assert(sizeof(MessageType) == sizeof(u8), "Whoops, message type size has changed, the below code doesn't work"); + rawData.push_back((u8)messageType); + } + + void Message::append(const void *data, const usize size) + { + rawData.insert(rawData.end(), (const u8*)data, (const u8*)data + size); + } +} \ No newline at end of file diff --git a/src/PubsubKey.cpp b/src/PubsubKey.cpp index dd807ce..64d2621 100644 --- a/src/PubsubKey.cpp +++ b/src/PubsubKey.cpp @@ -2,6 +2,8 @@ namespace sibs { + static const char *HEX_TABLE = "0123456789abcdef"; + PubsubKey::PubsubKey() : data({}) { @@ -10,9 +12,10 @@ namespace sibs PubsubKey::PubsubKey(const void *data, const usize size) { - std::copy((char*)data, (char*)data + std::min(size, PUBSUB_KEY_LENGTH), this->data.begin()); + usize _size = std::min(size, PUBSUB_KEY_LENGTH); + std::copy((char*)data, (char*)data + _size, this->data.begin()); if(size < PUBSUB_KEY_LENGTH) - std::fill_n((char*)data + size, PUBSUB_KEY_LENGTH - size, 0); + std::fill_n(this->data.begin() + size, PUBSUB_KEY_LENGTH - size, 0); } bool PubsubKey::operator == (const PubsubKey &other) const @@ -24,4 +27,17 @@ namespace sibs { return data != other.data; } + + std::string PubsubKey::toString() const + { + std::string result; + result.reserve(data.size()); + for(usize i = 0; i < data.size(); ++i) + { + u8 c = data[i]; + result += HEX_TABLE[(c & 0xF0) >> 4]; + result += HEX_TABLE[c & 0x0F]; + } + return result; + } } diff --git a/src/Socket.cpp b/src/Socket.cpp new file mode 100644 index 0000000..9c8da69 --- /dev/null +++ b/src/Socket.cpp @@ -0,0 +1,40 @@ +#include "../include/sibs/Socket.hpp" +#include + +namespace sibs +{ + Socket::Socket() : + eid(-1), + udtSocket(-1) + { + + } + + Socket::Socket(int _udtSocket) : + eid(-1), + udtSocket(_udtSocket) + { + + } + + Socket::Socket(int _eid, int _udtSocket) : + eid(_eid), + udtSocket(_udtSocket) + { + + } + + Socket::Socket(Socket &&other) + { + eid = other.eid; + udtSocket = other.udtSocket; + other.eid = 0; + other.udtSocket = 0; + } + + Socket::~Socket() + { + UDT::close(udtSocket); + UDT::epoll_remove_usock(eid, udtSocket); + } +} \ No newline at end of file diff --git a/tests/main.cpp b/tests/main.cpp index b2343af..d650128 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -1,7 +1,43 @@ -#include "../include/sibs/DirectConnection.hpp" +#include "../include/sibs/BootstrapConnection.hpp" +#include "../include/sibs/BootstrapNode.hpp" +#include +#include +#include + +const int PORT = 35231; + +#define REQUIRE(expr) do { if(!(expr)) { fprintf(stderr, "Assert failed: %s\n", #expr); exit(1); } }while(0) int main() { - sibs::DirectConnections user1(27137); + const sibs::PubsubKey key("abcdefghijklmnopqrstuvxyz0123456789", 35); + sibs::BootstrapNode boostrapNode(sibs::Ipv4(nullptr, PORT)); + + sibs::BootstrapConnection connection1(sibs::Ipv4("127.0.0.1", PORT)); + bool gotData1 = false; + connection1.listen(key, [&gotData1](const sibs::DirectConnectionPeer *peer, const void *data, const sibs::usize size) + { + if(size == 5 && strncmp((const char*)data, "hello", 5) == 0) + gotData1 = true; + return true; + }); + + sibs::BootstrapConnection connection2(sibs::Ipv4("127.0.0.1", PORT)); + bool gotData2 = false; + connection2.listen(key, [&gotData2](const sibs::DirectConnectionPeer *peer, const void *data, const sibs::usize size) + { + if(size == 5 && strncmp((const char*)data, "hello", 5) == 0) + gotData2 = true; + return true; + }); + // wait until connection1 and connection2 receive each other as peers from bootstrap node + std::this_thread::sleep_for(std::chrono::seconds(3)); + + connection1.put(key, "hello", 5); + std::this_thread::sleep_for(std::chrono::seconds(3)); + + REQUIRE(gotData1); + REQUIRE(gotData2); + return 0; } -- cgit v1.2.3