diff options
author | dec05eba <dec05eba@protonmail.com> | 2018-10-16 00:37:21 +0200 |
---|---|---|
committer | dec05eba <dec05eba@protonmail.com> | 2020-08-18 22:56:48 +0200 |
commit | c47870421f189eb98fc66e912693d73fbd8477ee (patch) | |
tree | 036ead590fa17bef279de483489a880c54ef4ba1 /src | |
parent | 0c1b3db7c4d9a4bcde4160c437613b32cd4081d6 (diff) |
Reuse peer connection if subscribed to same key
Diffstat (limited to 'src')
-rw-r--r-- | src/BootstrapConnection.cpp | 165 | ||||
-rw-r--r-- | src/BootstrapNode.cpp | 117 | ||||
-rw-r--r-- | src/DirectConnection.cpp | 161 | ||||
-rw-r--r-- | src/IpAddress.cpp | 7 | ||||
-rw-r--r-- | src/Message.cpp | 15 | ||||
-rw-r--r-- | src/PubsubKey.cpp | 20 | ||||
-rw-r--r-- | src/Socket.cpp | 40 |
7 files changed, 411 insertions, 114 deletions
diff --git a/src/BootstrapConnection.cpp b/src/BootstrapConnection.cpp index e5f5178..0237a90 100644 --- a/src/BootstrapConnection.cpp +++ b/src/BootstrapConnection.cpp @@ -8,7 +8,7 @@ namespace sibs { connections.onRemoveDisconnectedPeer([this](std::shared_ptr<DirectConnectionPeer> peer) { - std::lock_guard<std::mutex> lock(subscribedPeersMutex); + std::lock_guard<std::recursive_mutex> lock(subscribedPeersMutex); for(auto &topicUsers : subscribedPeers) { for(auto it = topicUsers.second.begin(); it != topicUsers.second.end(); ) @@ -30,7 +30,7 @@ namespace sibs connectResult = result; connectResultStr = resultStr; connected = true; - }, std::bind(&BootstrapConnection::receiveDataFromServer, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + }, std::bind(&BootstrapConnection::receiveDataFromServer, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); while(!connected) { @@ -46,23 +46,27 @@ namespace sibs } // TODO: This is vulnerable against MitM attack, replace with asymmetric cryptography, get data signed with server private key and verify against known server public key - void BootstrapConnection::receiveDataFromServer(std::shared_ptr<DirectConnectionPeer> peer, const void *data, const usize size) + void BootstrapConnection::receiveDataFromServer(std::shared_ptr<DirectConnectionPeer> peer, MessageType messageType, const void *data, const usize size) { + if(messageType != MessageType::SUBSCRIBE) + { + Log::warn("BootstrapConnection: received message from server that was not subscribe"); + return; + } + Log::debug("BootstrapConnection: Received subscriber(s) from bootstrap node"); sibs::SafeDeserializer deserializer((const u8*)data, size); PubsubKey pubsubKey; deserializer.extract(pubsubKey.data.data(), PUBSUB_KEY_LENGTH); - listenerCallbackFuncMutex.lock(); + // we want lock to live this whole scope so we dont connect to peer when cancelListen is called + std::lock_guard<std::recursive_mutex> lock(listenerCallbackFuncMutex); auto listenerFuncIt = listenCallbackFuncs.find(pubsubKey); if(listenerFuncIt == listenCallbackFuncs.end()) { - Log::debug("BoostrapConnection: No listener found for key XXX, ignoring..."); - listenerCallbackFuncMutex.unlock(); + Log::debug("BoostrapConnection: No listener found for key '%s', ignoring...", pubsubKey.toString().c_str()); return; } - auto listenerCallbackFunc = listenerFuncIt->second; - listenerCallbackFuncMutex.unlock(); while(!deserializer.empty()) { @@ -76,18 +80,19 @@ namespace sibs newPeerAddress.address.sin_addr.s_addr = ipv4Address; newPeerAddress.address.sin_port = port; memset(newPeerAddress.address.sin_zero, 0, sizeof(newPeerAddress.address.sin_zero)); - connections.connect(newPeerAddress, [this, pubsubKey](std::shared_ptr<DirectConnectionPeer> newPeer, PubSubResult result, const std::string &resultStr) + Log::debug("BootstrapConnection: received subscriber (ip: %s, port: %d) from bootstrap node", newPeerAddress.getAddress().c_str(), newPeerAddress.getPort()); + connections.connect(newPeerAddress, [this, pubsubKey](std::shared_ptr<DirectConnectionPeer> peer, PubSubResult result, const std::string &resultStr) { if(result == PubSubResult::OK) { - subscribedPeersMutex.lock(); - subscribedPeers[pubsubKey].push_back(newPeer); - subscribedPeersMutex.unlock(); - Log::debug("BootstrapConnection: Connected to peer (ip: %s, port: %d) given by bootstrap node", newPeer->address.getAddress().c_str(), newPeer->address.getPort()); + std::lock_guard<std::recursive_mutex> lock(subscribedPeersMutex); + subscribedPeers[pubsubKey].push_back(peer); + ++peer->sharedKeys; + Log::debug("BootstrapConnection: Connected to peer (ip: %s, port: %d) given by bootstrap node", peer->address.getAddress().c_str(), peer->address.getPort()); } else Log::error("BootstrapConnection: Failed to connect to peer given by bootstrap node, error: %s", resultStr.c_str()); - }, std::bind(&BootstrapConnection::receiveDataFromPeer, this, listenerCallbackFunc, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + }, std::bind(&BootstrapConnection::receiveDataFromPeer, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); } else { @@ -97,46 +102,152 @@ namespace sibs } } - void BootstrapConnection::receiveDataFromPeer(BoostrapConnectionListenCallbackFunc listenCallbackFunc, std::shared_ptr<DirectConnectionPeer> peer, const void *data, const usize size) + void BootstrapConnection::receiveDataFromPeer(std::shared_ptr<DirectConnectionPeer> peer, MessageType messageType, const void *data, const usize size) { - if(listenCallbackFunc) - listenCallbackFunc(peer.get(), data, size); + if(size < PUBSUB_KEY_LENGTH) + return; + + PubsubKey pubsubKey; + memcpy(pubsubKey.data.data(), data, PUBSUB_KEY_LENGTH); + if(messageType == MessageType::DATA) + { + listenerCallbackFuncMutex.lock(); + auto listenerFuncIt = listenCallbackFuncs.find(pubsubKey); + if(listenerFuncIt == listenCallbackFuncs.end()) + { + listenerCallbackFuncMutex.unlock(); + Log::debug("BoostrapConnection: No listener found for key '%s', ignoring...", pubsubKey.toString().c_str()); + return; + } + auto listenCallbackFunc = listenerFuncIt->second; + listenerCallbackFuncMutex.unlock(); + + if(listenCallbackFunc) + { + bool continueListening = listenCallbackFunc(peer.get(), (const u8*)data + PUBSUB_KEY_LENGTH, size - PUBSUB_KEY_LENGTH); + if(!continueListening) + cancelListen({ pubsubKey }); + } + } + else if(messageType == MessageType::UNSUBSCRIBE) + { + Log::debug("BootstrapConnection: peer (ip: %s, port: %d) unsubscribed from key '%s'", peer->address.getAddress().c_str(), peer->address.getPort(), pubsubKey.toString().c_str()); + std::lock_guard<std::recursive_mutex> subscribersMutex(subscribedPeersMutex); + auto peersListIt = subscribedPeers.find(pubsubKey); + if(peersListIt == subscribedPeers.end()) + return; + + for(auto it = peersListIt->second.begin(); it != peersListIt->second.end(); ++it) + { + auto existingPeer = *it; + if(*existingPeer == *peer) + { + peersListIt->second.erase(it); + --peer->sharedKeys; + if(peer->sharedKeys <= 0) + connections.removePeer(peer->socket->udtSocket); + break; + } + } + } + else + { + Log::warn("BootstrapConnection: received message from peer that was not data or unsubscribe"); + } } - void BootstrapConnection::listen(const PubsubKey &pubsubKey, BoostrapConnectionListenCallbackFunc callbackFunc) + ListenHandle BootstrapConnection::listen(const PubsubKey &pubsubKey, BoostrapConnectionListenCallbackFunc callbackFunc) { { - std::lock_guard<std::mutex> lock(listenerCallbackFuncMutex); + std::lock_guard<std::recursive_mutex> lock(listenerCallbackFuncMutex); if(listenCallbackFuncs.find(pubsubKey) != listenCallbackFuncs.end()) throw PubsubKeyAlreadyListeningException(""); listenCallbackFuncs[pubsubKey] = callbackFunc; } - connections.send(serverPeer, std::make_shared<std::vector<u8>>(pubsubKey.data.begin(), pubsubKey.data.end()), - [](PubSubResult result, const std::string &resultStr) + + auto message = std::make_shared<Message>(MessageType::SUBSCRIBE); + message->append(pubsubKey.data.data(), pubsubKey.data.size()); + connections.send(serverPeer, message, [](PubSubResult result, const std::string &resultStr) { Log::debug("BootstrapConnection::listen: PubSubResult: %d, result string: %s", result, resultStr.c_str()); }); + + return { pubsubKey }; } - void BootstrapConnection::put(const PubsubKey &pubsubKey, std::shared_ptr<std::vector<u8>> data) + bool BootstrapConnection::put(const PubsubKey &pubsubKey, const void *data, const usize size) { + if(size > 819200) // 800kb + return false; + { - std::lock_guard<std::mutex> lock(listenerCallbackFuncMutex); + std::lock_guard<std::recursive_mutex> lock(listenerCallbackFuncMutex); auto listenCallbackFuncIt = listenCallbackFuncs.find(pubsubKey); if(listenCallbackFuncIt != listenCallbackFuncs.end() && listenCallbackFuncIt->second) - listenCallbackFuncIt->second(nullptr, data->data(), data->size()); + listenCallbackFuncIt->second(nullptr, data, size); + + if(listenCallbackFuncIt == listenCallbackFuncs.end()) + Log::warn("BootstrapConnection::put on key '%s' which we are not listening to", pubsubKey.toString().c_str()); } - std::lock_guard<std::mutex> lock(subscribedPeersMutex); + std::lock_guard<std::recursive_mutex> lock(subscribedPeersMutex); auto peersIt = subscribedPeers.find(pubsubKey); if(peersIt == subscribedPeers.end()) { - return; + Log::warn("BootstrapConnection::put with no subscribers on same key '%s'", pubsubKey.toString().c_str()); + return true; } + auto message = std::make_shared<Message>(MessageType::DATA); + message->append(pubsubKey.data.data(), pubsubKey.data.size()); + message->append(data, size); for(auto &peer : peersIt->second) { - connections.send(peer, data); + connections.send(peer, message); + } + return true; + } + + bool BootstrapConnection::cancelListen(const ListenHandle &listener) + { + { + std::lock_guard<std::recursive_mutex> lock(listenerCallbackFuncMutex); + auto it = listenCallbackFuncs.find(listener.key); + if(it == listenCallbackFuncs.end()) + return false; + listenCallbackFuncs.erase(it); + + auto message = std::make_shared<Message>(MessageType::UNSUBSCRIBE); + message->append(listener.key.data.data(), listener.key.data.size()); + connections.send(serverPeer, message); + + std::lock_guard<std::recursive_mutex> subscribersMutex(subscribedPeersMutex); + auto peersListIt = subscribedPeers.find(listener.key); + // this will happen if there are no other peers subscribed to the key + if(peersListIt == subscribedPeers.end()) + return true; + + for(auto &peer : peersListIt->second) + { + --peer->sharedKeys; + if(peer->sharedKeys <= 0) + { + // disconnect from peer + connections.removePeer(peer->socket->udtSocket); + } + else + { + // unsubscribe from peers request, even if they dont accept it we wont listen to messages from the key anymore + connections.send(peer, message); + } + } + subscribedPeers.erase(peersListIt); } + return true; + } + + std::vector<std::shared_ptr<DirectConnectionPeer>> BootstrapConnection::getPeers() + { + return connections.getPeers(); } } diff --git a/src/BootstrapNode.cpp b/src/BootstrapNode.cpp index 273abf0..5c62e64 100644 --- a/src/BootstrapNode.cpp +++ b/src/BootstrapNode.cpp @@ -14,9 +14,14 @@ namespace sibs { BootstrapNode::BootstrapNode(const Ipv4 &address) : - connections(27130), + connections(address.getPort()), socket(connections.createSocket(address, false, true)) { + if(connections.port != address.getPort()) + { + throw SocketCreateException("BootstrapNode: Failed to bind port " + std::to_string(address.getPort())); + } + connections.onRemoveDisconnectedPeer([this](std::shared_ptr<DirectConnectionPeer> peer) { std::lock_guard<std::mutex> lock(subscribedPeersMutex); @@ -32,7 +37,7 @@ namespace sibs } }); - if(UDT::listen(socket, 10) == UDT::ERROR) + if(UDT::listen(socket->udtSocket, 10) == UDT::ERROR) { std::string errMsg = "UDT: Failed to listen, error: "; errMsg += UDT::getlasterror_desc(); @@ -43,8 +48,9 @@ namespace sibs BootstrapNode::~BootstrapNode() { + std::lock_guard<std::mutex> lock(connections.peersMutex); connections.alive = false; - UDT::close(socket); + socket.reset(); acceptConnectionsThread.join(); } @@ -55,8 +61,8 @@ namespace sibs while(connections.alive) { - UDTSOCKET clientSocket = UDT::accept(socket, (sockaddr*)&clientAddr, &addrLen); - if(clientSocket == UDT::INVALID_SOCK) + UDTSOCKET clientUdtSocket = UDT::accept(socket->udtSocket, (sockaddr*)&clientAddr, &addrLen); + if(clientUdtSocket == UDT::INVALID_SOCK) { // Connection was killed because bootstrap node was taken down if(!connections.alive) @@ -65,61 +71,90 @@ namespace sibs std::this_thread::sleep_for(std::chrono::milliseconds(10)); continue; } + auto clientSocket = std::make_unique<Socket>(clientUdtSocket); char clientHost[NI_MAXHOST]; char clientService[NI_MAXSERV]; getnameinfo((sockaddr *)&clientAddr, addrLen, clientHost, sizeof(clientHost), clientService, sizeof(clientService), NI_NUMERICHOST | NI_NUMERICSERV); - Log::debug("UDT: New connection: %s:%s (socket: %d)", clientHost, clientService, clientSocket); + Log::debug("UDT: New connection: %s:%s (socket: %d)", clientHost, clientService, clientUdtSocket); - UDT::epoll_add_usock(connections.eid, clientSocket); + std::lock_guard<std::mutex> lock(connections.peersMutex); + UDT::epoll_add_usock(connections.eid, clientUdtSocket); + clientSocket->eid = connections.eid; std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>(); - peer->socket = clientSocket; + peer->socket = std::move(clientSocket); sockaddr_in *clientAddrSock = (sockaddr_in*)&clientAddr; memcpy(&peer->address.address, clientAddrSock, sizeof(peer->address.address)); - peer->receiveDataCallbackFunc = std::bind(&BootstrapNode::peerSubscribe, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); - connections.peersMutex.lock(); - connections.peers[clientSocket] = peer; - connections.peersMutex.unlock(); + peer->receiveDataCallbackFunc = std::bind(&BootstrapNode::messageFromClient, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); + connections.peers[clientUdtSocket] = peer; } } - void BootstrapNode::peerSubscribe(std::shared_ptr<DirectConnectionPeer> newPeer, const void *data, const usize size) + void BootstrapNode::messageFromClient(std::shared_ptr<DirectConnectionPeer> peer, MessageType messageType, const void *data, const usize size) { - Log::debug("BootstrapNode: Received peer subscribe from (ip: %s, port: %d)", newPeer->address.getAddress().c_str(), newPeer->address.getPort()); sibs::SafeDeserializer deserializer((const u8*)data, size); PubsubKey pubsubKey; deserializer.extract(pubsubKey.data.data(), PUBSUB_KEY_LENGTH); - - std::lock_guard<std::mutex> lock(subscribedPeersMutex); - auto &peers = subscribedPeers[pubsubKey]; - for(auto &peer : peers) + + if(messageType == MessageType::SUBSCRIBE) { - if(peer->address == newPeer->address) - return; + Log::debug("BootstrapNode: Received peer subscribe from (ip: %s, port: %d)", peer->address.getAddress().c_str(), peer->address.getPort()); + std::lock_guard<std::mutex> lock(subscribedPeersMutex); + auto &peers = subscribedPeers[pubsubKey]; + for(auto &existingPeer : peers) + { + if(existingPeer->address == peer->address) + return; + } + + sibs::SafeSerializer serializer; + serializer.add((u32)peer->address.address.sin_family); + serializer.add((u32)peer->address.address.sin_addr.s_addr); + serializer.add((u16)peer->address.address.sin_port); + auto newPeerMessage = std::make_shared<Message>(MessageType::SUBSCRIBE); + newPeerMessage->append(pubsubKey.data.data(), pubsubKey.data.size()); + newPeerMessage->append(serializer.getBuffer().data(), serializer.getBuffer().size()); + auto sendCallbackFunc = [](PubSubResult result, const std::string &resultStr) + { + Log::debug("BootstrapNode::peerSubscribe send result: %d, result string: %s", result, resultStr.c_str()); + }; + for(auto &existingPeer : peers) + { + connections.send(existingPeer, newPeerMessage, sendCallbackFunc); + } + + sibs::SafeSerializer newPeerSerializer; + for(auto &existingPeer : peers) + { + newPeerSerializer.add((u32)existingPeer->address.address.sin_family); + newPeerSerializer.add((u32)existingPeer->address.address.sin_addr.s_addr); + newPeerSerializer.add((u16)existingPeer->address.address.sin_port); + } + peers.push_back(peer); + + auto existingPeerMessage = std::make_shared<Message>(MessageType::SUBSCRIBE); + existingPeerMessage->append(pubsubKey.data.data(), pubsubKey.data.size()); + existingPeerMessage->append(newPeerSerializer.getBuffer().data(), newPeerSerializer.getBuffer().size()); + connections.send(peer, existingPeerMessage, sendCallbackFunc); } - - sibs::SafeSerializer serializer; - serializer.add(pubsubKey.data.data(), pubsubKey.data.size()); - serializer.add((u32)newPeer->address.address.sin_family); - serializer.add((u32)newPeer->address.address.sin_addr.s_addr); - serializer.add((u16)newPeer->address.address.sin_port); - std::shared_ptr<std::vector<u8>> serializerData = std::make_shared<std::vector<u8>>(std::move(serializer.getBuffer())); - - auto sendCallbackFunc = [](PubSubResult result, const std::string &resultStr) + else if(messageType == MessageType::UNSUBSCRIBE) { - Log::debug("BootstrapNode::peerSubscribe send result: %d, result string: %s", result, resultStr.c_str()); - }; - - sibs::SafeSerializer newPeerSerializer; - newPeerSerializer.add(pubsubKey.data.data(), pubsubKey.data.size()); - for(auto &peer : peers) + Log::debug("BootstrapNode: Received peer unsubscribe from (ip: %s, port: %d)", peer->address.getAddress().c_str(), peer->address.getPort()); + std::lock_guard<std::mutex> lock(subscribedPeersMutex); + auto &peers = subscribedPeers[pubsubKey]; + for(auto it = peers.begin(); it != peers.end(); ++it) + { + auto existingPeer = *it; + if(existingPeer->address == peer->address) + { + peers.erase(it); + break; + } + } + } + else { - connections.send(peer, serializerData, sendCallbackFunc); - newPeerSerializer.add((u32)peer->address.address.sin_family); - newPeerSerializer.add((u32)peer->address.address.sin_addr.s_addr); - newPeerSerializer.add((u16)peer->address.address.sin_port); + Log::warn("BootstrapNode: received message from client that was not subscribe or unsubscribe"); } - peers.push_back(newPeer); - connections.send(newPeer, std::make_shared<std::vector<u8>>(std::move(newPeerSerializer.getBuffer())), sendCallbackFunc); } } diff --git a/src/DirectConnection.cpp b/src/DirectConnection.cpp index 083f557..e41a4a5 100644 --- a/src/DirectConnection.cpp +++ b/src/DirectConnection.cpp @@ -3,6 +3,8 @@ #include <cstdio> #include <cstring> #include <random> +#include <sibs/SafeSerializer.hpp> +#include <sibs/SafeDeserializer.hpp> #ifndef WIN32 #include <arpa/inet.h> @@ -26,6 +28,16 @@ namespace sibs // Max received data size allowed when receiving regular data, receive data as file to receive more data const int MAX_RECEIVED_DATA_SIZE = 1024 * 1024 * 1; // 1Mb + + bool DirectConnectionPeer::operator == (const DirectConnectionPeer &other) const + { + return socket->udtSocket == other.socket->udtSocket; + } + + bool DirectConnectionPeer::operator != (const DirectConnectionPeer &other) const + { + return !(*this == other); + } DirectConnections::DirectConnections(u16 _port) : port(_port == 0 ? (u16)generateRandomNumber(2000, 32000) : _port), @@ -40,29 +52,26 @@ namespace sibs DirectConnections::~DirectConnections() { alive = false; - receiveDataThread.join(); - - for(auto &peer : peers) - { - UDT::close(peer.first); - } + peers.clear(); UDT::epoll_release(eid); UDT::cleanup(); + receiveDataThread.join(); } - int DirectConnections::createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind) + std::unique_ptr<Socket> DirectConnections::createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind) { Log::debug("UDT: Creating socket for ipv4 address %s, port: %d", addressToBind.getAddress().c_str(), addressToBind.getPort()); - UDTSOCKET socket = UDT::socket(AF_INET, SOCK_STREAM, 0); - if(socket == UDT::INVALID_SOCK) + UDTSOCKET udtSocket = UDT::socket(AF_INET, SOCK_STREAM, 0); + if(udtSocket == UDT::INVALID_SOCK) { std::string errMsg = "UDT: Failed to create socket, error: "; errMsg += UDT::getlasterror_desc(); throw SocketCreateException(errMsg); } - UDT::setsockopt(socket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool)); - UDT::setsockopt(socket, 0, UDT_REUSEADDR, &reuseAddr, sizeof(bool)); + auto socket = std::make_unique<Socket>(udtSocket); + UDT::setsockopt(udtSocket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool)); + UDT::setsockopt(udtSocket, 0, UDT_REUSEADDR, &reuseAddr, sizeof(bool)); // Windows UDP issue // For better performance, modify HKLM\System\CurrentControlSet\Services\Afd\Parameters\FastSendDatagramThreshold @@ -86,9 +95,11 @@ namespace sibs Ipv4 myAddr = addressToBind; for(int i = 0; i < 2000; ++i) { - if(UDT::bind(socket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR) + if(UDT::bind(udtSocket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR) { - port = (u16)generateRandomNumber(2000, 32000); + u16 newPort = (u16)generateRandomNumber(2000, 32000); + Log::warn("DirectConnections: failed to bind socket to port %d, trying port %d. Fail reason: %s", port, newPort, UDT::getlasterror_desc()); + port = newPort; myAddr.address.sin_port = htons(port); } else @@ -101,7 +112,7 @@ namespace sibs Ipv4 myAddr = addressToBind; for(int i = 0; i < 2000; ++i) { - if(UDT::bind(socket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR) + if(UDT::bind(udtSocket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR) { port = (u16)generateRandomNumber(2000, 32000); myAddr.address.sin_port = htons(port); @@ -130,8 +141,22 @@ namespace sibs { std::thread([this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]() { - std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>(); - UDTSOCKET socket; + std::shared_ptr<DirectConnectionPeer> peer = getPeerByAddress(address); + if(peer) + { + // this doesn't really matter, we always call connect with same callback function + peer->receiveDataCallbackFunc = receiveDataCallbackFunc; + if(connectCallbackFunc) + connectCallbackFunc(peer, PubSubResult::OK, ""); + return; + } + else + { + peer = std::make_shared<DirectConnectionPeer>(); + peerByAddressMap[address] = peer; + } + + std::unique_ptr<Socket> socket; try { socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind); @@ -144,19 +169,20 @@ namespace sibs } Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no"); - if(UDT::connect(socket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) + if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR) { if(connectCallbackFunc) connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc()); return; } - UDT::epoll_add_usock(eid, socket); - peer->socket = socket; + UDT::epoll_add_usock(eid, socket->udtSocket); + socket->eid = eid; + peersMutex.lock(); + peers[socket->udtSocket] = peer; + peer->socket = std::move(socket); peer->address = address; peer->receiveDataCallbackFunc = receiveDataCallbackFunc; - peersMutex.lock(); - peers[socket] = peer; peersMutex.unlock(); if(connectCallbackFunc) @@ -164,27 +190,27 @@ namespace sibs }).detach(); } - void DirectConnections::send(const std::shared_ptr<DirectConnectionPeer> peer, std::shared_ptr<std::vector<u8>> data, PubSubSendDataCallback sendDataCallbackFunc) + bool DirectConnections::send(const std::shared_ptr<DirectConnectionPeer> peer, std::shared_ptr<Message> data, PubSubSendDataCallback sendDataCallbackFunc) { + if(data->getDataSize() == 0) + return true; + + if(data->getDataSize() > 819200) // 800kb + return false; + // TODO: Replace this with light-weight threads (fibers)? std::thread([peer, data, sendDataCallbackFunc]() { - usize sentSizeTotal = 0; - while(sentSizeTotal < data->size() || data->size() == 0) + int sentSize = UDT::send(peer->socket->udtSocket, (char*)data->data(), data->getRawSize(), 0); + if(sentSize == UDT::ERROR) { - int sentSize = UDT::send(peer->socket, (char*)data->data() + sentSizeTotal, data->size() - sentSizeTotal, 0); - if(sentSize == UDT::ERROR) - { - if(sendDataCallbackFunc) - sendDataCallbackFunc(PubSubResult::ERROR, UDT::getlasterror_desc()); - } - sentSizeTotal += sentSize; - if(data->size() == 0) - break; + if(sendDataCallbackFunc) + sendDataCallbackFunc(PubSubResult::ERROR, UDT::getlasterror_desc()); } - if(sendDataCallbackFunc) + else if(sendDataCallbackFunc) sendDataCallbackFunc(PubSubResult::OK, ""); }).detach(); + return true; } void DirectConnections::onRemoveDisconnectedPeer(PubSubOnRemoveDisconnectedPeerCallback callbackFunc) @@ -195,19 +221,29 @@ namespace sibs bool DirectConnections::removePeer(int peerSocket) { bool wasRemoved = false; - peersMutex.lock(); + std::lock_guard<std::mutex> lock(peersMutex); auto peerIt = peers.find(peerSocket); if(peerIt != peers.end()) { if(removeDisconnectedPeerCallback) removeDisconnectedPeerCallback(peerIt->second); - UDT::epoll_remove_usock(eid, peerSocket); peers.erase(peerIt); wasRemoved = true; } - peersMutex.unlock(); return wasRemoved; } + + std::vector<std::shared_ptr<DirectConnectionPeer>> DirectConnections::getPeers() + { + std::vector<std::shared_ptr<DirectConnectionPeer>> result; + result.reserve(peers.size()); + std::lock_guard<std::mutex> lock(peersMutex); + for(auto &it : peers) + { + result.push_back(it.second); + } + return result; + } void DirectConnections::removeDisconnectedPeers() { @@ -221,9 +257,6 @@ namespace sibs if(removeDisconnectedPeerCallback) removeDisconnectedPeerCallback(it->second); - if(peerSocketStatus == UDTSTATUS::BROKEN) - UDT::epoll_remove_usock(eid, socket); - Log::debug("UDT: Connection was broken to socket %d (peer most likely disconnected), removing peer", socket); it = peers.erase(it); Log::debug("UDT: Removed peer socket %d", socket); @@ -233,11 +266,19 @@ namespace sibs } peersMutex.unlock(); } + + std::shared_ptr<DirectConnectionPeer> DirectConnections::getPeerByAddress(const Ipv4 &address) const + { + auto it = peerByAddressMap.find(address); + if(it != peerByAddressMap.end()) + return it->second; + return nullptr; + } void DirectConnections::receiveData() { std::vector<char> data; - data.reserve(MAX_RECEIVED_DATA_SIZE); + data.resize(MAX_RECEIVED_DATA_SIZE); Log::debug("DirectConnections::receiveData(): waiting for events"); std::set<UDTSOCKET> readfds; @@ -251,6 +292,9 @@ namespace sibs } else if(numfsReady == -1) { + if(!alive) + continue; + if(UDT::getlasterror_code() == UDT::ERRORINFO::ETIMEOUT) { continue; @@ -275,8 +319,12 @@ namespace sibs try { Log::debug("DirectConnection: Received data from peer: (ip: %s, port: %d)", peer->address.getAddress().c_str(), peer->address.getPort()); - if(peer->receiveDataCallbackFunc) - peer->receiveDataCallbackFunc(peer, data.data(), receivedTotalSize); + if(peer->receiveDataCallbackFunc && receivedTotalSize > 0) + { + static_assert(sizeof(MessageType) == sizeof(u8), ""); + MessageType messageType = (MessageType)data[0]; + peer->receiveDataCallbackFunc(peer, messageType, data.data() + 1, receivedTotalSize - 1); + } } catch(std::exception &e) { @@ -340,4 +388,31 @@ namespace sibs Log::error("UDT: Received too much data, ignoring..."); return 0; } + + std::vector<u8> DirectConnectionsUtils::serializePeers(const std::vector<std::shared_ptr<DirectConnectionPeer>> &peers) + { + sibs::SafeSerializer serializer; + for(const auto &it : peers) + { + serializer.add((u32)it->address.address.sin_family); + serializer.add((u32)it->address.address.sin_addr.s_addr); + serializer.add((u16)it->address.address.sin_port); + } + return serializer.getBuffer(); + } + + std::vector<std::shared_ptr<DirectConnectionPeer>> DirectConnectionsUtils::deserializePeers(const u8 *data, const usize size) + { + std::vector<std::shared_ptr<DirectConnectionPeer>> result; + sibs::SafeDeserializer deserializer(data, size); + while(!deserializer.empty()) + { + std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>(); + peer->address.address.sin_family = deserializer.extract<u32>(); + peer->address.address.sin_addr.s_addr = deserializer.extract<u32>(); + peer->address.address.sin_port = deserializer.extract<u32>(); + result.push_back(peer); + } + return result; + } } diff --git a/src/IpAddress.cpp b/src/IpAddress.cpp index c90c924..58ed661 100644 --- a/src/IpAddress.cpp +++ b/src/IpAddress.cpp @@ -56,6 +56,11 @@ namespace sibs bool Ipv4::operator == (const Ipv4 &other) const { - return address.sin_addr.s_addr == other.address.sin_addr.s_addr && getPort() == other.getPort(); + return address.sin_addr.s_addr == other.address.sin_addr.s_addr && address.sin_port == other.address.sin_port; + } + + bool Ipv4::operator != (const Ipv4 &other) const + { + return !(*this == other); } } diff --git a/src/Message.cpp b/src/Message.cpp new file mode 100644 index 0000000..58b39da --- /dev/null +++ b/src/Message.cpp @@ -0,0 +1,15 @@ +#include "../include/sibs/Message.hpp" + +namespace sibs +{ + Message::Message(MessageType messageType) + { + static_assert(sizeof(MessageType) == sizeof(u8), "Whoops, message type size has changed, the below code doesn't work"); + rawData.push_back((u8)messageType); + } + + void Message::append(const void *data, const usize size) + { + rawData.insert(rawData.end(), (const u8*)data, (const u8*)data + size); + } +}
\ No newline at end of file diff --git a/src/PubsubKey.cpp b/src/PubsubKey.cpp index dd807ce..64d2621 100644 --- a/src/PubsubKey.cpp +++ b/src/PubsubKey.cpp @@ -2,6 +2,8 @@ namespace sibs { + static const char *HEX_TABLE = "0123456789abcdef"; + PubsubKey::PubsubKey() : data({}) { @@ -10,9 +12,10 @@ namespace sibs PubsubKey::PubsubKey(const void *data, const usize size) { - std::copy((char*)data, (char*)data + std::min(size, PUBSUB_KEY_LENGTH), this->data.begin()); + usize _size = std::min(size, PUBSUB_KEY_LENGTH); + std::copy((char*)data, (char*)data + _size, this->data.begin()); if(size < PUBSUB_KEY_LENGTH) - std::fill_n((char*)data + size, PUBSUB_KEY_LENGTH - size, 0); + std::fill_n(this->data.begin() + size, PUBSUB_KEY_LENGTH - size, 0); } bool PubsubKey::operator == (const PubsubKey &other) const @@ -24,4 +27,17 @@ namespace sibs { return data != other.data; } + + std::string PubsubKey::toString() const + { + std::string result; + result.reserve(data.size()); + for(usize i = 0; i < data.size(); ++i) + { + u8 c = data[i]; + result += HEX_TABLE[(c & 0xF0) >> 4]; + result += HEX_TABLE[c & 0x0F]; + } + return result; + } } diff --git a/src/Socket.cpp b/src/Socket.cpp new file mode 100644 index 0000000..9c8da69 --- /dev/null +++ b/src/Socket.cpp @@ -0,0 +1,40 @@ +#include "../include/sibs/Socket.hpp" +#include <udt/udt.h> + +namespace sibs +{ + Socket::Socket() : + eid(-1), + udtSocket(-1) + { + + } + + Socket::Socket(int _udtSocket) : + eid(-1), + udtSocket(_udtSocket) + { + + } + + Socket::Socket(int _eid, int _udtSocket) : + eid(_eid), + udtSocket(_udtSocket) + { + + } + + Socket::Socket(Socket &&other) + { + eid = other.eid; + udtSocket = other.udtSocket; + other.eid = 0; + other.udtSocket = 0; + } + + Socket::~Socket() + { + UDT::close(udtSocket); + UDT::epoll_remove_usock(eid, udtSocket); + } +}
\ No newline at end of file |