mirror of
https://gitlab.futo.org/videostreaming/grayjay.git
synced 2025-04-19 19:14:51 +00:00
Fixed various implementation bugs.
This commit is contained in:
parent
1ae9f0ea26
commit
955ba23b0d
3 changed files with 60 additions and 40 deletions
|
@ -93,13 +93,17 @@ class SyncServerTests {
|
|||
|
||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||
channelB.setDataHandler { _, _, o, so, d ->
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
|
||||
val b = ByteArray(d.remaining())
|
||||
d.get(b)
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b)
|
||||
}
|
||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(1, 2, 3)))
|
||||
|
||||
val tcsDataA = CompletableDeferred<ByteArray>()
|
||||
channelA.setDataHandler { _, _, o, so, d ->
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(d.array())
|
||||
val b = ByteArray(d.remaining())
|
||||
d.get(b)
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(b)
|
||||
}
|
||||
channelB.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(4, 5, 6)))
|
||||
|
||||
|
@ -231,7 +235,9 @@ class SyncServerTests {
|
|||
|
||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||
channelB.setDataHandler { _, _, o, so, d ->
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(d.array())
|
||||
val b = ByteArray(d.remaining())
|
||||
d.get(b)
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b)
|
||||
}
|
||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData))
|
||||
val receivedData = withTimeout(10000.milliseconds) { tcsDataB.await() }
|
||||
|
|
|
@ -14,7 +14,7 @@ interface IChannel : AutoCloseable {
|
|||
val remoteVersion: Int?
|
||||
var authorizable: IAuthorizable?
|
||||
fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?)
|
||||
fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer? = null)
|
||||
fun send(opcode: UByte, subOpcode: UByte = 0u, data: ByteBuffer? = null)
|
||||
fun setCloseHandler(onClose: ((IChannel) -> Unit)?)
|
||||
}
|
||||
|
||||
|
@ -326,12 +326,4 @@ class ChannelRelayed(
|
|||
completeHandshake(remoteVersion, transport)
|
||||
}
|
||||
}
|
||||
|
||||
fun handleData(data: ByteBuffer) {
|
||||
val size = data.int
|
||||
if (size != data.remaining()) throw IllegalStateException("Incomplete packet received")
|
||||
val opcode = data.get().toUByte()
|
||||
val subOpcode = data.get().toUByte()
|
||||
invokeDataHandler(opcode, subOpcode, data)
|
||||
}
|
||||
}
|
|
@ -155,7 +155,7 @@ class SyncSocketSession {
|
|||
val plen: Int = _cipherStatePair!!.receiver.decryptWithAd(null, _buffer, 0, _bufferDecrypted, 0, messageSize)
|
||||
//Logger.i(TAG, "Decrypted message (size = ${plen})")
|
||||
|
||||
handleData(_bufferDecrypted, plen)
|
||||
handleData(_bufferDecrypted, plen, null)
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Exception while receiving data", e)
|
||||
break
|
||||
|
@ -374,22 +374,25 @@ class SyncSocketSession {
|
|||
}
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalUnsignedTypes::class)
|
||||
private fun handleData(data: ByteArray, length: Int) {
|
||||
private fun handleData(data: ByteArray, length: Int, sourceChannel: ChannelRelayed?) {
|
||||
return handleData(ByteBuffer.wrap(data, 0, length).order(ByteOrder.LITTLE_ENDIAN), sourceChannel)
|
||||
}
|
||||
|
||||
private fun handleData(data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
val length = data.remaining()
|
||||
if (length < HEADER_SIZE)
|
||||
throw Exception("Packet must be at least 6 bytes (header size)")
|
||||
|
||||
val size = ByteBuffer.wrap(data, 0, 4).order(ByteOrder.LITTLE_ENDIAN).int
|
||||
val size = data.int
|
||||
if (size != length - 4)
|
||||
throw Exception("Incomplete packet received")
|
||||
|
||||
val opcode = data.asUByteArray()[4]
|
||||
val subOpcode = data.asUByteArray()[5]
|
||||
val packetData = ByteBuffer.wrap(data, HEADER_SIZE, size - 2)
|
||||
handlePacket(opcode, subOpcode, packetData.order(ByteOrder.LITTLE_ENDIAN))
|
||||
val opcode = data.get().toUByte()
|
||||
val subOpcode = data.get().toUByte()
|
||||
handlePacket(opcode, subOpcode, data, sourceChannel)
|
||||
}
|
||||
|
||||
private fun handleRequest(subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handleRequest(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
when (subOpcode) {
|
||||
RequestOpcode.TRANSPORT_RELAYED.value -> {
|
||||
Logger.i(TAG, "Received request for a relayed transport")
|
||||
|
@ -440,7 +443,7 @@ class SyncSocketSession {
|
|||
}
|
||||
}
|
||||
|
||||
private fun handleResponse(subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handleResponse(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
if (data.remaining() < 8) {
|
||||
Logger.e(TAG, "Response packet too short")
|
||||
return
|
||||
|
@ -651,7 +654,7 @@ class SyncSocketSession {
|
|||
return ConnectionInfo(port, name, remoteIp, ipv4Addresses, ipv6Addresses, allowLocalDirect, allowRemoteDirect, allowRemoteHolePunched, allowRemoteProxied)
|
||||
}
|
||||
|
||||
private fun handleNotify(subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handleNotify(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
when (subOpcode) {
|
||||
NotifyOpcode.AUTHORIZED.value, NotifyOpcode.UNAUTHORIZED.value -> _onData?.invoke(this, Opcode.NOTIFY.value, subOpcode, data)
|
||||
NotifyOpcode.CONNECTION_INFO.value -> { /* Handle connection info if needed */ }
|
||||
|
@ -666,7 +669,7 @@ class SyncSocketSession {
|
|||
send(Opcode.RELAY.value, RelayOpcode.RELAY_ERROR.value, packet)
|
||||
}
|
||||
|
||||
private fun handleRelay(subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handleRelay(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
when (subOpcode) {
|
||||
RelayOpcode.RELAYED_DATA.value -> {
|
||||
if (data.remaining() < 8) {
|
||||
|
@ -680,7 +683,7 @@ class SyncSocketSession {
|
|||
}
|
||||
val decryptedPayload = channel.decrypt(data)
|
||||
try {
|
||||
channel.handleData(decryptedPayload)
|
||||
handleData(decryptedPayload, channel)
|
||||
} catch (e: Exception) {
|
||||
Logger.e(TAG, "Exception while handling relayed data", e)
|
||||
channel.sendError(SyncErrorCode.ConnectionClosed)
|
||||
|
@ -726,33 +729,36 @@ class SyncSocketSession {
|
|||
}
|
||||
}
|
||||
|
||||
private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
Logger.i(TAG, "Handle packet (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
||||
|
||||
when (opcode) {
|
||||
Opcode.PING.value -> {
|
||||
send(Opcode.PONG.value)
|
||||
if (sourceChannel != null)
|
||||
sourceChannel.send(Opcode.PONG.value)
|
||||
else
|
||||
send(Opcode.PONG.value)
|
||||
//Logger.i(TAG, "Received ping, sent pong")
|
||||
return
|
||||
}
|
||||
Opcode.PONG.value -> {
|
||||
//Logger.i(TAG, "Received pong")
|
||||
Logger.v(TAG, "Received pong")
|
||||
return
|
||||
}
|
||||
Opcode.REQUEST.value -> {
|
||||
handleRequest(subOpcode, data)
|
||||
handleRequest(subOpcode, data, sourceChannel)
|
||||
return
|
||||
}
|
||||
Opcode.RESPONSE.value -> {
|
||||
handleResponse(subOpcode, data)
|
||||
handleResponse(subOpcode, data, sourceChannel)
|
||||
return
|
||||
}
|
||||
Opcode.NOTIFY.value -> {
|
||||
handleNotify(subOpcode, data)
|
||||
handleNotify(subOpcode, data, sourceChannel)
|
||||
return
|
||||
}
|
||||
Opcode.RELAY.value -> {
|
||||
handleRelay(subOpcode, data)
|
||||
handleRelay(subOpcode, data, sourceChannel)
|
||||
return
|
||||
}
|
||||
else -> if (isAuthorized) when (opcode) {
|
||||
|
@ -809,12 +815,18 @@ class SyncSocketSession {
|
|||
throw Exception("After sync stream end, the stream must be complete")
|
||||
}
|
||||
|
||||
handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) })
|
||||
}
|
||||
else -> {
|
||||
Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
||||
handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) }, sourceChannel)
|
||||
}
|
||||
}
|
||||
Opcode.DATA.value -> {
|
||||
if (sourceChannel != null)
|
||||
sourceChannel.invokeDataHandler(opcode, subOpcode, data)
|
||||
else
|
||||
_onData?.invoke(this, opcode, subOpcode, data)
|
||||
}
|
||||
else -> {
|
||||
Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -995,18 +1007,27 @@ class SyncSocketSession {
|
|||
val deferred = CompletableDeferred<Boolean>()
|
||||
_pendingPublishRequests[requestId] = deferred
|
||||
try {
|
||||
val MAX_PLAINTEXT_SIZE = 65535 - 16 // Adjust for tag size
|
||||
val MAX_PLAINTEXT_SIZE = 65535
|
||||
val HANDSHAKE_SIZE = 48
|
||||
val LENGTH_SIZE = 4
|
||||
val TAG_SIZE = 16
|
||||
val chunkCount = (data.size + MAX_PLAINTEXT_SIZE - 1) / MAX_PLAINTEXT_SIZE
|
||||
val blobSize = HANDSHAKE_SIZE + chunkCount * (LENGTH_SIZE + MAX_PLAINTEXT_SIZE + TAG_SIZE)
|
||||
|
||||
var blobSize = HANDSHAKE_SIZE
|
||||
var dataOffset = 0
|
||||
for (i in 0 until chunkCount) {
|
||||
val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
|
||||
blobSize += LENGTH_SIZE + (chunkSize + TAG_SIZE)
|
||||
dataOffset += chunkSize
|
||||
}
|
||||
|
||||
val totalPacketSize = 4 + 1 + keyBytes.size + 1 + consumerPublicKeys.size * (32 + 4 + blobSize)
|
||||
val packet = ByteBuffer.allocate(totalPacketSize).order(ByteOrder.LITTLE_ENDIAN)
|
||||
packet.putInt(requestId)
|
||||
packet.put(keyBytes.size.toByte())
|
||||
packet.put(keyBytes)
|
||||
packet.put(consumerPublicKeys.size.toByte())
|
||||
|
||||
for (consumer in consumerPublicKeys) {
|
||||
val consumerBytes = Base64.getDecoder().decode(consumer)
|
||||
if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes")
|
||||
|
@ -1020,9 +1041,10 @@ class SyncSocketSession {
|
|||
val transportPair = protocol.split()
|
||||
packet.putInt(blobSize)
|
||||
packet.put(handshakeMessage)
|
||||
var dataOffset = 0
|
||||
|
||||
dataOffset = 0
|
||||
for (i in 0 until chunkCount) {
|
||||
val chunkSize = min(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
|
||||
val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
|
||||
val plaintext = data.copyOfRange(dataOffset, dataOffset + chunkSize)
|
||||
val ciphertext = ByteArray(chunkSize + TAG_SIZE)
|
||||
val written = transportPair.sender.encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.size)
|
||||
|
|
Loading…
Add table
Reference in a new issue