mirror of
https://gitlab.futo.org/videostreaming/grayjay.git
synced 2025-04-20 03:24:50 +00:00
Added tests and fixes.
This commit is contained in:
parent
79a932b4ca
commit
97381739dd
3 changed files with 285 additions and 59 deletions
|
@ -0,0 +1,256 @@
|
|||
package com.futo.platformplayer
|
||||
|
||||
import com.futo.platformplayer.noise.protocol.Noise
|
||||
import com.futo.platformplayer.sync.internal.*
|
||||
import kotlinx.coroutines.*
|
||||
import org.junit.Assert.*
|
||||
import org.junit.Test
|
||||
import java.net.Socket
|
||||
import java.nio.ByteBuffer
|
||||
import kotlin.random.Random
|
||||
import kotlin.time.Duration.Companion.milliseconds
|
||||
|
||||
class SyncServerTests {
|
||||
|
||||
private val relayHost = "relay.grayjay.app"
|
||||
private val relayKey = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw="
|
||||
private val relayPort = 9000
|
||||
|
||||
/** Creates a client connected to the live relay server. */
|
||||
private suspend fun createClient(
|
||||
onHandshakeComplete: ((SyncSocketSession) -> Unit)? = null,
|
||||
onData: ((SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit)? = null,
|
||||
onNewChannel: ((SyncSocketSession, ChannelRelayed) -> Unit)? = null,
|
||||
isHandshakeAllowed: ((SyncSocketSession, String, String?) -> Boolean)? = null
|
||||
): SyncSocketSession = withContext(Dispatchers.IO) {
|
||||
val p = Noise.createDH("25519")
|
||||
p.generateKeyPair()
|
||||
val socket = Socket(relayHost, relayPort)
|
||||
val inputStream = LittleEndianDataInputStream(socket.getInputStream())
|
||||
val outputStream = LittleEndianDataOutputStream(socket.getOutputStream())
|
||||
val tcs = CompletableDeferred<Boolean>()
|
||||
val socketSession = SyncSocketSession(
|
||||
relayHost,
|
||||
p,
|
||||
inputStream,
|
||||
outputStream,
|
||||
onClose = { socket.close() },
|
||||
onHandshakeComplete = { s ->
|
||||
onHandshakeComplete?.invoke(s)
|
||||
tcs.complete(true)
|
||||
},
|
||||
onData = onData ?: { _, _, _, _ -> },
|
||||
onNewChannel = onNewChannel ?: { _, _ -> },
|
||||
isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _ -> true }
|
||||
)
|
||||
socketSession.authorizable = AlwaysAuthorized()
|
||||
socketSession.startAsInitiator(relayKey)
|
||||
withTimeout(5000.milliseconds) { tcs.await() }
|
||||
return@withContext socketSession
|
||||
}
|
||||
|
||||
@Test
|
||||
fun multipleClientsHandshake_Success() = runBlocking {
|
||||
val client1 = createClient()
|
||||
val client2 = createClient()
|
||||
assertNotNull(client1.remotePublicKey, "Client 1 handshake failed")
|
||||
assertNotNull(client2.remotePublicKey, "Client 2 handshake failed")
|
||||
client1.stop()
|
||||
client2.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun publishAndRequestConnectionInfo_Authorized_Success() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val clientC = createClient()
|
||||
clientA.publishConnectionInformation(arrayOf(clientB.localPublicKey), 12345, true, true, true, true)
|
||||
delay(100.milliseconds)
|
||||
val infoB = clientB.requestConnectionInfo(clientA.localPublicKey)
|
||||
val infoC = clientC.requestConnectionInfo(clientA.localPublicKey)
|
||||
assertNotNull("Client B should receive connection info", infoB)
|
||||
assertEquals(12345.toUShort(), infoB!!.port)
|
||||
assertNull("Client C should not receive connection info (unauthorized)", infoC)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
clientC.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun relayedTransport_Bidirectional_Success() = runBlocking {
|
||||
val tcsA = CompletableDeferred<ChannelRelayed>()
|
||||
val tcsB = CompletableDeferred<ChannelRelayed>()
|
||||
val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) })
|
||||
val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) })
|
||||
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) }
|
||||
val channelA = withTimeout(5000.milliseconds) { tcsA.await() }
|
||||
channelA.authorizable = AlwaysAuthorized()
|
||||
val channelB = withTimeout(5000.milliseconds) { tcsB.await() }
|
||||
channelB.authorizable = AlwaysAuthorized()
|
||||
channelTask.await()
|
||||
|
||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||
channelB.setDataHandler { _, _, o, so, d ->
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
|
||||
}
|
||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(1, 2, 3)))
|
||||
|
||||
val tcsDataA = CompletableDeferred<ByteArray>()
|
||||
channelA.setDataHandler { _, _, o, so, d ->
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(d.array())
|
||||
}
|
||||
channelB.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(4, 5, 6)))
|
||||
|
||||
val receivedB = withTimeout(5000.milliseconds) { tcsDataB.await() }
|
||||
val receivedA = withTimeout(5000.milliseconds) { tcsDataA.await() }
|
||||
assertArrayEquals(byteArrayOf(1, 2, 3), receivedB)
|
||||
assertArrayEquals(byteArrayOf(4, 5, 6), receivedA)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun relayedTransport_MaximumMessageSize_Success() = runBlocking {
|
||||
val MAX_DATA_PER_PACKET = SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE - 8 - 16 - 16
|
||||
val maxSizeData = ByteArray(MAX_DATA_PER_PACKET).apply { Random.nextBytes(this) }
|
||||
val tcsA = CompletableDeferred<ChannelRelayed>()
|
||||
val tcsB = CompletableDeferred<ChannelRelayed>()
|
||||
val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) })
|
||||
val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) })
|
||||
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) }
|
||||
val channelA = withTimeout(5000.milliseconds) { tcsA.await() }
|
||||
channelA.authorizable = AlwaysAuthorized()
|
||||
val channelB = withTimeout(5000.milliseconds) { tcsB.await() }
|
||||
channelB.authorizable = AlwaysAuthorized()
|
||||
channelTask.await()
|
||||
|
||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||
channelB.setDataHandler { _, _, o, so, d ->
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
|
||||
}
|
||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(maxSizeData))
|
||||
val receivedData = withTimeout(5000.milliseconds) { tcsDataB.await() }
|
||||
assertArrayEquals(maxSizeData, receivedData)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun publishAndGetRecord_Success() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val clientC = createClient()
|
||||
val data = byteArrayOf(1, 2, 3)
|
||||
val success = clientA.publishRecords(listOf(clientB.localPublicKey), "testKey", data)
|
||||
val recordB = clientB.getRecord(clientA.localPublicKey, "testKey")
|
||||
val recordC = clientC.getRecord(clientA.localPublicKey, "testKey")
|
||||
assertTrue(success)
|
||||
assertNotNull(recordB)
|
||||
assertArrayEquals(data, recordB!!.first)
|
||||
assertNull("Unauthorized client should not access record", recordC)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
clientC.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun getNonExistentRecord_ReturnsNull() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val record = clientB.getRecord(clientA.localPublicKey, "nonExistentKey")
|
||||
assertNull("Getting non-existent record should return null", record)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun updateRecord_TimestampUpdated() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val key = "updateKey"
|
||||
val data1 = byteArrayOf(1)
|
||||
val data2 = byteArrayOf(2)
|
||||
clientA.publishRecords(listOf(clientB.localPublicKey), key, data1)
|
||||
val record1 = clientB.getRecord(clientA.localPublicKey, key)
|
||||
delay(1000.milliseconds)
|
||||
clientA.publishRecords(listOf(clientB.localPublicKey), key, data2)
|
||||
val record2 = clientB.getRecord(clientA.localPublicKey, key)
|
||||
assertNotNull(record1)
|
||||
assertNotNull(record2)
|
||||
assertTrue(record2!!.second > record1!!.second)
|
||||
assertArrayEquals(data2, record2.first)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteRecord_Success() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val data = byteArrayOf(1, 2, 3)
|
||||
clientA.publishRecords(listOf(clientB.localPublicKey), "toDelete", data)
|
||||
val success = clientB.deleteRecords(clientA.localPublicKey, clientB.localPublicKey, listOf("toDelete"))
|
||||
val record = clientB.getRecord(clientA.localPublicKey, "toDelete")
|
||||
assertTrue(success)
|
||||
assertNull(record)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun listRecordKeys_Success() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val keys = arrayOf("key1", "key2", "key3")
|
||||
keys.forEach { key ->
|
||||
clientA.publishRecords(listOf(clientB.localPublicKey), key, byteArrayOf(1))
|
||||
}
|
||||
val listedKeys = clientB.listRecordKeys(clientA.localPublicKey, clientB.localPublicKey)
|
||||
assertArrayEquals(keys, listedKeys.map { it.first }.toTypedArray())
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun singleLargeMessageViaRelayedChannel_Success() = runBlocking {
|
||||
val largeData = ByteArray(100000).apply { Random.nextBytes(this) }
|
||||
val tcsA = CompletableDeferred<ChannelRelayed>()
|
||||
val tcsB = CompletableDeferred<ChannelRelayed>()
|
||||
val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) })
|
||||
val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) })
|
||||
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) }
|
||||
val channelA = withTimeout(5000.milliseconds) { tcsA.await() }
|
||||
channelA.authorizable = AlwaysAuthorized()
|
||||
val channelB = withTimeout(5000.milliseconds) { tcsB.await() }
|
||||
channelB.authorizable = AlwaysAuthorized()
|
||||
channelTask.await()
|
||||
|
||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||
channelB.setDataHandler { _, _, o, so, d ->
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
|
||||
}
|
||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData))
|
||||
val receivedData = withTimeout(10000.milliseconds) { tcsDataB.await() }
|
||||
assertArrayEquals(largeData, receivedData)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun publishAndGetLargeRecord_Success() = runBlocking {
|
||||
val largeData = ByteArray(1000000).apply { Random.nextBytes(this) }
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val success = clientA.publishRecords(listOf(clientB.localPublicKey), "largeRecord", largeData)
|
||||
val record = clientB.getRecord(clientA.localPublicKey, "largeRecord")
|
||||
assertTrue(success)
|
||||
assertNotNull(record)
|
||||
assertArrayEquals(largeData, record!!.first)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
}
|
||||
|
||||
class AlwaysAuthorized : IAuthorizable {
|
||||
override val isAuthorized: Boolean get() = true
|
||||
}
|
|
@ -251,8 +251,7 @@ class ChannelRelayed(
|
|||
if (publicKeyBytes.size != 32) throw IllegalArgumentException("Public key must be 32 bytes")
|
||||
|
||||
val (pairingMessageLength, pairingMessage) = if (pairingCode != null) {
|
||||
val pairingProtocolName = "Noise_N_25519_ChaChaPoly_Blake2b"
|
||||
val pairingHandshake = HandshakeState(pairingProtocolName, HandshakeState.INITIATOR).apply {
|
||||
val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR).apply {
|
||||
remotePublicKey.setPublicKey(publicKeyBytes, 0)
|
||||
start()
|
||||
}
|
||||
|
|
|
@ -206,8 +206,7 @@ class SyncSocketSession {
|
|||
val pairingMessage: ByteArray
|
||||
val pairingMessageLength: Int
|
||||
if (pairingCode != null) {
|
||||
val pairingProtocolName = "Noise_N_25519_ChaChaPoly_Blake2b"
|
||||
val pairingHandshake = HandshakeState(pairingProtocolName, HandshakeState.INITIATOR)
|
||||
val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR)
|
||||
pairingHandshake.remotePublicKey.setPublicKey(Base64.getDecoder().decode(remotePublicKey), 0)
|
||||
pairingHandshake.start()
|
||||
val pairingCodeBytes = pairingCode.toByteArray(Charsets.UTF_8)
|
||||
|
@ -261,8 +260,7 @@ class SyncSocketSession {
|
|||
|
||||
var pairingCode: String? = null
|
||||
if (pairingMessageLength > 0) {
|
||||
val pairingProtocolName = "Noise_N_25519_ChaChaPoly_Blake2b"
|
||||
val pairingHandshake = HandshakeState(pairingProtocolName, HandshakeState.RESPONDER)
|
||||
val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER)
|
||||
pairingHandshake.localKeyPair.copyFrom(_localKeyPair)
|
||||
pairingHandshake.start()
|
||||
val pairingPlaintext = ByteArray(512)
|
||||
|
@ -298,47 +296,8 @@ class SyncSocketSession {
|
|||
throw Exception("Invalid version")
|
||||
}
|
||||
|
||||
private fun handshake(handshakeState: HandshakeState): CipherStatePair {
|
||||
handshakeState.start()
|
||||
|
||||
val message = ByteArray(8192)
|
||||
val plaintext = ByteArray(8192)
|
||||
|
||||
while (_started) {
|
||||
when (handshakeState.action) {
|
||||
HandshakeState.READ_MESSAGE -> {
|
||||
val messageSize = _inputStream.readInt()
|
||||
Logger.i(TAG, "Handshake read message (size = ${messageSize})")
|
||||
|
||||
var bytesRead = 0
|
||||
while (bytesRead < messageSize) {
|
||||
val read = _inputStream.read(message, bytesRead, messageSize - bytesRead)
|
||||
if (read == -1)
|
||||
throw Exception("Stream closed")
|
||||
bytesRead += read
|
||||
}
|
||||
|
||||
handshakeState.readMessage(message, 0, messageSize, plaintext, 0)
|
||||
}
|
||||
HandshakeState.WRITE_MESSAGE -> {
|
||||
val messageSize = handshakeState.writeMessage(message, 0, null, 0, 0)
|
||||
Logger.i(TAG, "Handshake wrote message (size = ${messageSize})")
|
||||
_outputStream.writeInt(messageSize)
|
||||
_outputStream.write(message, 0, messageSize)
|
||||
}
|
||||
HandshakeState.SPLIT -> {
|
||||
//Logger.i(TAG, "Handshake split")
|
||||
return handshakeState.split()
|
||||
}
|
||||
else -> throw Exception("Unexpected state (handshakeState.action = ${handshakeState.action})")
|
||||
}
|
||||
}
|
||||
|
||||
throw Exception("Handshake finished without completing")
|
||||
}
|
||||
|
||||
fun generateStreamId(): Int = synchronized(_streamIdGeneratorLock) { _streamIdGenerator++ }
|
||||
fun generateRequestId(): Int = synchronized(_requestIdGeneratorLock) { _requestIdGenerator++ }
|
||||
private fun generateRequestId(): Int = synchronized(_requestIdGeneratorLock) { _requestIdGenerator++ }
|
||||
|
||||
fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer) {
|
||||
ensureNotMainThread()
|
||||
|
@ -453,7 +412,7 @@ class SyncSocketSession {
|
|||
val channelHandshakeMessage = ByteArray(channelMessageLength).also { data.get(it) }
|
||||
val publicKey = Base64.getEncoder().encodeToString(publicKeyBytes)
|
||||
val pairingCode = if (pairingMessageLength > 0) {
|
||||
val pairingProtocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
|
||||
val pairingProtocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
|
||||
localKeyPair.copyFrom(_localKeyPair)
|
||||
start()
|
||||
}
|
||||
|
@ -467,6 +426,7 @@ class SyncSocketSession {
|
|||
rp.putInt(2) // Status code for not allowed
|
||||
rp.putLong(connectionId)
|
||||
rp.putInt(requestId)
|
||||
rp.rewind()
|
||||
send(Opcode.RESPONSE.value, ResponseOpcode.TRANSPORT.value, rp)
|
||||
return
|
||||
}
|
||||
|
@ -502,7 +462,7 @@ class SyncSocketSession {
|
|||
}
|
||||
} ?: Logger.e(TAG, "No pending request for requestId $requestId")
|
||||
}
|
||||
ResponseOpcode.TRANSPORT.value -> {
|
||||
ResponseOpcode.TRANSPORT_RELAYED.value -> {
|
||||
if (statusCode == 0) {
|
||||
if (data.remaining() < 16) {
|
||||
Logger.e(TAG, "RESPONSE_TRANSPORT packet too short")
|
||||
|
@ -564,7 +524,7 @@ class SyncSocketSession {
|
|||
val blobLength = data.int
|
||||
val encryptedBlob = ByteArray(blobLength).also { data.get(it) }
|
||||
val timestamp = data.long
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
|
||||
localKeyPair.copyFrom(_localKeyPair)
|
||||
start()
|
||||
}
|
||||
|
@ -607,7 +567,7 @@ class SyncSocketSession {
|
|||
val blobLength = data.int
|
||||
val encryptedBlob = ByteArray(blobLength).also { data.get(it) }
|
||||
val timestamp = data.long
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
|
||||
localKeyPair.copyFrom(_localKeyPair)
|
||||
start()
|
||||
}
|
||||
|
@ -667,7 +627,7 @@ class SyncSocketSession {
|
|||
val remoteIp = remoteIpBytes.joinToString(".") { it.toUByte().toString() }
|
||||
val handshakeMessage = ByteArray(48).also { data.get(it) }
|
||||
val ciphertext = ByteArray(data.remaining()).also { data.get(it) }
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
|
||||
localKeyPair.copyFrom(_localKeyPair)
|
||||
start()
|
||||
}
|
||||
|
@ -702,6 +662,7 @@ class SyncSocketSession {
|
|||
val packet = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN)
|
||||
packet.putLong(connectionId)
|
||||
packet.putInt(errorCode.value)
|
||||
packet.rewind()
|
||||
send(Opcode.RELAY.value, RelayOpcode.RELAY_ERROR.value, packet)
|
||||
}
|
||||
|
||||
|
@ -872,6 +833,7 @@ class SyncSocketSession {
|
|||
val packet = ByteBuffer.allocate(4 + 32).order(ByteOrder.LITTLE_ENDIAN)
|
||||
packet.putInt(requestId)
|
||||
packet.put(publicKeyBytes)
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.CONNECTION_INFO.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingConnectionInfoRequests.remove(requestId)?.completeExceptionally(e)
|
||||
|
@ -893,6 +855,7 @@ class SyncSocketSession {
|
|||
if (pkBytes.size != 32) throw IllegalArgumentException("Invalid public key length for $pk")
|
||||
packet.put(pkBytes)
|
||||
}
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.BULK_CONNECTION_INFO.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingBulkConnectionInfoRequests.remove(requestId)?.completeExceptionally(e)
|
||||
|
@ -949,7 +912,6 @@ class SyncSocketSession {
|
|||
) {
|
||||
if (authorizedKeys.size > 255) throw IllegalArgumentException("Number of authorized keys exceeds 255")
|
||||
|
||||
// **Step 1: Collect Network Information**
|
||||
val ipv4Addresses = mutableListOf<String>()
|
||||
val ipv6Addresses = mutableListOf<String>()
|
||||
for (nic in NetworkInterface.getNetworkInterfaces()) {
|
||||
|
@ -965,11 +927,9 @@ class SyncSocketSession {
|
|||
}
|
||||
}
|
||||
|
||||
// **Step 2: Get Device Name**
|
||||
val deviceName = getDeviceName()
|
||||
val nameBytes = getLimitedUtf8Bytes(deviceName, 255)
|
||||
|
||||
// **Step 3: Serialize Connection Information**
|
||||
val blobSize = 2 + 1 + nameBytes.size + 1 + ipv4Addresses.size * 4 + 1 + ipv6Addresses.size * 16 + 1 + 1 + 1 + 1
|
||||
val data = ByteBuffer.allocate(blobSize).order(ByteOrder.LITTLE_ENDIAN)
|
||||
data.putShort(port.toShort())
|
||||
|
@ -990,19 +950,19 @@ class SyncSocketSession {
|
|||
data.put(if (allowRemoteHolePunched) 1 else 0)
|
||||
data.put(if (allowRemoteProxied) 1 else 0)
|
||||
|
||||
// **Step 4: Precalculate Total Size**
|
||||
val handshakeSize = 48 // Noise handshake size for N pattern
|
||||
|
||||
data.rewind()
|
||||
val ciphertextSize = data.remaining() + 16 // Encrypted data size
|
||||
val totalSize = 1 + authorizedKeys.size * (32 + handshakeSize + 4 + ciphertextSize)
|
||||
val publishBytes = ByteBuffer.allocate(totalSize).order(ByteOrder.LITTLE_ENDIAN)
|
||||
publishBytes.put(authorizedKeys.size.toByte())
|
||||
|
||||
// **Step 5: Encrypt Data for Each Authorized Key**
|
||||
for (key in authorizedKeys) {
|
||||
val publicKeyBytes = Base64.getDecoder().decode(key)
|
||||
if (publicKeyBytes.size != 32) throw IllegalArgumentException("Public key must be 32 bytes")
|
||||
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.INITIATOR)
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR)
|
||||
protocol.remotePublicKey.setPublicKey(publicKeyBytes, 0)
|
||||
protocol.start()
|
||||
|
||||
|
@ -1023,7 +983,7 @@ class SyncSocketSession {
|
|||
publishBytes.put(ciphertext, 0, ciphertextBytesWritten)
|
||||
}
|
||||
|
||||
// **Step 6: Send the Encrypted Data**
|
||||
publishBytes.rewind()
|
||||
send(Opcode.NOTIFY.value, NotifyOpcode.CONNECTION_INFO.value, publishBytes)
|
||||
}
|
||||
|
||||
|
@ -1051,7 +1011,7 @@ class SyncSocketSession {
|
|||
val consumerBytes = Base64.getDecoder().decode(consumer)
|
||||
if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes")
|
||||
packet.put(consumerBytes)
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.INITIATOR).apply {
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR).apply {
|
||||
remotePublicKey.setPublicKey(consumerBytes, 0)
|
||||
start()
|
||||
}
|
||||
|
@ -1071,6 +1031,7 @@ class SyncSocketSession {
|
|||
dataOffset += chunkSize
|
||||
}
|
||||
}
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.BULK_PUBLISH_RECORD.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingPublishRequests.remove(requestId)?.completeExceptionally(e)
|
||||
|
@ -1093,6 +1054,7 @@ class SyncSocketSession {
|
|||
packet.put(publisherBytes)
|
||||
packet.put(keyBytes.size.toByte())
|
||||
packet.put(keyBytes)
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.GET_RECORD.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingGetRecordRequests.remove(requestId)?.completeExceptionally(e)
|
||||
|
@ -1119,6 +1081,7 @@ class SyncSocketSession {
|
|||
if (bytes.size != 32) throw IllegalArgumentException("Publisher public key must be 32 bytes")
|
||||
packet.put(bytes)
|
||||
}
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.BULK_GET_RECORD.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingBulkGetRecordRequests.remove(requestId)?.completeExceptionally(e)
|
||||
|
@ -1148,6 +1111,7 @@ class SyncSocketSession {
|
|||
packet.put(keyBytes.size.toByte())
|
||||
packet.put(keyBytes)
|
||||
}
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.BULK_DELETE_RECORD.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingDeleteRequests.remove(requestId)?.completeExceptionally(e)
|
||||
|
@ -1169,6 +1133,7 @@ class SyncSocketSession {
|
|||
packet.putInt(requestId)
|
||||
packet.put(publisherBytes)
|
||||
packet.put(consumerBytes)
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.LIST_RECORD_KEYS.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingListKeysRequests.remove(requestId)?.completeExceptionally(e)
|
||||
|
@ -1178,6 +1143,12 @@ class SyncSocketSession {
|
|||
}
|
||||
|
||||
companion object {
|
||||
val dh = "25519"
|
||||
val pattern = "N"
|
||||
val cipher = "ChaChaPoly"
|
||||
val hash = "BLAKE2b"
|
||||
var nProtocolName = "Noise_${pattern}_${dh}_${cipher}_${hash}"
|
||||
|
||||
private const val TAG = "SyncSocketSession"
|
||||
const val MAXIMUM_PACKET_SIZE = 65535 - 16
|
||||
const val MAXIMUM_PACKET_SIZE_ENCRYPTED = MAXIMUM_PACKET_SIZE + 16
|
||||
|
|
Loading…
Add table
Reference in a new issue