Another iteration on the HttpContext fix.

This commit is contained in:
Koen 2023-11-20 12:14:03 +01:00
commit 8661ff88c0

View file

@ -7,19 +7,21 @@ import com.futo.platformplayer.api.media.Serializer
import kotlinx.serialization.decodeFromString import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import java.io.BufferedInputStream
import java.io.BufferedReader import java.io.BufferedReader
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream import java.io.OutputStream
import java.io.StringWriter import java.io.StringWriter
import java.net.SocketTimeoutException import java.net.SocketTimeoutException
import java.nio.ByteBuffer
class HttpContext : AutoCloseable { class HttpContext : AutoCloseable {
private val _stream: BufferedInputStream; private val _inputStream: InputStream;
private var _responseStream: OutputStream? = null; private var _responseStream: OutputStream? = null;
var id: String? = null; var id: String? = null;
var head: String = ""; var head: String = "";
var headers: HttpHeaders = HttpHeaders(); var headers: HttpHeaders = HttpHeaders();
@ -40,103 +42,131 @@ class HttpContext : AutoCloseable {
private val _responseHeaders: HttpHeaders = HttpHeaders(); private val _responseHeaders: HttpHeaders = HttpHeaders();
private val newLineByte = "\n"[0];
private fun readStreamLine(): String {
//TODO: This is not ideal..
var twoByteArray = ByteBuffer.allocate(2);
var lastChar: Char = Char.MIN_VALUE;
val builder = StringBuilder();
do {
val firstByte = _stream.read();
if(firstByte == -1)
break;
if(isCharacter2Bytes(firstByte)) {
twoByteArray.put(0, firstByte.toByte());
val secondByte = _stream.read();
if(secondByte == -1)
break;
twoByteArray.put(1, secondByte.toByte());
}
else
lastChar = firstByte.toChar();
builder.append(lastChar);
if(lastChar == newLineByte)
break;
}
while(lastChar != Char.MIN_VALUE);
return builder.toString();
}
constructor(stream: BufferedInputStream, responseStream: OutputStream? = null, requestId: String? = null, timeout: Int? = null) { constructor(inputStream: InputStream, responseStream: OutputStream? = null, requestId: String? = null, timeout: Int? = null) {
_stream = stream; _inputStream = inputStream;
_responseStream = responseStream; _responseStream = responseStream;
this.id = requestId; this.id = requestId;
try { val headerBytes = readHeaderBytes()
head = readStreamLine() ?: throw EmptyRequestException("No head found"); ByteArrayInputStream(headerBytes).use {
} val reader = it.bufferedReader(Charsets.UTF_8)
catch(ex: SocketTimeoutException) { try {
if((timeout ?: 0) > 0) head = reader.readLine() ?: throw EmptyRequestException("No head found");
throw KeepAliveTimeoutException("Keep-Alive timedout", ex); }
throw ex; catch(ex: SocketTimeoutException) {
} if((timeout ?: 0) > 0)
throw KeepAliveTimeoutException("Keep-Alive timedout", ex);
val methodEndIndex = head.indexOf(' '); throw ex;
val urlEndIndex = head.indexOf(' ', methodEndIndex + 1);
if (methodEndIndex == -1 || urlEndIndex == -1) {
Logger.w(TAG, "Skipped request, wrong format.");
throw IllegalStateException("Invalid request");
}
method = head.substring(0, methodEndIndex);
path = head.substring(methodEndIndex + 1, urlEndIndex);
if (path.contains("?")) {
val queryPartIndex = path.indexOf("?");
val queryParts = path.substring(queryPartIndex + 1).split("&");
path = path.substring(0, queryPartIndex);
for(queryPart in queryParts) {
val eqIndex = queryPart.indexOf("=");
if(eqIndex > 0)
query.put(queryPart.substring(0, eqIndex), queryPart.substring(eqIndex + 1));
else
query.put(queryPart, "");
} }
}
while (true) { val methodEndIndex = head.indexOf(' ');
val line = readStreamLine(); val urlEndIndex = head.indexOf(' ', methodEndIndex + 1);
val headerEndIndex = line.indexOf(":"); if (methodEndIndex == -1 || urlEndIndex == -1) {
if (headerEndIndex == -1) Logger.w(TAG, "Skipped request, wrong format.");
break; throw IllegalStateException("Invalid request");
}
val headerKey = line.substring(0, headerEndIndex).lowercase() method = head.substring(0, methodEndIndex);
val headerValue = line.substring(headerEndIndex + 1).trim(); path = head.substring(methodEndIndex + 1, urlEndIndex);
headers[headerKey] = headerValue;
when(headerKey) { if (path.contains("?")) {
"content-length" -> contentLength = headerValue.toLong(); val queryPartIndex = path.indexOf("?");
"content-type" -> contentType = headerValue; val queryParts = path.substring(queryPartIndex + 1).split("&");
"connection" -> keepAlive = headerValue.lowercase() == "keep-alive"; path = path.substring(0, queryPartIndex);
"keep-alive" -> {
val keepAliveParams = headerValue.split(","); for(queryPart in queryParts) {
for(keepAliveParam in keepAliveParams) { val eqIndex = queryPart.indexOf("=");
val eqIndex = keepAliveParam.indexOf("="); if(eqIndex > 0)
if(eqIndex > 0){ query.put(queryPart.substring(0, eqIndex), queryPart.substring(eqIndex + 1));
when(keepAliveParam.substring(0, eqIndex)) { else
"timeout" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt(); query.put(queryPart, "");
"max" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt(); }
}
while (true) {
val line = reader.readLine();
val headerEndIndex = line.indexOf(":");
if (headerEndIndex == -1)
break;
val headerKey = line.substring(0, headerEndIndex).lowercase()
val headerValue = line.substring(headerEndIndex + 1).trim();
headers[headerKey] = headerValue;
when(headerKey) {
"content-length" -> contentLength = headerValue.toLong();
"content-type" -> contentType = headerValue;
"connection" -> keepAlive = headerValue.lowercase() == "keep-alive";
"keep-alive" -> {
val keepAliveParams = headerValue.split(",");
for(keepAliveParam in keepAliveParams) {
val eqIndex = keepAliveParam.indexOf("=");
if(eqIndex > 0){
when(keepAliveParam.substring(0, eqIndex)) {
"timeout" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt();
"max" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt();
}
} }
} }
} }
} }
if(line.isNullOrEmpty())
break;
} }
if(line.isNullOrEmpty())
break;
} }
} }
private fun readHeaderBytes(): ByteArray {
val headerBytes = ByteArrayOutputStream()
var crlfCount = 0
while (crlfCount < 4) {
val b = _inputStream.read()
if (b == -1) {
throw IOException("Unexpected end of stream while reading headers")
}
if (b == 0x0D || b == 0x0A) { // CR or LF
crlfCount++
} else {
crlfCount = 0
}
headerBytes.write(b)
}
return headerBytes.toByteArray()
}
fun readContentBytes(buffer: ByteArray, length: Int): Int {
val remainingBytes = (contentLength - _totalRead).coerceAtMost(length.toLong()).toInt()
val read = _inputStream.read(buffer, 0, remainingBytes);
if (read > 0) {
_totalRead += read
}
return read;
}
fun readContentString(): String {
val byteArrayOutputStream = ByteArrayOutputStream()
val buffer = ByteArray(4096)
var read: Int
while (true) {
read = readContentBytes(buffer, buffer.size)
if (read <= 0) break
byteArrayOutputStream.write(buffer, 0, read)
}
return byteArrayOutputStream.toString(Charsets.UTF_8.name())
}
inline fun <reified T> readContentJson() : T {
return Serializer.json.decodeFromString(readContentString());
}
fun skipBody() {
if (contentLength > 0)
_inputStream.skip(contentLength - _totalRead)
}
fun getHttpHeaderString(): String { fun getHttpHeaderString(): String {
val writer = StringWriter(); val writer = StringWriter();
writer.write(head + "\r\n"); writer.write(head + "\r\n");
@ -200,58 +230,13 @@ class HttpContext : AutoCloseable {
statusCode = status; statusCode = status;
} }
fun readContentBytes(buffer: ByteArray, length: Int) : Int {
val reading = if(contentLength - _totalRead < length)
(contentLength - _totalRead).toInt();
else
length;
val read = _stream.read(buffer, 0, reading);
_totalRead += read;
return read;
}
fun readContentString() : String{
val writer = StringWriter();
var read = 0;
val buffer = ByteArray(8192);
val twoByteArray = ByteArray(2);
do {
read = readContentBytes(buffer, buffer.size);
if(read > 0) {
if (isCharacter2Bytes(buffer[read - 1].toInt())) {
//Fixes overlapping buffers
writer.write(String(buffer, 0, read - 1));
twoByteArray[0] = buffer[read - 1];
val secondByte = _stream.read();
if (secondByte < 0)
break;
twoByteArray[1] = secondByte.toByte();
writer.write(String(twoByteArray));
} else
writer.write(String(buffer, 0, read));
}
} while(read > 0);
return writer.toString();
}
inline fun <reified T> readContentJson() : T {
return Serializer.json.decodeFromString(readContentString());
}
fun skipBody() {
if(contentLength > 0)
_stream.skip(contentLength - _totalRead);
}
override fun close() { override fun close() {
if(!keepAlive) { if(!keepAlive) {
_stream?.close(); _inputStream.close();
_responseStream?.close(); _responseStream?.close();
} }
} }
private fun isCharacter2Bytes(firstByte: Int): Boolean {
return firstByte and 0xE0 == 0xC0
}
companion object { companion object {
private val TAG = "HttpRequest"; private val TAG = "HttpRequest";
private val statusCodeMap = mapOf( private val statusCodeMap = mapOf(