Fixed various implementation bugs.

This commit is contained in:
Koen J 2025-04-10 10:55:27 +02:00
parent 1ae9f0ea26
commit 955ba23b0d
3 changed files with 60 additions and 40 deletions

View file

@ -93,13 +93,17 @@ class SyncServerTests {
val tcsDataB = CompletableDeferred<ByteArray>()
channelB.setDataHandler { _, _, o, so, d ->
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
val b = ByteArray(d.remaining())
d.get(b)
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b)
}
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(1, 2, 3)))
val tcsDataA = CompletableDeferred<ByteArray>()
channelA.setDataHandler { _, _, o, so, d ->
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(d.array())
val b = ByteArray(d.remaining())
d.get(b)
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(b)
}
channelB.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(4, 5, 6)))
@ -231,7 +235,9 @@ class SyncServerTests {
val tcsDataB = CompletableDeferred<ByteArray>()
channelB.setDataHandler { _, _, o, so, d ->
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
val b = ByteArray(d.remaining())
d.get(b)
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b)
}
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData))
val receivedData = withTimeout(10000.milliseconds) { tcsDataB.await() }

View file

@ -14,7 +14,7 @@ interface IChannel : AutoCloseable {
val remoteVersion: Int?
var authorizable: IAuthorizable?
fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?)
fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer? = null)
fun send(opcode: UByte, subOpcode: UByte = 0u, data: ByteBuffer? = null)
fun setCloseHandler(onClose: ((IChannel) -> Unit)?)
}
@ -326,12 +326,4 @@ class ChannelRelayed(
completeHandshake(remoteVersion, transport)
}
}
fun handleData(data: ByteBuffer) {
val size = data.int
if (size != data.remaining()) throw IllegalStateException("Incomplete packet received")
val opcode = data.get().toUByte()
val subOpcode = data.get().toUByte()
invokeDataHandler(opcode, subOpcode, data)
}
}

View file

@ -155,7 +155,7 @@ class SyncSocketSession {
val plen: Int = _cipherStatePair!!.receiver.decryptWithAd(null, _buffer, 0, _bufferDecrypted, 0, messageSize)
//Logger.i(TAG, "Decrypted message (size = ${plen})")
handleData(_bufferDecrypted, plen)
handleData(_bufferDecrypted, plen, null)
} catch (e: Throwable) {
Logger.e(TAG, "Exception while receiving data", e)
break
@ -374,22 +374,25 @@ class SyncSocketSession {
}
}
@OptIn(ExperimentalUnsignedTypes::class)
private fun handleData(data: ByteArray, length: Int) {
private fun handleData(data: ByteArray, length: Int, sourceChannel: ChannelRelayed?) {
return handleData(ByteBuffer.wrap(data, 0, length).order(ByteOrder.LITTLE_ENDIAN), sourceChannel)
}
private fun handleData(data: ByteBuffer, sourceChannel: ChannelRelayed?) {
val length = data.remaining()
if (length < HEADER_SIZE)
throw Exception("Packet must be at least 6 bytes (header size)")
val size = ByteBuffer.wrap(data, 0, 4).order(ByteOrder.LITTLE_ENDIAN).int
val size = data.int
if (size != length - 4)
throw Exception("Incomplete packet received")
val opcode = data.asUByteArray()[4]
val subOpcode = data.asUByteArray()[5]
val packetData = ByteBuffer.wrap(data, HEADER_SIZE, size - 2)
handlePacket(opcode, subOpcode, packetData.order(ByteOrder.LITTLE_ENDIAN))
val opcode = data.get().toUByte()
val subOpcode = data.get().toUByte()
handlePacket(opcode, subOpcode, data, sourceChannel)
}
private fun handleRequest(subOpcode: UByte, data: ByteBuffer) {
private fun handleRequest(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
when (subOpcode) {
RequestOpcode.TRANSPORT_RELAYED.value -> {
Logger.i(TAG, "Received request for a relayed transport")
@ -440,7 +443,7 @@ class SyncSocketSession {
}
}
private fun handleResponse(subOpcode: UByte, data: ByteBuffer) {
private fun handleResponse(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
if (data.remaining() < 8) {
Logger.e(TAG, "Response packet too short")
return
@ -651,7 +654,7 @@ class SyncSocketSession {
return ConnectionInfo(port, name, remoteIp, ipv4Addresses, ipv6Addresses, allowLocalDirect, allowRemoteDirect, allowRemoteHolePunched, allowRemoteProxied)
}
private fun handleNotify(subOpcode: UByte, data: ByteBuffer) {
private fun handleNotify(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
when (subOpcode) {
NotifyOpcode.AUTHORIZED.value, NotifyOpcode.UNAUTHORIZED.value -> _onData?.invoke(this, Opcode.NOTIFY.value, subOpcode, data)
NotifyOpcode.CONNECTION_INFO.value -> { /* Handle connection info if needed */ }
@ -666,7 +669,7 @@ class SyncSocketSession {
send(Opcode.RELAY.value, RelayOpcode.RELAY_ERROR.value, packet)
}
private fun handleRelay(subOpcode: UByte, data: ByteBuffer) {
private fun handleRelay(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
when (subOpcode) {
RelayOpcode.RELAYED_DATA.value -> {
if (data.remaining() < 8) {
@ -680,7 +683,7 @@ class SyncSocketSession {
}
val decryptedPayload = channel.decrypt(data)
try {
channel.handleData(decryptedPayload)
handleData(decryptedPayload, channel)
} catch (e: Exception) {
Logger.e(TAG, "Exception while handling relayed data", e)
channel.sendError(SyncErrorCode.ConnectionClosed)
@ -726,33 +729,36 @@ class SyncSocketSession {
}
}
private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer) {
private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
Logger.i(TAG, "Handle packet (opcode = ${opcode}, subOpcode = ${subOpcode})")
when (opcode) {
Opcode.PING.value -> {
send(Opcode.PONG.value)
if (sourceChannel != null)
sourceChannel.send(Opcode.PONG.value)
else
send(Opcode.PONG.value)
//Logger.i(TAG, "Received ping, sent pong")
return
}
Opcode.PONG.value -> {
//Logger.i(TAG, "Received pong")
Logger.v(TAG, "Received pong")
return
}
Opcode.REQUEST.value -> {
handleRequest(subOpcode, data)
handleRequest(subOpcode, data, sourceChannel)
return
}
Opcode.RESPONSE.value -> {
handleResponse(subOpcode, data)
handleResponse(subOpcode, data, sourceChannel)
return
}
Opcode.NOTIFY.value -> {
handleNotify(subOpcode, data)
handleNotify(subOpcode, data, sourceChannel)
return
}
Opcode.RELAY.value -> {
handleRelay(subOpcode, data)
handleRelay(subOpcode, data, sourceChannel)
return
}
else -> if (isAuthorized) when (opcode) {
@ -809,12 +815,18 @@ class SyncSocketSession {
throw Exception("After sync stream end, the stream must be complete")
}
handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) })
}
else -> {
Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})")
handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) }, sourceChannel)
}
}
Opcode.DATA.value -> {
if (sourceChannel != null)
sourceChannel.invokeDataHandler(opcode, subOpcode, data)
else
_onData?.invoke(this, opcode, subOpcode, data)
}
else -> {
Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})")
}
}
}
@ -995,18 +1007,27 @@ class SyncSocketSession {
val deferred = CompletableDeferred<Boolean>()
_pendingPublishRequests[requestId] = deferred
try {
val MAX_PLAINTEXT_SIZE = 65535 - 16 // Adjust for tag size
val MAX_PLAINTEXT_SIZE = 65535
val HANDSHAKE_SIZE = 48
val LENGTH_SIZE = 4
val TAG_SIZE = 16
val chunkCount = (data.size + MAX_PLAINTEXT_SIZE - 1) / MAX_PLAINTEXT_SIZE
val blobSize = HANDSHAKE_SIZE + chunkCount * (LENGTH_SIZE + MAX_PLAINTEXT_SIZE + TAG_SIZE)
var blobSize = HANDSHAKE_SIZE
var dataOffset = 0
for (i in 0 until chunkCount) {
val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
blobSize += LENGTH_SIZE + (chunkSize + TAG_SIZE)
dataOffset += chunkSize
}
val totalPacketSize = 4 + 1 + keyBytes.size + 1 + consumerPublicKeys.size * (32 + 4 + blobSize)
val packet = ByteBuffer.allocate(totalPacketSize).order(ByteOrder.LITTLE_ENDIAN)
packet.putInt(requestId)
packet.put(keyBytes.size.toByte())
packet.put(keyBytes)
packet.put(consumerPublicKeys.size.toByte())
for (consumer in consumerPublicKeys) {
val consumerBytes = Base64.getDecoder().decode(consumer)
if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes")
@ -1020,9 +1041,10 @@ class SyncSocketSession {
val transportPair = protocol.split()
packet.putInt(blobSize)
packet.put(handshakeMessage)
var dataOffset = 0
dataOffset = 0
for (i in 0 until chunkCount) {
val chunkSize = min(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
val plaintext = data.copyOfRange(dataOffset, dataOffset + chunkSize)
val ciphertext = ByteArray(chunkSize + TAG_SIZE)
val written = transportPair.sender.encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.size)