diff --git a/Kernel/IPv4Socket.cpp b/Kernel/IPv4Socket.cpp index d300c772dbf..163e7597ff7 100644 --- a/Kernel/IPv4Socket.cpp +++ b/Kernel/IPv4Socket.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -19,17 +20,17 @@ Lockable>& IPv4Socket::sockets_by_udp_port() return *s_map; } -Lockable>& IPv4Socket::sockets_by_tcp_port() +Lockable>& IPv4Socket::sockets_by_tcp_port() { - static Lockable>* s_map; + static Lockable>* s_map; if (!s_map) - s_map = new Lockable>; + s_map = new Lockable>; return *s_map; } -IPv4SocketHandle IPv4Socket::from_tcp_port(word port) +TCPSocketHandle IPv4Socket::from_tcp_port(word port) { - RetainPtr socket; + RetainPtr socket; { LOCKER(sockets_by_tcp_port().lock()); auto it = sockets_by_tcp_port().resource().find(port); @@ -65,6 +66,8 @@ Lockable>& IPv4Socket::all_sockets() Retained IPv4Socket::create(int type, int protocol) { + if (type == SOCK_STREAM) + return TCPSocket::create(protocol); return adopt(*new IPv4Socket(type, protocol)); } @@ -125,28 +128,7 @@ KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size) m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr); m_destination_port = ntohs(ia.sin_port); - if (type() != SOCK_STREAM) - return KSuccess; - - // FIXME: Figure out the adapter somehow differently. - auto* adapter = NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2)); - if (!adapter) - ASSERT_NOT_REACHED(); - - allocate_source_port_if_needed(); - - m_tcp_sequence_number = 0; - m_tcp_ack_number = 0; - - send_tcp_packet(*adapter, TCPFlags::SYN); - m_tcp_state = TCPState::Connecting1; - - current->set_blocked_socket(this); - block(Process::BlockedConnect); - Scheduler::yield(); - - ASSERT(is_connected()); - return KSuccess; + return protocol_connect(); } void IPv4Socket::attach_fd(SocketRole) @@ -205,7 +187,7 @@ void IPv4Socket::allocate_source_port_if_needed() auto it = sockets_by_tcp_port().resource().find(port); if (it == sockets_by_tcp_port().resource().end()) { m_source_port = port; - sockets_by_tcp_port().resource().set(port, this); + sockets_by_tcp_port().resource().set(port, static_cast(this)); return; } } @@ -213,83 +195,6 @@ void IPv4Socket::allocate_source_port_if_needed() } } -struct [[gnu::packed]] TCPPseudoHeader { - IPv4Address source; - IPv4Address destination; - byte zero; - byte protocol; - NetworkOrdered payload_size; -}; - -NetworkOrdered IPv4Socket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, word payload_size) -{ - TCPPseudoHeader pseudo_header { source, destination, 0, (byte)IPv4Protocol::TCP, sizeof(TCPPacket) + payload_size }; - - dword checksum = 0; - auto* w = (const NetworkOrdered*)&pseudo_header; - for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(word); ++i) { - checksum += w[i]; - if (checksum > 0xffff) - checksum = (checksum >> 16) + (checksum & 0xffff); - } - w = (const NetworkOrdered*)&packet; - for (size_t i = 0; i < sizeof(packet) / sizeof(word); ++i) { - checksum += w[i]; - if (checksum > 0xffff) - checksum = (checksum >> 16) + (checksum & 0xffff); - } - ASSERT(packet.data_offset() * 4 == sizeof(TCPPacket)); - w = (const NetworkOrdered*)packet.payload(); - for (size_t i = 0; i < payload_size / sizeof(word); ++i) { - checksum += w[i]; - if (checksum > 0xffff) - checksum = (checksum >> 16) + (checksum & 0xffff); - } - if (payload_size & 1) { - word expanded_byte = ((const byte*)packet.payload())[payload_size - 1]; - checksum += expanded_byte; - if (checksum > 0xffff) - checksum = (checksum >> 16) + (checksum & 0xffff); - } - return ~(checksum & 0xffff); -} - -void IPv4Socket::send_tcp_packet(NetworkAdapter& adapter, word flags, const void* payload, size_t payload_size) -{ - auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size); - auto& tcp_packet = *(TCPPacket*)(buffer.pointer()); - ASSERT(m_source_port); - tcp_packet.set_source_port(m_source_port); - tcp_packet.set_destination_port(m_destination_port); - tcp_packet.set_window_size(1024); - tcp_packet.set_sequence_number(m_tcp_sequence_number); - tcp_packet.set_data_offset(sizeof(TCPPacket) / sizeof(dword)); - tcp_packet.set_flags(flags); - - if (flags & TCPFlags::ACK) - tcp_packet.set_ack_number(m_tcp_ack_number); - - if (flags == TCPFlags::SYN) { - ++m_tcp_sequence_number; - } else { - m_tcp_sequence_number += payload_size; - } - - memcpy(tcp_packet.payload(), payload, payload_size); - tcp_packet.set_checksum(compute_tcp_checksum(adapter.ipv4_address(), m_destination_address, tcp_packet, payload_size)); - kprintf("sending tcp packet from %s:%u to %s:%u with (%s %s) seq_no=%u, ack_no=%u\n", - adapter.ipv4_address().to_string().characters(), - source_port(), - m_destination_address.to_string().characters(), - m_destination_port, - tcp_packet.has_syn() ? "SYN" : "", - tcp_packet.has_ack() ? "ACK" : "", - tcp_packet.sequence_number(), - tcp_packet.ack_number() - ); - adapter.send_ipv4(MACAddress(), m_destination_address, IPv4Protocol::TCP, move(buffer)); -} - ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length) { (void)flags; @@ -338,10 +243,8 @@ ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, cons return data_length; } - if (type() == SOCK_STREAM) { - send_tcp_packet(*adapter, TCPFlags::PUSH | TCPFlags::ACK, data, data_length); - return data_length; - } + if (type() == SOCK_STREAM) + return protocol_send(data, data_length); ASSERT_NOT_REACHED(); } @@ -409,17 +312,8 @@ ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sock return udp_packet.length() - sizeof(UDPPacket); } - if (type() == SOCK_STREAM) { - auto& tcp_packet = *static_cast(ipv4_packet.payload()); - size_t payload_size = packet_buffer.size() - sizeof(IPv4Packet) - tcp_packet.header_size(); - ASSERT(buffer_length >= payload_size); - if (addr) { - auto& ia = *(sockaddr_in*)addr; - ia.sin_port = htons(tcp_packet.destination_port()); - } - memcpy(buffer, tcp_packet.payload(), payload_size); - return payload_size; - } + if (type() == SOCK_STREAM) + return protocol_receive(packet_buffer, buffer, buffer_length, flags, addr, addr_length); ASSERT_NOT_REACHED(); } diff --git a/Kernel/IPv4Socket.h b/Kernel/IPv4Socket.h index 2c84eabf9c7..5951099854f 100644 --- a/Kernel/IPv4Socket.h +++ b/Kernel/IPv4Socket.h @@ -8,28 +8,21 @@ #include class IPv4SocketHandle; +class TCPSocketHandle; class NetworkAdapter; class TCPPacket; +class TCPSocket; -enum TCPState { - Disconnected, - Connecting1, - Connecting2, - Connected, - Disconnecting1, - Disconnecting2, -}; - -class IPv4Socket final : public Socket { +class IPv4Socket : public Socket { public: static Retained create(int type, int protocol); virtual ~IPv4Socket() override; static Lockable>& all_sockets(); static Lockable>& sockets_by_udp_port(); - static Lockable>& sockets_by_tcp_port(); + static Lockable>& sockets_by_tcp_port(); - static IPv4SocketHandle from_tcp_port(word); + static TCPSocketHandle from_tcp_port(word); static IPv4SocketHandle from_udp_port(word); virtual KResult bind(const sockaddr*, socklen_t) override; @@ -46,24 +39,23 @@ public: void did_receive(ByteBuffer&&); + const IPv4Address& source_address() const; word source_port() const { return m_source_port; } + + const IPv4Address& destination_address() const { return m_destination_address; } word destination_port() const { return m_destination_port; } - void send_tcp_packet(NetworkAdapter&, word flags, const void* payload = nullptr, size_t = 0); - void set_tcp_state(TCPState state) { m_tcp_state = state; } - TCPState tcp_state() const { return m_tcp_state; } - void set_tcp_ack_number(dword n) { m_tcp_ack_number = n; } - void set_tcp_sequence_number(dword n) { m_tcp_sequence_number = n; } - dword tcp_ack_number() const { return m_tcp_ack_number; } - dword tcp_sequence_number() const { return m_tcp_sequence_number; } +protected: + IPv4Socket(int type, int protocol); + void allocate_source_port_if_needed(); + + virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; } + virtual int protocol_send(const void*, int) { return -ENOTIMPL; } + virtual KResult protocol_connect() { return KSuccess; } private: - IPv4Socket(int type, int protocol); virtual bool is_ipv4() const override { return true; } - void allocate_source_port_if_needed(); - NetworkOrdered compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket&, word payload_size); - bool m_bound { false }; int m_attached_fds { 0 }; IPv4Address m_destination_address; @@ -76,10 +68,6 @@ private: word m_source_port { 0 }; word m_destination_port { 0 }; - dword m_tcp_sequence_number { 0 }; - dword m_tcp_ack_number { 0 }; - TCPState m_tcp_state { Disconnected }; - bool m_can_read { false }; }; diff --git a/Kernel/Makefile b/Kernel/Makefile index db5ee4f0518..2f69877f128 100644 --- a/Kernel/Makefile +++ b/Kernel/Makefile @@ -35,6 +35,7 @@ KERNEL_OBJS = \ Socket.o \ LocalSocket.o \ IPv4Socket.o \ + TCPSocket.o \ NetworkAdapter.o \ E1000NetworkAdapter.o \ NetworkTask.o diff --git a/Kernel/NetworkTask.cpp b/Kernel/NetworkTask.cpp index 761a27ed448..e79f68d9e3f 100644 --- a/Kernel/NetworkTask.cpp +++ b/Kernel/NetworkTask.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -284,29 +285,29 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) ASSERT(socket->type() == SOCK_STREAM); ASSERT(socket->source_port() == tcp_packet.destination_port()); - if (tcp_packet.ack_number() != socket->tcp_sequence_number()) { - kprintf("handle_tcp: ack/seq mismatch: got %u, wanted %u\n",tcp_packet.ack_number(), socket->tcp_sequence_number()); + if (tcp_packet.ack_number() != socket->sequence_number()) { + kprintf("handle_tcp: ack/seq mismatch: got %u, wanted %u\n",tcp_packet.ack_number(), socket->sequence_number()); return; } if (tcp_packet.has_syn() && tcp_packet.has_ack()) { - socket->set_tcp_ack_number(tcp_packet.sequence_number() + payload_size + 1); - socket->send_tcp_packet(*adapter, TCPFlags::ACK); + socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); + socket->send_tcp_packet(TCPFlags::ACK); socket->set_connected(true); kprintf("handle_tcp: Connection established!\n"); - socket->set_tcp_state(Connected); + socket->set_state(TCPSocket::State::Connected); return; } - socket->set_tcp_ack_number(tcp_packet.sequence_number() + payload_size); + socket->set_ack_number(tcp_packet.sequence_number() + payload_size); kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n", tcp_packet.ack_number(), tcp_packet.sequence_number(), payload_size, - socket->tcp_ack_number(), - socket->tcp_sequence_number() + socket->ack_number(), + socket->sequence_number() ); - socket->send_tcp_packet(*adapter, TCPFlags::ACK); + socket->send_tcp_packet(TCPFlags::ACK); if (payload_size != 0) socket->did_receive(ByteBuffer::copy((const byte*)&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); diff --git a/Kernel/Socket.cpp b/Kernel/Socket.cpp index 8866c282d25..e6aeecc7aab 100644 --- a/Kernel/Socket.cpp +++ b/Kernel/Socket.cpp @@ -19,10 +19,10 @@ KResultOr> Socket::create(int domain, int type, int protocol) } Socket::Socket(int domain, int type, int protocol) - : m_domain(domain) + : m_lock("Socket") + , m_domain(domain) , m_type(type) , m_protocol(protocol) - , m_lock("Socket") { m_origin_pid = current->pid(); } @@ -69,12 +69,12 @@ KResult Socket::setsockopt(int level, int option, const void* value, socklen_t v case SO_SNDTIMEO: if (value_size != sizeof(timeval)) return KResult(-EINVAL); - m_send_timeout = *(timeval*)value; + m_send_timeout = *(const timeval*)value; return KSuccess; case SO_RCVTIMEO: if (value_size != sizeof(timeval)) return KResult(-EINVAL); - m_receive_timeout = *(timeval*)value; + m_receive_timeout = *(const timeval*)value; return KSuccess; default: kprintf("%s(%u): setsockopt() at SOL_SOCKET with unimplemented option %d\n", option); diff --git a/Kernel/TCPSocket.cpp b/Kernel/TCPSocket.cpp new file mode 100644 index 00000000000..0433d6d8097 --- /dev/null +++ b/Kernel/TCPSocket.cpp @@ -0,0 +1,148 @@ +#include +#include +#include +#include + +TCPSocket::TCPSocket(int protocol) + : IPv4Socket(SOCK_STREAM, protocol) +{ +} + +TCPSocket::~TCPSocket() +{ +} + +Retained TCPSocket::create(int protocol) +{ + return adopt(*new TCPSocket(protocol)); +} + +int TCPSocket::protocol_receive(const ByteBuffer& packet_buffer, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) +{ + (void)flags; + (void)addr_length; + ASSERT(!packet_buffer.is_null()); + auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.pointer()); + auto& tcp_packet = *static_cast(ipv4_packet.payload()); + size_t payload_size = packet_buffer.size() - sizeof(IPv4Packet) - tcp_packet.header_size(); + ASSERT(buffer_size >= payload_size); + if (addr) { + auto& ia = *(sockaddr_in*)addr; + ia.sin_port = htons(tcp_packet.destination_port()); + } + memcpy(buffer, tcp_packet.payload(), payload_size); + return payload_size; +} + +int TCPSocket::protocol_send(const void* data, int data_length) +{ + // FIXME: Figure out the adapter somehow differently. + auto* adapter = NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2)); + if (!adapter) + ASSERT_NOT_REACHED(); + send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, data, data_length); + return data_length; +} + +void TCPSocket::send_tcp_packet(word flags, const void* payload, int payload_size) +{ + // FIXME: Figure out the adapter somehow differently. + auto& adapter = *NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2)); + + auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size); + auto& tcp_packet = *(TCPPacket*)(buffer.pointer()); + ASSERT(source_port()); + tcp_packet.set_source_port(source_port()); + tcp_packet.set_destination_port(destination_port()); + tcp_packet.set_window_size(1024); + tcp_packet.set_sequence_number(m_sequence_number); + tcp_packet.set_data_offset(sizeof(TCPPacket) / sizeof(dword)); + tcp_packet.set_flags(flags); + + if (flags & TCPFlags::ACK) + tcp_packet.set_ack_number(m_ack_number); + + if (flags == TCPFlags::SYN) { + ++m_sequence_number; + } else { + m_sequence_number += payload_size; + } + + memcpy(tcp_packet.payload(), payload, payload_size); + tcp_packet.set_checksum(compute_tcp_checksum(adapter.ipv4_address(), destination_address(), tcp_packet, payload_size)); + kprintf("sending tcp packet from %s:%u to %s:%u with (%s %s) seq_no=%u, ack_no=%u\n", + adapter.ipv4_address().to_string().characters(), + source_port(), + destination_address().to_string().characters(), + destination_port(), + tcp_packet.has_syn() ? "SYN" : "", + tcp_packet.has_ack() ? "ACK" : "", + tcp_packet.sequence_number(), + tcp_packet.ack_number() + ); + adapter.send_ipv4(MACAddress(), destination_address(), IPv4Protocol::TCP, move(buffer)); +} + +NetworkOrdered TCPSocket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, word payload_size) +{ + struct [[gnu::packed]] PseudoHeader { + IPv4Address source; + IPv4Address destination; + byte zero; + byte protocol; + NetworkOrdered payload_size; + }; + + PseudoHeader pseudo_header { source, destination, 0, (byte)IPv4Protocol::TCP, sizeof(TCPPacket) + payload_size }; + + dword checksum = 0; + auto* w = (const NetworkOrdered*)&pseudo_header; + for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(word); ++i) { + checksum += w[i]; + if (checksum > 0xffff) + checksum = (checksum >> 16) + (checksum & 0xffff); + } + w = (const NetworkOrdered*)&packet; + for (size_t i = 0; i < sizeof(packet) / sizeof(word); ++i) { + checksum += w[i]; + if (checksum > 0xffff) + checksum = (checksum >> 16) + (checksum & 0xffff); + } + ASSERT(packet.data_offset() * 4 == sizeof(TCPPacket)); + w = (const NetworkOrdered*)packet.payload(); + for (size_t i = 0; i < payload_size / sizeof(word); ++i) { + checksum += w[i]; + if (checksum > 0xffff) + checksum = (checksum >> 16) + (checksum & 0xffff); + } + if (payload_size & 1) { + word expanded_byte = ((const byte*)packet.payload())[payload_size - 1]; + checksum += expanded_byte; + if (checksum > 0xffff) + checksum = (checksum >> 16) + (checksum & 0xffff); + } + return ~(checksum & 0xffff); +} + +KResult TCPSocket::protocol_connect() +{ + // FIXME: Figure out the adapter somehow differently. + auto* adapter = NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2)); + if (!adapter) + ASSERT_NOT_REACHED(); + + allocate_source_port_if_needed(); + + m_sequence_number = 0; + m_ack_number = 0; + + send_tcp_packet(TCPFlags::SYN); + m_state = State::Connecting; + + current->set_blocked_socket(this); + block(Process::BlockedConnect); + Scheduler::yield(); + + ASSERT(is_connected()); + return KSuccess; +} diff --git a/Kernel/TCPSocket.h b/Kernel/TCPSocket.h new file mode 100644 index 00000000000..92b431463ad --- /dev/null +++ b/Kernel/TCPSocket.h @@ -0,0 +1,63 @@ +#pragma once + +#include + +class TCPSocket final : public IPv4Socket { +public: + static Retained create(int protocol); + virtual ~TCPSocket() override; + + enum class State { + Disconnected, + Connecting, + Connected, + Disconnecting, + }; + + State state() const { return m_state; } + void set_state(State state) { m_state = state; } + + void set_ack_number(dword n) { m_ack_number = n; } + void set_sequence_number(dword n) { m_sequence_number = n; } + dword ack_number() const { return m_ack_number; } + dword sequence_number() const { return m_sequence_number; } + + void send_tcp_packet(word flags, const void* = nullptr, int = 0); + +private: + explicit TCPSocket(int protocol); + + NetworkOrdered compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket&, word payload_size); + + virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override; + virtual int protocol_send(const void*, int) override; + virtual KResult protocol_connect() override; + + dword m_sequence_number { 0 }; + dword m_ack_number { 0 }; + State m_state { State::Disconnected }; +}; + +class TCPSocketHandle : public SocketHandle { +public: + TCPSocketHandle() { } + + TCPSocketHandle(RetainPtr&& socket) + : SocketHandle(move(socket)) + { + } + + TCPSocketHandle(TCPSocketHandle&& other) + : SocketHandle(move(other)) + { + } + + TCPSocketHandle(const TCPSocketHandle&) = delete; + TCPSocketHandle& operator=(const TCPSocketHandle&) = delete; + + TCPSocket* operator->() { return &socket(); } + const TCPSocket* operator->() const { return &socket(); } + + TCPSocket& socket() { return static_cast(SocketHandle::socket()); } + const TCPSocket& socket() const { return static_cast(SocketHandle::socket()); } +};