#include "../include/sibs/DirectConnection.hpp"
#include "../include/Log.hpp"
// NOTE(review): the names of the angle-bracket headers were missing from the reviewed copy of
// this file. The list below is reconstructed from the identifiers this file uses; confirm it
// against the project's build (in particular the UDT header path) before merging.
#include <chrono>
#include <future>
#include <memory>
#include <mutex>
#include <random>
#include <set>
#include <string>
#include <thread>
#include <vector>

#ifndef WIN32
    #include <arpa/inet.h>
    #include <netinet/in.h>
#else
    #include <winsock2.h>
    #include <ws2tcpip.h>
#endif

#include <udt/udt.h>

namespace sibs
{
    static std::random_device rd;
    static std::mt19937 gen(rd());

    // Returns a random integer drawn uniformly from the closed range [start, end].
    static u32 generateRandomNumber(u32 start, u32 end)
    {
        std::uniform_int_distribution<> dis(start, end);
        return dis(gen);
    }

    // Max received data size allowed when receiving regular data, receive data as file to receive more data
    const int MAX_RECEIVED_DATA_SIZE = 1024 * 1024 * 1; // 1Mb

    // Two peers are equal when they refer to the same UDT socket handle.
    // NOTE(review): both sockets are dereferenced without a null check, while peers produced by
    // DirectConnectionsUtils::deserializePeers are never given a socket in this file - confirm
    // such peers are never compared.
    bool DirectConnectionPeer::operator == (const DirectConnectionPeer &other) const
    {
        return socket->udtSocket == other.socket->udtSocket;
    }

    bool DirectConnectionPeer::operator != (const DirectConnectionPeer &other) const
    {
        return !(*this == other);
    }

    // Starts UDT, creates the epoll set and launches the thread that runs receiveData().
    // A @_port of 0 selects a random port in the range [2000, 32000].
    DirectConnections::DirectConnections(u16 _port) :
        port(_port == 0 ? (u16)generateRandomNumber(2000, 32000) : _port),
        alive(true),
        removeDisconnectedPeerCallback(nullptr)
    {
        UDT::startup();
        eid = UDT::epoll_create();
        receiveDataThread = std::thread(&DirectConnections::receiveData, this);
    }

    // Signals the receive thread to stop and joins it before the peers and the epoll set are released.
    DirectConnections::~DirectConnections()
    {
        alive = false;
        receiveDataThread.join();
        peers.clear();
        UDT::epoll_release(eid);
        UDT::cleanup();
    }

    // Creates a UDT datagram socket and, when @rendezvous or @bind is set, binds it to @addressToBind.
    // If binding fails, a new random port in [2000, 32000] is stored in the member `port` and tried
    // instead, for at most 2000 attempts.
    // Throws SocketCreateException if the socket can't be created or if every bind attempt fails.
    std::unique_ptr<Socket> DirectConnections::createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind)
    {
        Log::debug("UDT: Creating socket for ipv4 address %s, port: %d", addressToBind.getAddress().c_str(), addressToBind.getPort());
        UDTSOCKET udtSocket = UDT::socket(AF_INET, SOCK_DGRAM, 0);
        if(udtSocket == UDT::INVALID_SOCK)
        {
            std::string errMsg = "UDT: Failed to create socket, error: ";
            errMsg += UDT::getlasterror_desc();
            throw SocketCreateException(errMsg);
        }
        // From here on the handle is owned by `socket`, including on the throwing path below
        auto socket = std::make_unique<Socket>(udtSocket);

        UDT::setsockopt(udtSocket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool));
        UDT::setsockopt(udtSocket, 0, UDT_REUSEADDR, &reuseAddr, sizeof(bool));

        // Windows UDP issue
        // For better performance, modify HKLM\System\CurrentControlSet\Services\Afd\Parameters\FastSendDatagramThreshold
#ifdef WIN32
        int mss = 1052;
        UDT::setsockopt(udtSocket, 0, UDT_MSS, &mss, sizeof(mss));
#endif

        if(rendezvous || bind)
        {
            Ipv4 myAddr = addressToBind;
            for(int i = 0; i < 2000; ++i)
            {
                if(UDT::bind(udtSocket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR)
                {
                    port = (u16)generateRandomNumber(2000, 32000);
                    Log::warn("DirectConnections::createSocket: failed to bind socket to port %d, trying port %d. Fail reason: %s", myAddr.getPort(), port, UDT::getlasterror_desc());
                    myAddr.address.sin_port = htons(port);
                }
                else
                    return socket;
            }
            throw SocketCreateException("UDT: Failed to bind after 2000 tries");
        }

        return socket;
    }

    // Connects to @address without rendezvous mode and registers the resulting peer as PeerType::SERVER.
    void DirectConnections::connectServer(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc)
    {
        connect(address, false, true, connectCallbackFunc, receiveDataCallbackFunc, true, true);
    }

    // Connects to @address in rendezvous mode and registers the resulting peer as PeerType::CLIENT.
    void DirectConnections::connect(const Ipv4 &address, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc)
    {
        connect(address, true, true, connectCallbackFunc, receiveDataCallbackFunc, true, false);
    }

    // Returns true if the connection attempt has finished, without blocking.
    static bool isReady(const std::shared_future<PubSubConnectResult> &connectionResultFuture)
    {
        return connectionResultFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
    }

    // Connects to @address, running at most one connection attempt per address.
    //
    // First call for an address: the attempt is started with std::async and its future is stored in
    // `connectionResults`. @connectCallbackFunc is invoked from that task with the outcome.
    // Later calls for the same address: no new attempt is made. @connectCallbackFunc is invoked with
    // the stored result, immediately if it is ready and otherwise from a detached thread that waits
    // for it. In that case @receiveDataCallbackFunc and the mode flags of the later call are not used.
    void DirectConnections::connect(const Ipv4 &address, bool rendezvous, bool reuseAddr, PubSubConnectCallback connectCallbackFunc, PubSubReceiveDataCallback receiveDataCallbackFunc, bool bind, bool server)
    {
        std::shared_future<PubSubConnectResult> connectionResultFuture;
        {
            // Scoped lock: the mutex is released even if std::async or the map insertion throws
            std::lock_guard<std::mutex> lock(connectionResultsMutex);
            auto it = connectionResults.find(address);
            if(it == connectionResults.end())
            {
                connectionResults[address] = std::async(std::launch::async, [this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind, server]()
                {
                    std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>();
                    std::unique_ptr<Socket> socket;
                    try
                    {
                        socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind);
                    }
                    catch(SocketCreateException &e)
                    {
                        if(connectCallbackFunc)
                            connectCallbackFunc(peer, PubSubResult::RESULT_ERROR, e.what());
                        return PubSubConnectResult { peer, PubSubResult::RESULT_ERROR, e.what() };
                    }

                    Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no");
                    if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR)
                    {
                        if(connectCallbackFunc)
                            connectCallbackFunc(peer, PubSubResult::RESULT_ERROR, UDT::getlasterror_desc());
                        return PubSubConnectResult{ peer, PubSubResult::RESULT_ERROR, UDT::getlasterror_desc() };
                    }

                    UDT::epoll_add_usock(eid, socket->udtSocket);
                    socket->eid = eid;
                    {
                        std::lock_guard<std::mutex> peersLock(peersMutex);
                        peers[socket->udtSocket] = peer;
                        peer->socket = std::move(socket);
                        peer->address = address;
                        peer->receiveDataCallbackFunc = receiveDataCallbackFunc;
                        peer->type = (server ? PeerType::SERVER : PeerType::CLIENT);
                    }

                    if(connectCallbackFunc)
                        connectCallbackFunc(peer, PubSubResult::RESULT_OK, "");
                    return PubSubConnectResult { peer, PubSubResult::RESULT_OK, "" };
                });
                return;
            }
            connectionResultFuture = it->second;
        }

        // An attempt for this address already exists; the user callback is invoked without holding the lock
        if(isReady(connectionResultFuture))
        {
            PubSubConnectResult connectResult = connectionResultFuture.get();
            if(connectCallbackFunc)
                connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr);
        }
        else
        {
            std::thread([connectCallbackFunc, connectionResultFuture]()
            {
                PubSubConnectResult connectResult = connectionResultFuture.get();
                if(connectCallbackFunc)
                    connectCallbackFunc(connectResult.peer, connectResult.result, connectResult.resultStr);
            }).detach();
        }
    }

    // Sends @data to @peer from a detached thread and reports the outcome through @sendDataCallbackFunc (if set).
    // Returns false without sending if the data is larger than 800kb.
    // Returns true if the send was started, and also if there was nothing to send (0 bytes).
    bool DirectConnections::send(const std::shared_ptr<DirectConnectionPeer> peer, std::shared_ptr<Message> data, PubSubSendDataCallback sendDataCallbackFunc)
    {
        if(data->getDataSize() == 0)
        {
            Log::warn("No data sent because you are trying to send 0 bytes");
            return true;
        }

        if(data->getDataSize() > 819200) // 800kb
        {
            Log::error("Data not sent, data was over 800kb");
            return false;
        }

        // Explicit conversion so the argument matches the %d format specifier
        Log::debug("DirectConnections::send: sending %d bytes to %s:%d", static_cast<int>(data->getRawSize()), peer->address.getAddress().c_str(), peer->address.getPort());
        // TODO: Replace this with light-weight threads (fibers)?
        std::thread([peer, data, sendDataCallbackFunc]()
        {
            const int one_min_ms = 1000 * 60;
            int sentSize = UDT::sendmsg(peer->socket->udtSocket, (char*)data->data(), data->getRawSize(), one_min_ms, true);
            if(sentSize == UDT::ERROR)
            {
                if(sendDataCallbackFunc)
                    sendDataCallbackFunc(PubSubResult::RESULT_ERROR, UDT::getlasterror_desc());
            }
            else if(sendDataCallbackFunc)
                sendDataCallbackFunc(PubSubResult::RESULT_OK, "");
        }).detach();
        return true;
    }

    // Sets the callback that is invoked with a peer right before it's removed from the peer list.
    void DirectConnections::onRemoveDisconnectedPeer(PubSubOnRemoveDisconnectedPeerCallback callbackFunc)
    {
        removeDisconnectedPeerCallback = callbackFunc;
    }

    // Removes the peer registered for @peerSocket, invoking the remove callback (if set) first.
    // Returns true if a peer was removed, false if no peer is registered for the socket.
    bool DirectConnections::removePeer(int peerSocket)
    {
        bool wasRemoved = false;
        std::lock_guard<std::mutex> lock(peersMutex);
        auto peerIt = peers.find(peerSocket);
        if(peerIt != peers.end())
        {
            if(removeDisconnectedPeerCallback)
                removeDisconnectedPeerCallback(peerIt->second);
            peers.erase(peerIt);
            wasRemoved = true;
        }
        return wasRemoved;
    }

    // Returns a snapshot of the registered peers of type PeerType::CLIENT.
    std::vector<std::shared_ptr<DirectConnectionPeer>> DirectConnections::getPeers()
    {
        std::vector<std::shared_ptr<DirectConnectionPeer>> result;
        // The lock has to be taken before peers.size() is read, the map is modified by other threads
        std::lock_guard<std::mutex> lock(peersMutex);
        result.reserve(peers.size());
        for(auto &it : peers)
        {
            if(it.second->type == PeerType::CLIENT)
                result.push_back(it.second);
        }
        return result;
    }

    // Removes every peer whose socket state is BROKEN, CLOSING, CLOSED or NONEXIST,
    // invoking the remove callback (if set) for each of them before it's erased.
    void DirectConnections::removeDisconnectedPeers()
    {
        // Scoped lock: the mutex is released even if the remove callback throws
        std::lock_guard<std::mutex> lock(peersMutex);
        for(auto it = peers.begin(); it != peers.end(); )
        {
            int socket = it->first;
            UDTSTATUS peerSocketStatus = UDT::getsockstate(socket);
            if(peerSocketStatus == UDTSTATUS::BROKEN || peerSocketStatus == UDTSTATUS::CLOSING || peerSocketStatus == UDTSTATUS::CLOSED || peerSocketStatus == UDTSTATUS::NONEXIST)
            {
                if(removeDisconnectedPeerCallback)
                    removeDisconnectedPeerCallback(it->second);
                Log::debug("UDT: Connection was broken to socket %d (peer most likely disconnected), removing peer", socket);
                it = peers.erase(it);
                Log::debug("UDT: Removed peer socket %d", socket);
            }
            else
                ++it;
        }
    }

    // Receive loop, runs in `receiveDataThread` until `alive` is set to false or epoll_wait fails
    // with an error other than a timeout.
    // Each iteration removes disconnected peers, waits up to 1000 ms for readable sockets and passes
    // every received message to the receive callback of the peer it came from. The first byte of a
    // message is its MessageType, the callback gets the bytes after it.
    void DirectConnections::receiveData()
    {
        std::vector<char> data;
        data.resize(MAX_RECEIVED_DATA_SIZE);

        Log::debug("DirectConnections::receiveData(): waiting for events");
        std::set<UDTSOCKET> readfds;
        while(alive)
        {
            removeDisconnectedPeers();
            readfds.clear();

            int numfsReady = UDT::epoll_wait(eid, &readfds, nullptr, 1000);
            if(numfsReady == 0)
            {
                continue;
            }
            else if(numfsReady == -1)
            {
                if(!alive)
                    continue;

                if(UDT::getlasterror_code() == UDT::ERRORINFO::ETIMEOUT)
                {
                    continue;
                }
                else
                {
                    Log::error("UDT: Stop receiving data, got error: %s", UDT::getlasterror_desc());
                    return;
                }
            }

            for(UDTSOCKET receivedDataFromPeer : readfds)
            {
                bool peerDisconnected = false;
                usize receivedTotalSize = 0;
                int receivedDataStatus = receiveDataFromPeer(receivedDataFromPeer, data.data(), &receivedTotalSize);
                if(receivedDataStatus == 0)
                {
                    // Look the peer up with find(): operator[] would insert a null entry for a socket
                    // that isn't registered (yet, or anymore) and the null peer would then be dereferenced
                    std::shared_ptr<DirectConnectionPeer> peer;
                    {
                        std::lock_guard<std::mutex> lock(peersMutex);
                        auto peerIt = peers.find(receivedDataFromPeer);
                        if(peerIt != peers.end())
                            peer = peerIt->second;
                    }

                    if(!peer)
                    {
                        Log::warn("DirectConnections::receiveData: Received data from socket %d which is not a registered peer, ignoring...", receivedDataFromPeer);
                        continue;
                    }

                    try
                    {
                        // Explicit conversion so the argument matches the %d format specifier,
                        // the value can't be larger than MAX_RECEIVED_DATA_SIZE
                        Log::debug("DirectConnections::receiveData: Received %d bytes from peer: (ip: %s, port: %d)", static_cast<int>(receivedTotalSize), peer->address.getAddress().c_str(), peer->address.getPort());
                        if(peer->receiveDataCallbackFunc && receivedTotalSize > 0)
                        {
                            static_assert(sizeof(MessageType) == sizeof(u8), "");
                            MessageType messageType = (MessageType)data[0];
                            peer->receiveDataCallbackFunc(peer, messageType, data.data() + 1, receivedTotalSize - 1);
                        }
                    }
                    catch(std::exception &e)
                    {
                        Log::error("UDT: Receive callback function threw exception: %s, ignoring...", e.what());
                    }
                    catch(...)
                    {
                        Log::error("UDT: Receive callback function threw exception, ignoring...");
                    }
                }
                else if(receivedDataStatus == CUDTException::EINVSOCK)
                {
                    Log::debug("UDT: Invalid socket %d, did remote peer disconnect?", receivedDataFromPeer);
                    peerDisconnected = true;
                }
                else if(receivedDataStatus == CUDTException::ECONNLOST)
                {
                    Log::debug("UDT: Connection was broken to socket %d (peer most likely disconnected), removing peer", receivedDataFromPeer);
                    peerDisconnected = true;
                }

                if(peerDisconnected)
                {
                    if(removePeer(receivedDataFromPeer))
                        Log::debug("UDT: Removed peer socket %d", receivedDataFromPeer);
                    else
                        Log::error("UDT: Failed to remove peer socket %d, system said we got data from it but we are not connected to it", receivedDataFromPeer);
                }
            }
        }
    }

    // Receives one message from @socket into @output, which has to have room for
    // MAX_RECEIVED_DATA_SIZE bytes. The number of received bytes is stored in @receivedTotalSize.
    // Returns 0 on success (including when nothing was received), otherwise the UDT error code.
    int DirectConnections::receiveDataFromPeer(const int socket, char *output, usize *receivedTotalSize)
    {
        *receivedTotalSize = 0;

        int dataAvailableSize;
        int receiveSizeDataTypeSize = sizeof(dataAvailableSize);
        if(UDT::getsockopt(socket, 0, UDT_RCVDATA, &dataAvailableSize, &receiveSizeDataTypeSize) == UDT::ERROR)
        {
            Log::error("DirectConnections::receiveDataFromPeer: Failed to receive data available size, error: %s (%d)", UDT::getlasterror_desc(), UDT::getlasterror_code());
            return UDT::getlasterror_code();
        }

        if(dataAvailableSize == 0)
            return 0;

        // NOTE(review): nothing is read from the socket in this case, so the pending data appears to
        // stay queued - confirm this can't keep the receive loop spinning on the same socket.
        if(dataAvailableSize > MAX_RECEIVED_DATA_SIZE)
        {
            Log::error("DirectConnections::receiveDataFromPeer: Received too much data, ignoring...");
            return 0;
        }

        int receivedSize = UDT::recvmsg(socket, &output[0], MAX_RECEIVED_DATA_SIZE);
        if(receivedSize == UDT::ERROR)
        {
            Log::error("UDT: Failed to receive data, error: %s (%d)", UDT::getlasterror_desc(), UDT::getlasterror_code());
            return UDT::getlasterror_code();
        }

        (*receivedTotalSize) = receivedSize;
        return 0;
    }

    // Serializes the address of every peer as (u16 sin_family, u32 s_addr, u16 sin_port),
    // with the values stored as they are in the sockaddr_in.
    std::vector<u8> DirectConnectionsUtils::serializePeers(const std::vector<std::shared_ptr<DirectConnectionPeer>> &peers)
    {
        sibs::SafeSerializer serializer;
        for(const auto &it : peers)
        {
            serializer.add((u16)it->address.address.sin_family);
            serializer.add((u32)it->address.address.sin_addr.s_addr);
            serializer.add((u16)it->address.address.sin_port);
        }
        return serializer.getBuffer();
    }

    // Inverse of serializePeers: reads (u16, u32, u16) address records until @data is exhausted.
    // Only the address of the returned peers is set by this function.
    std::vector<std::shared_ptr<DirectConnectionPeer>> DirectConnectionsUtils::deserializePeers(const u8 *data, const usize size)
    {
        std::vector<std::shared_ptr<DirectConnectionPeer>> result;
        sibs::SafeDeserializer deserializer(data, size);
        while(!deserializer.empty())
        {
            std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>();
            peer->address.address.sin_family = deserializer.extract<u16>();
            peer->address.address.sin_addr.s_addr = deserializer.extract<u32>();
            peer->address.address.sin_port = deserializer.extract<u16>();
            result.push_back(peer);
        }
        return result;
    }
}