diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt index 38408c25..89e4e3a7 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/Channel.kt @@ -96,11 +96,39 @@ class ChannelRelayed( private var onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)? = null private var onClose: ((IChannel) -> Unit)? = null private var disposed = false + private var _lastPongTime: Long = 0 + private val _pingInterval: Long = 5000 // 5 seconds in milliseconds + private val _disconnectTimeout: Long = 30000 // 30 seconds in milliseconds init { handshakeState?.start() } + private fun startPingLoop() { + if (remoteVersion!! < 5) { + return + } + + _lastPongTime = System.currentTimeMillis() + + Thread { + try { + while (!disposed) { + Thread.sleep(_pingInterval) + if (System.currentTimeMillis() - _lastPongTime > _disconnectTimeout) { + Logger.e("ChannelRelayed", "Channel timed out waiting for PONG; closing.") + close() + break + } + send(Opcode.PING.value, 0u) + } + } catch (e: Exception) { + Logger.e("ChannelRelayed", "Ping loop failed", e) + close() + } + }.start() + } + override fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?) { this.onData = onData } @@ -136,6 +164,10 @@ class ChannelRelayed( } fun invokeDataHandler(opcode: UByte, subOpcode: UByte, data: ByteBuffer) { + if (opcode == Opcode.PONG.value) { + _lastPongTime = System.currentTimeMillis() + return + } onData?.invoke(session, this, opcode, subOpcode, data) } @@ -150,6 +182,7 @@ class ChannelRelayed( handshakeState = null this.transport = transport Logger.i("ChannelRelayed", "Completed handshake for connectionId $connectionId") + startPingLoop() } private fun sendPacket(packet: ByteArray) { diff --git a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt index 5c3bdb9d..3dc81334 100644 --- a/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt +++ b/app/src/main/java/com/futo/platformplayer/sync/internal/SyncSocketSession.kt @@ -9,11 +9,15 @@ import com.futo.platformplayer.logging.Logger import com.futo.platformplayer.noise.protocol.CipherStatePair import com.futo.platformplayer.noise.protocol.DHState import com.futo.platformplayer.noise.protocol.HandshakeState +import com.futo.platformplayer.states.StateApp import com.futo.platformplayer.states.StateSync import com.futo.platformplayer.sync.internal.ChannelRelayed.Companion import com.futo.polycentric.core.base64ToByteArray import com.futo.polycentric.core.toBase64 import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import java.io.InputStream @@ -80,6 +84,11 @@ class SyncSocketSession { private val _pendingBulkGetRecordRequests = ConcurrentHashMap>>>() private val _pendingBulkConnectionInfoRequests = ConcurrentHashMap>>() + @Volatile + private var _lastPongTime: Long = System.currentTimeMillis() + private val _pingInterval: Long = 5000 // 5 seconds in milliseconds + private val _disconnectTimeout: Long = 30000 // 30 seconds in milliseconds + data class ConnectionInfo( val port: UShort, val name: String, @@ -129,6 +138,7 @@ class SyncSocketSession { try { handshakeAsInitiator(remotePublicKey, appId, pairingCode) _onHandshakeComplete?.invoke(this) + startPingLoop() receiveLoop() } catch (e: Throwable) { Logger.e(TAG, "Failed to run as initiator", e) @@ -143,6 +153,7 @@ class SyncSocketSession { try { handshakeAsInitiator(remotePublicKey, appId, pairingCode) _onHandshakeComplete?.invoke(this) + startPingLoop() receiveLoop() } catch (e: Throwable) { Logger.e(TAG, "Failed to run as initiator", e) @@ -157,6 +168,7 @@ class SyncSocketSession { try { if (handshakeAsResponder()) { _onHandshakeComplete?.invoke(this) + startPingLoop() receiveLoop() } } catch (e: Throwable) { @@ -352,7 +364,7 @@ class SyncSocketSession { } private fun performVersionCheck() { - val CURRENT_VERSION = 4 + val CURRENT_VERSION = 5 val MINIMUM_VERSION = 4 val versionBytes = ByteArray(4) @@ -833,6 +845,30 @@ class SyncSocketSession { } } + private fun startPingLoop() { + if (remoteVersion < 5) return + + _lastPongTime = System.currentTimeMillis() + + StateApp.instance.scopeOrNull?.launch(Dispatchers.IO) { + try { + while (_started) { + delay(_pingInterval) + + if (System.currentTimeMillis() - _lastPongTime > _disconnectTimeout) { + Logger.e(TAG, "Session timed out waiting for PONG; closing.") + stop() + break + } + send(Opcode.PING.value) + } + } catch (e: Exception) { + Logger.e(TAG, "Ping loop failed", e) + stop() + } + } + } + private fun handlePacket(opcode: UByte, subOpcode: UByte, d: ByteBuffer, contentEncoding: UByte, sourceChannel: ChannelRelayed?) { Logger.i(TAG, "Handle packet (opcode = ${opcode}, subOpcode = ${subOpcode})") @@ -864,6 +900,11 @@ class SyncSocketSession { return } Opcode.PONG.value -> { + if (sourceChannel != null) { + sourceChannel.invokeDataHandler(opcode, subOpcode, data) + } else { + _lastPongTime = System.currentTimeMillis() + } Logger.v(TAG, "Received pong") return }