Added tests and fixes.

This commit is contained in:
Koen J 2025-04-09 14:01:36 +02:00
parent 79a932b4ca
commit 97381739dd
3 changed files with 285 additions and 59 deletions

View file

@ -0,0 +1,256 @@
package com.futo.platformplayer
import com.futo.platformplayer.noise.protocol.Noise
import com.futo.platformplayer.sync.internal.*
import kotlinx.coroutines.*
import org.junit.Assert.*
import org.junit.Test
import java.net.Socket
import java.nio.ByteBuffer
import kotlin.random.Random
import kotlin.time.Duration.Companion.milliseconds
class SyncServerTests {
private val relayHost = "relay.grayjay.app"
private val relayKey = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw="
private val relayPort = 9000
/** Creates a client connected to the live relay server. */
private suspend fun createClient(
onHandshakeComplete: ((SyncSocketSession) -> Unit)? = null,
onData: ((SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit)? = null,
onNewChannel: ((SyncSocketSession, ChannelRelayed) -> Unit)? = null,
isHandshakeAllowed: ((SyncSocketSession, String, String?) -> Boolean)? = null
): SyncSocketSession = withContext(Dispatchers.IO) {
val p = Noise.createDH("25519")
p.generateKeyPair()
val socket = Socket(relayHost, relayPort)
val inputStream = LittleEndianDataInputStream(socket.getInputStream())
val outputStream = LittleEndianDataOutputStream(socket.getOutputStream())
val tcs = CompletableDeferred<Boolean>()
val socketSession = SyncSocketSession(
relayHost,
p,
inputStream,
outputStream,
onClose = { socket.close() },
onHandshakeComplete = { s ->
onHandshakeComplete?.invoke(s)
tcs.complete(true)
},
onData = onData ?: { _, _, _, _ -> },
onNewChannel = onNewChannel ?: { _, _ -> },
isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _ -> true }
)
socketSession.authorizable = AlwaysAuthorized()
socketSession.startAsInitiator(relayKey)
withTimeout(5000.milliseconds) { tcs.await() }
return@withContext socketSession
}
@Test
fun multipleClientsHandshake_Success() = runBlocking {
val client1 = createClient()
val client2 = createClient()
assertNotNull(client1.remotePublicKey, "Client 1 handshake failed")
assertNotNull(client2.remotePublicKey, "Client 2 handshake failed")
client1.stop()
client2.stop()
}
@Test
fun publishAndRequestConnectionInfo_Authorized_Success() = runBlocking {
val clientA = createClient()
val clientB = createClient()
val clientC = createClient()
clientA.publishConnectionInformation(arrayOf(clientB.localPublicKey), 12345, true, true, true, true)
delay(100.milliseconds)
val infoB = clientB.requestConnectionInfo(clientA.localPublicKey)
val infoC = clientC.requestConnectionInfo(clientA.localPublicKey)
assertNotNull("Client B should receive connection info", infoB)
assertEquals(12345.toUShort(), infoB!!.port)
assertNull("Client C should not receive connection info (unauthorized)", infoC)
clientA.stop()
clientB.stop()
clientC.stop()
}
@Test
fun relayedTransport_Bidirectional_Success() = runBlocking {
val tcsA = CompletableDeferred<ChannelRelayed>()
val tcsB = CompletableDeferred<ChannelRelayed>()
val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) })
val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) })
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) }
val channelA = withTimeout(5000.milliseconds) { tcsA.await() }
channelA.authorizable = AlwaysAuthorized()
val channelB = withTimeout(5000.milliseconds) { tcsB.await() }
channelB.authorizable = AlwaysAuthorized()
channelTask.await()
val tcsDataB = CompletableDeferred<ByteArray>()
channelB.setDataHandler { _, _, o, so, d ->
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
}
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(1, 2, 3)))
val tcsDataA = CompletableDeferred<ByteArray>()
channelA.setDataHandler { _, _, o, so, d ->
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(d.array())
}
channelB.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(4, 5, 6)))
val receivedB = withTimeout(5000.milliseconds) { tcsDataB.await() }
val receivedA = withTimeout(5000.milliseconds) { tcsDataA.await() }
assertArrayEquals(byteArrayOf(1, 2, 3), receivedB)
assertArrayEquals(byteArrayOf(4, 5, 6), receivedA)
clientA.stop()
clientB.stop()
}
@Test
fun relayedTransport_MaximumMessageSize_Success() = runBlocking {
val MAX_DATA_PER_PACKET = SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE - 8 - 16 - 16
val maxSizeData = ByteArray(MAX_DATA_PER_PACKET).apply { Random.nextBytes(this) }
val tcsA = CompletableDeferred<ChannelRelayed>()
val tcsB = CompletableDeferred<ChannelRelayed>()
val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) })
val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) })
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) }
val channelA = withTimeout(5000.milliseconds) { tcsA.await() }
channelA.authorizable = AlwaysAuthorized()
val channelB = withTimeout(5000.milliseconds) { tcsB.await() }
channelB.authorizable = AlwaysAuthorized()
channelTask.await()
val tcsDataB = CompletableDeferred<ByteArray>()
channelB.setDataHandler { _, _, o, so, d ->
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
}
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(maxSizeData))
val receivedData = withTimeout(5000.milliseconds) { tcsDataB.await() }
assertArrayEquals(maxSizeData, receivedData)
clientA.stop()
clientB.stop()
}
@Test
fun publishAndGetRecord_Success() = runBlocking {
val clientA = createClient()
val clientB = createClient()
val clientC = createClient()
val data = byteArrayOf(1, 2, 3)
val success = clientA.publishRecords(listOf(clientB.localPublicKey), "testKey", data)
val recordB = clientB.getRecord(clientA.localPublicKey, "testKey")
val recordC = clientC.getRecord(clientA.localPublicKey, "testKey")
assertTrue(success)
assertNotNull(recordB)
assertArrayEquals(data, recordB!!.first)
assertNull("Unauthorized client should not access record", recordC)
clientA.stop()
clientB.stop()
clientC.stop()
}
@Test
fun getNonExistentRecord_ReturnsNull() = runBlocking {
val clientA = createClient()
val clientB = createClient()
val record = clientB.getRecord(clientA.localPublicKey, "nonExistentKey")
assertNull("Getting non-existent record should return null", record)
clientA.stop()
clientB.stop()
}
@Test
fun updateRecord_TimestampUpdated() = runBlocking {
val clientA = createClient()
val clientB = createClient()
val key = "updateKey"
val data1 = byteArrayOf(1)
val data2 = byteArrayOf(2)
clientA.publishRecords(listOf(clientB.localPublicKey), key, data1)
val record1 = clientB.getRecord(clientA.localPublicKey, key)
delay(1000.milliseconds)
clientA.publishRecords(listOf(clientB.localPublicKey), key, data2)
val record2 = clientB.getRecord(clientA.localPublicKey, key)
assertNotNull(record1)
assertNotNull(record2)
assertTrue(record2!!.second > record1!!.second)
assertArrayEquals(data2, record2.first)
clientA.stop()
clientB.stop()
}
@Test
fun deleteRecord_Success() = runBlocking {
val clientA = createClient()
val clientB = createClient()
val data = byteArrayOf(1, 2, 3)
clientA.publishRecords(listOf(clientB.localPublicKey), "toDelete", data)
val success = clientB.deleteRecords(clientA.localPublicKey, clientB.localPublicKey, listOf("toDelete"))
val record = clientB.getRecord(clientA.localPublicKey, "toDelete")
assertTrue(success)
assertNull(record)
clientA.stop()
clientB.stop()
}
@Test
fun listRecordKeys_Success() = runBlocking {
val clientA = createClient()
val clientB = createClient()
val keys = arrayOf("key1", "key2", "key3")
keys.forEach { key ->
clientA.publishRecords(listOf(clientB.localPublicKey), key, byteArrayOf(1))
}
val listedKeys = clientB.listRecordKeys(clientA.localPublicKey, clientB.localPublicKey)
assertArrayEquals(keys, listedKeys.map { it.first }.toTypedArray())
clientA.stop()
clientB.stop()
}
@Test
fun singleLargeMessageViaRelayedChannel_Success() = runBlocking {
val largeData = ByteArray(100000).apply { Random.nextBytes(this) }
val tcsA = CompletableDeferred<ChannelRelayed>()
val tcsB = CompletableDeferred<ChannelRelayed>()
val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) })
val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) })
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) }
val channelA = withTimeout(5000.milliseconds) { tcsA.await() }
channelA.authorizable = AlwaysAuthorized()
val channelB = withTimeout(5000.milliseconds) { tcsB.await() }
channelB.authorizable = AlwaysAuthorized()
channelTask.await()
val tcsDataB = CompletableDeferred<ByteArray>()
channelB.setDataHandler { _, _, o, so, d ->
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
}
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData))
val receivedData = withTimeout(10000.milliseconds) { tcsDataB.await() }
assertArrayEquals(largeData, receivedData)
clientA.stop()
clientB.stop()
}
@Test
fun publishAndGetLargeRecord_Success() = runBlocking {
val largeData = ByteArray(1000000).apply { Random.nextBytes(this) }
val clientA = createClient()
val clientB = createClient()
val success = clientA.publishRecords(listOf(clientB.localPublicKey), "largeRecord", largeData)
val record = clientB.getRecord(clientA.localPublicKey, "largeRecord")
assertTrue(success)
assertNotNull(record)
assertArrayEquals(largeData, record!!.first)
clientA.stop()
clientB.stop()
}
}
class AlwaysAuthorized : IAuthorizable {
override val isAuthorized: Boolean get() = true
}

View file

@ -251,8 +251,7 @@ class ChannelRelayed(
if (publicKeyBytes.size != 32) throw IllegalArgumentException("Public key must be 32 bytes")
val (pairingMessageLength, pairingMessage) = if (pairingCode != null) {
val pairingProtocolName = "Noise_N_25519_ChaChaPoly_Blake2b"
val pairingHandshake = HandshakeState(pairingProtocolName, HandshakeState.INITIATOR).apply {
val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR).apply {
remotePublicKey.setPublicKey(publicKeyBytes, 0)
start()
}

View file

@ -206,8 +206,7 @@ class SyncSocketSession {
val pairingMessage: ByteArray
val pairingMessageLength: Int
if (pairingCode != null) {
val pairingProtocolName = "Noise_N_25519_ChaChaPoly_Blake2b"
val pairingHandshake = HandshakeState(pairingProtocolName, HandshakeState.INITIATOR)
val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR)
pairingHandshake.remotePublicKey.setPublicKey(Base64.getDecoder().decode(remotePublicKey), 0)
pairingHandshake.start()
val pairingCodeBytes = pairingCode.toByteArray(Charsets.UTF_8)
@ -261,8 +260,7 @@ class SyncSocketSession {
var pairingCode: String? = null
if (pairingMessageLength > 0) {
val pairingProtocolName = "Noise_N_25519_ChaChaPoly_Blake2b"
val pairingHandshake = HandshakeState(pairingProtocolName, HandshakeState.RESPONDER)
val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER)
pairingHandshake.localKeyPair.copyFrom(_localKeyPair)
pairingHandshake.start()
val pairingPlaintext = ByteArray(512)
@ -298,47 +296,8 @@ class SyncSocketSession {
throw Exception("Invalid version")
}
private fun handshake(handshakeState: HandshakeState): CipherStatePair {
handshakeState.start()
val message = ByteArray(8192)
val plaintext = ByteArray(8192)
while (_started) {
when (handshakeState.action) {
HandshakeState.READ_MESSAGE -> {
val messageSize = _inputStream.readInt()
Logger.i(TAG, "Handshake read message (size = ${messageSize})")
var bytesRead = 0
while (bytesRead < messageSize) {
val read = _inputStream.read(message, bytesRead, messageSize - bytesRead)
if (read == -1)
throw Exception("Stream closed")
bytesRead += read
}
handshakeState.readMessage(message, 0, messageSize, plaintext, 0)
}
HandshakeState.WRITE_MESSAGE -> {
val messageSize = handshakeState.writeMessage(message, 0, null, 0, 0)
Logger.i(TAG, "Handshake wrote message (size = ${messageSize})")
_outputStream.writeInt(messageSize)
_outputStream.write(message, 0, messageSize)
}
HandshakeState.SPLIT -> {
//Logger.i(TAG, "Handshake split")
return handshakeState.split()
}
else -> throw Exception("Unexpected state (handshakeState.action = ${handshakeState.action})")
}
}
throw Exception("Handshake finished without completing")
}
fun generateStreamId(): Int = synchronized(_streamIdGeneratorLock) { _streamIdGenerator++ }
fun generateRequestId(): Int = synchronized(_requestIdGeneratorLock) { _requestIdGenerator++ }
private fun generateRequestId(): Int = synchronized(_requestIdGeneratorLock) { _requestIdGenerator++ }
fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer) {
ensureNotMainThread()
@ -453,7 +412,7 @@ class SyncSocketSession {
val channelHandshakeMessage = ByteArray(channelMessageLength).also { data.get(it) }
val publicKey = Base64.getEncoder().encodeToString(publicKeyBytes)
val pairingCode = if (pairingMessageLength > 0) {
val pairingProtocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
val pairingProtocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
localKeyPair.copyFrom(_localKeyPair)
start()
}
@ -467,6 +426,7 @@ class SyncSocketSession {
rp.putInt(2) // Status code for not allowed
rp.putLong(connectionId)
rp.putInt(requestId)
rp.rewind()
send(Opcode.RESPONSE.value, ResponseOpcode.TRANSPORT.value, rp)
return
}
@ -502,7 +462,7 @@ class SyncSocketSession {
}
} ?: Logger.e(TAG, "No pending request for requestId $requestId")
}
ResponseOpcode.TRANSPORT.value -> {
ResponseOpcode.TRANSPORT_RELAYED.value -> {
if (statusCode == 0) {
if (data.remaining() < 16) {
Logger.e(TAG, "RESPONSE_TRANSPORT packet too short")
@ -564,7 +524,7 @@ class SyncSocketSession {
val blobLength = data.int
val encryptedBlob = ByteArray(blobLength).also { data.get(it) }
val timestamp = data.long
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
localKeyPair.copyFrom(_localKeyPair)
start()
}
@ -607,7 +567,7 @@ class SyncSocketSession {
val blobLength = data.int
val encryptedBlob = ByteArray(blobLength).also { data.get(it) }
val timestamp = data.long
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
localKeyPair.copyFrom(_localKeyPair)
start()
}
@ -667,7 +627,7 @@ class SyncSocketSession {
val remoteIp = remoteIpBytes.joinToString(".") { it.toUByte().toString() }
val handshakeMessage = ByteArray(48).also { data.get(it) }
val ciphertext = ByteArray(data.remaining()).also { data.get(it) }
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
localKeyPair.copyFrom(_localKeyPair)
start()
}
@ -702,6 +662,7 @@ class SyncSocketSession {
val packet = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN)
packet.putLong(connectionId)
packet.putInt(errorCode.value)
packet.rewind()
send(Opcode.RELAY.value, RelayOpcode.RELAY_ERROR.value, packet)
}
@ -872,6 +833,7 @@ class SyncSocketSession {
val packet = ByteBuffer.allocate(4 + 32).order(ByteOrder.LITTLE_ENDIAN)
packet.putInt(requestId)
packet.put(publicKeyBytes)
packet.rewind()
send(Opcode.REQUEST.value, RequestOpcode.CONNECTION_INFO.value, packet)
} catch (e: Exception) {
_pendingConnectionInfoRequests.remove(requestId)?.completeExceptionally(e)
@ -893,6 +855,7 @@ class SyncSocketSession {
if (pkBytes.size != 32) throw IllegalArgumentException("Invalid public key length for $pk")
packet.put(pkBytes)
}
packet.rewind()
send(Opcode.REQUEST.value, RequestOpcode.BULK_CONNECTION_INFO.value, packet)
} catch (e: Exception) {
_pendingBulkConnectionInfoRequests.remove(requestId)?.completeExceptionally(e)
@ -949,7 +912,6 @@ class SyncSocketSession {
) {
if (authorizedKeys.size > 255) throw IllegalArgumentException("Number of authorized keys exceeds 255")
// **Step 1: Collect Network Information**
val ipv4Addresses = mutableListOf<String>()
val ipv6Addresses = mutableListOf<String>()
for (nic in NetworkInterface.getNetworkInterfaces()) {
@ -965,11 +927,9 @@ class SyncSocketSession {
}
}
// **Step 2: Get Device Name**
val deviceName = getDeviceName()
val nameBytes = getLimitedUtf8Bytes(deviceName, 255)
// **Step 3: Serialize Connection Information**
val blobSize = 2 + 1 + nameBytes.size + 1 + ipv4Addresses.size * 4 + 1 + ipv6Addresses.size * 16 + 1 + 1 + 1 + 1
val data = ByteBuffer.allocate(blobSize).order(ByteOrder.LITTLE_ENDIAN)
data.putShort(port.toShort())
@ -990,19 +950,19 @@ class SyncSocketSession {
data.put(if (allowRemoteHolePunched) 1 else 0)
data.put(if (allowRemoteProxied) 1 else 0)
// **Step 4: Precalculate Total Size**
val handshakeSize = 48 // Noise handshake size for N pattern
data.rewind()
val ciphertextSize = data.remaining() + 16 // Encrypted data size
val totalSize = 1 + authorizedKeys.size * (32 + handshakeSize + 4 + ciphertextSize)
val publishBytes = ByteBuffer.allocate(totalSize).order(ByteOrder.LITTLE_ENDIAN)
publishBytes.put(authorizedKeys.size.toByte())
// **Step 5: Encrypt Data for Each Authorized Key**
for (key in authorizedKeys) {
val publicKeyBytes = Base64.getDecoder().decode(key)
if (publicKeyBytes.size != 32) throw IllegalArgumentException("Public key must be 32 bytes")
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.INITIATOR)
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR)
protocol.remotePublicKey.setPublicKey(publicKeyBytes, 0)
protocol.start()
@ -1023,7 +983,7 @@ class SyncSocketSession {
publishBytes.put(ciphertext, 0, ciphertextBytesWritten)
}
// **Step 6: Send the Encrypted Data**
publishBytes.rewind()
send(Opcode.NOTIFY.value, NotifyOpcode.CONNECTION_INFO.value, publishBytes)
}
@ -1051,7 +1011,7 @@ class SyncSocketSession {
val consumerBytes = Base64.getDecoder().decode(consumer)
if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes")
packet.put(consumerBytes)
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.INITIATOR).apply {
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR).apply {
remotePublicKey.setPublicKey(consumerBytes, 0)
start()
}
@ -1071,6 +1031,7 @@ class SyncSocketSession {
dataOffset += chunkSize
}
}
packet.rewind()
send(Opcode.REQUEST.value, RequestOpcode.BULK_PUBLISH_RECORD.value, packet)
} catch (e: Exception) {
_pendingPublishRequests.remove(requestId)?.completeExceptionally(e)
@ -1093,6 +1054,7 @@ class SyncSocketSession {
packet.put(publisherBytes)
packet.put(keyBytes.size.toByte())
packet.put(keyBytes)
packet.rewind()
send(Opcode.REQUEST.value, RequestOpcode.GET_RECORD.value, packet)
} catch (e: Exception) {
_pendingGetRecordRequests.remove(requestId)?.completeExceptionally(e)
@ -1119,6 +1081,7 @@ class SyncSocketSession {
if (bytes.size != 32) throw IllegalArgumentException("Publisher public key must be 32 bytes")
packet.put(bytes)
}
packet.rewind()
send(Opcode.REQUEST.value, RequestOpcode.BULK_GET_RECORD.value, packet)
} catch (e: Exception) {
_pendingBulkGetRecordRequests.remove(requestId)?.completeExceptionally(e)
@ -1148,6 +1111,7 @@ class SyncSocketSession {
packet.put(keyBytes.size.toByte())
packet.put(keyBytes)
}
packet.rewind()
send(Opcode.REQUEST.value, RequestOpcode.BULK_DELETE_RECORD.value, packet)
} catch (e: Exception) {
_pendingDeleteRequests.remove(requestId)?.completeExceptionally(e)
@ -1169,6 +1133,7 @@ class SyncSocketSession {
packet.putInt(requestId)
packet.put(publisherBytes)
packet.put(consumerBytes)
packet.rewind()
send(Opcode.REQUEST.value, RequestOpcode.LIST_RECORD_KEYS.value, packet)
} catch (e: Exception) {
_pendingListKeysRequests.remove(requestId)?.completeExceptionally(e)
@ -1178,6 +1143,12 @@ class SyncSocketSession {
}
companion object {
val dh = "25519"
val pattern = "N"
val cipher = "ChaChaPoly"
val hash = "BLAKE2b"
var nProtocolName = "Noise_${pattern}_${dh}_${cipher}_${hash}"
private const val TAG = "SyncSocketSession"
const val MAXIMUM_PACKET_SIZE = 65535 - 16
const val MAXIMUM_PACKET_SIZE_ENCRYPTED = MAXIMUM_PACKET_SIZE + 16