diff --git a/Libraries/LibCrypto/PK/PK.h b/Libraries/LibCrypto/PK/PK.h index 4e24f4a2af7..baa4ae81c83 100644 --- a/Libraries/LibCrypto/PK/PK.h +++ b/Libraries/LibCrypto/PK/PK.h @@ -232,11 +232,11 @@ public: PKSystem() = default; - virtual void encrypt(ReadonlyBytes in, Bytes& out) = 0; - virtual void decrypt(ReadonlyBytes in, Bytes& out) = 0; + virtual ErrorOr encrypt(ReadonlyBytes in, Bytes& out) = 0; + virtual ErrorOr decrypt(ReadonlyBytes in, Bytes& out) = 0; - virtual void sign(ReadonlyBytes in, Bytes& out) = 0; - virtual void verify(ReadonlyBytes in, Bytes& out) = 0; + virtual ErrorOr verify(ReadonlyBytes in, Bytes& out) = 0; + virtual ErrorOr sign(ReadonlyBytes in, Bytes& out) = 0; virtual ByteString class_name() const = 0; diff --git a/Libraries/LibCrypto/PK/RSA.cpp b/Libraries/LibCrypto/PK/RSA.cpp index 4032636b7d4..af7c9d580a0 100644 --- a/Libraries/LibCrypto/PK/RSA.cpp +++ b/Libraries/LibCrypto/PK/RSA.cpp @@ -136,25 +136,21 @@ ErrorOr RSA::generate_key_pair(size_t bits, IntegerType e) return keys; } -void RSA::encrypt(ReadonlyBytes in, Bytes& out) +ErrorOr RSA::encrypt(ReadonlyBytes in, Bytes& out) { dbgln_if(CRYPTO_DEBUG, "in size: {}", in.size()); auto in_integer = UnsignedBigInteger::import_data(in.data(), in.size()); - if (!(in_integer < m_public_key.modulus())) { - dbgln("value too large for key"); - out = {}; - return; - } + if (in_integer >= m_public_key.modulus()) + return Error::from_string_literal("Data too large for key"); + auto exp = NumberTheory::ModularPower(in_integer, m_public_key.public_exponent(), m_public_key.modulus()); auto size = exp.export_data(out); auto outsize = out.size(); - if (size != outsize) { - dbgln("POSSIBLE RSA BUG!!! Size mismatch: {} requested but {} bytes generated", outsize, size); - out = out.slice(outsize - size, size); - } + VERIFY(size == outsize); + return {}; } -void RSA::decrypt(ReadonlyBytes in, Bytes& out) +ErrorOr RSA::decrypt(ReadonlyBytes in, Bytes& out) { auto in_integer = UnsignedBigInteger::import_data(in.data(), in.size()); @@ -178,22 +174,25 @@ void RSA::decrypt(ReadonlyBytes in, Bytes& out) for (auto i = size; i < aligned_size; ++i) out[out.size() - i - 1] = 0; // zero the non-aligned values out = out.slice(out.size() - aligned_size, aligned_size); + return {}; } -void RSA::sign(ReadonlyBytes in, Bytes& out) +ErrorOr RSA::sign(ReadonlyBytes in, Bytes& out) { auto in_integer = UnsignedBigInteger::import_data(in.data(), in.size()); auto exp = NumberTheory::ModularPower(in_integer, m_private_key.private_exponent(), m_private_key.modulus()); auto size = exp.export_data(out); out = out.slice(out.size() - size, size); + return {}; } -void RSA::verify(ReadonlyBytes in, Bytes& out) +ErrorOr RSA::verify(ReadonlyBytes in, Bytes& out) { auto in_integer = UnsignedBigInteger::import_data(in.data(), in.size()); auto exp = NumberTheory::ModularPower(in_integer, m_public_key.public_exponent(), m_public_key.modulus()); auto size = exp.export_data(out); out = out.slice(out.size() - size, size); + return {}; } void RSA::import_private_key(ReadonlyBytes bytes, bool pem) @@ -258,19 +257,15 @@ void RSA::import_public_key(ReadonlyBytes bytes, bool pem) m_public_key = maybe_key.release_value().public_key; } -void RSA_PKCS1_EME::encrypt(ReadonlyBytes in, Bytes& out) +ErrorOr RSA_PKCS1_EME::encrypt(ReadonlyBytes in, Bytes& out) { auto mod_len = (m_public_key.modulus().trimmed_length() * sizeof(u32) * 8 + 7) / 8; dbgln_if(CRYPTO_DEBUG, "key size: {}", mod_len); - if (in.size() > mod_len - 11) { - dbgln("message too long :("); - out = out.trim(0); - return; - } - if (out.size() < mod_len) { - dbgln("output buffer too small"); - return; - } + if (in.size() > mod_len - 11) + return Error::from_string_literal("Message too long"); + + if (out.size() < mod_len) + return Error::from_string_literal("Output buffer too small"); auto ps_length = mod_len - in.size() - 3; Vector ps; @@ -294,60 +289,47 @@ void RSA_PKCS1_EME::encrypt(ReadonlyBytes in, Bytes& out) dbgln_if(CRYPTO_DEBUG, "padded output size: {} buffer size: {}", 3 + ps_length + in.size(), out.size()); - RSA::encrypt(out, out); + TRY(RSA::encrypt(out, out)); + return {}; } -void RSA_PKCS1_EME::decrypt(ReadonlyBytes in, Bytes& out) + +ErrorOr RSA_PKCS1_EME::decrypt(ReadonlyBytes in, Bytes& out) { auto mod_len = (m_public_key.modulus().trimmed_length() * sizeof(u32) * 8 + 7) / 8; - if (in.size() != mod_len) { - dbgln("decryption error: wrong amount of data: {}", in.size()); - out = out.trim(0); - return; - } + if (in.size() != mod_len) + return Error::from_string_literal("Invalid input size"); - RSA::decrypt(in, out); + TRY(RSA::decrypt(in, out)); - if (out.size() < RSA::output_size()) { - dbgln("decryption error: not enough data after decryption: {}", out.size()); - out = out.trim(0); - return; - } + if (out.size() < RSA::output_size()) + return Error::from_string_literal("Not enough data after decryption"); - if (out[0] != 0x00) { - dbgln("invalid padding byte 0 : {}", out[0]); - return; - } - - if (out[1] != 0x02) { - dbgln("invalid padding byte 1 : {}", out[1]); - return; - } + if (out[0] != 0x00 || out[1] != 0x02) + return Error::from_string_literal("Invalid padding"); size_t offset = 2; while (offset < out.size() && out[offset]) ++offset; - if (offset == out.size()) { - dbgln("garbage data, no zero to split padding"); - return; - } + if (offset == out.size()) + return Error::from_string_literal("Garbage data, no zero to split padding"); ++offset; - if (offset - 3 < 8) { - dbgln("PS too small"); - return; - } + if (offset - 3 < 8) + return Error::from_string_literal("PS too small"); out = out.slice(offset, out.size() - offset); + return {}; } -void RSA_PKCS1_EME::sign(ReadonlyBytes, Bytes&) +ErrorOr RSA_PKCS1_EME::sign(ReadonlyBytes, Bytes&) { - dbgln("FIXME: RSA_PKCS_EME::sign"); + return Error::from_string_literal("FIXME: RSA_PKCS_EME::sign"); } -void RSA_PKCS1_EME::verify(ReadonlyBytes, Bytes&) + +ErrorOr RSA_PKCS1_EME::verify(ReadonlyBytes, Bytes&) { - dbgln("FIXME: RSA_PKCS_EME::verify"); + return Error::from_string_literal("FIXME: RSA_PKCS_EME::verify"); } } diff --git a/Libraries/LibCrypto/PK/RSA.h b/Libraries/LibCrypto/PK/RSA.h index 11b127aed6d..4706ce5772b 100644 --- a/Libraries/LibCrypto/PK/RSA.h +++ b/Libraries/LibCrypto/PK/RSA.h @@ -195,11 +195,11 @@ public: m_public_key.set(m_private_key.modulus(), m_private_key.public_exponent()); } - virtual void encrypt(ReadonlyBytes in, Bytes& out) override; - virtual void decrypt(ReadonlyBytes in, Bytes& out) override; + virtual ErrorOr encrypt(ReadonlyBytes in, Bytes& out) override; + virtual ErrorOr decrypt(ReadonlyBytes in, Bytes& out) override; - virtual void sign(ReadonlyBytes in, Bytes& out) override; - virtual void verify(ReadonlyBytes in, Bytes& out) override; + virtual ErrorOr verify(ReadonlyBytes in, Bytes& out) override; + virtual ErrorOr sign(ReadonlyBytes in, Bytes& out) override; virtual ByteString class_name() const override { @@ -232,11 +232,11 @@ public: ~RSA_PKCS1_EME() = default; - virtual void encrypt(ReadonlyBytes in, Bytes& out) override; - virtual void decrypt(ReadonlyBytes in, Bytes& out) override; + virtual ErrorOr encrypt(ReadonlyBytes in, Bytes& out) override; + virtual ErrorOr decrypt(ReadonlyBytes in, Bytes& out) override; - virtual void sign(ReadonlyBytes, Bytes&) override; - virtual void verify(ReadonlyBytes, Bytes&) override; + virtual ErrorOr verify(ReadonlyBytes in, Bytes& out) override; + virtual ErrorOr sign(ReadonlyBytes in, Bytes& out) override; virtual ByteString class_name() const override { diff --git a/Libraries/LibTLS/HandshakeClient.cpp b/Libraries/LibTLS/HandshakeClient.cpp index bd2b8502dde..a94c878bfba 100644 --- a/Libraries/LibTLS/HandshakeClient.cpp +++ b/Libraries/LibTLS/HandshakeClient.cpp @@ -195,7 +195,7 @@ void TLSv12::build_rsa_pre_master_secret(PacketBuilder& builder) Vector out; out.resize(rsa.output_size()); auto outbuf = out.span(); - rsa.encrypt(m_context.premaster_key, outbuf); + MUST(rsa.encrypt(m_context.premaster_key, outbuf)); if constexpr (TLS_DEBUG) { dbgln("Encrypted: "); diff --git a/Libraries/LibTLS/HandshakeServer.cpp b/Libraries/LibTLS/HandshakeServer.cpp index a84806c24ef..cef2fe0ea1f 100644 --- a/Libraries/LibTLS/HandshakeServer.cpp +++ b/Libraries/LibTLS/HandshakeServer.cpp @@ -388,7 +388,7 @@ ssize_t TLSv12::verify_rsa_server_key_exchange(ReadonlyBytes server_key_info_buf } auto signature_verify_buffer = signature_verify_buffer_result.release_value(); auto signature_verify_bytes = signature_verify_buffer.bytes(); - rsa.verify(signature, signature_verify_bytes); + MUST(rsa.verify(signature, signature_verify_bytes)); auto message_result = ByteBuffer::create_uninitialized(64 + server_key_info_buffer.size()); if (message_result.is_error()) { diff --git a/Libraries/LibTLS/TLSv12.cpp b/Libraries/LibTLS/TLSv12.cpp index d4517d7074c..86b196a049e 100644 --- a/Libraries/LibTLS/TLSv12.cpp +++ b/Libraries/LibTLS/TLSv12.cpp @@ -352,7 +352,7 @@ bool Context::verify_certificate_pair(Certificate const& subject, Certificate co } auto verification_buffer = verification_buffer_result.release_value(); auto verification_buffer_bytes = verification_buffer.bytes(); - rsa.verify(subject.signature_value, verification_buffer_bytes); + MUST(rsa.verify(subject.signature_value, verification_buffer_bytes)); ReadonlyBytes message = subject.tbs_asn1.bytes(); auto pkcs1 = Crypto::PK::EMSA_PKCS1_V1_5(kind); diff --git a/Libraries/LibWeb/Crypto/CryptoAlgorithms.cpp b/Libraries/LibWeb/Crypto/CryptoAlgorithms.cpp index df50b2ca4df..b3eb9c2a121 100644 --- a/Libraries/LibWeb/Crypto/CryptoAlgorithms.cpp +++ b/Libraries/LibWeb/Crypto/CryptoAlgorithms.cpp @@ -691,7 +691,9 @@ WebIDL::ExceptionOr> RSAOAEP::encrypt(AlgorithmParams c auto ciphertext_bytes = ciphertext.bytes(); auto rsa = ::Crypto::PK::RSA { public_key }; - rsa.encrypt(padding, ciphertext_bytes); + auto maybe_encrypt_error = rsa.encrypt(padding, ciphertext_bytes); + if (maybe_encrypt_error.is_error()) + return WebIDL::OperationError::create(realm, "Failed to encrypt"_string); // 6. Return the result of creating an ArrayBuffer containing ciphertext. return JS::ArrayBuffer::create(realm, move(ciphertext)); @@ -723,7 +725,9 @@ WebIDL::ExceptionOr> RSAOAEP::decrypt(AlgorithmParams c auto padding = TRY_OR_THROW_OOM(vm, ByteBuffer::create_uninitialized(private_key_length)); auto padding_bytes = padding.bytes(); - rsa.decrypt(ciphertext, padding_bytes); + auto maybe_encrypt_error = rsa.decrypt(ciphertext, padding_bytes); + if (maybe_encrypt_error.is_error()) + return WebIDL::OperationError::create(realm, "Failed to encrypt"_string); auto error_message = MUST(String::formatted("Invalid hash function '{}'", hash)); ErrorOr maybe_plaintext = Error::from_string_view(error_message.bytes_as_string_view()); diff --git a/Tests/LibCrypto/TestOAEP.cpp b/Tests/LibCrypto/TestOAEP.cpp index a19f8531ef9..96d696b49fc 100644 --- a/Tests/LibCrypto/TestOAEP.cpp +++ b/Tests/LibCrypto/TestOAEP.cpp @@ -137,7 +137,7 @@ TEST_CASE(test_oaep) auto output_buffer = maybe_output_buffer.release_value(); auto output_span = output_buffer.bytes(); - rsa.encrypt(result, output_span); + TRY_OR_FAIL(rsa.encrypt(result, output_span)); EXPECT_EQ(expected_rsa_value, output_span); } diff --git a/Tests/LibCrypto/TestRSA.cpp b/Tests/LibCrypto/TestRSA.cpp index 1e06ccb5cde..fba3c387961 100644 --- a/Tests/LibCrypto/TestRSA.cpp +++ b/Tests/LibCrypto/TestRSA.cpp @@ -28,7 +28,7 @@ TEST_CASE(test_RSA_raw_encrypt) ByteBuffer buffer = {}; buffer.resize(rsa.output_size()); auto buf = buffer.bytes(); - rsa.encrypt(data, buf); + TRY_OR_FAIL(rsa.encrypt(data, buf)); EXPECT(memcmp(result, buf.data(), buf.size()) == 0); } @@ -42,8 +42,8 @@ TEST_CASE(test_RSA_PKCS_1_encrypt) ByteBuffer buffer = {}; buffer.resize(rsa.output_size()); auto buf = buffer.bytes(); - rsa.encrypt(data, buf); - rsa.decrypt(buf, buf); + TRY_OR_FAIL(rsa.encrypt(data, buf)); + TRY_OR_FAIL(rsa.decrypt(buf, buf)); EXPECT(memcmp(buf.data(), "hellohellohellohellohellohellohellohellohello123-", 49) == 0); } @@ -149,8 +149,8 @@ c8yGzl89pYST dec.overwrite(0, "WellHelloFriends", 16); - rsa_from_pair.encrypt(dec, enc); - rsa_from_pem.decrypt(enc, dec); + TRY_OR_FAIL(rsa_from_pair.encrypt(dec, enc)); + TRY_OR_FAIL(rsa_from_pem.decrypt(enc, dec)); EXPECT_EQ(memcmp(dec.data(), "WellHelloFriends", 16), 0); } @@ -171,8 +171,8 @@ TEST_CASE(test_RSA_encrypt_decrypt) enc.overwrite(0, "WellHelloFriendsWellHelloFriendsWellHelloFriendsWellHelloFriends", 64); - rsa.encrypt(enc, dec); - rsa.decrypt(dec, enc); + TRY_OR_FAIL(rsa.encrypt(enc, dec)); + TRY_OR_FAIL(rsa.decrypt(dec, enc)); EXPECT(memcmp(enc.data(), "WellHelloFriendsWellHelloFriendsWellHelloFriendsWellHelloFriends", 64) == 0); }