diff --git a/app/src/main/java/com/futo/platformplayer/api/http/server/HttpContext.kt b/app/src/main/java/com/futo/platformplayer/api/http/server/HttpContext.kt index 5b7740c8..9302629d 100644 --- a/app/src/main/java/com/futo/platformplayer/api/http/server/HttpContext.kt +++ b/app/src/main/java/com/futo/platformplayer/api/http/server/HttpContext.kt @@ -7,19 +7,21 @@ import com.futo.platformplayer.api.media.Serializer import kotlinx.serialization.decodeFromString import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json -import java.io.BufferedInputStream import java.io.BufferedReader +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.io.IOException +import java.io.InputStream import java.io.OutputStream import java.io.StringWriter import java.net.SocketTimeoutException -import java.nio.ByteBuffer class HttpContext : AutoCloseable { - private val _stream: BufferedInputStream; + private val _inputStream: InputStream; private var _responseStream: OutputStream? = null; - + var id: String? = null; - + var head: String = ""; var headers: HttpHeaders = HttpHeaders(); @@ -40,103 +42,131 @@ class HttpContext : AutoCloseable { private val _responseHeaders: HttpHeaders = HttpHeaders(); - private val newLineByte = "\n"[0]; - private fun readStreamLine(): String { - //TODO: This is not ideal.. - var twoByteArray = ByteBuffer.allocate(2); - var lastChar: Char = Char.MIN_VALUE; - val builder = StringBuilder(); - do { - val firstByte = _stream.read(); - if(firstByte == -1) - break; - if(isCharacter2Bytes(firstByte)) { - twoByteArray.put(0, firstByte.toByte()); - val secondByte = _stream.read(); - if(secondByte == -1) - break; - twoByteArray.put(1, secondByte.toByte()); - } - else - lastChar = firstByte.toChar(); - builder.append(lastChar); - if(lastChar == newLineByte) - break; - } - while(lastChar != Char.MIN_VALUE); - return builder.toString(); - } - constructor(stream: BufferedInputStream, responseStream: OutputStream? = null, requestId: String? = null, timeout: Int? = null) { - _stream = stream; + constructor(inputStream: InputStream, responseStream: OutputStream? = null, requestId: String? = null, timeout: Int? = null) { + _inputStream = inputStream; _responseStream = responseStream; this.id = requestId; - try { - head = readStreamLine() ?: throw EmptyRequestException("No head found"); - } - catch(ex: SocketTimeoutException) { - if((timeout ?: 0) > 0) - throw KeepAliveTimeoutException("Keep-Alive timedout", ex); - throw ex; - } - - val methodEndIndex = head.indexOf(' '); - val urlEndIndex = head.indexOf(' ', methodEndIndex + 1); - if (methodEndIndex == -1 || urlEndIndex == -1) { - Logger.w(TAG, "Skipped request, wrong format."); - throw IllegalStateException("Invalid request"); - } - - method = head.substring(0, methodEndIndex); - path = head.substring(methodEndIndex + 1, urlEndIndex); - - if (path.contains("?")) { - val queryPartIndex = path.indexOf("?"); - val queryParts = path.substring(queryPartIndex + 1).split("&"); - path = path.substring(0, queryPartIndex); - - for(queryPart in queryParts) { - val eqIndex = queryPart.indexOf("="); - if(eqIndex > 0) - query.put(queryPart.substring(0, eqIndex), queryPart.substring(eqIndex + 1)); - else - query.put(queryPart, ""); + val headerBytes = readHeaderBytes() + ByteArrayInputStream(headerBytes).use { + val reader = it.bufferedReader(Charsets.UTF_8) + try { + head = reader.readLine() ?: throw EmptyRequestException("No head found"); + } + catch(ex: SocketTimeoutException) { + if((timeout ?: 0) > 0) + throw KeepAliveTimeoutException("Keep-Alive timedout", ex); + throw ex; } - } - while (true) { - val line = readStreamLine(); - val headerEndIndex = line.indexOf(":"); - if (headerEndIndex == -1) - break; + val methodEndIndex = head.indexOf(' '); + val urlEndIndex = head.indexOf(' ', methodEndIndex + 1); + if (methodEndIndex == -1 || urlEndIndex == -1) { + Logger.w(TAG, "Skipped request, wrong format."); + throw IllegalStateException("Invalid request"); + } - val headerKey = line.substring(0, headerEndIndex).lowercase() - val headerValue = line.substring(headerEndIndex + 1).trim(); - headers[headerKey] = headerValue; + method = head.substring(0, methodEndIndex); + path = head.substring(methodEndIndex + 1, urlEndIndex); - when(headerKey) { - "content-length" -> contentLength = headerValue.toLong(); - "content-type" -> contentType = headerValue; - "connection" -> keepAlive = headerValue.lowercase() == "keep-alive"; - "keep-alive" -> { - val keepAliveParams = headerValue.split(","); - for(keepAliveParam in keepAliveParams) { - val eqIndex = keepAliveParam.indexOf("="); - if(eqIndex > 0){ - when(keepAliveParam.substring(0, eqIndex)) { - "timeout" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt(); - "max" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt(); + if (path.contains("?")) { + val queryPartIndex = path.indexOf("?"); + val queryParts = path.substring(queryPartIndex + 1).split("&"); + path = path.substring(0, queryPartIndex); + + for(queryPart in queryParts) { + val eqIndex = queryPart.indexOf("="); + if(eqIndex > 0) + query.put(queryPart.substring(0, eqIndex), queryPart.substring(eqIndex + 1)); + else + query.put(queryPart, ""); + } + } + + while (true) { + val line = reader.readLine(); + val headerEndIndex = line.indexOf(":"); + if (headerEndIndex == -1) + break; + + val headerKey = line.substring(0, headerEndIndex).lowercase() + val headerValue = line.substring(headerEndIndex + 1).trim(); + headers[headerKey] = headerValue; + + when(headerKey) { + "content-length" -> contentLength = headerValue.toLong(); + "content-type" -> contentType = headerValue; + "connection" -> keepAlive = headerValue.lowercase() == "keep-alive"; + "keep-alive" -> { + val keepAliveParams = headerValue.split(","); + for(keepAliveParam in keepAliveParams) { + val eqIndex = keepAliveParam.indexOf("="); + if(eqIndex > 0){ + when(keepAliveParam.substring(0, eqIndex)) { + "timeout" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt(); + "max" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt(); + } } } } } + if(line.isNullOrEmpty()) + break; } - if(line.isNullOrEmpty()) - break; } } + private fun readHeaderBytes(): ByteArray { + val headerBytes = ByteArrayOutputStream() + var crlfCount = 0 + + while (crlfCount < 4) { + val b = _inputStream.read() + if (b == -1) { + throw IOException("Unexpected end of stream while reading headers") + } + + if (b == 0x0D || b == 0x0A) { // CR or LF + crlfCount++ + } else { + crlfCount = 0 + } + + headerBytes.write(b) + } + + return headerBytes.toByteArray() + } + + fun readContentBytes(buffer: ByteArray, length: Int): Int { + val remainingBytes = (contentLength - _totalRead).coerceAtMost(length.toLong()).toInt() + val read = _inputStream.read(buffer, 0, remainingBytes); + if (read > 0) { + _totalRead += read + } + + return read; + } + fun readContentString(): String { + val byteArrayOutputStream = ByteArrayOutputStream() + val buffer = ByteArray(4096) + var read: Int + while (true) { + read = readContentBytes(buffer, buffer.size) + if (read <= 0) break + byteArrayOutputStream.write(buffer, 0, read) + } + return byteArrayOutputStream.toString(Charsets.UTF_8.name()) + } + inline fun readContentJson() : T { + return Serializer.json.decodeFromString(readContentString()); + } + fun skipBody() { + if (contentLength > 0) + _inputStream.skip(contentLength - _totalRead) + } + fun getHttpHeaderString(): String { val writer = StringWriter(); writer.write(head + "\r\n"); @@ -200,58 +230,13 @@ class HttpContext : AutoCloseable { statusCode = status; } - fun readContentBytes(buffer: ByteArray, length: Int) : Int { - val reading = if(contentLength - _totalRead < length) - (contentLength - _totalRead).toInt(); - else - length; - val read = _stream.read(buffer, 0, reading); - _totalRead += read; - return read; - } - fun readContentString() : String{ - val writer = StringWriter(); - var read = 0; - val buffer = ByteArray(8192); - val twoByteArray = ByteArray(2); - do { - read = readContentBytes(buffer, buffer.size); - - if(read > 0) { - if (isCharacter2Bytes(buffer[read - 1].toInt())) { - //Fixes overlapping buffers - writer.write(String(buffer, 0, read - 1)); - twoByteArray[0] = buffer[read - 1]; - val secondByte = _stream.read(); - if (secondByte < 0) - break; - twoByteArray[1] = secondByte.toByte(); - writer.write(String(twoByteArray)); - } else - writer.write(String(buffer, 0, read)); - } - } while(read > 0); - return writer.toString(); - } - inline fun readContentJson() : T { - return Serializer.json.decodeFromString(readContentString()); - } - fun skipBody() { - if(contentLength > 0) - _stream.skip(contentLength - _totalRead); - } - override fun close() { if(!keepAlive) { - _stream?.close(); + _inputStream.close(); _responseStream?.close(); } } - private fun isCharacter2Bytes(firstByte: Int): Boolean { - return firstByte and 0xE0 == 0xC0 - } - companion object { private val TAG = "HttpRequest"; private val statusCodeMap = mapOf(