#include "../include/sibs/DirectConnection.hpp"
#include "../include/Log.hpp"
// NOTE(review): the system header names and the template arguments in this file were
// lost in the copy under review; they are reconstructed here from how the code uses
// them (e.g. `(char*)data->data()` suggested std::vector<u8> for send). Confirm against
// DirectConnection.hpp and version control, in particular the UDT header path.
#include <exception>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>
#ifndef WIN32
#include <arpa/inet.h>
#include <sys/socket.h>
#else
#include <winsock2.h>
#include <ws2tcpip.h>
#endif
#include <udt/udt.h>

namespace sibs
{
    // Max received data size allowed when receiving regular data, receive data as file to receive more data
    const int MAX_RECEIVED_DATA_SIZE = 1024 * 1024 * 1; // 1 MiB

    namespace
    {
        // Reads all data currently buffered on UDT socket `socket` into `output`.
        // `output` must have room for MAX_RECEIVED_DATA_SIZE bytes.
        // On success returns true and stores the number of bytes written in `receivedTotalSize`.
        // Returns false (after logging) on a UDT error, or when more than
        // MAX_RECEIVED_DATA_SIZE bytes are pending.
        bool receiveAvailableData(const int socket, char *output, usize &receivedTotalSize)
        {
            receivedTotalSize = 0;
            while(receivedTotalSize < static_cast<usize>(MAX_RECEIVED_DATA_SIZE))
            {
                // createSocket never changes UDT_RCVSYN, so this recv is blocking (UDT default).
                // It is only reached once epoll reported the socket readable, or once UDT_RCVDATA
                // below reported more buffered data, so it does not wait on an idle peer.
                const int receivedSize = UDT::recv(socket, output + receivedTotalSize, MAX_RECEIVED_DATA_SIZE - static_cast<int>(receivedTotalSize), 0);
                if(receivedSize == UDT::ERROR)
                {
                    Log::error("UDT: Failed to receive data, error: %s", UDT::getlasterror_desc());
                    return false;
                }

                if(receivedSize == 0)
                    return true;

                // Count the bytes recv actually wrote, not a size queried before the call
                receivedTotalSize += static_cast<usize>(receivedSize);

                // UDT reports UDT_RCVDATA as a 32-bit int; querying it into a wider type would
                // leave the upper bytes uninitialized.
                int dataAvailableSize = 0;
                int dataAvailableSizeLen = sizeof(dataAvailableSize);
                if(UDT::getsockopt(socket, 0, UDT_RCVDATA, &dataAvailableSize, &dataAvailableSizeLen) == UDT::ERROR)
                {
                    Log::error("UDT: Failed to receive data available size, error: %s", UDT::getlasterror_desc());
                    return false;
                }

                // Nothing more buffered: stop, instead of calling the blocking recv again and
                // stalling the receive thread until the peer sends more.
                if(dataAvailableSize <= 0)
                    return true;
            }

            Log::error("UDT: Received too much data, ignoring...");
            return false;
        }
    }

    // Initializes UDT, creates the epoll set and starts the background receive thread.
    // `_port` is the local port every socket created by this instance binds to.
    DirectConnections::DirectConnections(u16 _port) :
        port(_port),
        alive(true)
    {
        UDT::startup();
        eid = UDT::epoll_create();
        receiveDataThread = std::thread(&DirectConnections::receiveData, this);
    }

    // Stops and joins the receive thread, then closes every peer socket and releases UDT.
    DirectConnections::~DirectConnections()
    {
        alive = false;
        receiveDataThread.join();
        for(auto &peer : peers)
        {
            UDT::close(peer.first);
        }
        UDT::epoll_release(eid);
        UDT::cleanup();
    }

    // Creates a UDT socket with the given rendezvous/reuse-address options and binds it to
    // `addressToBind`. Throws ConnectionException if the socket cannot be created or bound.
    int DirectConnections::createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr)
    {
        Log::debug("UDT: Creating socket for ipv4 address %s, port: %d", addressToBind.getAddress().c_str(), addressToBind.getPort());
        UDTSOCKET socket = UDT::socket(AF_INET, SOCK_STREAM, IPPROTO_UDP);
        if(socket == UDT::INVALID_SOCK)
        {
            std::string errMsg = "UDT: Failed to create socket, error: ";
            errMsg += UDT::getlasterror_desc();
            throw ConnectionException(errMsg);
        }

        UDT::setsockopt(socket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool));
        UDT::setsockopt(socket, 0, UDT_REUSEADDR, &reuseAddr, sizeof(bool));

        // Windows UDP issue
        // For better performance, modify HKLM\System\CurrentControlSet\Services\Afd\Parameters\FastSendDatagramThreshold
#ifdef WIN32
        int mss = 1052;
        UDT::setsockopt(socket, 0, UDT_MSS, &mss, sizeof(mss));
#endif

        if(UDT::bind(socket, reinterpret_cast<const sockaddr*>(&addressToBind.address), sizeof(addressToBind.address)) == UDT::ERROR)
        {
            // Read the error before close() gets a chance to replace it
            std::string errMsg = "UDT: Failed to bind, error: ";
            errMsg += UDT::getlasterror_desc();
            // Don't leak the socket when bind fails
            UDT::close(socket);
            throw ConnectionException(errMsg);
        }

        return socket;
    }

    // Connects to a server (non-rendezvous, reuse address).
    std::shared_ptr<DirectConnection> DirectConnections::connectServer(const Ipv4 &address, PubSubReceiveDataCallback receiveDataCallbackFunc)
    {
        return connect(address, false, true, receiveDataCallbackFunc);
    }

    // Connects to a peer in rendezvous mode (reuse address).
    std::shared_ptr<DirectConnection> DirectConnections::connect(const Ipv4 &address, PubSubReceiveDataCallback receiveDataCallbackFunc)
    {
        return connect(address, true, true, receiveDataCallbackFunc);
    }

    // Creates a socket bound to `port`, connects it to `address`, registers the peer and adds the
    // socket to the epoll set watched by receiveData. Throws ConnectionException on failure.
    std::shared_ptr<DirectConnection> DirectConnections::connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubReceiveDataCallback receiveDataCallbackFunc)
    {
        UDTSOCKET socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr);
        if(UDT::connect(socket, reinterpret_cast<const sockaddr*>(&address.address), sizeof(address.address)) == UDT::ERROR)
        {
            // Read the error before close() gets a chance to replace it
            std::string errMsg = "UDT: Failed to connect, error: ";
            errMsg += UDT::getlasterror_desc();
            UDT::close(socket);
            throw ConnectionException(errMsg);
        }

        std::shared_ptr<DirectConnection> peer = std::make_shared<DirectConnection>();
        peer->socket = socket;
        peer->address = address;
        peer->receiveDataCallbackFunc = receiveDataCallbackFunc;

        // Register the peer BEFORE adding the socket to epoll. Otherwise receiveData may see
        // the socket become readable before it has an entry in `peers`.
        {
            std::lock_guard<decltype(peersMutex)> lock(peersMutex);
            peers[socket] = peer;
        }

        if(UDT::epoll_add_usock(eid, socket) == UDT::ERROR)
        {
            std::string errMsg = "UDT: Failed to add socket to epoll, error: ";
            errMsg += UDT::getlasterror_desc();
            {
                std::lock_guard<decltype(peersMutex)> lock(peersMutex);
                peers.erase(socket);
            }
            UDT::close(socket);
            throw ConnectionException(errMsg);
        }

        return peer;
    }

    // Sends `data` to `peer` on a detached thread.
    // Reports PubSubResult::OK once everything has been sent, or PubSubResult::ERROR on the
    // first failed UDT::send, through `sendDataCallbackFunc` if it is set.
    void DirectConnections::send(const std::shared_ptr<DirectConnection> &peer, std::shared_ptr<std::vector<u8>> data, PubSubSendDataCallback sendDataCallbackFunc)
    {
        // TODO: Replace this with light-weight threads (fibers)?
        std::thread([peer, data, sendDataCallbackFunc]()
        {
            // Exceptions must not escape a detached thread (std::terminate)
            try
            {
                usize sentSizeTotal = 0;
                while(sentSizeTotal < data->size())
                {
                    int sentSize = UDT::send(peer->socket, reinterpret_cast<const char*>(data->data()) + sentSizeTotal, static_cast<int>(data->size() - sentSizeTotal), 0);
                    if(sentSize == UDT::ERROR)
                    {
                        if(sendDataCallbackFunc)
                            sendDataCallbackFunc(PubSubResult::ERROR, UDT::getlasterror_desc());
                        // Stop on error: adding UDT::ERROR (-1) to the unsigned counter would retry forever
                        return;
                    }
                    sentSizeTotal += sentSize;
                }

                if(sendDataCallbackFunc)
                    sendDataCallbackFunc(PubSubResult::OK, "");
            }
            catch(const std::exception &e)
            {
                Log::error("UDT: Send callback function threw exception: %s, ignoring...", e.what());
            }
        }).detach();
    }

    // Receive-thread loop.
    // Waits on the epoll set (250 ms timeout) while `alive` is set, reads the available data
    // of each readable socket and hands it to that peer's receive callback.
    // Returns (ending the thread) on a non-timeout epoll error.
    void DirectConnections::receiveData()
    {
        // Sized, not just reserved, so writing through data() is valid
        std::vector<char> data(MAX_RECEIVED_DATA_SIZE);
        std::set<UDTSOCKET> readfds;
        while(alive)
        {
            int numfsReady = UDT::epoll_wait(eid, &readfds, nullptr, 250);
            if(numfsReady == 0)
                continue;
            else if(numfsReady == -1)
            {
                if(UDT::getlasterror_code() == UDT::ERRORINFO::ETIMEOUT)
                    continue;
                else
                {
                    Log::error("UDT: Stop receiving data, got error: %s", UDT::getlasterror_desc());
                    return;
                }
            }

            for(UDTSOCKET receivedDataFromPeer : readfds)
            {
                usize receivedSize = 0;
                // NOTE(review): a socket whose receive fails stays registered with epoll, so a broken
                // connection is reported again on every wait; consider UDT::epoll_remove_usock here.
                if(!receiveAvailableData(receivedDataFromPeer, data.data(), receivedSize))
                    continue;

                // find() instead of operator[]: never insert (and then dereference) a null peer
                // for a socket that is not registered
                std::shared_ptr<DirectConnection> peer;
                {
                    std::lock_guard<decltype(peersMutex)> lock(peersMutex);
                    auto peerIt = peers.find(receivedDataFromPeer);
                    if(peerIt != peers.end())
                        peer = peerIt->second;
                }
                if(!peer)
                    continue;

                try
                {
                    if(peer->receiveDataCallbackFunc)
                        peer->receiveDataCallbackFunc(peer, data.data(), receivedSize);
                }
                catch(const std::exception &e)
                {
                    Log::error("UDT: Receive callback function threw exception: %s, ignoring...", e.what());
                }
            }
            readfds.clear();
        }
    }

    // Reads the data currently available on `socket` into `output`, which must hold at least
    // MAX_RECEIVED_DATA_SIZE bytes. Returns false on error or when too much data is pending.
    bool DirectConnections::receiveDataFromPeer(const int socket, char *output)
    {
        usize receivedSize = 0;
        return receiveAvailableData(socket, output, receivedSize);
    }
}