Another iteration on the HttpContext fix.

This commit is contained in:
Koen 2023-11-20 12:14:03 +01:00
parent 0bba7fa373
commit 8661ff88c0

View file

@ -7,19 +7,21 @@ import com.futo.platformplayer.api.media.Serializer
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import java.io.BufferedInputStream
import java.io.BufferedReader
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.io.StringWriter
import java.net.SocketTimeoutException
import java.nio.ByteBuffer
class HttpContext : AutoCloseable {
private val _stream: BufferedInputStream;
private val _inputStream: InputStream;
private var _responseStream: OutputStream? = null;
var id: String? = null;
var head: String = "";
var headers: HttpHeaders = HttpHeaders();
@ -40,103 +42,131 @@ class HttpContext : AutoCloseable {
private val _responseHeaders: HttpHeaders = HttpHeaders();
private val newLineByte = "\n"[0];
private fun readStreamLine(): String {
//TODO: This is not ideal..
var twoByteArray = ByteBuffer.allocate(2);
var lastChar: Char = Char.MIN_VALUE;
val builder = StringBuilder();
do {
val firstByte = _stream.read();
if(firstByte == -1)
break;
if(isCharacter2Bytes(firstByte)) {
twoByteArray.put(0, firstByte.toByte());
val secondByte = _stream.read();
if(secondByte == -1)
break;
twoByteArray.put(1, secondByte.toByte());
}
else
lastChar = firstByte.toChar();
builder.append(lastChar);
if(lastChar == newLineByte)
break;
}
while(lastChar != Char.MIN_VALUE);
return builder.toString();
}
constructor(stream: BufferedInputStream, responseStream: OutputStream? = null, requestId: String? = null, timeout: Int? = null) {
_stream = stream;
constructor(inputStream: InputStream, responseStream: OutputStream? = null, requestId: String? = null, timeout: Int? = null) {
_inputStream = inputStream;
_responseStream = responseStream;
this.id = requestId;
try {
head = readStreamLine() ?: throw EmptyRequestException("No head found");
}
catch(ex: SocketTimeoutException) {
if((timeout ?: 0) > 0)
throw KeepAliveTimeoutException("Keep-Alive timedout", ex);
throw ex;
}
val methodEndIndex = head.indexOf(' ');
val urlEndIndex = head.indexOf(' ', methodEndIndex + 1);
if (methodEndIndex == -1 || urlEndIndex == -1) {
Logger.w(TAG, "Skipped request, wrong format.");
throw IllegalStateException("Invalid request");
}
method = head.substring(0, methodEndIndex);
path = head.substring(methodEndIndex + 1, urlEndIndex);
if (path.contains("?")) {
val queryPartIndex = path.indexOf("?");
val queryParts = path.substring(queryPartIndex + 1).split("&");
path = path.substring(0, queryPartIndex);
for(queryPart in queryParts) {
val eqIndex = queryPart.indexOf("=");
if(eqIndex > 0)
query.put(queryPart.substring(0, eqIndex), queryPart.substring(eqIndex + 1));
else
query.put(queryPart, "");
val headerBytes = readHeaderBytes()
ByteArrayInputStream(headerBytes).use {
val reader = it.bufferedReader(Charsets.UTF_8)
try {
head = reader.readLine() ?: throw EmptyRequestException("No head found");
}
catch(ex: SocketTimeoutException) {
if((timeout ?: 0) > 0)
throw KeepAliveTimeoutException("Keep-Alive timedout", ex);
throw ex;
}
}
while (true) {
val line = readStreamLine();
val headerEndIndex = line.indexOf(":");
if (headerEndIndex == -1)
break;
val methodEndIndex = head.indexOf(' ');
val urlEndIndex = head.indexOf(' ', methodEndIndex + 1);
if (methodEndIndex == -1 || urlEndIndex == -1) {
Logger.w(TAG, "Skipped request, wrong format.");
throw IllegalStateException("Invalid request");
}
val headerKey = line.substring(0, headerEndIndex).lowercase()
val headerValue = line.substring(headerEndIndex + 1).trim();
headers[headerKey] = headerValue;
method = head.substring(0, methodEndIndex);
path = head.substring(methodEndIndex + 1, urlEndIndex);
when(headerKey) {
"content-length" -> contentLength = headerValue.toLong();
"content-type" -> contentType = headerValue;
"connection" -> keepAlive = headerValue.lowercase() == "keep-alive";
"keep-alive" -> {
val keepAliveParams = headerValue.split(",");
for(keepAliveParam in keepAliveParams) {
val eqIndex = keepAliveParam.indexOf("=");
if(eqIndex > 0){
when(keepAliveParam.substring(0, eqIndex)) {
"timeout" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt();
"max" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt();
if (path.contains("?")) {
val queryPartIndex = path.indexOf("?");
val queryParts = path.substring(queryPartIndex + 1).split("&");
path = path.substring(0, queryPartIndex);
for(queryPart in queryParts) {
val eqIndex = queryPart.indexOf("=");
if(eqIndex > 0)
query.put(queryPart.substring(0, eqIndex), queryPart.substring(eqIndex + 1));
else
query.put(queryPart, "");
}
}
while (true) {
val line = reader.readLine();
val headerEndIndex = line.indexOf(":");
if (headerEndIndex == -1)
break;
val headerKey = line.substring(0, headerEndIndex).lowercase()
val headerValue = line.substring(headerEndIndex + 1).trim();
headers[headerKey] = headerValue;
when(headerKey) {
"content-length" -> contentLength = headerValue.toLong();
"content-type" -> contentType = headerValue;
"connection" -> keepAlive = headerValue.lowercase() == "keep-alive";
"keep-alive" -> {
val keepAliveParams = headerValue.split(",");
for(keepAliveParam in keepAliveParams) {
val eqIndex = keepAliveParam.indexOf("=");
if(eqIndex > 0){
when(keepAliveParam.substring(0, eqIndex)) {
"timeout" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt();
"max" -> keepAliveTimeout = keepAliveParam.substring(eqIndex+1).toInt();
}
}
}
}
}
if(line.isNullOrEmpty())
break;
}
if(line.isNullOrEmpty())
break;
}
}
private fun readHeaderBytes(): ByteArray {
val headerBytes = ByteArrayOutputStream()
var crlfCount = 0
while (crlfCount < 4) {
val b = _inputStream.read()
if (b == -1) {
throw IOException("Unexpected end of stream while reading headers")
}
if (b == 0x0D || b == 0x0A) { // CR or LF
crlfCount++
} else {
crlfCount = 0
}
headerBytes.write(b)
}
return headerBytes.toByteArray()
}
fun readContentBytes(buffer: ByteArray, length: Int): Int {
val remainingBytes = (contentLength - _totalRead).coerceAtMost(length.toLong()).toInt()
val read = _inputStream.read(buffer, 0, remainingBytes);
if (read > 0) {
_totalRead += read
}
return read;
}
fun readContentString(): String {
val byteArrayOutputStream = ByteArrayOutputStream()
val buffer = ByteArray(4096)
var read: Int
while (true) {
read = readContentBytes(buffer, buffer.size)
if (read <= 0) break
byteArrayOutputStream.write(buffer, 0, read)
}
return byteArrayOutputStream.toString(Charsets.UTF_8.name())
}
inline fun <reified T> readContentJson() : T {
return Serializer.json.decodeFromString(readContentString());
}
fun skipBody() {
if (contentLength > 0)
_inputStream.skip(contentLength - _totalRead)
}
fun getHttpHeaderString(): String {
val writer = StringWriter();
writer.write(head + "\r\n");
@ -200,58 +230,13 @@ class HttpContext : AutoCloseable {
statusCode = status;
}
fun readContentBytes(buffer: ByteArray, length: Int) : Int {
val reading = if(contentLength - _totalRead < length)
(contentLength - _totalRead).toInt();
else
length;
val read = _stream.read(buffer, 0, reading);
_totalRead += read;
return read;
}
fun readContentString() : String{
val writer = StringWriter();
var read = 0;
val buffer = ByteArray(8192);
val twoByteArray = ByteArray(2);
do {
read = readContentBytes(buffer, buffer.size);
if(read > 0) {
if (isCharacter2Bytes(buffer[read - 1].toInt())) {
//Fixes overlapping buffers
writer.write(String(buffer, 0, read - 1));
twoByteArray[0] = buffer[read - 1];
val secondByte = _stream.read();
if (secondByte < 0)
break;
twoByteArray[1] = secondByte.toByte();
writer.write(String(twoByteArray));
} else
writer.write(String(buffer, 0, read));
}
} while(read > 0);
return writer.toString();
}
inline fun <reified T> readContentJson() : T {
return Serializer.json.decodeFromString(readContentString());
}
fun skipBody() {
if(contentLength > 0)
_stream.skip(contentLength - _totalRead);
}
override fun close() {
if(!keepAlive) {
_stream?.close();
_inputStream.close();
_responseStream?.close();
}
}
private fun isCharacter2Bytes(firstByte: Int): Boolean {
return firstByte and 0xE0 == 0xC0
}
companion object {
private val TAG = "HttpRequest";
private val statusCodeMap = mapOf(