aboutsummaryrefslogtreecommitdiff
path: root/src/DirectConnection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/DirectConnection.cpp')
-rw-r--r--src/DirectConnection.cpp161
1 files changed, 118 insertions, 43 deletions
diff --git a/src/DirectConnection.cpp b/src/DirectConnection.cpp
index 083f557..e41a4a5 100644
--- a/src/DirectConnection.cpp
+++ b/src/DirectConnection.cpp
@@ -3,6 +3,8 @@
#include <cstdio>
#include <cstring>
#include <random>
+#include <sibs/SafeSerializer.hpp>
+#include <sibs/SafeDeserializer.hpp>
#ifndef WIN32
#include <arpa/inet.h>
@@ -26,6 +28,16 @@ namespace sibs
// Max received data size allowed when receiving regular data, receive data as file to receive more data
const int MAX_RECEIVED_DATA_SIZE = 1024 * 1024 * 1; // 1Mb
+
+ bool DirectConnectionPeer::operator == (const DirectConnectionPeer &other) const
+ {
+ return socket->udtSocket == other.socket->udtSocket;
+ }
+
+ bool DirectConnectionPeer::operator != (const DirectConnectionPeer &other) const
+ {
+ return !(*this == other);
+ }
DirectConnections::DirectConnections(u16 _port) :
port(_port == 0 ? (u16)generateRandomNumber(2000, 32000) : _port),
@@ -40,29 +52,26 @@ namespace sibs
DirectConnections::~DirectConnections()
{
alive = false;
- receiveDataThread.join();
-
- for(auto &peer : peers)
- {
- UDT::close(peer.first);
- }
+ peers.clear();
UDT::epoll_release(eid);
UDT::cleanup();
+ receiveDataThread.join();
}
- int DirectConnections::createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind)
+ std::unique_ptr<Socket> DirectConnections::createSocket(const Ipv4 &addressToBind, bool rendezvous, bool reuseAddr, bool bind)
{
Log::debug("UDT: Creating socket for ipv4 address %s, port: %d", addressToBind.getAddress().c_str(), addressToBind.getPort());
- UDTSOCKET socket = UDT::socket(AF_INET, SOCK_STREAM, 0);
- if(socket == UDT::INVALID_SOCK)
+ UDTSOCKET udtSocket = UDT::socket(AF_INET, SOCK_STREAM, 0);
+ if(udtSocket == UDT::INVALID_SOCK)
{
std::string errMsg = "UDT: Failed to create socket, error: ";
errMsg += UDT::getlasterror_desc();
throw SocketCreateException(errMsg);
}
- UDT::setsockopt(socket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool));
- UDT::setsockopt(socket, 0, UDT_REUSEADDR, &reuseAddr, sizeof(bool));
+ auto socket = std::make_unique<Socket>(udtSocket);
+ UDT::setsockopt(udtSocket, 0, UDT_RENDEZVOUS, &rendezvous, sizeof(bool));
+ UDT::setsockopt(udtSocket, 0, UDT_REUSEADDR, &reuseAddr, sizeof(bool));
// Windows UDP issue
// For better performance, modify HKLM\System\CurrentControlSet\Services\Afd\Parameters\FastSendDatagramThreshold
@@ -86,9 +95,11 @@ namespace sibs
Ipv4 myAddr = addressToBind;
for(int i = 0; i < 2000; ++i)
{
- if(UDT::bind(socket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR)
+ if(UDT::bind(udtSocket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR)
{
- port = (u16)generateRandomNumber(2000, 32000);
+ u16 newPort = (u16)generateRandomNumber(2000, 32000);
+ Log::warn("DirectConnections: failed to bind socket to port %d, trying port %d. Fail reason: %s", port, newPort, UDT::getlasterror_desc());
+ port = newPort;
myAddr.address.sin_port = htons(port);
}
else
@@ -101,7 +112,7 @@ namespace sibs
Ipv4 myAddr = addressToBind;
for(int i = 0; i < 2000; ++i)
{
- if(UDT::bind(socket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR)
+ if(UDT::bind(udtSocket, (sockaddr*)&myAddr.address, sizeof(myAddr.address)) == UDT::ERROR)
{
port = (u16)generateRandomNumber(2000, 32000);
myAddr.address.sin_port = htons(port);
@@ -130,8 +141,22 @@ namespace sibs
{
std::thread([this, address, rendezvous, reuseAddr, connectCallbackFunc, receiveDataCallbackFunc, bind]()
{
- std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>();
- UDTSOCKET socket;
+ std::shared_ptr<DirectConnectionPeer> peer = getPeerByAddress(address);
+ if(peer)
+ {
+ // this doesn't really matter, we always call connect with same callback function
+ peer->receiveDataCallbackFunc = receiveDataCallbackFunc;
+ if(connectCallbackFunc)
+ connectCallbackFunc(peer, PubSubResult::OK, "");
+ return;
+ }
+ else
+ {
+ peer = std::make_shared<DirectConnectionPeer>();
+ peerByAddressMap[address] = peer;
+ }
+
+ std::unique_ptr<Socket> socket;
try
{
socket = createSocket(Ipv4(nullptr, port), rendezvous, reuseAddr, bind);
@@ -144,19 +169,20 @@ namespace sibs
}
Log::debug("DirectConnections: Connecting to peer (ip: %s, port: %d, rendezvous: %s)", address.getAddress().c_str(), address.getPort(), rendezvous ? "yes" : "no");
- if(UDT::connect(socket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR)
+ if(UDT::connect(socket->udtSocket, (sockaddr*)&address.address, sizeof(address.address)) == UDT::ERROR)
{
if(connectCallbackFunc)
connectCallbackFunc(peer, PubSubResult::ERROR, UDT::getlasterror_desc());
return;
}
- UDT::epoll_add_usock(eid, socket);
- peer->socket = socket;
+ UDT::epoll_add_usock(eid, socket->udtSocket);
+ socket->eid = eid;
+ peersMutex.lock();
+ peers[socket->udtSocket] = peer;
+ peer->socket = std::move(socket);
peer->address = address;
peer->receiveDataCallbackFunc = receiveDataCallbackFunc;
- peersMutex.lock();
- peers[socket] = peer;
peersMutex.unlock();
if(connectCallbackFunc)
@@ -164,27 +190,27 @@ namespace sibs
}).detach();
}
- void DirectConnections::send(const std::shared_ptr<DirectConnectionPeer> peer, std::shared_ptr<std::vector<u8>> data, PubSubSendDataCallback sendDataCallbackFunc)
+ bool DirectConnections::send(const std::shared_ptr<DirectConnectionPeer> peer, std::shared_ptr<Message> data, PubSubSendDataCallback sendDataCallbackFunc)
{
+ if(data->getDataSize() == 0)
+ return true;
+
+ if(data->getDataSize() > 819200) // 800kb
+ return false;
+
// TODO: Replace this with light-weight threads (fibers)?
std::thread([peer, data, sendDataCallbackFunc]()
{
- usize sentSizeTotal = 0;
- while(sentSizeTotal < data->size() || data->size() == 0)
+ int sentSize = UDT::send(peer->socket->udtSocket, (char*)data->data(), data->getRawSize(), 0);
+ if(sentSize == UDT::ERROR)
{
- int sentSize = UDT::send(peer->socket, (char*)data->data() + sentSizeTotal, data->size() - sentSizeTotal, 0);
- if(sentSize == UDT::ERROR)
- {
- if(sendDataCallbackFunc)
- sendDataCallbackFunc(PubSubResult::ERROR, UDT::getlasterror_desc());
- }
- sentSizeTotal += sentSize;
- if(data->size() == 0)
- break;
+ if(sendDataCallbackFunc)
+ sendDataCallbackFunc(PubSubResult::ERROR, UDT::getlasterror_desc());
}
- if(sendDataCallbackFunc)
+ else if(sendDataCallbackFunc)
sendDataCallbackFunc(PubSubResult::OK, "");
}).detach();
+ return true;
}
void DirectConnections::onRemoveDisconnectedPeer(PubSubOnRemoveDisconnectedPeerCallback callbackFunc)
@@ -195,19 +221,29 @@ namespace sibs
bool DirectConnections::removePeer(int peerSocket)
{
bool wasRemoved = false;
- peersMutex.lock();
+ std::lock_guard<std::mutex> lock(peersMutex);
auto peerIt = peers.find(peerSocket);
if(peerIt != peers.end())
{
if(removeDisconnectedPeerCallback)
removeDisconnectedPeerCallback(peerIt->second);
- UDT::epoll_remove_usock(eid, peerSocket);
peers.erase(peerIt);
wasRemoved = true;
}
- peersMutex.unlock();
return wasRemoved;
}
+
+ std::vector<std::shared_ptr<DirectConnectionPeer>> DirectConnections::getPeers()
+ {
+ std::vector<std::shared_ptr<DirectConnectionPeer>> result;
+ result.reserve(peers.size());
+ std::lock_guard<std::mutex> lock(peersMutex);
+ for(auto &it : peers)
+ {
+ result.push_back(it.second);
+ }
+ return result;
+ }
void DirectConnections::removeDisconnectedPeers()
{
@@ -221,9 +257,6 @@ namespace sibs
if(removeDisconnectedPeerCallback)
removeDisconnectedPeerCallback(it->second);
- if(peerSocketStatus == UDTSTATUS::BROKEN)
- UDT::epoll_remove_usock(eid, socket);
-
Log::debug("UDT: Connection was broken to socket %d (peer most likely disconnected), removing peer", socket);
it = peers.erase(it);
Log::debug("UDT: Removed peer socket %d", socket);
@@ -233,11 +266,19 @@ namespace sibs
}
peersMutex.unlock();
}
+
+ std::shared_ptr<DirectConnectionPeer> DirectConnections::getPeerByAddress(const Ipv4 &address) const
+ {
+ auto it = peerByAddressMap.find(address);
+ if(it != peerByAddressMap.end())
+ return it->second;
+ return nullptr;
+ }
void DirectConnections::receiveData()
{
std::vector<char> data;
- data.reserve(MAX_RECEIVED_DATA_SIZE);
+ data.resize(MAX_RECEIVED_DATA_SIZE);
Log::debug("DirectConnections::receiveData(): waiting for events");
std::set<UDTSOCKET> readfds;
@@ -251,6 +292,9 @@ namespace sibs
}
else if(numfsReady == -1)
{
+ if(!alive)
+ continue;
+
if(UDT::getlasterror_code() == UDT::ERRORINFO::ETIMEOUT)
{
continue;
@@ -275,8 +319,12 @@ namespace sibs
try
{
Log::debug("DirectConnection: Received data from peer: (ip: %s, port: %d)", peer->address.getAddress().c_str(), peer->address.getPort());
- if(peer->receiveDataCallbackFunc)
- peer->receiveDataCallbackFunc(peer, data.data(), receivedTotalSize);
+ if(peer->receiveDataCallbackFunc && receivedTotalSize > 0)
+ {
+ static_assert(sizeof(MessageType) == sizeof(u8), "");
+ MessageType messageType = (MessageType)data[0];
+ peer->receiveDataCallbackFunc(peer, messageType, data.data() + 1, receivedTotalSize - 1);
+ }
}
catch(std::exception &e)
{
@@ -340,4 +388,31 @@ namespace sibs
Log::error("UDT: Received too much data, ignoring...");
return 0;
}
+
+ std::vector<u8> DirectConnectionsUtils::serializePeers(const std::vector<std::shared_ptr<DirectConnectionPeer>> &peers)
+ {
+ sibs::SafeSerializer serializer;
+ for(const auto &it : peers)
+ {
+ serializer.add((u32)it->address.address.sin_family);
+ serializer.add((u32)it->address.address.sin_addr.s_addr);
+ serializer.add((u16)it->address.address.sin_port);
+ }
+ return serializer.getBuffer();
+ }
+
+ std::vector<std::shared_ptr<DirectConnectionPeer>> DirectConnectionsUtils::deserializePeers(const u8 *data, const usize size)
+ {
+ std::vector<std::shared_ptr<DirectConnectionPeer>> result;
+ sibs::SafeDeserializer deserializer(data, size);
+ while(!deserializer.empty())
+ {
+ std::shared_ptr<DirectConnectionPeer> peer = std::make_shared<DirectConnectionPeer>();
+ peer->address.address.sin_family = deserializer.extract<u32>();
+ peer->address.address.sin_addr.s_addr = deserializer.extract<u32>();
+ peer->address.address.sin_port = deserializer.extract<u32>();
+ result.push_back(peer);
+ }
+ return result;
+ }
}