diff --git a/AK/Base64.cpp b/AK/Base64.cpp index c4155e12fcf..47303358547 100644 --- a/AK/Base64.cpp +++ b/AK/Base64.cpp @@ -17,21 +17,40 @@ size_t size_required_to_decode_base64(StringView input) return simdutf::maximal_binary_length_from_base64(input.characters_without_null_termination(), input.length()); } +static ErrorOr decode_base64_into_impl(StringView input, ByteBuffer& output, simdutf::base64_options options) +{ + size_t output_length = output.size(); + + auto result = simdutf::base64_to_binary_safe( + input.characters_without_null_termination(), + input.length(), + reinterpret_cast(output.data()), + output_length, + options); + + if (result.error != simdutf::SUCCESS && result.error != simdutf::OUTPUT_BUFFER_TOO_SMALL) { + output.resize((result.count / 4) * 3); + + return InvalidBase64 { + .error = Error::from_string_literal("Invalid base64-encoded data"), + .valid_input_bytes = result.count, + }; + } + + VERIFY(output_length <= output.size()); + output.resize(output_length); + + return result.error == simdutf::SUCCESS ? input.length() : result.count; +} + static ErrorOr decode_base64_impl(StringView input, simdutf::base64_options options) { ByteBuffer output; TRY(output.try_resize(size_required_to_decode_base64(input))); - auto result = simdutf::base64_to_binary( - input.characters_without_null_termination(), - input.length(), - reinterpret_cast(output.data()), - options); + if (auto result = decode_base64_into_impl(input, output, options); result.is_error()) + return result.release_error().error; - if (result.error != simdutf::SUCCESS) - return Error::from_string_literal("Invalid base64-encoded data"); - - output.resize(result.count); return output; } @@ -59,6 +78,16 @@ ErrorOr decode_base64url(StringView input) return decode_base64_impl(input, simdutf::base64_url); } +ErrorOr decode_base64_into(StringView input, ByteBuffer& output) +{ + return decode_base64_into_impl(input, output, simdutf::base64_default); +} + +ErrorOr decode_base64url_into(StringView input, ByteBuffer& output) +{ + return decode_base64_into_impl(input, output, simdutf::base64_url); +} + ErrorOr encode_base64(ReadonlyBytes input, OmitPadding omit_padding) { auto options = omit_padding == OmitPadding::Yes diff --git a/AK/Base64.h b/AK/Base64.h index f448a0f9657..0670a1d9df3 100644 --- a/AK/Base64.h +++ b/AK/Base64.h @@ -18,6 +18,16 @@ size_t size_required_to_decode_base64(StringView); ErrorOr decode_base64(StringView); ErrorOr decode_base64url(StringView); +struct InvalidBase64 { + Error error; + size_t valid_input_bytes { 0 }; +}; + +// On success, these return the number of input bytes that were decoded. This might be less than the +// string length if the output buffer was not large enough. +ErrorOr decode_base64_into(StringView, ByteBuffer&); +ErrorOr decode_base64url_into(StringView, ByteBuffer&); + enum class OmitPadding { No, Yes, diff --git a/Tests/AK/TestBase64.cpp b/Tests/AK/TestBase64.cpp index e8006f84755..383938d5ea2 100644 --- a/Tests/AK/TestBase64.cpp +++ b/Tests/AK/TestBase64.cpp @@ -29,6 +29,43 @@ TEST_CASE(test_decode) decode_equal("aGVsbG8/d29ybGQ="sv, "hello?world"sv); } +TEST_CASE(test_decode_into) +{ + ByteBuffer buffer; + + auto decode_equal = [&](StringView input, StringView expected, Optional buffer_size = {}) { + buffer.resize(buffer_size.value_or_lazy_evaluated([&]() { + return AK::size_required_to_decode_base64(input); + })); + + auto result = AK::decode_base64_into(input, buffer); + VERIFY(!result.is_error()); + + EXPECT_EQ(StringView { buffer }, expected); + }; + + decode_equal(""sv, ""sv); + + decode_equal("Zg=="sv, "f"sv); + decode_equal("Zm8="sv, "fo"sv); + decode_equal("Zm9v"sv, "foo"sv); + decode_equal("Zm9vYg=="sv, "foob"sv); + decode_equal("Zm9vYmE="sv, "fooba"sv); + decode_equal("Zm9vYmFy"sv, "foobar"sv); + decode_equal(" Zm9vYmFy "sv, "foobar"sv); + decode_equal(" \n\r \t Zm 9v \t YmFy \n"sv, "foobar"sv); + decode_equal("aGVsbG8/d29ybGQ="sv, "hello?world"sv); + + decode_equal("Zm9vYmFy"sv, ""sv, 0); + decode_equal("Zm9vYmFy"sv, ""sv, 1); + decode_equal("Zm9vYmFy"sv, ""sv, 2); + decode_equal("Zm9vYmFy"sv, "foo"sv, 3); + decode_equal("Zm9vYmFy"sv, "foo"sv, 4); + decode_equal("Zm9vYmFy"sv, "foo"sv, 5); + decode_equal("Zm9vYmFy"sv, "foobar"sv, 6); + decode_equal("Zm9vYmFy"sv, "foobar"sv, 7); +} + TEST_CASE(test_decode_invalid) { EXPECT(decode_base64(("asdf\xffqwe"sv)).is_error());