소스 검색

LibCrypto: Make a better ASN.1 parser

And use it to parse RSA keys.
As a bonus, this one shouldn't be reading out of bounds or messing with
the stack (as much) anymore.
AnotherTest 4 년 전
부모
커밋
3fe7ac0924

+ 4 - 0
AK/Debug.h.in

@@ -322,6 +322,10 @@
 #cmakedefine01 RESOURCE_DEBUG
 #endif
 
+#ifndef RSA_PARSE_DEBUG
+#cmakedefine01 RSA_PARSE_DEBUG
+#endif
+
 #ifndef SAFE_SYSCALL_DEBUG
 #cmakedefine01 SAFE_SYSCALL_DEBUG
 #endif

+ 1 - 0
Meta/CMake/all_the_debug_macros.cmake

@@ -167,6 +167,7 @@ set(DEBUG_CPP_LANGUAGE_SERVER ON)
 set(DEBUG_AUTOCOMPLETE ON)
 set(FILE_WATCHER_DEBUG ON)
 set(SYSCALL_1_DEBUG ON)
+set(RSA_PARSE_DEBUG ON)
 
 # False positive: DEBUG is a flag but it works differently.
 # set(DEBUG ON)

+ 93 - 0
Userland/Libraries/LibCrypto/ASN1/ASN1.cpp

@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2021, the SerenityOS developers.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include <LibCrypto/ASN1/ASN1.h>
+
+namespace Crypto::ASN1 {
+
+String kind_name(Kind kind)
+{
+    switch (kind) {
+    case Kind::Eol:
+        return "EndOfList";
+    case Kind::Boolean:
+        return "Boolean";
+    case Kind::Integer:
+        return "Integer";
+    case Kind::BitString:
+        return "BitString";
+    case Kind::OctetString:
+        return "OctetString";
+    case Kind::Null:
+        return "Null";
+    case Kind::ObjectIdentifier:
+        return "ObjectIdentifier";
+    case Kind::IA5String:
+        return "IA5String";
+    case Kind::PrintableString:
+        return "PrintableString";
+    case Kind::Utf8String:
+        return "UTF8String";
+    case Kind::UTCTime:
+        return "UTCTime";
+    case Kind::Sequence:
+        return "Sequence";
+    case Kind::Set:
+        return "Set";
+    }
+
+    return "InvalidKind";
+}
+
+String class_name(Class class_)
+{
+    switch (class_) {
+    case Class::Application:
+        return "Application";
+    case Class::Context:
+        return "Context";
+    case Class::Private:
+        return "Private";
+    case Class::Universal:
+        return "Universal";
+    }
+
+    return "InvalidClass";
+}
+
+String type_name(Type type)
+{
+    switch (type) {
+    case Type::Constructed:
+        return "Constructed";
+    case Type::Primitive:
+        return "Primitive";
+    }
+
+    return "InvalidType";
+}
+
+}

+ 31 - 70
Userland/Libraries/LibCrypto/ASN1/ASN1.h

@@ -29,84 +29,45 @@
 #include <AK/Types.h>
 #include <LibCrypto/BigInt/UnsignedBigInteger.h>
 
-namespace Crypto {
+namespace Crypto::ASN1 {
 
-namespace ASN1 {
-
-enum class Kind {
+enum class Kind : u8 {
     Eol,
-    Boolean,
-    Integer,
-    ShortInteger,
-    BitString,
-    OctetString,
-    Null,
-    ObjectIdentifier,
-    IA5String,
-    PrintableString,
-    Utf8String,
-    UTCTime,
-    Choice,
-    Sequence,
-    Set,
-    SetOf
+    Boolean = 0x01,
+    Integer = 0x02,
+    BitString = 0x03,
+    OctetString = 0x04,
+    Null = 0x05,
+    ObjectIdentifier = 0x06,
+    IA5String = 0x16,
+    PrintableString = 0x13,
+    Utf8String = 0x0c,
+    UTCTime = 0x017,
+    Sequence = 0x10,
+    Set = 0x11,
+    // Choice = ??,
 };
 
-static inline StringView kind_name(Kind kind)
-{
-    switch (kind) {
-    case Kind::Eol:
-        return "EndOfList";
-    case Kind::Boolean:
-        return "Boolean";
-    case Kind::Integer:
-        return "Integer";
-    case Kind::ShortInteger:
-        return "ShortInteger";
-    case Kind::BitString:
-        return "BitString";
-    case Kind::OctetString:
-        return "OctetString";
-    case Kind::Null:
-        return "Null";
-    case Kind::ObjectIdentifier:
-        return "ObjectIdentifier";
-    case Kind::IA5String:
-        return "IA5String";
-    case Kind::PrintableString:
-        return "PrintableString";
-    case Kind::Utf8String:
-        return "UTF8String";
-    case Kind::UTCTime:
-        return "UTCTime";
-    case Kind::Choice:
-        return "Choice";
-    case Kind::Sequence:
-        return "Sequence";
-    case Kind::Set:
-        return "Set";
-    case Kind::SetOf:
-        return "SetOf";
-    }
+enum class Class : u8 {
+    Universal = 0,
+    Application = 0x40,
+    Context = 0x80,
+    Private = 0xc0,
+};
 
-    return "InvalidKind";
-}
+enum class Type : u8 {
+    Primitive = 0,
+    Constructed = 0x20,
+};
 
-struct List {
+struct Tag {
     Kind kind;
-    void* data;
-    size_t size;
-    bool used;
-    List *prev, *next, *child, *parent;
+    Class class_;
+    Type type;
 };
 
-static constexpr void set(List& list, Kind type, void* data, size_t size)
-{
-    list.kind = type;
-    list.data = data;
-    list.size = size;
-    list.used = false;
-}
-}
+String kind_name(Kind);
+String class_name(Class);
+String type_name(Type);
 
 }

+ 300 - 0
Userland/Libraries/LibCrypto/ASN1/DER.cpp

@@ -0,0 +1,300 @@
+/*
+ * Copyright (c) 2021, the SerenityOS developers.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include <AK/Bitmap.h>
+#include <AK/Utf8View.h>
+#include <LibCrypto/ASN1/DER.h>
+
+namespace Crypto::ASN1 {
+
+Result<Tag, DecodeError> Decoder::read_tag()
+{
+    auto byte_or_error = read_byte();
+    if (byte_or_error.is_error())
+        return byte_or_error.error();
+
+    auto byte = byte_or_error.value();
+    u8 class_ = byte & 0xc0;
+    u8 type = byte & 0x20;
+    u8 kind = byte & 0x1f;
+
+    if (kind == 0x1f) {
+        kind = 0;
+        while (byte & 0x80) {
+            auto byte_or_error = read_byte();
+            if (byte_or_error.is_error())
+                return byte_or_error.error();
+
+            byte = byte_or_error.value();
+            kind = (kind << 7) | (byte & 0x7f);
+        }
+    }
+
+    return Tag { (Kind)kind, (Class)class_, (Type)type };
+}
+
+Result<size_t, DecodeError> Decoder::read_length()
+{
+    auto byte_or_error = read_byte();
+    if (byte_or_error.is_error())
+        return byte_or_error.error();
+
+    auto byte = byte_or_error.value();
+    size_t length = byte;
+
+    if (byte & 0x80) {
+        auto count = byte & 0x7f;
+        if (count == 0x7f)
+            return DecodeError::InvalidInputFormat;
+        auto data_or_error = read_bytes(count);
+        if (data_or_error.is_error())
+            return data_or_error.error();
+
+        auto data = data_or_error.value();
+        length = 0;
+
+        if (data.size() > sizeof(size_t))
+            return DecodeError::Overflow;
+
+        for (auto&& byte : data)
+            length = (length << 8) | (size_t)byte;
+    }
+
+    return length;
+}
+
+Result<u8, DecodeError> Decoder::read_byte()
+{
+    if (m_stack.is_empty())
+        return DecodeError::NoInput;
+
+    auto& entry = m_stack.last();
+    if (entry.is_empty())
+        return DecodeError::NotEnoughData;
+
+    auto byte = entry[0];
+    entry = entry.slice(1);
+
+    return byte;
+}
+
+Result<ReadonlyBytes, DecodeError> Decoder::read_bytes(size_t length)
+{
+    if (m_stack.is_empty())
+        return DecodeError::NoInput;
+
+    auto& entry = m_stack.last();
+    if (entry.size() < length)
+        return DecodeError::NotEnoughData;
+
+    auto bytes = entry.slice(0, length);
+    entry = entry.slice(length);
+
+    return bytes;
+}
+
+Result<bool, DecodeError> Decoder::decode_boolean(ReadonlyBytes data)
+{
+    if (data.size() != 1)
+        return DecodeError::InvalidInputFormat;
+
+    return data[0] == 0;
+}
+
+Result<UnsignedBigInteger, DecodeError> Decoder::decode_arbitrary_sized_integer(ReadonlyBytes data)
+{
+    if (data.size() < 1)
+        return DecodeError::NotEnoughData;
+
+    if (data.size() > 1
+        && ((data[0] == 0xff && data[1] & 0x80)
+            || (data[0] == 0x00 && !(data[1] & 0x80)))) {
+        return DecodeError::InvalidInputFormat;
+    }
+
+    bool is_negative = data[0] & 0x80;
+    if (is_negative)
+        return DecodeError::UnsupportedFormat;
+
+    return UnsignedBigInteger::import_data(data.data(), data.size());
+}
+
+Result<StringView, DecodeError> Decoder::decode_octet_string(ReadonlyBytes bytes)
+{
+    return StringView { bytes.data(), bytes.size() };
+}
+
+Result<std::nullptr_t, DecodeError> Decoder::decode_null(ReadonlyBytes data)
+{
+    if (data.size() != 0)
+        return DecodeError::InvalidInputFormat;
+
+    return nullptr;
+}
+
+Result<Vector<int>, DecodeError> Decoder::decode_object_identifier(ReadonlyBytes data)
+{
+    Vector<int> result;
+    result.append(0); // Reserved space.
+
+    u32 value = 0;
+    for (auto&& byte : data) {
+        if (value == 0 && byte == 0x80)
+            return DecodeError::InvalidInputFormat;
+
+        value = (value << 7) | (byte & 0x7f);
+        if (!(byte & 0x80)) {
+            result.append(value);
+            value = 0;
+        }
+    }
+
+    if (result.size() == 1 || result[1] >= 1600)
+        return DecodeError::InvalidInputFormat;
+
+    result[0] = result[1] / 40;
+    result[1] = result[1] % 40;
+
+    return result;
+}
+
+Result<StringView, DecodeError> Decoder::decode_printable_string(ReadonlyBytes data)
+{
+    Utf8View view { data };
+    if (!view.validate())
+        return DecodeError::InvalidInputFormat;
+
+    return StringView { data };
+}
+
+Result<Bitmap, DecodeError> Decoder::decode_bit_string(ReadonlyBytes data)
+{
+    if (data.size() < 1)
+        return DecodeError::InvalidInputFormat;
+
+    auto unused_bits = data[0];
+    // FIXME: It's rather annoying that `Bitmap` is always mutable.
+    return Bitmap::wrap(const_cast<u8*>(data.offset_pointer(1)), data.size() * 8 - unused_bits);
+}
+
+Result<Tag, DecodeError> Decoder::peek()
+{
+    if (m_stack.is_empty())
+        return DecodeError::NoInput;
+
+    if (eof())
+        return DecodeError::EndOfStream;
+
+    if (m_current_tag.has_value())
+        return m_current_tag.value();
+
+    auto tag_or_error = read_tag();
+    if (tag_or_error.is_error())
+        return tag_or_error.error();
+
+    m_current_tag = tag_or_error.value();
+
+    return m_current_tag.value();
+}
+
+bool Decoder::eof() const
+{
+    return m_stack.is_empty() || m_stack.last().is_empty();
+}
+
+Optional<DecodeError> Decoder::enter()
+{
+    if (m_stack.is_empty())
+        return DecodeError::NoInput;
+
+    auto tag_or_error = peek();
+    if (tag_or_error.is_error())
+        return tag_or_error.error();
+
+    auto tag = tag_or_error.value();
+    if (tag.type != Type::Constructed)
+        return DecodeError::EnteringNonConstructedTag;
+
+    auto length_or_error = read_length();
+    if (length_or_error.is_error())
+        return length_or_error.error();
+
+    auto length = length_or_error.value();
+
+    auto data_or_error = read_bytes(length);
+    if (data_or_error.is_error())
+        return data_or_error.error();
+
+    m_current_tag.clear();
+
+    auto data = data_or_error.value();
+    m_stack.append(data);
+    return {};
+}
+
+Optional<DecodeError> Decoder::leave()
+{
+    if (m_stack.is_empty())
+        return DecodeError::NoInput;
+
+    if (m_stack.size() == 1)
+        return DecodeError::LeavingMainContext;
+
+    m_stack.take_last();
+    m_current_tag.clear();
+
+    return {};
+}
+
+}
+
+void AK::Formatter<Crypto::ASN1::DecodeError>::format(FormatBuilder& fmtbuilder, Crypto::ASN1::DecodeError error)
+{
+    using Crypto::ASN1::DecodeError;
+
+    switch (error) {
+    case DecodeError::NoInput:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(No input provided)");
+    case DecodeError::NonConformingType:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(Tried to read with a non-conforming type)");
+    case DecodeError::EndOfStream:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(End of stream)");
+    case DecodeError::NotEnoughData:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(Not enough data)");
+    case DecodeError::EnteringNonConstructedTag:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(Tried to enter a primitive tag)");
+    case DecodeError::LeavingMainContext:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(Tried to leave the main context)");
+    case DecodeError::InvalidInputFormat:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(Input data contained invalid syntax/data)");
+    case DecodeError::Overflow:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(Construction would overflow)");
+    case DecodeError::UnsupportedFormat:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(Input data format not supported by this parser)");
+    default:
+        return Formatter<StringView>::format(fmtbuilder, "DecodeError(Unknown)");
+    }
+}

+ 114 - 398
Userland/Libraries/LibCrypto/ASN1/DER.h

@@ -26,449 +26,165 @@
 
 #pragma once
 
+#include <AK/Bitmap.h>
+#include <AK/Result.h>
 #include <AK/Types.h>
 #include <LibCrypto/ASN1/ASN1.h>
 #include <LibCrypto/BigInt/UnsignedBigInteger.h>
 
-namespace Crypto {
-
-static bool der_decode_integer(const u8* in, size_t length, UnsignedBigInteger& number)
-{
-    if (length < 3) {
-        dbgln("invalid header size");
-        return false;
-    }
+namespace Crypto::ASN1 {
+
+enum class DecodeError {
+    NoInput,
+    NonConformingType,
+    EndOfStream,
+    NotEnoughData,
+    EnteringNonConstructedTag,
+    LeavingMainContext,
+    InvalidInputFormat,
+    Overflow,
+    UnsupportedFormat,
+};
 
-    size_t x { 0 };
-    // must start with 0x02
-    if ((in[x++] & 0x1f) != 0x02) {
-        dbgln("not an integer {} ({} follows)", in[x - 1], in[x]);
-        return false;
+class Decoder {
+public:
+    Decoder(ReadonlyBytes data)
+    {
+        m_stack.append(data);
     }
 
-    // decode length
-    size_t z = in[x++];
-    if ((x & 0x80) == 0) {
-        // overflow
-        if (x + z > length) {
-            dbgln("would overflow {} > {}", z + x, length);
-            return false;
-        }
-
-        number = UnsignedBigInteger::import_data(in + x, z);
-        return true;
-    } else {
-        // actual big integer
-        z &= 0x7f;
+    // Read a tag without consuming it (and its data).
+    Result<Tag, DecodeError> peek();
 
-        // overflow
-        if ((x + z) > length || z > 4 || z == 0) {
-            dbgln("would overflow {} > {}", z + x, length);
-            return false;
-        }
+    bool eof() const;
 
-        size_t y = 0;
-        while (z--) {
-            y = ((size_t)(in[x++])) | (y << 8);
-        }
+    template<typename ValueType>
+    struct TaggedValue {
+        Tag tag;
+        ValueType value;
+    };
 
-        // overflow
-        if (x + y > length) {
-            dbgln("would overflow {} > {}", y + x, length);
-            return false;
-        }
+    template<typename ValueType>
+    Result<ValueType, DecodeError> read()
+    {
+        if (m_stack.is_empty())
+            return DecodeError::NoInput;
 
-        number = UnsignedBigInteger::import_data(in + x, y);
-        return true;
-    }
+        if (eof())
+            return DecodeError::EndOfStream;
 
-    // see if it's negative
-    if (in[x] & 0x80) {
-        dbgln("negative bigint unsupported in der_decode_integer");
-        return false;
-    }
+        auto previous_position = m_stack;
 
-    return true;
-}
-static bool der_length_integer(UnsignedBigInteger* num, size_t* out_length)
-{
-    auto& bigint = *num;
-    size_t value_length = bigint.trimmed_length() * sizeof(u32);
-    auto length = value_length;
-    if (length == 0) {
-        ++length;
-    } else {
-        // the number comes with a 0 padding to make it positive in 2's comp
-        // add that zero if the msb is 1, but only if we haven't padded it
-        // ourselves
-        auto ms2b = (u16)(bigint.words()[bigint.trimmed_length() - 1] >> 16);
-
-        if ((ms2b & 0xff00) == 0) {
-            if (!(((u8)ms2b) & 0x80))
-                --length;
-        } else if (ms2b & 0x8000) {
-            ++length;
-        }
-    }
-    if (value_length < 128) {
-        ++length;
-    } else {
-        ++length;
-        while (value_length) {
-            ++length;
-            value_length >>= 8;
+        auto tag_or_error = peek();
+        if (tag_or_error.is_error()) {
+            m_stack = move(previous_position);
+            return tag_or_error.error();
         }
-    }
-    // kind
-    ++length;
-    *out_length = length;
-    return true;
-}
-constexpr static bool der_decode_object_identifier(const u8* in, size_t in_length, u8* words, u8* out_length)
-{
-    if (in_length < 3)
-        return false; // invalid header
-
-    if (*out_length < 2)
-        return false; // need at least two words
-
-    size_t x { 0 };
-    if ((in[x++] & 0x1f) != 0x06) {
-        return false; // invalid header value
-    }
 
-    size_t length { 0 };
-    if (in[x] < 128) {
-        length = in[x++];
-    } else {
-        if ((in[x] < 0x81) | (in[x] > 0x82))
-            return false; // invalid header
+        auto length_or_error = read_length();
+        if (length_or_error.is_error()) {
+            m_stack = move(previous_position);
+            return length_or_error.error();
+        }
 
-        size_t y = in[x++] & 0x7f;
-        while (y--)
-            length = (length << 8) | (size_t)in[x++];
-    }
+        auto tag = tag_or_error.value();
+        auto length = length_or_error.value();
 
-    if (length < 1 || length + x > in_length)
-        return false; // invalid length or overflow
-
-    size_t y { 0 }, t { 0 };
-    while (length--) {
-        t = (t << 7) | (in[x] & 0x7f);
-        if (!(in[x++] & 0x80)) {
-            if (y >= *out_length)
-                return false; // overflow
-
-            if (y == 0) {
-                words[0] = t / 40;
-                words[1] = t % 40;
-                y = 2;
-            } else {
-                words[y++] = t;
-            }
-            t = 0;
+        auto value_or_error = read_value<ValueType>(tag.class_, tag.kind, length);
+        if (value_or_error.is_error()) {
+            m_stack = move(previous_position);
+            return value_or_error.error();
         }
-    }
-    *out_length = y;
-    return true;
-}
 
-static constexpr size_t der_object_identifier_bits(size_t x)
-{
-    x &= 0xffffffff;
-    size_t c { 0 };
-    while (x) {
-        ++c;
-        x >>= 1;
-    }
-    return c;
-}
-
-constexpr static bool der_length_object_identifier(u8* words, size_t num_words, size_t* out_length)
-{
-    if (num_words < 2)
-        return false;
-
-    if (words[0] > 3 || (words[0] < 2 && words[1] > 39))
-        return false;
-
-    size_t z { 0 };
-    size_t wordbuf = words[0] * 40 + words[1];
-    for (size_t y = 0; y < num_words; ++y) {
-        auto t = der_object_identifier_bits(wordbuf);
-        z = t / 7 + (!!(t % 7)) + (!!(wordbuf == 0));
-        if (y < num_words - 1)
-            wordbuf = words[y + 1];
-    }
+        m_current_tag.clear();
 
-    if (z < 128) {
-        z += 2;
-    } else if (z < 256) {
-        z += 3;
-    } else {
-        z += 4;
+        return value_or_error.release_value();
     }
-    *out_length = z;
-    return true;
-}
 
-constexpr static bool der_length_sequence(ASN1::List* list, size_t in_length, size_t* out_length)
-{
-    size_t y { 0 }, x { 0 };
-    for (size_t i = 0; i < in_length; ++i) {
-        auto type = list[i].kind;
-        auto size = list[i].size;
-        auto data = list[i].data;
-
-        if (type == ASN1::Kind::Eol)
-            break;
-
-        switch (type) {
-        case ASN1::Kind::Integer:
-            if (!der_length_integer((UnsignedBigInteger*)data, &x)) {
-                return false;
-            }
-            y += x;
-            break;
-        case ASN1::Kind::ObjectIdentifier:
-            if (!der_length_object_identifier((u8*)data, size, &x)) {
-                return false;
-            }
-            y += x;
-            break;
-        case ASN1::Kind::Sequence:
-            if (!der_length_sequence((ASN1::List*)data, size, &x)) {
-                return false;
-            }
-            y += x;
-            break;
-        default:
-            dbgln("Unhandled Kind {}", ASN1::kind_name(type));
-            ASSERT_NOT_REACHED();
-            break;
-        }
-    }
+    Optional<DecodeError> enter();
+    Optional<DecodeError> leave();
 
-    if (y < 128) {
-        y += 2;
-    } else if (y < 256) {
-        y += 3;
-    } else if (y < 65536) {
-        y += 4;
-    } else if (y < 16777216ul) {
-        y += 5;
-    } else {
-        dbgln("invalid length {}", y);
-        return false;
-    }
-    *out_length = y;
-    return true;
-}
+private:
+    template<typename ValueType, typename DecodedType>
+    Result<ValueType, DecodeError> with_type_check(DecodedType&& value)
+    {
+        if constexpr (requires { ValueType { value }; })
+            return ValueType { value };
 
-static inline bool der_decode_sequence(const u8* in, size_t in_length, ASN1::List* list, size_t out_length, bool ordered = true)
-{
-    if (in_length < 2) {
-        dbgln("header too small");
-        return false; // invalid header
+        return DecodeError::NonConformingType;
     }
-    size_t x { 0 };
-    if (in[x++] != 0x30) {
-        dbgln("not a sequence: {}", in[x - 1]);
-        return false; // not a sequence
-    }
-    size_t block_size { 0 };
-    size_t y { 0 };
-    if (in[x] < 128) {
-        block_size = in[x++];
-    } else if (in[x] & 0x80) {
-        if ((in[x] < 0x81) || (in[x] > 0x83)) {
-            dbgln("invalid length element {}", in[x]);
-            return false;
-        }
 
-        y = in[x++] & 0x7f;
+    template<typename ValueType, typename DecodedType>
+    Result<ValueType, DecodeError> with_type_check(Result<DecodedType, DecodeError>&& value_or_error)
+    {
+        if (value_or_error.is_error())
+            return value_or_error.error();
 
-        if (x + y > in_length) {
-            dbgln("would overflow {} > {}", x + y, in_length);
-            return false; // overflow
-        }
-        block_size = 0;
-        while (y--)
-            block_size = (block_size << 8) | (size_t)in[x++];
-    }
+        auto&& value = value_or_error.value();
+        if constexpr (requires { ValueType { value }; })
+            return ValueType { value };
 
-    // overflow
-    if (x + block_size > in_length) {
-        dbgln("would overflow {} > {}", x + block_size, in_length);
-        return false;
+        return DecodeError::NonConformingType;
     }
 
-    for (size_t i = 0; i < out_length; ++i)
-        list[i].used = false;
+    template<typename ValueType>
+    Result<ValueType, DecodeError> read_value(Class klass, Kind kind, size_t length)
+    {
+        auto data_or_error = read_bytes(length);
+        if (data_or_error.is_error())
+            return data_or_error.error();
+        auto data = data_or_error.value();
 
-    in_length = block_size;
-    for (size_t i = 0; i < out_length; ++i) {
-        size_t z = 0;
-        auto kind = list[i].kind;
-        auto size = list[i].size;
-        auto data = list[i].data;
+        if (klass != Class::Universal)
+            return with_type_check<ValueType>(data);
 
-        if (!ordered && list[i].used) {
-            continue;
-        }
+        if (kind == Kind::Boolean)
+            return with_type_check<ValueType>(decode_boolean(data));
 
-        switch (kind) {
-        case ASN1::Kind::Integer:
-            z = in_length;
-            if (!der_decode_integer(in + x, z, *(UnsignedBigInteger*)data)) {
-                dbgln("could not decode an integer");
-                return false;
-            }
-            if (!der_length_integer((UnsignedBigInteger*)data, &z)) {
-                dbgln("could not figure out the length");
-                return false;
-            }
-            break;
-        case ASN1::Kind::ObjectIdentifier:
-            z = in_length;
-            if (!der_decode_object_identifier(in + x, z, (u8*)data, (u8*)&size)) {
-                if (!ordered)
-                    continue;
-                return false;
-            }
-            list[i].size = size;
-            if (!der_length_object_identifier((u8*)data, size, &z)) {
-                return false;
-            }
-            break;
-        case ASN1::Kind::Sequence:
-            if ((in[x] & 0x3f) != 0x30) {
-                dbgln("Not a sequence: {}", (in[x] & 0x3f));
-                return false;
-            }
-            z = in_length;
-            if (!der_decode_sequence(in + x, z, (ASN1::List*)data, size)) {
-                if (!ordered)
-                    continue;
-                return false;
-            }
-            if (!der_length_sequence((ASN1::List*)data, size, &z)) {
-                return false;
-            }
-            break;
-        default:
-            dbgln("Unhandled ASN1 kind {}", ASN1::kind_name(kind));
-            ASSERT_NOT_REACHED();
-            break;
-        }
-        x += z;
-        in_length -= z;
-        list[i].used = true;
-        if (!ordered)
-            i = -1;
-    }
-    for (size_t i = 0; i < out_length; ++i)
-        if (!list[i].used) {
-            dbgln("index {} was not read", i);
-            return false;
-        }
+        if (kind == Kind::Integer)
+            return with_type_check<ValueType>(decode_arbitrary_sized_integer(data));
 
-    return true;
-}
+        if (kind == Kind::OctetString)
+            return with_type_check<ValueType>(decode_octet_string(data));
 
-template<size_t element_count>
-struct der_decode_sequence_many_base {
-    constexpr void set(size_t index, ASN1::Kind kind, size_t size, void* data)
-    {
-        ASN1::set(m_list[index], kind, data, size);
-    }
+        if (kind == Kind::Null)
+            return with_type_check<ValueType>(decode_null(data));
 
-    constexpr der_decode_sequence_many_base(const u8* in, size_t in_length)
-        : m_in(in)
-        , m_in_length(in_length)
-    {
-    }
+        if (kind == Kind::ObjectIdentifier)
+            return with_type_check<ValueType>(decode_object_identifier(data));
 
-    ASN1::List* list() { return m_list; }
-    const u8* in() { return m_in; }
-    size_t in_length() { return m_in_length; }
+        if (kind == Kind::PrintableString || kind == Kind::IA5String || kind == Kind::UTCTime)
+            return with_type_check<ValueType>(decode_printable_string(data));
 
-protected:
-    ASN1::List m_list[element_count];
-    const u8* m_in;
-    size_t m_in_length;
-};
+        if (kind == Kind::Utf8String)
+            return with_type_check<ValueType>(StringView { data.data(), data.size() });
 
-template<size_t element_count>
-struct der_decode_sequence_many : public der_decode_sequence_many_base<element_count> {
+        if (kind == Kind::BitString)
+            return with_type_check<ValueType>(decode_bit_string(data));
 
-    template<typename ElementType, typename... Args>
-    constexpr void construct(size_t index, ASN1::Kind kind, size_t size, ElementType data, Args... args)
-    {
-        der_decode_sequence_many_base<element_count>::set(index, kind, size, (void*)data);
-        construct(index + 1, args...);
+        return with_type_check<ValueType>(data);
     }
 
-    constexpr void construct(size_t index)
-    {
-        ASSERT(index == element_count);
-    }
+    Result<Tag, DecodeError> read_tag();
+    Result<size_t, DecodeError> read_length();
+    Result<u8, DecodeError> read_byte();
+    Result<ReadonlyBytes, DecodeError> read_bytes(size_t length);
 
-    template<typename... Args>
-    constexpr der_decode_sequence_many(const u8* in, size_t in_length, Args... args)
-        : der_decode_sequence_many_base<element_count>(in, in_length)
-    {
-        construct(0, args...);
-    }
+    static Result<bool, DecodeError> decode_boolean(ReadonlyBytes);
+    static Result<UnsignedBigInteger, DecodeError> decode_arbitrary_sized_integer(ReadonlyBytes);
+    static Result<StringView, DecodeError> decode_octet_string(ReadonlyBytes);
+    static Result<std::nullptr_t, DecodeError> decode_null(ReadonlyBytes);
+    static Result<Vector<int>, DecodeError> decode_object_identifier(ReadonlyBytes);
+    static Result<StringView, DecodeError> decode_printable_string(ReadonlyBytes);
+    static Result<Bitmap, DecodeError> decode_bit_string(ReadonlyBytes);
 
-    constexpr operator bool()
-    {
-        return der_decode_sequence(this->m_in, this->m_in_length, this->m_list, element_count);
-    }
+    Vector<ReadonlyBytes> m_stack;
+    Optional<Tag> m_current_tag;
 };
 
-// FIXME: Move these terrible constructs into their own place
-constexpr static void decode_b64_block(const u8 in[4], u8 out[3])
-{
-    out[0] = (u8)(in[0] << 2 | in[1] >> 4);
-    out[1] = (u8)(in[1] << 4 | in[2] >> 2);
-    out[2] = (u8)(((in[2] << 6) & 0xc0) | in[3]);
 }
 
-constexpr static char base64_chars[] { "|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`abcdefghijklmnopq" };
-constexpr static size_t decode_b64(const u8* in_buffer, size_t in_length, ByteBuffer& out_buffer)
-{
-    u8 in[4] { 0 }, out[3] { 0 }, v { 0 };
-    size_t i { 0 }, length { 0 };
-    size_t output_offset { 0 };
-
-    const u8* ptr = in_buffer;
-
-    while (ptr <= in_buffer + in_length) {
-        for (length = 0, i = 0; i < 4 && (ptr <= in_buffer + in_length); ++i) {
-            v = 0;
-            while ((ptr <= in_buffer + in_length) && !v) {
-                v = ptr[0];
-                ++ptr;
-                v = (u8)((v < 43 || v > 122) ? 0 : base64_chars[v - 43]);
-                if (v)
-                    v = (u8)(v == '$' ? 0 : v - 61);
-            }
-            if (ptr <= in_buffer + in_length) {
-                ++length;
-                if (v)
-                    in[i] = v - 1;
-
-            } else {
-                in[i] = 0;
-            }
-        }
-        if (length) {
-            decode_b64_block(in, out);
-            out_buffer.overwrite(output_offset, out, length - 1);
-            output_offset += length - 1;
-        }
-    }
-    return output_offset;
-}
-}
+template<>
+struct AK::Formatter<Crypto::ASN1::DecodeError> : Formatter<StringView> {
+    void format(FormatBuilder&, Crypto::ASN1::DecodeError);
+};

+ 72 - 0
Userland/Libraries/LibCrypto/ASN1/PEM.cpp

@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2021, the SerenityOS developers.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include <AK/Base64.h>
+#include <AK/GenericLexer.h>
+#include <LibCrypto/ASN1/PEM.h>
+
+namespace Crypto {
+
+ByteBuffer decode_pem(ReadonlyBytes data)
+{
+    GenericLexer lexer { data };
+    ByteBuffer decoded;
+
+    // FIXME: Parse multiple.
+    enum {
+        PreStartData,
+        Started,
+        Ended,
+    } state { PreStartData };
+    while (!lexer.is_eof()) {
+        switch (state) {
+        case PreStartData:
+            if (lexer.consume_specific("-----BEGIN"))
+                state = Started;
+            lexer.consume_line();
+            break;
+        case Started: {
+            if (lexer.consume_specific("-----END")) {
+                state = Ended;
+                lexer.consume_line();
+                break;
+            }
+            auto b64decoded = decode_base64(lexer.consume_line().trim_whitespace(TrimMode::Right));
+            decoded.append(b64decoded.data(), b64decoded.size());
+            break;
+        }
+        case Ended:
+            lexer.consume_all();
+            break;
+        default:
+            ASSERT_NOT_REACHED();
+        }
+    }
+
+    return decoded;
+}
+
+}

+ 1 - 39
Userland/Libraries/LibCrypto/ASN1/PEM.h

@@ -32,44 +32,6 @@
 
 namespace Crypto {
 
-static inline ByteBuffer decode_pem(ReadonlyBytes data_in, size_t cert_index = 0)
-{
-    size_t i { 0 };
-    size_t start_at { 0 };
-    size_t idx { 0 };
-    size_t input_length = data_in.size();
-    auto alloc_len = input_length / 4 * 3;
-    auto output = ByteBuffer::create_uninitialized(alloc_len);
-
-    for (i = 0; i < input_length; i++) {
-        if ((data_in[i] == '\n') || (data_in[i] == '\r'))
-            continue;
-
-        if (data_in[i] != '-') {
-            // Read entire line.
-            while ((i < input_length) && (data_in[i] != '\n'))
-                i++;
-            continue;
-        }
-
-        if (data_in[i] == '-') {
-            auto end_idx = i;
-            // Read until end of line.
-            while ((i < input_length) && (data_in[i] != '\n'))
-                i++;
-            if (start_at) {
-                if (cert_index > 0) {
-                    cert_index--;
-                    start_at = 0;
-                } else {
-                    idx = decode_b64(data_in.offset(start_at), end_idx - start_at, output);
-                    break;
-                }
-            } else
-                start_at = i + 1;
-        }
-    }
-    return output.slice(0, idx);
-}
+ByteBuffer decode_pem(ReadonlyBytes);
 
 }

+ 3 - 0
Userland/Libraries/LibCrypto/CMakeLists.txt

@@ -1,4 +1,7 @@
 set(SOURCES
+    ASN1/ASN1.cpp
+    ASN1/DER.cpp
+    ASN1/PEM.cpp
     Authentication/GHash.cpp
     BigInt/SignedBigInteger.cpp
     BigInt/UnsignedBigInteger.cpp

+ 176 - 65
Userland/Libraries/LibCrypto/PK/RSA.cpp

@@ -26,6 +26,7 @@
 
 #include <AK/Debug.h>
 #include <AK/Random.h>
+#include <AK/ScopeGuard.h>
 #include <LibCrypto/ASN1/ASN1.h>
 #include <LibCrypto/ASN1/DER.h>
 #include <LibCrypto/ASN1/PEM.h>
@@ -34,84 +35,194 @@
 namespace Crypto {
 namespace PK {
 
-RSA::KeyPairType RSA::parse_rsa_key(ReadonlyBytes in)
+static constexpr Array<int, 7> pkcs8_rsa_key_oid { 1, 2, 840, 113549, 1, 1, 1 };
+
+RSA::KeyPairType RSA::parse_rsa_key(ReadonlyBytes der)
 {
     // we are going to assign to at least one of these
     KeyPairType keypair;
-    // TODO: move ASN parsing logic out
-    u64 t, x, y, z, tmp_oid[16];
-    u8 tmp_buf[4096] { 0 };
-    UnsignedBigInteger n, e, d;
-    ASN1::List pubkey_hash_oid[2], pubkey[2];
-
-    ASN1::set(pubkey_hash_oid[0], ASN1::Kind::ObjectIdentifier, tmp_oid, sizeof(tmp_oid) / sizeof(tmp_oid[0]));
-    ASN1::set(pubkey_hash_oid[1], ASN1::Kind::Null, nullptr, 0);
-
-    // DER is weird in that it stores pubkeys as bitstrings
-    // we must first extract that crap
-    ASN1::set(pubkey[0], ASN1::Kind::Sequence, &pubkey_hash_oid, 2);
-    ASN1::set(pubkey[1], ASN1::Kind::Null, nullptr, 0);
-
-    dbgln("we were offered {} bytes of input", in.size());
-
-    if (der_decode_sequence(in.data(), in.size(), pubkey, 2)) {
-        // yay, now we have to reassemble the bitstring to a bytestring
-        t = 0;
-        y = 0;
-        z = 0;
-        x = 0;
-        for (; x < pubkey[1].size; ++x) {
-            y = (y << 1) | tmp_buf[x];
-            if (++z == 8) {
-                tmp_buf[t++] = (u8)y;
-                y = 0;
-                z = 0;
-            }
+
+    ASN1::Decoder decoder(der);
+    // There are four possible (supported) formats:
+    // PKCS#1 private key
+    // PKCS#1 public key
+    // PKCS#8 private key
+    // PKCS#8 public key
+
+    // They're all a single sequence, so let's check that first
+    {
+        auto result = decoder.peek();
+        if (result.is_error()) {
+            // Bad data.
+            dbgln_if(RSA_PARSE_DEBUG, "RSA key parse failed: {}", result.error());
+            return keypair;
         }
-        // now the buffer is correct (Sequence { Integer, Integer })
-        if (!der_decode_sequence_many<2>(tmp_buf, t,
-                ASN1::Kind::Integer, 1, &n,
-                ASN1::Kind::Integer, 1, &e)) {
-            // something was fucked up
-            dbgln("bad pubkey: e={} n={}", e, n);
+        auto tag = result.value();
+        if (tag.kind != ASN1::Kind::Sequence) {
+            dbgln_if(RSA_PARSE_DEBUG, "RSA key parse failed: Expected a Sequence but got {}", ASN1::kind_name(tag.kind));
             return keypair;
         }
-        // correct public key
-        keypair.public_key.set(n, e);
-        return keypair;
     }
 
-    // could be a private key
-    if (!der_decode_sequence_many<1>(in.data(), in.size(),
-            ASN1::Kind::Integer, 1, &n)) {
-        // that's no key
-        // that's a death star
-        dbgln("that's a death star");
-        return keypair;
+    // Then enter the sequence
+    {
+        auto error = decoder.enter();
+        if (error.has_value()) {
+            // Something was weird with the input.
+            dbgln_if(RSA_PARSE_DEBUG, "RSA key parse failed: {}", error.value());
+            return keypair;
+        }
     }
 
-    if (n == 0) {
-        // it is a private key
-        UnsignedBigInteger zero;
-        if (!der_decode_sequence_many<4>(in.data(), in.size(),
-                ASN1::Kind::Integer, 1, &zero,
-                ASN1::Kind::Integer, 1, &n,
-                ASN1::Kind::Integer, 1, &e,
-                ASN1::Kind::Integer, 1, &d)) {
-            dbgln("bad privkey n={} e={} d={}", n, e, d);
+    bool has_read_error = false;
+
+    const auto check_if_pkcs8_rsa_key = [&] {
+        // see if it's a sequence:
+        auto tag_result = decoder.peek();
+        if (tag_result.is_error()) {
+            // Decode error :shrug:
+            dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#8 public key parse failed: {}", tag_result.error());
+            return false;
+        }
+
+        auto tag = tag_result.value();
+        if (tag.kind != ASN1::Kind::Sequence) {
+            // We don't know what this is, but it sure isn't a PKCS#8 key.
+            dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#8 public key parse failed: Expected a Sequence but got {}", ASN1::kind_name(tag.kind));
+            return false;
+        }
+
+        // It's a sequence, now let's see if it's actually an RSA key.
+        auto error = decoder.enter();
+        if (error.has_value()) {
+            // Shenanigans!
+            dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#8 public key parse failed: {}", error.value());
+            return false;
+        }
+
+        ScopeGuard leave { [&] {
+            auto error = decoder.leave();
+            if (error.has_value()) {
+                dbgln_if(RSA_PARSE_DEBUG, "RSA key parse failed: {}", error.value());
+                has_read_error = true;
+            }
+        } };
+
+        // Now let's read the OID.
+        auto oid_result = decoder.read<Vector<int>>();
+        if (oid_result.is_error()) {
+            dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#8 public key parse failed: {}", oid_result.error());
+            return false;
+        }
+
+        auto oid = oid_result.release_value();
+        // Now let's check that the OID matches "RSA key"
+        if (oid != pkcs8_rsa_key_oid) {
+            // Oh well. not an RSA key at all.
+            dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#8 public key parse failed: Not an RSA key");
+            return false;
+        }
+
+        return true;
+    };
+
+    auto integer_result = decoder.read<UnsignedBigInteger>();
+
+    if (!integer_result.is_error()) {
+        auto first_integer = integer_result.release_value();
+
+        // It's either a PKCS#1 key, or a PKCS#8 private key.
+        // Check for the PKCS#8 private key right away.
+        if (check_if_pkcs8_rsa_key()) {
+            if (has_read_error)
+                return keypair;
+            // Now read the private key, which is actually an octet string containing the PKCS#1 encoded private key.
+            auto data_result = decoder.read<StringView>();
+            if (data_result.is_error()) {
+                dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#8 private key parse failed: {}", data_result.error());
+                return keypair;
+            }
+            return parse_rsa_key(data_result.value().bytes());
+        }
+
+        if (has_read_error)
+            return keypair;
+
+        // It's not a PKCS#8 key, so it's a PKCS#1 key (or something we don't support)
+        // if the first integer is zero or one, it's a private key.
+        if (first_integer == 0) {
+            // This is a private key, parse the rest.
+            auto modulus_result = decoder.read<UnsignedBigInteger>();
+            if (modulus_result.is_error()) {
+                dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#1 private key parse failed: {}", modulus_result.error());
+                return keypair;
+            }
+            auto modulus = modulus_result.release_value();
+
+            auto public_exponent_result = decoder.read<UnsignedBigInteger>();
+            if (public_exponent_result.is_error()) {
+                dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#1 private key parse failed: {}", public_exponent_result.error());
+                return keypair;
+            }
+            auto public_exponent = public_exponent_result.release_value();
+
+            auto private_exponent_result = decoder.read<UnsignedBigInteger>();
+            if (private_exponent_result.is_error()) {
+                dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#1 private key parse failed: {}", private_exponent_result.error());
+                return keypair;
+            }
+            auto private_exponent = private_exponent_result.release_value();
+
+            // Drop the rest of the fields on the floor, we don't use them.
+            // FIXME: Actually use them...
+            keypair.private_key = { modulus, move(private_exponent), public_exponent };
+            keypair.public_key = { move(modulus), move(public_exponent) };
+
+            return keypair;
+        } else if (first_integer == 1) {
+            // This is a multi-prime key, we don't support that.
+            dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#1 private key parse failed: Multi-prime key not supported");
+            return keypair;
+        } else {
+            auto&& modulus = move(first_integer);
+
+            // Try reading a public key, `first_integer` is the modulus.
+            auto public_exponent_result = decoder.read<UnsignedBigInteger>();
+            if (public_exponent_result.is_error()) {
+                // Bad public key.
+                dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#1 public key parse failed: {}", public_exponent_result.error());
+                return keypair;
+            }
+
+            auto public_exponent = public_exponent_result.release_value();
+            keypair.public_key.set(move(modulus), move(public_exponent));
+
             return keypair;
         }
-        keypair.private_key.set(n, d, e);
-        return keypair;
-    }
-    if (n == 1) {
-        // multiprime key, we don't know how to deal with this
-        dbgln("Unsupported key type");
-        return keypair;
+
+    } else {
+        // It wasn't a PKCS#1 key, let's try our luck with PKCS#8.
+        if (!check_if_pkcs8_rsa_key())
+            return keypair;
+
+        if (has_read_error)
+            return keypair;
+
+        // Now we have a bit string, which contains the PKCS#1 encoded public key.
+        auto data_result = decoder.read<Bitmap>();
+        if (data_result.is_error()) {
+            dbgln_if(RSA_PARSE_DEBUG, "RSA PKCS#8 public key parse failed: {}", data_result.error());
+            return keypair;
+        }
+
+        // Now just read it as a PKCS#1 DER.
+        auto data = data_result.release_value();
+        // FIXME: This is pretty awkward, maybe just generate a zero'd out ByteBuffer from the parser instead?
+        auto padded_data = ByteBuffer::create_zeroed(data.size_in_bytes());
+        padded_data.overwrite(0, data.data(), data.size_in_bytes());
+
+        return parse_rsa_key(padded_data.bytes());
     }
-    // it's a broken public key
-    keypair.public_key.set(n, 65537);
-    return keypair;
 }
 
 void RSA::encrypt(ReadonlyBytes in, Bytes& out)

+ 20 - 18
Userland/Libraries/LibCrypto/PK/RSA.h

@@ -35,12 +35,13 @@
 
 namespace Crypto {
 namespace PK {
-template<typename Integer = u64>
+template<typename Integer = UnsignedBigInteger>
 class RSAPublicKey {
 public:
-    RSAPublicKey(const Integer& n, const Integer& e)
-        : m_modulus(n)
-        , m_public_exponent(e)
+    RSAPublicKey(Integer n, Integer e)
+        : m_modulus(move(n))
+        , m_public_exponent(move(e))
+        , m_length(m_modulus.trimmed_length() * sizeof(u32))
     {
     }
 
@@ -57,11 +58,11 @@ public:
     size_t length() const { return m_length; }
     void set_length(size_t length) { m_length = length; }
 
-    void set(const Integer& n, const Integer& e)
+    void set(Integer n, Integer e)
     {
-        m_modulus = n;
-        m_public_exponent = e;
-        m_length = (n.trimmed_length() * sizeof(u32));
+        m_modulus = move(n);
+        m_public_exponent = move(e);
+        m_length = (m_modulus.trimmed_length() * sizeof(u32));
     }
 
 private:
@@ -73,10 +74,11 @@ private:
 template<typename Integer = UnsignedBigInteger>
 class RSAPrivateKey {
 public:
-    RSAPrivateKey(const Integer& n, const Integer& d, const Integer& e)
-        : m_modulus(n)
-        , m_private_exponent(d)
-        , m_public_exponent(e)
+    RSAPrivateKey(Integer n, Integer d, Integer e)
+        : m_modulus(move(n))
+        , m_private_exponent(move(d))
+        , m_public_exponent(move(e))
+        , m_length(m_modulus.trimmed_length() * sizeof(u32))
     {
     }
 
@@ -91,12 +93,12 @@ public:
     size_t length() const { return m_length; }
     void set_length(size_t length) { m_length = length; }
 
-    void set(const Integer& n, const Integer& d, const Integer& e)
+    void set(Integer n, Integer d, Integer e)
     {
-        m_modulus = n;
-        m_private_exponent = d;
-        m_public_exponent = e;
-        m_length = (n.length() * sizeof(u32));
+        m_modulus = move(n);
+        m_private_exponent = move(d);
+        m_public_exponent = move(e);
+        m_length = m_modulus.trimmed_length() * sizeof(u32);
     }
 
 private:
@@ -120,7 +122,7 @@ class RSA : public PKSystem<RSAPrivateKey<IntegerType>, RSAPublicKey<IntegerType
 public:
     using KeyPairType = RSAKeyPair<PublicKeyType, PrivateKeyType>;
 
-    static KeyPairType parse_rsa_key(ReadonlyBytes);
+    static KeyPairType parse_rsa_key(ReadonlyBytes der);
     static KeyPairType generate_key_pair(size_t bits = 256)
     {
         IntegerType e { 65537 }; // :P

+ 1 - 1
Userland/Libraries/LibTLS/TLSv12.cpp

@@ -839,7 +839,7 @@ bool TLSv12::add_client_key(ReadonlyBytes certificate_pem_buffer, ReadonlyBytes
     if (certificate_pem_buffer.is_empty() || rsa_key.is_empty()) {
         return true;
     }
-    auto decoded_certificate = Crypto::decode_pem(certificate_pem_buffer, 0);
+    auto decoded_certificate = Crypto::decode_pem(certificate_pem_buffer);
     if (decoded_certificate.is_empty()) {
         dbgln("Certificate not PEM");
         return false;

+ 54 - 10
Userland/Utilities/test-crypto.cpp

@@ -1970,8 +1970,9 @@ static void rsa_emsa_pss_test_create()
 
 static void rsa_test_der_parse()
 {
-    I_TEST((RSA | ASN1 DER / PEM encoded Key import));
-    auto privkey = R"(-----BEGIN RSA PRIVATE KEY-----
+    {
+        I_TEST((RSA | ASN1 PKCS1 DER / PEM encoded Key import));
+        auto privkey = R"(-----BEGIN RSA PRIVATE KEY-----
 MIIBOgIBAAJBAJsrIYHxs1YL9tpfodaWs1lJoMdF4kgFisUFSj6nvBhJUlmBh607AlgTaX0E
 DGPYycXYGZ2n6rqmms5lpDXBpUcCAwEAAQJAUNpPkmtEHDENxsoQBUXvXDYeXdePSiIBJhpU
 joNOYoR5R9z5oX2cpcyykQ58FC2vKKg+x8N6xczG7qO95tw5UQIhAN354CP/FA+uTeJ6KJ+i
@@ -1980,14 +1981,57 @@ IQCTjYI861Y+hjMnlORkGSdvWlTHUj6gjEOh4TlWeJzQoQIgAxMZOQKtxCZUuxFwzRq4xLRG
 nrDlBQpuxz7bwSyQO7UCIHrYMnDohgNbwtA5ZpW3H1cKKQQvueWm6sxW9P5sUrZ3
 -----END RSA PRIVATE KEY-----)";
 
-    Crypto::PK::RSA rsa(privkey);
-    if (rsa.public_key().public_exponent() == 65537) {
-        if (rsa.private_key().private_exponent() == "4234603516465654167360850580101327813936403862038934287300450163438938741499875303761385527882335478349599685406941909381269804396099893549838642251053393"_bigint) {
-            PASS;
-        } else
-            FAIL(Invalid private exponent);
-    } else {
-        FAIL(Invalid public exponent);
+        Crypto::PK::RSA rsa(privkey);
+        if (rsa.public_key().public_exponent() == 65537) {
+            if (rsa.private_key().private_exponent() == "4234603516465654167360850580101327813936403862038934287300450163438938741499875303761385527882335478349599685406941909381269804396099893549838642251053393"_bigint) {
+                PASS;
+            } else
+                FAIL(Invalid private exponent);
+        } else {
+            FAIL(Invalid public exponent);
+        }
+    }
+
+    {
+        I_TEST((RSA | ASN1 PKCS8 DER / PEM encoded Key import));
+        auto privkey = R"(-----BEGIN PRIVATE KEY-----
+MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC7ZBYaG9+CcJP7
+WVFJRI/uw3hljc7WpzeYs8MN82/g9CG1gnEF3P3ZSBdWVr8gnbh05EsSGHKghIce
+CB7DNrM5Ab0ru04CuODdPx56xCj+4MmzTc/aq79ntmOt131NGHgq9yVwfJqnSpyl
+OoVw7j/Wg4ciwPDQaeLmD1BsE/W9UsF1km7DWasBpW5br82DpudKgJq2Ixf52+rY
+TCkMgyWcetx4MfXll4y5ZVtJXCnHJfkCS64EaCqXmClP4ovOuHH4khJ3rW9j4yuL
+e5ck3PSXOrtOTR43HZkCXzseCkbW7qKSmk/9ZreImOzOgu8vvw7ewLAQR9qYVS6X
+PXY8IilDAgMBAAECggEBAIV3ld5mt90Z/exqA2Fh+fofMyNxyz5Lv2d9sZHAL5FT
+kKbND18TtaIKnMSb6Gl8rKJk76slyo7Vlb8oHXEBBsm1mV0KfVenAlHS4QyjpmdT
+B5Yz97VR2nQuDfUFpHNC2GQRv5LMzQIWPFfaxKxYpRNOfvOb5Gks4bTmd2tjFAYR
+MCbHgPw1liKA9dYKk4NB0301EY05e4Zz8RjqYHkkmOPD7DnjFbHqcFUjVKK5E3vD
+WjxNXUbiSudCCN7WLEOyeHZNd+l6kSAVxZuCAp0G3Da5ndXgIStcy4hYi/fL3XQQ
+bNpxjfhsjlD3tdHNr3NNYDAqxcxpsyO1NCpCIW3ZVrECgYEA7l6gTZ3e9AiSNlMd
+2O2vNnbQ6UZfsEfu2y7HmpCuNJkFkAnM/1h72Krejnn31rRuR6uCFn4YgQUN9Eq0
+E1PJCtTay2ucZw5rqtkewT9QzXvVD9eiGM+MF89UzSCC+dOW0/odkD+xP2evnPvG
+PbXztnuERC1pi0YWLj1YcsfsEX0CgYEAyUA2UtYjnvCcteIy+rURT0aoZ9tDMrG+
+Es42EURVv1sduVdUst5R+bXx1aDzpCkcdni3TyxeosvTGAZngI3O8ghh1GV7NPZR
+nkiPXjMnhL0Zf+X9gCA6TFANfPuWhMSGijYsCd46diKGDReGYUnmcN9XopeG1h6i
+3JiOuVPAIb8CgYBmIcUtfGb6yHFdNV+kgrJ/84ivaqe1MBz3bKO5ZiQ+BRKNFKXx
+AkiOHSgeg8PdCpH1w1aJrJ1zKmdANIHThiKtsWXNot3wig03tq+mvSox4Mz5bLrX
+RpYP3ZXIDhYQVMhbKt9f3upi8FoeOQJHjp5Nob6aN5rxQaZfSYmMJHzRQQKBgQCO
+ALwUGTtLNBYvlKtKEadkG8RKfAFfbOFkXZLy/hfPDRjdJY0DJTIMk+BPT+F6rPOD
+eMxHllQ0ZMPPiP1RTT5/s4BsISsdhMy0dhiLbGbvF4s9nugPly3rmPTbgp6DkjQo
+o+7RC7iOkO+rnzTXwxBSBpXMiUTAIx/hrdfPVxQT+wKBgCh7N3OLIOH6EWcW1fif
+UoENh8rkt/kzm89G1JLwBhuBIBPXUEZt2dS/xSUempqVqFGONpP87gvqxkMTtgCA
+73KXn/cxHWM2kmXyHA3kQlOYw6WHjpldQAxLE+TRHXO2JUtZ09Mu4rVXX7lmwbTm
+l3vmuDEF3/Bo1C1HTg0xRV/l
+-----END PRIVATE KEY-----)";
+
+        Crypto::PK::RSA rsa(privkey);
+        if (rsa.public_key().public_exponent() == 65537) {
+            if (rsa.private_key().private_exponent() == "16848664331299797559656678180469464902267415922431923391961407795209879741791261105581093539484181644099608161661780611501562625272630894063592208758992911105496755004417051031019663332258403844985328863382168329621318366311519850803972480500782200178279692319955495383119697563295214236936264406600739633470565823022975212999060908747002623721589308539473108154612454595201561671949550531384574873324370774408913092560971930541734744950937900805812300970883306404011323308000168926094053141613790857814489531436452649384151085451448183385611208320292948291211969430321231180227006521681776197974694030147965578466993"_bigint) {
+                PASS;
+            } else
+                FAIL(Invalid private exponent);
+        } else {
+            FAIL(Invalid public exponent);
+        }
     }
 }