AK+LibWeb: Replace our home-grown base64 encoder/decoders with simdutf

We currently have 2 base64 coders: one in AK, another in LibWeb for a
"forgiving" implementation. ECMA-262 has an upcoming proposal which will
require a third implementation.

Instead, let's use the base64 implementation that is used by Node.js and
recommended by the upcoming proposal. It handles forgiving decoding as
well.

Our users of AK's implementation should be fine with the forgiving
implementation. The AK impl originally had naive forgiving behavior, but
that was removed solely for performance reasons.

Using http://mattmahoney.net/dc/enwik8.zip (100MB unzipped) as a test,
performance of our old home-grown implementations vs. the simdutf
implementation (on Linux x64):

                Encode    Decode
AK base64       0.226s    0.169s
LibWeb base64   N/A       1.244s
simdutf         0.161s    0.047s
This commit is contained in:
Timothy Flynn 2024-07-15 15:25:08 -04:00 committed by Andreas Kling
parent 58dfe5424f
commit bfc9dc447f
Notes: sideshowbarker 2024-07-16 23:34:49 +09:00
11 changed files with 60 additions and 310 deletions

View file

@ -4,116 +4,51 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/Assertions.h>
#define AK_DONT_REPLACE_STD
#include <AK/Base64.h>
#include <AK/Error.h>
#include <AK/StringBuilder.h>
#include <AK/Types.h>
#include <AK/Vector.h>
#include <simdutf.h>
namespace AK {
size_t calculate_base64_decoded_length(StringView input)
static ErrorOr<ByteBuffer> decode_base64_impl(StringView input, simdutf::base64_options options)
{
auto length = input.length() * 3 / 4;
if (input.ends_with("="sv))
--length;
if (input.ends_with("=="sv))
--length;
return length;
}
size_t calculate_base64_encoded_length(ReadonlyBytes input)
{
return ((4 * input.size() / 3) + 3) & ~3;
}
static ErrorOr<ByteBuffer> decode_base64_impl(StringView input, ReadonlySpan<i16> alphabet_lookup_table)
{
input = input.trim_whitespace();
if (input.length() % 4 != 0)
return Error::from_string_literal("Invalid length of Base64 encoded string");
auto get = [&](size_t offset, bool* is_padding) -> ErrorOr<u8> {
if (offset >= input.length())
return 0;
auto ch = static_cast<unsigned char>(input[offset]);
if (ch == '=') {
if (!is_padding)
return Error::from_string_literal("Invalid '=' character outside of padding in base64 data");
*is_padding = true;
return 0;
}
i16 result = alphabet_lookup_table[ch];
if (result < 0)
return Error::from_string_literal("Invalid character in base64 data");
VERIFY(result < 256);
return { result };
};
ByteBuffer output;
TRY(output.try_resize(calculate_base64_decoded_length(input)));
TRY(output.try_resize(simdutf::maximal_binary_length_from_base64(input.characters_without_null_termination(), input.length())));
size_t input_offset = 0;
size_t output_offset = 0;
auto result = simdutf::base64_to_binary(
input.characters_without_null_termination(),
input.length(),
reinterpret_cast<char*>(output.data()),
options);
while (input_offset < input.length()) {
bool in2_is_padding = false;
bool in3_is_padding = false;
u8 const in0 = TRY(get(input_offset++, nullptr));
u8 const in1 = TRY(get(input_offset++, nullptr));
u8 const in2 = TRY(get(input_offset++, &in2_is_padding));
u8 const in3 = TRY(get(input_offset++, &in3_is_padding));
output[output_offset++] = (in0 << 2) | ((in1 >> 4) & 3);
if (!in2_is_padding)
output[output_offset++] = ((in1 & 0xf) << 4) | ((in2 >> 2) & 0xf);
if (!in3_is_padding)
output[output_offset++] = ((in2 & 0x3) << 6) | in3;
}
if (result.error != simdutf::SUCCESS)
return Error::from_string_literal("Invalid base64-encoded data");
output.resize(result.count);
return output;
}
static ErrorOr<String> encode_base64_impl(ReadonlyBytes input, ReadonlySpan<char> alphabet)
static ErrorOr<String> encode_base64_impl(StringView input, simdutf::base64_options options)
{
Vector<u8> output;
TRY(output.try_ensure_capacity(calculate_base64_encoded_length(input)));
auto get = [&](size_t const offset, bool* need_padding = nullptr) -> u8 {
if (offset >= input.size()) {
if (need_padding)
*need_padding = true;
return 0;
}
return input[offset];
};
// simdutf does not append padding to base64url encodings. We use the default encoding option here to allocate room
// for the padding characters that we will later append ourselves if necessary.
TRY(output.try_resize(simdutf::base64_length_from_binary(input.length(), simdutf::base64_default)));
for (size_t i = 0; i < input.size(); i += 3) {
bool is_8bit = false;
bool is_16bit = false;
auto size_written = simdutf::binary_to_base64(
input.characters_without_null_termination(),
input.length(),
reinterpret_cast<char*>(output.data()),
options);
u8 const in0 = get(i);
u8 const in1 = get(i + 1, &is_16bit);
u8 const in2 = get(i + 2, &is_8bit);
u8 const index0 = (in0 >> 2) & 0x3f;
u8 const index1 = ((in0 << 4) | (in1 >> 4)) & 0x3f;
u8 const index2 = ((in1 << 2) | (in2 >> 6)) & 0x3f;
u8 const index3 = in2 & 0x3f;
output.unchecked_append(alphabet[index0]);
output.unchecked_append(alphabet[index1]);
output.unchecked_append(is_16bit ? '=' : alphabet[index2]);
output.unchecked_append(is_8bit ? '=' : alphabet[index3]);
if (options == simdutf::base64_url) {
for (size_t i = size_written; i < output.size(); ++i)
output[i] = '=';
}
return String::from_utf8_without_validation(output);
@ -121,23 +56,22 @@ static ErrorOr<String> encode_base64_impl(ReadonlyBytes input, ReadonlySpan<char
ErrorOr<ByteBuffer> decode_base64(StringView input)
{
static constexpr auto lookup_table = base64_lookup_table();
return decode_base64_impl(input, lookup_table);
return decode_base64_impl(input, simdutf::base64_default);
}
ErrorOr<ByteBuffer> decode_base64url(StringView input)
{
static constexpr auto lookup_table = base64url_lookup_table();
return decode_base64_impl(input, lookup_table);
return decode_base64_impl(input, simdutf::base64_url);
}
ErrorOr<String> encode_base64(ReadonlyBytes input)
{
return encode_base64_impl(input, base64_alphabet);
return encode_base64_impl(input, simdutf::base64_default);
}
ErrorOr<String> encode_base64url(ReadonlyBytes input)
{
return encode_base64_impl(input, base64url_alphabet);
return encode_base64_impl(input, simdutf::base64_url);
}
}

View file

@ -6,7 +6,6 @@
#pragma once
#include <AK/Array.h>
#include <AK/ByteBuffer.h>
#include <AK/Error.h>
#include <AK/String.h>
@ -14,59 +13,12 @@
namespace AK {
// https://datatracker.ietf.org/doc/html/rfc4648#section-4
constexpr Array base64_alphabet = {
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H',
'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
'w', 'x', 'y', 'z', '0', '1', '2', '3',
'4', '5', '6', '7', '8', '9', '+', '/'
};
ErrorOr<ByteBuffer> decode_base64(StringView);
ErrorOr<ByteBuffer> decode_base64url(StringView);
// https://datatracker.ietf.org/doc/html/rfc4648#section-5
constexpr Array base64url_alphabet = {
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H',
'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
'w', 'x', 'y', 'z', '0', '1', '2', '3',
'4', '5', '6', '7', '8', '9', '-', '_'
};
ErrorOr<String> encode_base64(ReadonlyBytes);
ErrorOr<String> encode_base64url(ReadonlyBytes);
consteval auto base64_lookup_table()
{
Array<i16, 256> table;
table.fill(-1);
for (size_t i = 0; i < base64_alphabet.size(); ++i) {
table[base64_alphabet[i]] = static_cast<i16>(i);
}
return table;
}
consteval auto base64url_lookup_table()
{
Array<i16, 256> table;
table.fill(-1);
for (size_t i = 0; i < base64url_alphabet.size(); ++i) {
table[base64url_alphabet[i]] = static_cast<i16>(i);
}
return table;
}
[[nodiscard]] size_t calculate_base64_decoded_length(StringView);
[[nodiscard]] size_t calculate_base64_encoded_length(ReadonlyBytes);
[[nodiscard]] ErrorOr<ByteBuffer> decode_base64(StringView);
[[nodiscard]] ErrorOr<ByteBuffer> decode_base64url(StringView);
[[nodiscard]] ErrorOr<String> encode_base64(ReadonlyBytes);
[[nodiscard]] ErrorOr<String> encode_base64url(ReadonlyBytes);
}
#if USING_AK_GLOBALLY

View file

@ -1,6 +1,5 @@
set(SOURCES
Assertions.cpp
Base64.cpp
CircularBuffer.cpp
ConstrainedStream.cpp
CountingStream.cpp
@ -38,6 +37,10 @@ set(SOURCES
kmalloc.cpp
)
if (NOT LAGOM_TOOLS_ONLY)
list(APPEND SOURCES Base64.cpp)
endif()
serenity_lib(AK ak)
serenity_install_headers(AK)
@ -56,3 +59,8 @@ if (Backtrace_FOUND)
else()
message(WARNING "Backtrace not found, stack traces will be unavailable")
endif()
if (NOT LAGOM_TOOLS_ONLY)
find_package(simdutf REQUIRED)
target_link_libraries(AK PRIVATE simdutf::simdutf)
endif()

View file

@ -7,15 +7,13 @@
#include <LibTest/TestCase.h>
#include <AK/Base64.h>
#include <AK/ByteString.h>
#include <string.h>
TEST_CASE(test_decode)
{
auto decode_equal = [&](StringView input, StringView expected) {
auto decoded = TRY_OR_FAIL(decode_base64(input));
EXPECT(ByteString::copy(decoded) == expected);
EXPECT(expected.length() <= calculate_base64_decoded_length(input.bytes()));
EXPECT_EQ(StringView { decoded }, expected);
};
decode_equal(""sv, ""sv);
@ -26,7 +24,7 @@ TEST_CASE(test_decode)
decode_equal("Zm9vYmE="sv, "fooba"sv);
decode_equal("Zm9vYmFy"sv, "foobar"sv);
decode_equal(" Zm9vYmFy "sv, "foobar"sv);
decode_equal(" \n\r \t Zm9vYmFy \n"sv, "foobar"sv);
decode_equal(" \n\r \t Zm 9v \t YmFy \n"sv, "foobar"sv);
decode_equal("aGVsbG8/d29ybGQ="sv, "hello?world"sv);
}
@ -42,9 +40,7 @@ TEST_CASE(test_decode_invalid)
EXPECT(decode_base64url("aGVsbG8/d29ybGQ="sv).is_error());
EXPECT(decode_base64("Y"sv).is_error());
EXPECT(decode_base64("YQ"sv).is_error());
EXPECT(decode_base64("YQ="sv).is_error());
EXPECT(decode_base64("PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIxMC42MDUiIGhlaWdodD0iMTUuNTU1Ij48cGF0aCBmaWxsPSIjODg5IiBkPSJtMi44MjggMTUuNTU1IDcuNzc3LTcuNzc5TDIuODI4IDAgMCAyLjgyOGw0Ljk0OSA0Ljk0OEwwIDEyLjcyN2wyLjgyOCAyLjgyOHoiLz48L3N2Zz4"sv).is_error());
}
TEST_CASE(test_decode_only_padding)
@ -65,8 +61,7 @@ TEST_CASE(test_encode)
{
auto encode_equal = [&](StringView input, StringView expected) {
auto encoded = MUST(encode_base64(input.bytes()));
EXPECT(encoded == expected);
EXPECT_EQ(expected.length(), calculate_base64_encoded_length(input.bytes()));
EXPECT_EQ(encoded, expected);
};
encode_equal(""sv, ""sv);
@ -82,8 +77,7 @@ TEST_CASE(test_urldecode)
{
auto decode_equal = [&](StringView input, StringView expected) {
auto decoded = TRY_OR_FAIL(decode_base64url(input));
EXPECT(ByteString::copy(decoded) == expected);
EXPECT(expected.length() <= calculate_base64_decoded_length(input.bytes()));
EXPECT_EQ(StringView { decoded }, expected);
};
decode_equal(""sv, ""sv);
@ -104,8 +98,7 @@ TEST_CASE(test_urlencode)
{
auto encode_equal = [&](StringView input, StringView expected) {
auto encoded = MUST(encode_base64url(input.bytes()));
EXPECT(encoded == expected);
EXPECT_EQ(expected.length(), calculate_base64_encoded_length(input.bytes()));
EXPECT_EQ(encoded, expected);
};
encode_equal(""sv, ""sv);

View file

@ -460,7 +460,6 @@ set(SOURCES
HTML/ValidityState.cpp
HighResolutionTime/Performance.cpp
HighResolutionTime/TimeOrigin.cpp
Infra/Base64.cpp
Infra/ByteSequences.cpp
Infra/JSON.cpp
Infra/Strings.cpp

View file

@ -6,10 +6,10 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/Base64.h>
#include <LibURL/URL.h>
#include <LibWeb/Fetch/Infrastructure/URL.h>
#include <LibWeb/MimeSniff/MimeType.h>
#include <Userland/Libraries/LibWeb/Infra/Base64.h>
namespace Web::Fetch::Infrastructure {
@ -79,7 +79,7 @@ ErrorOr<DataURL> process_data_url(URL::URL const& data_url)
// 2. Set body to the forgiving-base64 decode of stringBody.
// 3. If body is failure, then return failure.
body = TRY(Infra::decode_forgiving_base64(string_body));
body = TRY(decode_base64(string_body));
// 4. Remove the last 6 code points from mimeType.
// 5. Remove trailing U+0020 SPACE code points from mimeType, if any.

View file

@ -6,7 +6,6 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/Base64.h>
#include <AK/Utf8View.h>
#include <LibIPC/File.h>
#include <LibJS/Runtime/AbstractOperations.h>
@ -59,7 +58,6 @@
#include <LibWeb/HTML/Window.h>
#include <LibWeb/HTML/WindowProxy.h>
#include <LibWeb/HighResolutionTime/TimeOrigin.h>
#include <LibWeb/Infra/Base64.h>
#include <LibWeb/Infra/CharacterTypes.h>
#include <LibWeb/Internals/Inspector.h>
#include <LibWeb/Internals/Internals.h>

View file

@ -32,7 +32,6 @@
#include <LibWeb/HighResolutionTime/Performance.h>
#include <LibWeb/HighResolutionTime/SupportedPerformanceTypes.h>
#include <LibWeb/IndexedDB/IDBFactory.h>
#include <LibWeb/Infra/Base64.h>
#include <LibWeb/PerformanceTimeline/EntryTypes.h>
#include <LibWeb/PerformanceTimeline/PerformanceObserver.h>
#include <LibWeb/PerformanceTimeline/PerformanceObserverEntryList.h>
@ -130,14 +129,14 @@ WebIDL::ExceptionOr<String> WindowOrWorkerGlobalScopeMixin::atob(String const& d
auto& realm = *vm.current_realm();
// 1. Let decodedData be the result of running forgiving-base64 decode on data.
auto decoded_data = Infra::decode_forgiving_base64(data.bytes_as_string_view());
auto decoded_data = decode_base64(data);
// 2. If decodedData is failure, then throw an "InvalidCharacterError" DOMException.
if (decoded_data.is_error())
return WebIDL::InvalidCharacterError::create(realm, "Input string is not valid base64 data"_fly_string);
// 3. Return decodedData.
// decode_base64() returns a byte string. LibJS uses UTF-8 for strings. Use Latin1Decoder to convert bytes 128-255 to UTF-8.
// decode_base64() returns a byte buffer. LibJS uses UTF-8 for strings. Use Latin1Decoder to convert bytes 128-255 to UTF-8.
auto decoder = TextCodec::decoder_for_exact_name("ISO-8859-1"sv);
VERIFY(decoder.has_value());
return TRY_OR_THROW_OOM(vm, decoder->to_utf8(decoded_data.value()));

View file

@ -1,123 +0,0 @@
/*
* Copyright (c) 2022-2023, the SerenityOS developers.
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/Base64.h>
#include <AK/ByteBuffer.h>
#include <AK/CharacterTypes.h>
#include <AK/Error.h>
#include <AK/StringBuilder.h>
#include <AK/StringView.h>
#include <AK/Vector.h>
#include <LibWeb/Infra/Base64.h>
#include <LibWeb/Infra/CharacterTypes.h>
namespace Web::Infra {
// https://infra.spec.whatwg.org/#forgiving-base64
ErrorOr<ByteBuffer> decode_forgiving_base64(StringView input)
{
// 1. Remove all ASCII whitespace from data.
// FIXME: It is possible to avoid copying input here, it's just a bit tricky to remove the equal signs
StringBuilder builder;
for (auto character : input) {
if (!is_ascii_whitespace(character))
TRY(builder.try_append(character));
}
auto data = builder.string_view();
// 2. If datas code point length divides by 4 leaving no remainder, then:
if (data.length() % 4 == 0) {
// If data ends with one or two U+003D (=) code points, then remove them from data.
if (data.ends_with("=="sv))
data = data.substring_view(0, data.length() - 2);
else if (data.ends_with('='))
data = data.substring_view(0, data.length() - 1);
}
// 3. If datas code point length divides by 4 leaving a remainder of 1, then return failure.
if (data.length() % 4 == 1)
return Error::from_string_literal("Invalid input length in forgiving base64 decode");
// 4. If data contains a code point that is not one of
// U+002B (+), U+002F (/), ASCII alphanumeric
// then return failure.
for (auto point : data) {
if (point != '+' && point != '/' && !is_ascii_alphanumeric(point))
return Error::from_string_literal("Invalid character in forgiving base64 decode");
}
// 5. Let output be an empty byte sequence.
// 6. Let buffer be an empty buffer that can have bits appended to it.
Vector<u8> output;
u32 buffer = 0;
auto accumulated_bits = 0;
auto add_to_buffer = [&](u8 number) {
VERIFY(number < 64);
u32 buffer_mask = number;
if (accumulated_bits == 0)
buffer_mask <<= 18;
else if (accumulated_bits == 6)
buffer_mask <<= 12;
else if (accumulated_bits == 12)
buffer_mask <<= 6;
else if (accumulated_bits == 18)
buffer_mask <<= 0;
buffer |= buffer_mask;
accumulated_bits += 6;
};
auto append_bytes = [&]() {
output.append(static_cast<u8>((buffer & 0xff0000) >> 16));
output.append(static_cast<u8>((buffer & 0xff00) >> 8));
output.append(static_cast<u8>(buffer & 0xff));
buffer = 0;
accumulated_bits = 0;
};
auto alphabet_lookup_table = AK::base64_lookup_table();
// 7. Let position be a position variable for data, initially pointing at the start of data.
// 8. While position does not point past the end of data:
for (auto point : data) {
// 1. Find the code point pointed to by position in the second column of Table 1: The Base 64 Alphabet of RFC 4648.
// Let n be the number given in the first cell of the same row. [RFC4648]
auto n = alphabet_lookup_table[point];
VERIFY(n >= 0);
// 2. Append the six bits corresponding to n, most significant bit first, to buffer.
add_to_buffer(static_cast<u8>(n));
// 3. buffer has accumulated 24 bits,
if (accumulated_bits == 24) {
// interpret them as three 8-bit big-endian numbers.
// Append three bytes with values equal to those numbers to output, in the same order, and then empty buffer
append_bytes();
}
}
// 9. If buffer is not empty, it contains either 12 or 18 bits.
VERIFY(accumulated_bits == 0 || accumulated_bits == 12 || accumulated_bits == 18);
// If it contains 12 bits, then discard the last four and interpret the remaining eight as an 8-bit big-endian number.
if (accumulated_bits == 12)
output.append(static_cast<u8>((buffer & 0xff0000) >> 16));
// If it contains 18 bits, then discard the last two and interpret the remaining 16 as two 8-bit big-endian numbers.
// Append the one or two bytes with values equal to those one or two numbers to output, in the same order.
if (accumulated_bits == 18) {
output.append(static_cast<u8>((buffer & 0xff0000) >> 16));
output.append(static_cast<u8>((buffer & 0xff00) >> 8));
}
return ByteBuffer::copy(output);
}
}

View file

@ -1,15 +0,0 @@
/*
* Copyright (c) 2022-2023, the SerenityOS developers.
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#pragma once
#include <AK/Forward.h>
namespace Web::Infra {
[[nodiscard]] ErrorOr<ByteBuffer> decode_forgiving_base64(StringView);
}

View file

@ -19,6 +19,7 @@
"dav1d"
]
},
"simdutf",
{
"name": "skia",
"platform": "osx",
@ -38,7 +39,6 @@
"platform": "android"
},
"sqlite3",
"woff2",
{
"name": "vulkan",
"platform": "!android"
@ -46,7 +46,8 @@
{
"name": "vulkan-headers",
"platform": "!android"
}
},
"woff2"
],
"overrides": [
{
@ -69,6 +70,10 @@
"name": "libavif",
"version": "1.0.4#1"
},
{
"name": "simdutf",
"version": "5.2.5#0"
},
{
"name": "skia",
"version": "124#0"