#include "../include/sibs/DirectConnection.hpp"
#include "../include/Log.hpp"
// NOTE(review): the angle-bracket include targets and several template arguments were lost in the
// copy this file was restored from; the list below is a reconstruction — confirm against the build.
#include <cstdio>
#include <cstring>
#include <mutex>
#include <random>
#include <set>
#include <string>
#include <thread>
#include <vector>
#ifndef WIN32
#include <arpa/inet.h>
#include <netdb.h>
#else
#include <winsock2.h>
#include <ws2tcpip.h>
#endif
#include <udt/udt.h>

namespace sibs
{
    static std::random_device rd;
    static std::mt19937 gen(rd());
    // Guards `gen`: generateRandomNumber is reached from createSocket, which runs on the
    // detached threads spawned by connect(), and std::mt19937 is not thread-safe.
    static std::mutex genMutex;

    // Returns a uniformly distributed number in the inclusive range [start, end].
    static u32 generateRandomNumber(u32 start, u32 end)
    {
        std::uniform_int_distribution<u32> dis(start, end);
        std::lock_guard<std::mutex> lock(genMutex);
        return dis(gen);
    }

    // Max received data size allowed when receiving regular data, receive data as file to receive more data
    const int MAX_RECEIVED_DATA_SIZE = 1024 * 1024 * 1; // 1 MiB

    // Starts UDT, creates the epoll set and launches the background receive thread.
    // A port of 0 selects a random port in [2000, 32000].
    DirectConnections::DirectConnections(u16 _port) :
        port(_port == 0 ? (u16)generateRandomNumber(2000, 32000) : _port),
        alive(true),
        removeDisconnectedPeerCallback(nullptr)
    {
        UDT::startup();
        eid = UDT::epoll_create();
        receiveDataThread = std::thread(&DirectConnections::receiveData, this);
    }

    // Stops the receive thread, closes every known peer socket and shuts UDT down.
    // NOTE(review): threads detached by connect()/send() are not joined here; connect()'s thread
    // captures `this`, so the object must outlive any connection still in progress.
    DirectConnections::~DirectConnections()
    {
        alive = false;
        receiveDataThread.join();
        {
            // Detached connect() threads may still be inserting into `peers`.
            std::lock_guard<decltype(peersMutex)> lock(peersMutex);
            for(auto &peer : peers)
            {
                UDT::close(peer.first);
            }
        }
        UDT::epoll_release(eid);
        UDT::cleanup();
    }

    // Creates a UDT stream socket with the given rendezvous/reuse-address options.
    // If `rendezvous` or `bind` is set, the socket is bound to `addressToBind`. On each failed
    // bind attempt a new random port in [2000, 32000] is tried (and stored in `port`), up to
    // 2000 attempts.
    // Throws SocketCreateException if the socket cannot be created or bound; the socket is
    // closed before throwing so it does not leak.
    // NOTE(review): `port` is written here while other connect() threads may read it — confirm
    // whether the member should be atomic.
    int DirectConnections::createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind)
    {
        Log::debug("UDT: Creating socket for ipv4 address %s, port: %d", addressToBind.getAddress().c_str(), addressToBind.getPort());
        UDTSOCKET socket = UDT::socket(AF_INET, SOCK_STREAM, 0);
        if(socket == UDT::INVALID_SOCK)
        {
            std::string errMsg = "UDT: Failed to create socket, error: ";
            errMsg += UDT::getlasterror_desc();
            throw SocketCreateException(errMsg);
        }

        UDT::setsockopt(socket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool));
        UDT::setsockopt(socket, 0, UDT_REUSEADDR, &reuseAddr, sizeof(bool));

        // Windows UDP issue
        // For better performance, modify HKLM\System\CurrentControlSet\Services\Afd\Parameters\FastSendDatagramThreshold
#ifdef WIN32
        int mss = 1052;
        UDT::setsockopt(socket, 0, UDT_MSS, &mss, sizeof(mss));
#endif

        if(!rendezvous && !bind)
            return socket;

        // The reuseAddr and non-reuseAddr cases used identical retry logic, so they share one path.
        Ipv4 myAddr = addressToBind;
        for(int i = 0; i < 2000; ++i)
        {
            if(UDT::bind(socket, reinterpret_cast<const sockaddr*>(&myAddr.address), sizeof(myAddr.address)) != UDT::ERROR)
                return socket;

            port = (u16)generateRandomNumber(2000, 32000);
            myAddr.address.sin_port = htons(port);
        }

        UDT::close(socket);
        throw SocketCreateException("UDT: Failed to bind after 2000 tries");
    }

    // Connects to a server: non-rendezvous, reuse address, bound locally.
    void DirectConnections::connectServer(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc)
    {
        connect(address, false, true, connectCallbackFunc, receiveDataCallbackFunc, true);
    }

    // Connects to a peer in rendezvous mode with address reuse and a local bind.
    void DirectConnections::connect(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc)
    {
        connect(address, true, true, connectCallbackFunc, receiveDataCallbackFunc, true);
    }

    // Asynchronously creates a socket and connects it to `address` on a detached thread.
    // On success the peer is stored in `peers`, registered with epoll, and `connectCallbackFunc`
    // receives PubSubResult::OK. On failure it receives PubSubResult::ERROR plus a description,
    // and no socket is leaked.
    void DirectConnections::connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc, bool bind)
    {
        std::thread([this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]()
        {
            std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>();
            UDTSOCKET socket;
            try
            {
                socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind);
            }
            catch(SocketCreateException &e)
            {
                if(connectCallbackFunc)
                    connectCallbackFunc(peer, PubSubResult::ERROR, e.what());
                return;
            }

            Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no");
            if(UDT::connect(socket, reinterpret_cast<const sockaddr*>(&address.address), sizeof(address.address)) == UDT::ERROR)
            {
                // Copy the error text first: UDT::close may overwrite this thread's last-error state.
                std::string errMsg = UDT::getlasterror_desc();
                UDT::close(socket);
                if(connectCallbackFunc)
                    connectCallbackFunc(peer, PubSubResult::ERROR, errMsg.c_str());
                return;
            }

            peer->socket = socket;
            peer->address = address;
            peer->receiveDataCallbackFunc = receiveDataCallbackFunc;
            {
                std::lock_guard<decltype(peersMutex)> lock(peersMutex);
                peers[socket] = peer;
            }
            // Register with epoll only once the peer is in the map, so the receive thread never
            // gets a readable event for a socket it cannot look up (which would drop that data).
            UDT::epoll_add_usock(eid, socket);

            if(connectCallbackFunc)
                connectCallbackFunc(peer, PubSubResult::OK, "");
        }).detach();
    }

    // Sends `data` to `peer` on a detached thread, looping until every byte is written.
    // `sendDataCallbackFunc` (if set) is invoked exactly once: with PubSubResult::ERROR and
    // the UDT error text on the first failed send, otherwise with PubSubResult::OK.
    void DirectConnections::send(const std::shared_ptr<DirectConnectionPeer> peer, std::shared_ptr<std::vector<u8>> data, PubSubSendDataCallback sendDataCallbackFunc)
    {
        // TODO: Replace this with light-weight threads (fibers)?
        std::thread([peer, data, sendDataCallbackFunc]()
        {
            const char *buffer = reinterpret_cast<const char*>(data->data());
            usize sentSizeTotal = 0;
            // do-while so an empty payload still results in exactly one UDT::send call.
            do
            {
                int sentSize = UDT::send(peer->socket, buffer + sentSizeTotal, static_cast<int>(data->size() - sentSizeTotal), 0);
                if(sentSize == UDT::ERROR)
                {
                    // Stop here: continuing would add -1 to the unsigned total and could
                    // additionally report success below.
                    if(sendDataCallbackFunc)
                        sendDataCallbackFunc(PubSubResult::ERROR, UDT::getlasterror_desc());
                    return;
                }
                sentSizeTotal += sentSize;
            } while(sentSizeTotal < data->size());

            if(sendDataCallbackFunc)
                sendDataCallbackFunc(PubSubResult::OK, "");
        }).detach();
    }

    // Registers the callback invoked (under peersMutex) whenever a peer is removed.
    // NOTE(review): this write is not synchronized with the receive thread, which reads the
    // callback under peersMutex — consider setting it before any connection is made.
    void DirectConnections::onRemoveDisconnectedPeer(PubSubOnRemoveDisconnectedPeerCallback callbackFunc)
    {
        removeDisconnectedPeerCallback = callbackFunc;
    }

    // Removes the peer owning `peerSocket` from `peers` and from the epoll set, invoking the
    // removal callback first. Returns false if no such peer exists. The socket is not closed here.
    bool DirectConnections::removePeer(int peerSocket)
    {
        // lock_guard so the mutex is released even if the user callback throws.
        std::lock_guard<decltype(peersMutex)> lock(peersMutex);
        auto peerIt = peers.find(peerSocket);
        if(peerIt == peers.end())
            return false;

        if(removeDisconnectedPeerCallback)
            removeDisconnectedPeerCallback(peerIt->second);
        UDT::epoll_remove_usock(eid, peerSocket);
        peers.erase(peerIt);
        return true;
    }

    // Drops every peer whose UDT socket is BROKEN, CLOSING, CLOSED or NONEXIST, invoking the
    // removal callback for each. Only BROKEN sockets are explicitly removed from the epoll set.
    void DirectConnections::removeDisconnectedPeers()
    {
        std::lock_guard<decltype(peersMutex)> lock(peersMutex);
        for(auto it = peers.begin(); it != peers.end(); )
        {
            int socket = it->first;
            UDTSTATUS peerSocketStatus = UDT::getsockstate(socket);
            if(peerSocketStatus == UDTSTATUS::BROKEN || peerSocketStatus == UDTSTATUS::CLOSING || peerSocketStatus == UDTSTATUS::CLOSED || peerSocketStatus == UDTSTATUS::NONEXIST)
            {
                if(removeDisconnectedPeerCallback)
                    removeDisconnectedPeerCallback(it->second);

                if(peerSocketStatus == UDTSTATUS::BROKEN)
                    UDT::epoll_remove_usock(eid, socket);

                Log::debug("UDT: Connection was broken to socket %d (peer most likely disconnected), removing peer", socket);
                it = peers.erase(it);
                Log::debug("UDT: Removed peer socket %d", socket);
            }
            else
                ++it;
        }
    }

    // Body of the background receive thread. While `alive`, it prunes disconnected peers,
    // waits up to 1000 ms for readable sockets, reads each one into a shared buffer and hands
    // the bytes to that peer's receive callback. Exceptions thrown by receive callbacks are
    // logged and ignored. The thread exits on any epoll error other than a timeout.
    void DirectConnections::receiveData()
    {
        // Sized, not merely reserved: receiveDataFromPeer writes through data.data(), and writing
        // past a vector's size() is undefined behavior.
        std::vector<char> data(MAX_RECEIVED_DATA_SIZE);
        Log::debug("DirectConnections::receiveData(): waiting for events");

        std::set<UDTSOCKET> readfds;
        while(alive)
        {
            removeDisconnectedPeers();
            readfds.clear();
            int numfsReady = UDT::epoll_wait(eid, &readfds, nullptr, 1000);
            if(numfsReady == 0)
                continue;

            if(numfsReady == -1)
            {
                if(UDT::getlasterror_code() == UDT::ERRORINFO::ETIMEOUT)
                    continue;

                Log::error("UDT: Stop receiving data, got error: %s", UDT::getlasterror_desc());
                return;
            }

            for(UDTSOCKET receivedDataFromPeer : readfds)
            {
                bool peerDisconnected = false;
                usize receivedTotalSize = 0;
                int receivedDataStatus = receiveDataFromPeer(receivedDataFromPeer, data.data(), &receivedTotalSize);
                if(receivedDataStatus == 0)
                {
                    // find() instead of operator[]: the latter would insert a null peer for an
                    // unknown socket and then dereference it.
                    decltype(peers)::mapped_type peer;
                    {
                        std::lock_guard<decltype(peersMutex)> lock(peersMutex);
                        auto peerIt = peers.find(receivedDataFromPeer);
                        if(peerIt != peers.end())
                            peer = peerIt->second;
                    }

                    if(!peer)
                    {
                        Log::debug("UDT: Received data from socket %d which is not a connected peer, ignoring", receivedDataFromPeer);
                        continue;
                    }

                    try
                    {
                        Log::debug("DirectConnection: Received data from peer: (ip: %s, port: %d)", peer->address.getAddress().c_str(), peer->address.getPort());
                        if(peer->receiveDataCallbackFunc)
                            peer->receiveDataCallbackFunc(peer, data.data(), receivedTotalSize);
                    }
                    catch(std::exception &e)
                    {
                        Log::error("UDT: Receive callback function threw exception: %s, ignoring...", e.what());
                    }
                }
                else if(receivedDataStatus == CUDTException::EINVSOCK)
                {
                    Log::debug("UDT: Invalid socket %d, did remote peer disconnect?", receivedDataFromPeer);
                    peerDisconnected = true;
                }
                else if(receivedDataStatus == CUDTException::ECONNLOST)
                {
                    Log::debug("UDT: Connection was broken to socket %d (peer most likely disconnected), removing peer", receivedDataFromPeer);
                    peerDisconnected = true;
                }

                if(peerDisconnected)
                {
                    if(removePeer(receivedDataFromPeer))
                        Log::debug("UDT: Removed peer socket %d", receivedDataFromPeer);
                    else
                        Log::error("UDT: Failed to remove peer socket %d, system said we got data from it but we are not connected to it", receivedDataFromPeer);
                }
            }
        }
    }

    // Reads whatever is currently available on `socket` into `output`, at most
    // MAX_RECEIVED_DATA_SIZE bytes; `output` must have room for that many bytes.
    // Stores the byte count in *receivedTotalSize.
    // Returns 0 on success (including when nothing was available) or the UDT error code.
    // NOTE(review): when the limit is reached this logs "ignoring" but still returns 0, so the
    // caller delivers the truncated data; the remainder is left in the socket.
    int DirectConnections::receiveDataFromPeer(const int socket, char *output, usize *receivedTotalSize)
    {
        *receivedTotalSize = 0;
        while(*receivedTotalSize < MAX_RECEIVED_DATA_SIZE)
        {
            int dataAvailableSize;
            int receiveSizeDataTypeSize = sizeof(dataAvailableSize);
            if(UDT::getsockopt(socket, 0, UDT_RCVDATA, &dataAvailableSize, &receiveSizeDataTypeSize) == UDT::ERROR)
            {
                Log::error("UDT: Failed to receive data available size, error: %s (%d)", UDT::getlasterror_desc(), UDT::getlasterror_code());
                return UDT::getlasterror_code();
            }

            if(dataAvailableSize == 0)
                return 0;

            int receivedSize = UDT::recv(socket, &output[*receivedTotalSize], static_cast<int>(MAX_RECEIVED_DATA_SIZE - *receivedTotalSize), 0);
            if(receivedSize == UDT::ERROR)
            {
                Log::error("UDT: Failed to receive data, error: %s (%d)", UDT::getlasterror_desc(), UDT::getlasterror_code());
                return UDT::getlasterror_code();
            }

            if(receivedSize == 0)
                return 0;

            (*receivedTotalSize) += receivedSize;
        }

        Log::error("UDT: Received too much data, ignoring...");
        return 0;
    }
}