mirror of
				https://github.com/dolphin-emu/dolphin.git
				synced 2025-10-25 09:29:43 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			482 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			482 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // SPDX-License-Identifier: CC0-1.0
 | |
| 
 | |
| #include "Common/TraversalClient.h"
 | |
| 
 | |
| #include <cstddef>
 | |
| #include <cstring>
 | |
| #include <string>
 | |
| 
 | |
| #include "Common/CommonTypes.h"
 | |
| #include "Common/Logging/Log.h"
 | |
| #include "Common/MsgHandler.h"
 | |
| #include "Common/Random.h"
 | |
| #include "Core/NetPlayProto.h"
 | |
| 
 | |
| namespace Common
 | |
| {
 | |
| TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port,
 | |
|                                  const u16 port_alt)
 | |
|     : m_NetHost(netHost), m_Server(server), m_port(port), m_portAlt(port_alt)
 | |
| {
 | |
|   netHost->intercept = TraversalClient::InterceptCallback;
 | |
| 
 | |
|   Reset();
 | |
| 
 | |
|   ReconnectToServer();
 | |
| }
 | |
| 
 | |
| TraversalClient::~TraversalClient() = default;
 | |
| 
 | |
| TraversalHostId TraversalClient::GetHostID() const
 | |
| {
 | |
|   return m_HostId;
 | |
| }
 | |
| 
 | |
| TraversalInetAddress TraversalClient::GetExternalAddress() const
 | |
| {
 | |
|   return m_external_address;
 | |
| }
 | |
| 
 | |
| TraversalClient::State TraversalClient::GetState() const
 | |
| {
 | |
|   return m_State;
 | |
| }
 | |
| 
 | |
| TraversalClient::FailureReason TraversalClient::GetFailureReason() const
 | |
| {
 | |
|   return m_FailureReason;
 | |
| }
 | |
| 
 | |
| void TraversalClient::ReconnectToServer()
 | |
| {
 | |
|   if (enet_address_set_host(&m_ServerAddress, m_Server.c_str()))
 | |
|   {
 | |
|     OnFailure(FailureReason::BadHost);
 | |
|     return;
 | |
|   }
 | |
|   m_ServerAddress.port = m_port;
 | |
| 
 | |
|   m_State = State::Connecting;
 | |
| 
 | |
|   TraversalPacket hello = {};
 | |
|   hello.type = TraversalPacketType::HelloFromClient;
 | |
|   hello.helloFromClient.protoVersion = TraversalProtoVersion;
 | |
|   SendTraversalPacket(hello);
 | |
|   if (m_Client)
 | |
|     m_Client->OnTraversalStateChanged();
 | |
| }
 | |
| 
 | |
| static ENetAddress MakeENetAddress(const TraversalInetAddress& address)
 | |
| {
 | |
|   ENetAddress eaddr{};
 | |
|   if (address.isIPV6)
 | |
|   {
 | |
|     eaddr.port = 0;  // no support yet :(
 | |
|   }
 | |
|   else
 | |
|   {
 | |
|     eaddr.host = address.address[0];
 | |
|     eaddr.port = ntohs(address.port);
 | |
|   }
 | |
|   return eaddr;
 | |
| }
 | |
| 
 | |
| void TraversalClient::ConnectToClient(std::string_view host)
 | |
| {
 | |
|   if (host.size() > sizeof(TraversalHostId))
 | |
|   {
 | |
|     PanicAlertFmt("Host too long");
 | |
|     return;
 | |
|   }
 | |
|   TraversalPacket packet = {};
 | |
|   packet.type = TraversalPacketType::ConnectPlease;
 | |
|   memcpy(packet.connectPlease.hostId.data(), host.data(), host.size());
 | |
|   m_ConnectRequestId = SendTraversalPacket(packet);
 | |
|   m_PendingConnect = true;
 | |
| }
 | |
| 
 | |
| bool TraversalClient::TestPacket(u8* data, size_t size, ENetAddress* from)
 | |
| {
 | |
|   if (from->host == m_ServerAddress.host && from->port == m_ServerAddress.port)
 | |
|   {
 | |
|     if (size < sizeof(TraversalPacket))
 | |
|     {
 | |
|       ERROR_LOG_FMT(NETPLAY, "Received too-short traversal packet.");
 | |
|     }
 | |
|     else
 | |
|     {
 | |
|       HandleServerPacket((TraversalPacket*)data);
 | |
|       return true;
 | |
|     }
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| //--Temporary until more of the old netplay branch is moved over
 | |
| void TraversalClient::Update()
 | |
| {
 | |
|   ENetEvent netEvent;
 | |
|   if (enet_host_service(m_NetHost, &netEvent, 4) > 0)
 | |
|   {
 | |
|     switch (netEvent.type)
 | |
|     {
 | |
|     case ENET_EVENT_TYPE_RECEIVE:
 | |
|       TestPacket(netEvent.packet->data, netEvent.packet->dataLength, &netEvent.peer->address);
 | |
| 
 | |
|       enet_packet_destroy(netEvent.packet);
 | |
|       break;
 | |
|     default:
 | |
|       break;
 | |
|     }
 | |
|   }
 | |
|   HandleResends();
 | |
| }
 | |
| 
 | |
| void TraversalClient::HandleServerPacket(TraversalPacket* packet)
 | |
| {
 | |
|   u8 ok = 1;
 | |
|   switch (packet->type)
 | |
|   {
 | |
|   case TraversalPacketType::Ack:
 | |
|     if (!packet->ack.ok)
 | |
|     {
 | |
|       OnFailure(FailureReason::ServerForgotAboutUs);
 | |
|       break;
 | |
|     }
 | |
|     for (auto it = m_OutgoingTraversalPackets.begin(); it != m_OutgoingTraversalPackets.end(); ++it)
 | |
|     {
 | |
|       if (it->packet.requestId == packet->requestId)
 | |
|       {
 | |
|         if (packet->requestId == m_TestRequestId)
 | |
|           HandleTraversalTest();
 | |
|         m_OutgoingTraversalPackets.erase(it);
 | |
|         break;
 | |
|       }
 | |
|     }
 | |
|     break;
 | |
|   case TraversalPacketType::HelloFromServer:
 | |
|     if (!IsConnecting())
 | |
|       break;
 | |
|     if (!packet->helloFromServer.ok)
 | |
|     {
 | |
|       OnFailure(FailureReason::VersionTooOld);
 | |
|       break;
 | |
|     }
 | |
|     m_HostId = packet->helloFromServer.yourHostId;
 | |
|     m_external_address = packet->helloFromServer.yourAddress;
 | |
|     NewTraversalTest();
 | |
|     m_State = State::Connected;
 | |
|     if (m_Client)
 | |
|       m_Client->OnTraversalStateChanged();
 | |
|     break;
 | |
|   case TraversalPacketType::PleaseSendPacket:
 | |
|   {
 | |
|     // security is overrated.
 | |
|     ENetAddress addr = MakeENetAddress(packet->pleaseSendPacket.address);
 | |
|     if (addr.port != 0)
 | |
|     {
 | |
|       char message[] = "Hello from Dolphin Netplay...";
 | |
|       ENetBuffer buf;
 | |
|       buf.data = message;
 | |
|       buf.dataLength = sizeof(message) - 1;
 | |
|       if (m_ttlReady)
 | |
|       {
 | |
|         int oldttl;
 | |
|         enet_socket_get_option(m_NetHost->socket, ENET_SOCKOPT_TTL, &oldttl);
 | |
|         enet_socket_set_option(m_NetHost->socket, ENET_SOCKOPT_TTL, m_ttl);
 | |
|         enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
 | |
|         enet_socket_set_option(m_NetHost->socket, ENET_SOCKOPT_TTL, oldttl);
 | |
|       }
 | |
|       else
 | |
|       {
 | |
|         enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
 | |
|       }
 | |
|     }
 | |
|     else
 | |
|     {
 | |
|       // invalid IPV6
 | |
|       ok = 0;
 | |
|     }
 | |
|     break;
 | |
|   }
 | |
|   case TraversalPacketType::ConnectReady:
 | |
|   case TraversalPacketType::ConnectFailed:
 | |
|   {
 | |
|     if (!m_PendingConnect || packet->connectReady.requestId != m_ConnectRequestId)
 | |
|       break;
 | |
| 
 | |
|     m_PendingConnect = false;
 | |
| 
 | |
|     if (!m_Client)
 | |
|       break;
 | |
| 
 | |
|     if (packet->type == TraversalPacketType::ConnectReady)
 | |
|       m_Client->OnConnectReady(MakeENetAddress(packet->connectReady.address));
 | |
|     else
 | |
|       m_Client->OnConnectFailed(packet->connectFailed.reason);
 | |
|     break;
 | |
|   }
 | |
|   default:
 | |
|     WARN_LOG_FMT(NETPLAY, "Received unknown packet with type {}", static_cast<int>(packet->type));
 | |
|     break;
 | |
|   }
 | |
|   if (packet->type != TraversalPacketType::Ack)
 | |
|   {
 | |
|     TraversalPacket ack = {};
 | |
|     ack.type = TraversalPacketType::Ack;
 | |
|     ack.requestId = packet->requestId;
 | |
|     ack.ack.ok = ok;
 | |
| 
 | |
|     ENetBuffer buf;
 | |
|     buf.data = &ack;
 | |
|     buf.dataLength = sizeof(ack);
 | |
|     if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
 | |
|       OnFailure(FailureReason::SocketSendError);
 | |
|   }
 | |
| }
 | |
| 
 | |
| void TraversalClient::OnFailure(FailureReason reason)
 | |
| {
 | |
|   m_State = State::Failure;
 | |
|   m_FailureReason = reason;
 | |
| 
 | |
|   if (m_Client)
 | |
|     m_Client->OnTraversalStateChanged();
 | |
| }
 | |
| 
 | |
| void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
 | |
| {
 | |
|   bool testPacket =
 | |
|       m_TestSocket != ENET_SOCKET_NULL && info->packet.type == TraversalPacketType::TestPlease;
 | |
|   info->sendTime = enet_time_get();
 | |
|   info->tries++;
 | |
|   ENetBuffer buf;
 | |
|   buf.data = &info->packet;
 | |
|   buf.dataLength = sizeof(info->packet);
 | |
|   if (enet_socket_send(testPacket ? m_TestSocket : m_NetHost->socket, &m_ServerAddress, &buf, 1) ==
 | |
|       -1)
 | |
|     OnFailure(FailureReason::SocketSendError);
 | |
| }
 | |
| 
 | |
| void TraversalClient::HandleResends()
 | |
| {
 | |
|   const u32 now = enet_time_get();
 | |
|   for (auto& tpi : m_OutgoingTraversalPackets)
 | |
|   {
 | |
|     if (now - tpi.sendTime >= (u32)(300 * tpi.tries))
 | |
|     {
 | |
|       if (tpi.tries >= 5)
 | |
|       {
 | |
|         OnFailure(FailureReason::ResendTimeout);
 | |
|         m_OutgoingTraversalPackets.clear();
 | |
|         break;
 | |
|       }
 | |
|       else
 | |
|       {
 | |
|         ResendPacket(&tpi);
 | |
|       }
 | |
|     }
 | |
|   }
 | |
|   HandlePing();
 | |
| }
 | |
| 
 | |
| void TraversalClient::HandlePing()
 | |
| {
 | |
|   const u32 now = enet_time_get();
 | |
|   if (IsConnected() && now - m_PingTime >= 500)
 | |
|   {
 | |
|     TraversalPacket ping = {};
 | |
|     ping.type = TraversalPacketType::Ping;
 | |
|     ping.ping.hostId = m_HostId;
 | |
|     SendTraversalPacket(ping);
 | |
|     m_PingTime = now;
 | |
|   }
 | |
| }
 | |
| 
 | |
| void TraversalClient::NewTraversalTest()
 | |
| {
 | |
|   // create test socket
 | |
|   if (m_TestSocket != ENET_SOCKET_NULL)
 | |
|     enet_socket_destroy(m_TestSocket);
 | |
|   m_TestSocket = enet_socket_create(ENET_SOCKET_TYPE_DATAGRAM);
 | |
|   ENetAddress addr = {ENET_HOST_ANY, 0};
 | |
|   if (m_TestSocket == ENET_SOCKET_NULL || enet_socket_bind(m_TestSocket, &addr) < 0)
 | |
|   {
 | |
|     // error, abort
 | |
|     if (m_TestSocket != ENET_SOCKET_NULL)
 | |
|     {
 | |
|       enet_socket_destroy(m_TestSocket);
 | |
|       m_TestSocket = ENET_SOCKET_NULL;
 | |
|     }
 | |
|     return;
 | |
|   }
 | |
|   enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_NONBLOCK, 1);
 | |
|   // create holepunch packet
 | |
|   TraversalPacket packet = {};
 | |
|   packet.type = TraversalPacketType::Ping;
 | |
|   packet.ping.hostId = m_HostId;
 | |
|   packet.requestId = Common::Random::GenerateValue<TraversalRequestId>();
 | |
|   // create buffer
 | |
|   ENetBuffer buf;
 | |
|   buf.data = &packet;
 | |
|   buf.dataLength = sizeof(packet);
 | |
|   // send to alt port
 | |
|   ENetAddress altAddress = m_ServerAddress;
 | |
|   altAddress.port = m_portAlt;
 | |
|   // set up ttl and send
 | |
|   int oldttl;
 | |
|   enet_socket_get_option(m_TestSocket, ENET_SOCKOPT_TTL, &oldttl);
 | |
|   enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_TTL, m_ttl);
 | |
|   if (enet_socket_send(m_TestSocket, &altAddress, &buf, 1) == -1)
 | |
|   {
 | |
|     // error, abort
 | |
|     enet_socket_destroy(m_TestSocket);
 | |
|     m_TestSocket = ENET_SOCKET_NULL;
 | |
|     return;
 | |
|   }
 | |
|   enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_TTL, oldttl);
 | |
|   // send the test request
 | |
|   packet.type = TraversalPacketType::TestPlease;
 | |
|   m_TestRequestId = SendTraversalPacket(packet);
 | |
| }
 | |
| 
 | |
| void TraversalClient::HandleTraversalTest()
 | |
| {
 | |
|   if (m_TestSocket != ENET_SOCKET_NULL)
 | |
|   {
 | |
|     // check for packet on test socket (with timeout)
 | |
|     u32 deadline = enet_time_get() + 50;
 | |
|     u32 waitCondition;
 | |
|     do
 | |
|     {
 | |
|       waitCondition = ENET_SOCKET_WAIT_RECEIVE | ENET_SOCKET_WAIT_INTERRUPT;
 | |
|       u32 currentTime = enet_time_get();
 | |
|       if (currentTime > deadline ||
 | |
|           enet_socket_wait(m_TestSocket, &waitCondition, deadline - currentTime) != 0)
 | |
|       {
 | |
|         // error or timeout, exit the loop and assume test failure
 | |
|         waitCondition = 0;
 | |
|         break;
 | |
|       }
 | |
|       else if (waitCondition & ENET_SOCKET_WAIT_RECEIVE)
 | |
|       {
 | |
|         // try reading the packet and see if it's relevant
 | |
|         ENetAddress raddr;
 | |
|         TraversalPacket packet;
 | |
|         ENetBuffer buf;
 | |
|         buf.data = &packet;
 | |
|         buf.dataLength = sizeof(packet);
 | |
|         int rv = enet_socket_receive(m_TestSocket, &raddr, &buf, 1);
 | |
|         if (rv < 0)
 | |
|         {
 | |
|           // error, exit the loop and assume test failure
 | |
|           waitCondition = 0;
 | |
|           break;
 | |
|         }
 | |
|         else if (rv < int(sizeof(packet)) || raddr.host != m_ServerAddress.host ||
 | |
|                  raddr.host != m_portAlt || packet.requestId != m_TestRequestId)
 | |
|         {
 | |
|           // irrelevant packet, ignore
 | |
|           continue;
 | |
|         }
 | |
|       }
 | |
|     } while (waitCondition & ENET_SOCKET_WAIT_INTERRUPT);
 | |
|     // regardless of what happens next, we can throw out the socket
 | |
|     enet_socket_destroy(m_TestSocket);
 | |
|     m_TestSocket = ENET_SOCKET_NULL;
 | |
|     if (waitCondition & ENET_SOCKET_WAIT_RECEIVE)
 | |
|     {
 | |
|       // success, we can stop now
 | |
|       m_ttlReady = true;
 | |
|       m_Client->OnTtlDetermined(m_ttl);
 | |
|     }
 | |
|     else
 | |
|     {
 | |
|       // fail, increment and retry
 | |
|       if (++m_ttl < 32)
 | |
|         NewTraversalTest();
 | |
|     }
 | |
|   }
 | |
| }
 | |
| 
 | |
| TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
 | |
| {
 | |
|   OutgoingTraversalPacketInfo info;
 | |
|   info.packet = packet;
 | |
|   info.packet.requestId = Common::Random::GenerateValue<TraversalRequestId>();
 | |
|   info.tries = 0;
 | |
|   m_OutgoingTraversalPackets.push_back(info);
 | |
|   ResendPacket(&m_OutgoingTraversalPackets.back());
 | |
|   return info.packet.requestId;
 | |
| }
 | |
| 
 | |
| void TraversalClient::Reset()
 | |
| {
 | |
|   m_PendingConnect = false;
 | |
|   m_Client = nullptr;
 | |
| }
 | |
| 
 | |
| int ENET_CALLBACK TraversalClient::InterceptCallback(ENetHost* host, ENetEvent* event)
 | |
| {
 | |
|   auto traversalClient = g_TraversalClient.get();
 | |
|   if (traversalClient->TestPacket(host->receivedData, host->receivedDataLength,
 | |
|                                   &host->receivedAddress) ||
 | |
|       (host->receivedDataLength == 1 && host->receivedData[0] == 0))
 | |
|   {
 | |
|     event->type = static_cast<ENetEventType>(Common::ENet::SKIPPABLE_EVENT);
 | |
|     return 1;
 | |
|   }
 | |
|   return 0;
 | |
| }
 | |
| 
 | |
| std::unique_ptr<TraversalClient> g_TraversalClient;
 | |
| ENet::ENetHostPtr g_MainNetHost;
 | |
| 
 | |
| // The settings at the previous TraversalClient reset - notably, we
 | |
| // need to know not just what port it's on, but whether it was
 | |
| // explicitly requested.
 | |
| static std::string g_OldServer;
 | |
| static u16 g_OldServerPort;
 | |
| static u16 g_OldServerPortAlt;
 | |
| static u16 g_OldListenPort;
 | |
| 
 | |
| bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 server_port_alt,
 | |
|                            u16 listen_port)
 | |
| {
 | |
|   if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer ||
 | |
|       server_port != g_OldServerPort || server_port_alt != g_OldServerPortAlt ||
 | |
|       listen_port != g_OldListenPort)
 | |
|   {
 | |
|     g_OldServer = server;
 | |
|     g_OldServerPort = server_port;
 | |
|     g_OldServerPortAlt = server_port_alt;
 | |
|     g_OldListenPort = listen_port;
 | |
| 
 | |
|     ENetAddress addr = {ENET_HOST_ANY, listen_port};
 | |
|     auto host = Common::ENet::ENetHostPtr{enet_host_create(&addr,                   // address
 | |
|                                                            50,                      // peerCount
 | |
|                                                            NetPlay::CHANNEL_COUNT,  // channelLimit
 | |
|                                                            0,    // incomingBandwidth
 | |
|                                                            0)};  // outgoingBandwidth
 | |
|     if (!host)
 | |
|     {
 | |
|       g_MainNetHost.reset();
 | |
|       return false;
 | |
|     }
 | |
|     host->mtu = std::min(host->mtu, NetPlay::MAX_ENET_MTU);
 | |
|     g_MainNetHost = std::move(host);
 | |
|     g_TraversalClient.reset(
 | |
|         new TraversalClient(g_MainNetHost.get(), server, server_port, server_port_alt));
 | |
|   }
 | |
|   return true;
 | |
| }
 | |
| 
 | |
| void ReleaseTraversalClient()
 | |
| {
 | |
|   if (!g_TraversalClient)
 | |
|     return;
 | |
| 
 | |
|   g_TraversalClient.reset();
 | |
|   g_MainNetHost.reset();
 | |
| }
 | |
| }  // namespace Common
 |