Browse Source

LibCompress: Handle arbitrarily long FF-chains in the LZMA encoder

Tim Schumacher 2 years ago
parent
commit
a01968ee6d

+ 20 - 0
Tests/LibCompress/TestLzma.cpp

@@ -75,6 +75,26 @@ TEST_CASE(compress_decompress_roundtrip_with_unknown_size)
     EXPECT_EQ(uncompressed, result.span());
 }
 
+TEST_CASE(compress_long_overflow_chain)
+{
+    // Encoding 0xFF followed by the end-of-stream marker results in a chain of bytes that doesn't fit into 64 bits,
+    // which breaks naive implementations of "hold back the byte until it no longer changes".
+
+    Array<u8, 1> const uncompressed {
+        0xFF
+    };
+
+    auto stream = MUST(try_make<AllocatingMemoryStream>());
+    auto compressor = TRY_OR_FAIL(Compress::LzmaCompressor::create_container(MaybeOwned<Stream> { *stream }, {}));
+    TRY_OR_FAIL(compressor->write_until_depleted(uncompressed));
+    TRY_OR_FAIL(compressor->flush());
+
+    auto decompressor = TRY_OR_FAIL(Compress::LzmaDecompressor::create_from_container(MaybeOwned<Stream> { *stream }));
+    auto result = TRY_OR_FAIL(decompressor->read_until_eof());
+
+    EXPECT_EQ(uncompressed, result.span());
+}
+
 // The following tests are based on test files from the LZMA specification, which has been placed in the public domain.
 // LZMA Specification Draft (2015): https://www.7-zip.org/a/lzma-specification.7z
 

+ 49 - 28
Userland/Libraries/LibCompress/Lzma.cpp

@@ -249,33 +249,53 @@ ErrorOr<void> LzmaDecompressor::normalize_range_decoder()
     return {};
 }
 
+ErrorOr<void> LzmaCompressor::shift_range_encoder()
+{
+    if ((m_range_encoder_code >> 32) == 0x01) {
+        // If there is an overflow, we can finalize the chain we were previously building.
+        // This includes incrementing both the cached byte and all the 0xFF bytes that we generate.
+        VERIFY(m_range_encoder_cached_byte != 0xFF);
+        TRY(m_stream->write_value<u8>(m_range_encoder_cached_byte + 1));
+        for (size_t i = 0; i < m_range_encoder_ff_chain_length; i++)
+            TRY(m_stream->write_value<u8>(0x00));
+        m_range_encoder_ff_chain_length = 0;
+        m_range_encoder_cached_byte = (m_range_encoder_code >> 24);
+    } else if ((m_range_encoder_code >> 24) == 0xFF) {
+        // If the byte to flush is 0xFF, it can potentially propagate an overflow and needs to be added to the chain.
+        m_range_encoder_ff_chain_length++;
+    } else {
+        // If the byte to flush isn't 0xFF, any future overflows will not be propagated beyond this point,
+        // so we can be sure that the built chain doesn't change anymore.
+        TRY(m_stream->write_value<u8>(m_range_encoder_cached_byte));
+        for (size_t i = 0; i < m_range_encoder_ff_chain_length; i++)
+            TRY(m_stream->write_value<u8>(0xFF));
+        m_range_encoder_ff_chain_length = 0;
+        m_range_encoder_cached_byte = (m_range_encoder_code >> 24);
+    }
+
+    // In all three cases we now recorded the highest byte in some way, so we can shift it away and shift in a null byte as the lowest byte.
+    m_range_encoder_range <<= 8;
+    m_range_encoder_code <<= 8;
+
+    // Since we are working with a 64-bit code, we need to limit it to 32 bits artificially.
+    m_range_encoder_code &= 0xFFFFFFFF;
+
+    return {};
+}
+
 ErrorOr<void> LzmaCompressor::normalize_range_encoder()
 {
     u64 const maximum_range_value = m_range_encoder_code + m_range_encoder_range;
 
-    // If we hit this, we have the potential to overflow into a byte that we already flushed.
-    VERIFY((maximum_range_value & ((1ull << m_range_encoder_code_used_bits) - 1)) == maximum_range_value);
+    // Logically, we should only ever build up an overflow that is smaller than or equal to 0x01.
+    VERIFY((maximum_range_value >> 32) <= 0x01);
 
     constexpr u32 minimum_range_value = 1 << 24;
 
     if (m_range_encoder_range >= minimum_range_value)
         return {};
 
-    u64 const flipped_bits = maximum_range_value ^ m_range_encoder_code;
-    u64 const size_of_flipped_bits = count_required_bits(flipped_bits);
-
-    // If we can flush a full byte without impacting future bits, do so.
-    while (m_range_encoder_code_used_bits - 8 >= size_of_flipped_bits) {
-        u8 const next_byte = (m_range_encoder_code >> (m_range_encoder_code_used_bits - 8));
-        m_range_encoder_code -= static_cast<u64>(next_byte) << (m_range_encoder_code_used_bits - 8);
-        m_range_encoder_code_used_bits -= 8;
-        TRY(m_stream->write_value(next_byte));
-    }
-
-    // Now, shift in a fresh null byte from the bottom.
-    m_range_encoder_range <<= 8;
-    m_range_encoder_code <<= 8;
-    m_range_encoder_code_used_bits += 8;
+    TRY(shift_range_encoder());
 
     VERIFY(m_range_encoder_range >= minimum_range_value);
 
@@ -1212,10 +1232,6 @@ ErrorOr<NonnullOwnPtr<LzmaCompressor>> LzmaCompressor::create_container(MaybeOwn
     auto header = TRY(LzmaHeader::from_compressor_options(options));
     TRY(stream->write_value(header));
 
-    // Note: The reference LZMA implementation has a starting null byte due to how their overflow reservoir is implemented and subsequently wrote it into the specification.
-    //       Therefore, we just have to add it manually.
-    TRY(stream->write_value<u8>(0x00));
-
     auto compressor = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LzmaCompressor(move(stream), options, move(dictionary), move(literal_probabilities))));
 
     return compressor;
@@ -1276,13 +1292,18 @@ ErrorOr<void> LzmaCompressor::flush()
     if (!m_options.uncompressed_size.has_value())
         TRY(encode_normalized_simple_match(end_of_stream_marker, 0));
 
-    while (m_range_encoder_code_used_bits > 0) {
-        VERIFY(m_range_encoder_code_used_bits >= 8);
-        u8 const next_byte = (m_range_encoder_code >> (m_range_encoder_code_used_bits - 8));
-        m_range_encoder_code -= static_cast<u64>(next_byte) << (m_range_encoder_code_used_bits - 8);
-        m_range_encoder_code_used_bits -= 8;
-        TRY(m_stream->write_value(next_byte));
-    }
+    // Shifting the range encoder using the normal operation handles any pending overflows.
+    TRY(shift_range_encoder());
+
+    // Now, the remaining bytes are the cached byte, the chain of 0xFF, and the upper 3 bytes of the current `code`.
+    // Incrementing the values does not have to be considered as no overflows are pending. The fourth byte is the
+    // null byte that we just shifted in, which should not be flushed as it would be extraneous junk data.
+    TRY(m_stream->write_value<u8>(m_range_encoder_cached_byte));
+    for (size_t i = 0; i < m_range_encoder_ff_chain_length; i++)
+        TRY(m_stream->write_value<u8>(0xFF));
+    TRY(m_stream->write_value<u8>(m_range_encoder_code >> 24));
+    TRY(m_stream->write_value<u8>(m_range_encoder_code >> 16));
+    TRY(m_stream->write_value<u8>(m_range_encoder_code >> 8));
 
     m_has_flushed_data = true;
     return {};

+ 7 - 1
Userland/Libraries/LibCompress/Lzma.h

@@ -225,6 +225,7 @@ public:
 private:
     LzmaCompressor(MaybeOwned<Stream>, LzmaCompressorOptions, MaybeOwned<CircularBuffer>, FixedArray<Probability> literal_probabilities);
 
+    ErrorOr<void> shift_range_encoder();
     ErrorOr<void> normalize_range_encoder();
     ErrorOr<void> encode_direct_bit(u8 value);
     ErrorOr<void> encode_bit_with_probability(Probability&, u8 value);
@@ -253,7 +254,12 @@ private:
     // Range encoder state.
     u32 m_range_encoder_range { 0xFFFFFFFF };
     u64 m_range_encoder_code { 0 };
-    size_t m_range_encoder_code_used_bits { 32 };
+
+    // Since the range is only 32-bits, we can overflow at most +1 into the next byte beyond the usual 32-bit code.
+    // Therefore, it is sufficient to store the highest byte (which may still change due to that +1 overflow) and
+    // the length of the chain of 0xFF bytes that may end up propagating that change.
+    u8 m_range_encoder_cached_byte { 0x00 };
+    size_t m_range_encoder_ff_chain_length { 0 };
 };
 
 }