Browse Source

LibCompress: Implement DEFLATE properly.

Now we have an actual stream implementation that can read arbitrary
DEFLATE-encoded data (dynamic codes aren't supported yet), even if
the blocks are really large.

And all of that happens with a single buffer of 32KiB. DEFLATE is
amazing!
asynts 5 năm trước cách đây
mục cha
commit
7c53f14bbc

+ 225 - 297
Libraries/LibCompress/Deflate.cpp

@@ -25,404 +25,332 @@
  */
 
 #include <AK/Assertions.h>
+#include <AK/BinarySearch.h>
+#include <AK/FixedArray.h>
 #include <AK/LogStream.h>
-#include <AK/Span.h>
-#include <AK/Types.h>
-#include <AK/Vector.h>
+
 #include <LibCompress/Deflate.h>
 
 namespace Compress {
 
-bool DeflateStream::read_next_block() const
+// FIXME: This logic needs to go into the deflate decoder somehow, we don't want
+//        to assert that the input is valid. Instead we need to set m_error on the
+//        stream.
+DeflateDecompressor::CanonicalCode::CanonicalCode(ReadonlyBytes codes)
 {
-    if (m_read_last_block)
-        return false;
+    // FIXME: I can't quite follow the algorithm here, but it seems to work.
 
-    m_read_last_block = m_reader.read_bits(1);
-    auto block_type = m_reader.read_bits(2);
-
-    switch (block_type) {
-    case 0:
-        decompress_uncompressed_block();
-        break;
-    case 1:
-        decompress_static_block();
-        break;
-    case 2:
-        decompress_dynamic_block();
-        break;
-    case 3:
-        dbg() << "Block contains reserved block type...";
-        ASSERT_NOT_REACHED();
-        break;
-    default:
-        dbg() << "Invalid block type was read...";
+    m_symbol_codes.resize(codes.size());
+    m_symbol_values.resize(codes.size());
+
+    auto allocated_symbols_count = 0;
+    auto next_code = 0;
+
+    for (size_t code_length = 1; code_length <= 15; ++code_length) {
+        next_code <<= 1;
+        auto start_bit = 1 << code_length;
+
+        for (size_t symbol = 0; symbol < codes.size(); ++symbol) {
+            if (codes[symbol] != code_length)
+                continue;
+
+            if (next_code > start_bit) {
+                dbg() << "Canonical code overflows the huffman tree";
+                ASSERT_NOT_REACHED();
+            }
+
+            m_symbol_codes[allocated_symbols_count] = start_bit | next_code;
+            m_symbol_values[allocated_symbols_count] = symbol;
+
+            allocated_symbols_count++;
+            next_code++;
+        }
+    }
+
+    if (next_code != (1 << 15)) {
+        dbg() << "Canonical code underflows the huffman tree " << next_code;
         ASSERT_NOT_REACHED();
-        break;
     }
+}
 
-    return true;
+const DeflateDecompressor::CanonicalCode& DeflateDecompressor::CanonicalCode::fixed_literal_codes()
+{
+    static CanonicalCode* code = nullptr;
+
+    if (code)
+        return *code;
+
+    FixedArray<u8> data { 288 };
+    data.bytes().slice(0, 144 - 0).fill(8);
+    data.bytes().slice(144, 256 - 144).fill(9);
+    data.bytes().slice(256, 280 - 256).fill(7);
+    data.bytes().slice(280, 288 - 280).fill(8);
+
+    code = new CanonicalCode(data);
+    return *code;
 }
 
-void DeflateStream::decompress_uncompressed_block() const
+const DeflateDecompressor::CanonicalCode& DeflateDecompressor::CanonicalCode::fixed_distance_codes()
 {
-    // Align to the next byte boundary.
-    while (m_reader.get_bit_byte_offset() != 0) {
-        m_reader.read();
-    }
+    static CanonicalCode* code = nullptr;
 
-    auto length = m_reader.read_bits(16) & 0xFFFF;
-    auto negated_length = m_reader.read_bits(16) & 0xFFFF;
+    if (code)
+        return *code;
 
-    if ((length ^ 0xFFFF) != negated_length) {
-        dbg() << "Block length is invalid...";
-        ASSERT_NOT_REACHED();
-    }
+    FixedArray<u8> data { 32 };
+    data.bytes().fill(5);
 
-    for (size_t i = 0; i < length; i++) {
-        auto byte = m_reader.read_byte();
-        if (byte < 0) {
-            dbg() << "Ran out of bytes while reading uncompressed block...";
-            ASSERT_NOT_REACHED();
-        }
+    code = new CanonicalCode(data);
+    return *code;
+}
+
+u32 DeflateDecompressor::CanonicalCode::read_symbol(InputBitStream& stream) const
+{
+    u32 code_bits = 1;
+
+    for (;;) {
+        code_bits = code_bits << 1 | stream.read_bits(1);
 
-        m_intermediate_stream << byte;
+        size_t index;
+        if (AK::binary_search(m_symbol_codes.span(), code_bits, AK::integral_compare<u32>, &index))
+            return m_symbol_values[index];
     }
 }
 
-void DeflateStream::decompress_static_block() const
+DeflateDecompressor::CompressedBlock::CompressedBlock(DeflateDecompressor& decompressor, CanonicalCode literal_codes, Optional<CanonicalCode> distance_codes)
+    : m_decompressor(decompressor)
+    , m_literal_codes(literal_codes)
+    , m_distance_codes(distance_codes)
 {
-    decompress_huffman_block(m_literal_length_codes, &m_fixed_distance_codes);
 }
 
-void DeflateStream::decompress_dynamic_block() const
+bool DeflateDecompressor::CompressedBlock::try_read_more()
 {
-    auto codes = decode_huffman_codes();
-    if (codes.size() == 2) {
-        decompress_huffman_block(codes[0], &codes[1]);
+    if (m_eof == true)
+        return false;
+
+    const auto symbol = m_literal_codes.read_symbol(m_decompressor.m_input_stream);
+
+    if (symbol < 256) {
+        m_decompressor.m_output_stream << static_cast<u8>(symbol);
+        return true;
+    } else if (symbol == 256) {
+        m_eof = true;
+        return false;
     } else {
-        decompress_huffman_block(codes[0], nullptr);
+        ASSERT(m_distance_codes.has_value());
+
+        const auto run_length = m_decompressor.decode_run_length(symbol);
+        const auto distance = m_decompressor.decode_distance(m_distance_codes.value().read_symbol(m_decompressor.m_input_stream));
+
+        auto bytes = m_decompressor.m_output_stream.reserve_contigous_space(run_length);
+        m_decompressor.m_output_stream.read(bytes, distance + bytes.size());
+
+        return true;
     }
 }
 
-void DeflateStream::decompress_huffman_block(CanonicalCode& length_codes, CanonicalCode* distance_codes) const
+DeflateDecompressor::UncompressedBlock::UncompressedBlock(DeflateDecompressor& decompressor, size_t length)
+    : m_decompressor(decompressor)
+    , m_bytes_remaining(length)
 {
-    for (;;) {
-        u32 symbol = length_codes.next_symbol(m_reader);
-
-        // End of block.
-        if (symbol == 256) {
-            break;
-        }
+}
 
-        // literal byte.
-        if (symbol < 256) {
-            m_intermediate_stream << static_cast<u8>(symbol);
-            continue;
-        }
+bool DeflateDecompressor::UncompressedBlock::try_read_more()
+{
+    if (m_bytes_remaining == 0)
+        return false;
 
-        // Length and distance for copying.
-        ASSERT(distance_codes);
+    const auto nread = min(m_bytes_remaining, m_decompressor.m_output_stream.remaining_contigous_space());
+    m_bytes_remaining -= nread;
 
-        auto run = decode_run_length(symbol);
-        if (run < 3 || run > 258) {
-            dbg() << "Invalid run length";
-            ASSERT_NOT_REACHED();
-        }
+    m_decompressor.m_input_stream >> m_decompressor.m_output_stream.reserve_contigous_space(nread);
 
-        auto distance_symbol = distance_codes->next_symbol(m_reader);
-        auto distance = decode_distance(distance_symbol);
-        if (distance < 1 || distance > 32768) {
-            dbg() << "Invalid distance";
-            ASSERT_NOT_REACHED();
-        }
+    return true;
+}
 
-        copy_from_history(distance, run);
-    }
+DeflateDecompressor::DeflateDecompressor(InputStream& stream)
+    : m_input_stream(stream)
+{
 }
 
-Vector<CanonicalCode> DeflateStream::decode_huffman_codes() const
+DeflateDecompressor::~DeflateDecompressor()
 {
-    // FIXME: This path is not tested.
-    Vector<CanonicalCode> result;
-
-    auto length_code_count = m_reader.read_bits(5) + 257;
-    auto distance_code_count = m_reader.read_bits(5) + 1;
-
-    size_t length_code_code_length = m_reader.read_bits(4) + 4;
-
-    Vector<u8> code_length_code_length;
-    code_length_code_length.resize(19);
-    code_length_code_length[16] = m_reader.read_bits(3);
-    code_length_code_length[17] = m_reader.read_bits(3);
-    code_length_code_length[18] = m_reader.read_bits(3);
-    code_length_code_length[0] = m_reader.read_bits(3);
-    for (size_t i = 0; i < length_code_code_length; i++) {
-        auto index = (i % 2 == 0) ? (8 + (i / 2)) : (7 - (i / 2));
-        code_length_code_length[index] = m_reader.read_bits(3);
-    }
+    if (m_state == State::ReadingCompressedBlock)
+        m_compressed_block.~CompressedBlock();
+    if (m_state == State::ReadingUncompressedBlock)
+        m_uncompressed_block.~UncompressedBlock();
+}
 
-    auto code_length_code = CanonicalCode(code_length_code_length);
+size_t DeflateDecompressor::read(Bytes bytes)
+{
+    // FIXME: There are surely a ton of bugs because we don't check for read errors
+    //        very often.
 
-    Vector<u32> code_lens;
-    code_lens.resize(length_code_count + distance_code_count);
+    if (m_state == State::Idle) {
+        if (m_read_final_bock)
+            return 0;
 
-    for (size_t index = 0; index < code_lens.capacity();) {
-        auto symbol = code_length_code.next_symbol(m_reader);
+        m_read_final_bock = m_input_stream.read_bit();
+        const auto block_type = m_input_stream.read_bits(2);
 
-        if (symbol <= 15) {
-            code_lens[index] = symbol;
-            index++;
-            continue;
-        }
+        if (block_type == 0b00) {
+            m_input_stream.align_to_byte_boundary();
 
-        u32 run_length;
-        u32 run_value = 0;
+            LittleEndian<u16> length, negated_length;
+            m_input_stream >> length >> negated_length;
 
-        if (symbol == 16) {
-            if (index == 0) {
-                dbg() << "No code length value avaliable";
-                ASSERT_NOT_REACHED();
+            if ((length ^ 0xffff) != negated_length) {
+                m_error = true;
+                return 0;
             }
 
-            run_length = m_reader.read_bits(2) + 3;
-            run_value = code_lens[index - 1];
-        } else if (symbol == 17) {
-            run_length = m_reader.read_bits(3) + 3;
-        } else if (symbol == 18) {
-            run_length = m_reader.read_bits(7) + 11;
-        } else {
-            dbg() << "Code symbol is out of range!";
-            ASSERT_NOT_REACHED();
-        }
+            m_state = State::ReadingUncompressedBlock;
+            new (&m_uncompressed_block) UncompressedBlock(*this, length);
 
-        u32 end = index + run_length;
-        if (end > code_lens.capacity()) {
-            dbg() << "Code run is out of range!";
-            ASSERT_NOT_REACHED();
+            return read(bytes);
         }
 
-        memset(code_lens.data() + index, run_value, run_length);
-        index = end;
-    }
+        if (block_type == 0b01) {
+            m_state = State::ReadingCompressedBlock;
+            new (&m_compressed_block) CompressedBlock(*this, CanonicalCode::fixed_literal_codes(), CanonicalCode::fixed_distance_codes());
 
-    Vector<u8> literal_codes;
-    literal_codes.resize(length_code_count);
-    memcpy(literal_codes.data(), code_lens.data(), literal_codes.capacity());
-    result.append(CanonicalCode(literal_codes));
+            return read(bytes);
+        }
 
-    Vector<u8> distance_codes;
-    distance_codes.resize(distance_code_count);
-    memcpy(distance_codes.data(), code_lens.data() + length_code_count, distance_codes.capacity());
+        if (block_type == 0b10) {
+            CanonicalCode literal_codes, distance_codes;
+            decode_codes(literal_codes, distance_codes);
+            new (&m_compressed_block) CompressedBlock(*this, literal_codes, distance_codes);
 
-    if (distance_code_count == 1 && distance_codes[0] == 0) {
-        return result;
-    }
+            return read(bytes);
+        }
 
-    u8 one_count = 0;
-    u8 other_count = 0;
+        ASSERT_NOT_REACHED();
+    }
 
-    for (size_t i = 0; i < distance_codes.capacity(); i++) {
-        u8 value = distance_codes.at(i);
+    if (m_state == State::ReadingCompressedBlock) {
+        auto nread = m_output_stream.read(bytes);
 
-        if (value == 1) {
-            one_count++;
-        } else if (value > 1) {
-            other_count++;
+        while (nread < bytes.size() && m_compressed_block.try_read_more()) {
+            nread += m_output_stream.read(bytes.slice(nread));
         }
-    }
 
-    if (one_count == 1 && other_count == 0) {
-        distance_codes.resize(32);
-        distance_codes[31] = 1;
-    }
+        if (nread == bytes.size())
+            return nread;
 
-    result.append(CanonicalCode(distance_codes));
-    return result;
-}
+        m_compressed_block.~CompressedBlock();
+        m_state = State::Idle;
 
-u32 DeflateStream::decode_run_length(u32 symbol) const
-{
-    if (symbol <= 264) {
-        return symbol - 254;
+        return nread + read(bytes.slice(nread));
     }
 
-    if (symbol <= 284) {
-        auto extra_bits = (symbol - 261) / 4;
-        return ((((symbol - 265) % 4) + 4) << extra_bits) + 3 + m_reader.read_bits(extra_bits);
-    }
+    if (m_state == State::ReadingUncompressedBlock) {
+        auto nread = m_output_stream.read(bytes);
 
-    if (symbol == 285) {
-        return 258;
-    }
+        while (nread < bytes.size() && m_uncompressed_block.try_read_more()) {
+            nread += m_output_stream.read(bytes.slice(nread));
+        }
 
-    dbg() << "Found invalid symbol in run length " << symbol;
-    ASSERT_NOT_REACHED();
-}
+        if (nread == bytes.size())
+            return nread;
 
-u32 DeflateStream::decode_distance(u32 symbol) const
-{
-    if (symbol <= 3) {
-        return symbol + 1;
-    }
+        m_uncompressed_block.~UncompressedBlock();
+        m_state = State::Idle;
 
-    if (symbol <= 29) {
-        auto extra_bits = (symbol / 2) - 1;
-        return (((symbol % 2) + 2) << extra_bits) + 1 + m_reader.read_bits(extra_bits);
+        return nread + read(bytes.slice(nread));
     }
 
-    dbg() << "Found invalid symbol in distance" << symbol;
     ASSERT_NOT_REACHED();
 }
 
-void DeflateStream::copy_from_history(u32 distance, u32 run) const
+bool DeflateDecompressor::read_or_error(Bytes bytes)
 {
-    for (size_t i = 0; i < run; i++) {
-        u8 byte;
-
-        // FIXME: In many cases we can read more than one byte at a time, this should
-        //        be refactored into a while loop. Beware, edge case:
-        //
-        //            // The first four bytes are on the stream already, the other four
-        //            // are written by copy_from_history() itself.
-        //            copy_from_history(4, 8);
-        m_intermediate_stream.read({ &byte, sizeof(byte) }, m_intermediate_stream.woffset() - distance);
-        m_intermediate_stream << byte;
+    if (read(bytes) < bytes.size()) {
+        m_error = true;
+        return false;
     }
+
+    return true;
 }
 
-i8 BitStreamReader::read()
+bool DeflateDecompressor::discard_or_error(size_t count)
 {
-    if (m_current_byte == -1) {
-        return -1;
-    }
+    u8 buffer[4096];
 
-    if (m_remaining_bits == 0) {
-        if (m_data_index + 1 > m_data.size())
-            return -1;
+    size_t ndiscarded = 0;
+    while (ndiscarded < count) {
+        if (eof()) {
+            m_error = true;
+            return false;
+        }
 
-        m_current_byte = m_data.at(m_data_index++);
-        m_remaining_bits = 8;
+        ndiscarded += read({ buffer, min<size_t>(count - ndiscarded, 4096) });
     }
 
-    m_remaining_bits--;
-    return (m_current_byte >> (7 - m_remaining_bits)) & 1;
+    return true;
 }
 
-i8 BitStreamReader::read_byte()
-{
-    m_current_byte = 0;
-    m_remaining_bits = 0;
-
-    if (m_data_index + 1 > m_data.size())
-        return -1;
-
-    return m_data.at(m_data_index++);
-}
+bool DeflateDecompressor::eof() const { return m_state == State::Idle && m_read_final_bock; }
 
-u8 BitStreamReader::get_bit_byte_offset()
+ByteBuffer DeflateDecompressor::decompress_all(ReadonlyBytes bytes)
 {
-    return (8 - m_remaining_bits) % 8;
-}
+    InputMemoryStream memory_stream { bytes };
+    InputBitStream bit_stream { memory_stream };
+    DeflateDecompressor deflate_stream { bit_stream };
 
-u32 BitStreamReader::read_bits(u8 count)
-{
-    ASSERT(count > 0 && count < 32);
+    auto buffer = ByteBuffer::create_uninitialized(4096);
 
-    u32 result = 0;
-    for (size_t i = 0; i < count; i++) {
-        result |= read() << i;
+    size_t nread = 0;
+    while (!deflate_stream.eof()) {
+        nread += deflate_stream.read(buffer.bytes().slice(nread));
+        if (buffer.size() - nread < 4096)
+            buffer.grow(buffer.size() + 4096);
     }
-    return result;
-}
 
-Vector<u8> DeflateStream::generate_literal_length_codes() const
-{
-    Vector<u8> ll_codes;
-    ll_codes.resize(288);
-    memset(ll_codes.data() + 0, 8, 144 - 0);
-    memset(ll_codes.data() + 144, 9, 256 - 144);
-    memset(ll_codes.data() + 256, 7, 280 - 256);
-    memset(ll_codes.data() + 280, 8, 288 - 280);
-    return ll_codes;
-}
-
-Vector<u8> DeflateStream::generate_fixed_distance_codes() const
-{
-    Vector<u8> fd_codes;
-    fd_codes.resize(32);
-    memset(fd_codes.data(), 5, 32);
-    return fd_codes;
+    buffer.trim(nread);
+    return buffer;
 }
 
-CanonicalCode::CanonicalCode(Vector<u8> codes)
+u32 DeflateDecompressor::decode_run_length(u32 symbol)
 {
-    m_symbol_codes.resize(codes.size());
-    m_symbol_values.resize(codes.size());
-
-    auto allocated_symbols_count = 0;
-    auto next_code = 0;
-
-    for (size_t code_length = 1; code_length <= 15; code_length++) {
-        next_code <<= 1;
-        auto start_bit = 1 << code_length;
-
-        for (size_t symbol = 0; symbol < codes.size(); symbol++) {
-            if (codes.at(symbol) != code_length) {
-                continue;
-            }
+    // FIXME: I can't quite follow the algorithm here, but it seems to work.
 
-            if (next_code > start_bit) {
-                dbg() << "Canonical code overflows the huffman tree";
-                ASSERT_NOT_REACHED();
-            }
-
-            m_symbol_codes[allocated_symbols_count] = start_bit | next_code;
-            m_symbol_values[allocated_symbols_count] = symbol;
+    if (symbol <= 264)
+        return symbol - 254;
 
-            allocated_symbols_count++;
-            next_code++;
-        }
+    if (symbol <= 284) {
+        auto extra_bits = (symbol - 261) / 4;
+        return (((symbol - 265) % 4 + 4) << extra_bits) + 3 + m_input_stream.read_bits(extra_bits);
     }
 
-    if (next_code != (1 << 15)) {
-        dbg() << "Canonical code underflows the huffman tree " << next_code;
-        ASSERT_NOT_REACHED();
-    }
+    if (symbol == 285)
+        return 258;
+
+    ASSERT_NOT_REACHED();
 }
 
-static i32 binary_search(Vector<u32>& heystack, u32 needle)
+u32 DeflateDecompressor::decode_distance(u32 symbol)
 {
-    i32 low = 0;
-    i32 high = heystack.size();
-
-    while (low <= high) {
-        u32 mid = (low + high) >> 1;
-        u32 value = heystack.at(mid);
-
-        if (value < needle) {
-            low = mid + 1;
-        } else if (value > needle) {
-            high = mid - 1;
-        } else {
-            return mid;
-        }
+    // FIXME: I can't quite follow the algorithm here, but it seems to work.
+
+    if (symbol <= 3)
+        return symbol + 1;
+
+    if (symbol <= 29) {
+        auto extra_bits = (symbol / 2) - 1;
+        return ((symbol % 2 + 2) << extra_bits) + 1 + m_input_stream.read_bits(extra_bits);
     }
 
-    return -1;
+    ASSERT_NOT_REACHED();
 }
 
-u32 CanonicalCode::next_symbol(BitStreamReader& reader)
+void DeflateDecompressor::decode_codes(CanonicalCode&, CanonicalCode&)
 {
-    auto code_bits = 1;
-
-    for (;;) {
-        code_bits = code_bits << 1 | reader.read();
-        i32 index = binary_search(m_symbol_codes, code_bits);
-        if (index >= 0) {
-            return m_symbol_values.at(index);
-        }
-    }
+    // FIXME: This was already implemented but I removed it because it was quite chaotic and untested.
+    //        I am planning to come back to this. @asynts
+    //        https://github.com/SerenityOS/serenity/blob/208cb995babb13e0af07bb9d3219f0a9fe7bca7d/Libraries/LibCompress/Deflate.cpp#L144-L242
+    TODO();
 }
 
 }

+ 66 - 136
Libraries/LibCompress/Deflate.h

@@ -26,160 +26,90 @@
 
 #pragma once
 
-#include <AK/CircularQueue.h>
-#include <AK/Span.h>
-#include <AK/Stream.h>
-#include <AK/Types.h>
+#include <AK/BitStream.h>
+#include <AK/ByteBuffer.h>
+#include <AK/CircularDuplexStream.h>
+#include <AK/Endian.h>
 #include <AK/Vector.h>
-#include <cstring>
 
 namespace Compress {
 
-// Reads one bit at a time starting with the rightmost bit
-class BitStreamReader {
-public:
-    BitStreamReader(ReadonlyBytes data)
-        : m_data(data)
-    {
-    }
+class DeflateDecompressor final : public InputStream {
+private:
+    class CanonicalCode {
+    public:
+        CanonicalCode() = default;
+        CanonicalCode(ReadonlyBytes);
+        u32 read_symbol(InputBitStream&) const;
 
-    i8 read();
-    i8 read_byte();
-    u32 read_bits(u8);
-    u8 get_bit_byte_offset();
+        static const CanonicalCode& fixed_literal_codes();
+        static const CanonicalCode& fixed_distance_codes();
 
-private:
-    ReadonlyBytes m_data;
-    size_t m_data_index { 0 };
+    private:
+        Vector<u32> m_symbol_codes;
+        Vector<u32> m_symbol_values;
+    };
 
-    i8 m_current_byte { 0 };
-    u8 m_remaining_bits { 0 };
-};
+    class CompressedBlock {
+    public:
+        CompressedBlock(DeflateDecompressor&, CanonicalCode literal_codes, Optional<CanonicalCode> distance_codes);
 
-class CanonicalCode {
-public:
-    CanonicalCode(Vector<u8>);
-    u32 next_symbol(BitStreamReader&);
+        bool try_read_more();
 
-private:
-    Vector<u32> m_symbol_codes;
-    Vector<u32> m_symbol_values;
-};
+    private:
+        bool m_eof { false };
 
-// Implements a DEFLATE decompressor according to RFC 1951.
-class DeflateStream final : public InputStream {
-public:
-    // FIXME: This should really return a ByteBuffer.
-    static Vector<u8> decompress_all(ReadonlyBytes bytes)
-    {
-        DeflateStream stream { bytes };
-        while (stream.read_next_block()) {
-        }
-
-        Vector<u8> vector;
-        vector.resize(stream.m_intermediate_stream.remaining());
-        stream >> vector;
-
-        return vector;
-    }
-
-    DeflateStream(ReadonlyBytes data)
-        : m_reader(data)
-        , m_literal_length_codes(generate_literal_length_codes())
-        , m_fixed_distance_codes(generate_fixed_distance_codes())
-    {
-    }
-
-    // FIXME: Accept an InputStream.
-
-    size_t read(Bytes bytes) override
-    {
-        if (m_intermediate_stream.remaining() >= bytes.size())
-            return m_intermediate_stream.read_or_error(bytes);
-
-        while (read_next_block()) {
-            if (m_intermediate_stream.remaining() >= bytes.size())
-                return m_intermediate_stream.read_or_error(bytes);
-        }
-
-        return m_intermediate_stream.read(bytes);
-    }
-
-    bool read_or_error(Bytes bytes) override
-    {
-        if (m_intermediate_stream.remaining() >= bytes.size()) {
-            m_intermediate_stream.read_or_error(bytes);
-            return true;
-        }
-
-        while (read_next_block()) {
-            if (m_intermediate_stream.remaining() >= bytes.size()) {
-                m_intermediate_stream.read_or_error(bytes);
-                return true;
-            }
-        }
-
-        m_error = true;
-        return false;
-    }
-
-    bool eof() const override
-    {
-        if (!m_intermediate_stream.eof())
-            return false;
-
-        while (read_next_block()) {
-            if (!m_intermediate_stream.eof())
-                return false;
-        }
-
-        return true;
-    }
-
-    bool discard_or_error(size_t count) override
-    {
-        if (m_intermediate_stream.remaining() >= count) {
-            m_intermediate_stream.discard_or_error(count);
-            return true;
-        }
-
-        while (read_next_block()) {
-            if (m_intermediate_stream.remaining() >= count) {
-                m_intermediate_stream.discard_or_error(count);
-                return true;
-            }
-        }
-
-        m_error = true;
-        return false;
-    }
+        DeflateDecompressor& m_decompressor;
+        CanonicalCode m_literal_codes;
+        Optional<CanonicalCode> m_distance_codes;
+    };
 
-private:
-    void decompress_uncompressed_block() const;
-    void decompress_static_block() const;
-    void decompress_dynamic_block() const;
-    void decompress_huffman_block(CanonicalCode&, CanonicalCode*) const;
+    class UncompressedBlock {
+    public:
+        UncompressedBlock(DeflateDecompressor&, size_t);
+
+        bool try_read_more();
 
-    Vector<CanonicalCode> decode_huffman_codes() const;
-    u32 decode_run_length(u32) const;
-    u32 decode_distance(u32) const;
+    private:
+        DeflateDecompressor& m_decompressor;
+        size_t m_bytes_remaining;
+    };
 
-    void copy_from_history(u32, u32) const;
+    enum class State {
+        Idle,
+        ReadingCompressedBlock,
+        ReadingUncompressedBlock
+    };
+
+public:
+    friend CompressedBlock;
+    friend UncompressedBlock;
 
-    Vector<u8> generate_literal_length_codes() const;
-    Vector<u8> generate_fixed_distance_codes() const;
+    DeflateDecompressor(InputStream&);
+    ~DeflateDecompressor();
 
-    mutable BitStreamReader m_reader;
+    size_t read(Bytes) override;
+    bool read_or_error(Bytes) override;
+    bool discard_or_error(size_t) override;
+    bool eof() const override;
+
+    static ByteBuffer decompress_all(ReadonlyBytes);
+
+private:
+    u32 decode_run_length(u32);
+    u32 decode_distance(u32);
+    void decode_codes(CanonicalCode&, CanonicalCode&);
 
-    mutable CanonicalCode m_literal_length_codes;
-    mutable CanonicalCode m_fixed_distance_codes;
+    bool m_read_final_bock { false };
 
-    // FIXME: Theoretically, blocks can be extremly large, reading a single block could
-    //        exhaust memory. Maybe wait for C++20 coroutines?
-    bool read_next_block() const;
+    State m_state { State::Idle };
+    union {
+        CompressedBlock m_compressed_block;
+        UncompressedBlock m_uncompressed_block;
+    };
 
-    mutable bool m_read_last_block { false };
-    mutable DuplexMemoryStream m_intermediate_stream;
+    InputBitStream m_input_stream;
+    CircularDuplexStream<32 * 1024> m_output_stream;
 };
 
 }

+ 2 - 2
Libraries/LibCompress/Zlib.cpp

@@ -55,9 +55,9 @@ Zlib::Zlib(ReadonlyBytes data)
     m_data_bytes = data.slice(2, data.size() - 2 - 4);
 }
 
-Vector<u8> Zlib::decompress()
+ByteBuffer Zlib::decompress()
 {
-    return DeflateStream::decompress_all(m_data_bytes);
+    return DeflateDecompressor::decompress_all(m_data_bytes);
 }
 
 u32 Zlib::checksum()

+ 9 - 2
Libraries/LibCompress/Zlib.h

@@ -26,18 +26,25 @@
 
 #pragma once
 
+#include <AK/ByteBuffer.h>
 #include <AK/Span.h>
 #include <AK/Types.h>
-#include <AK/Vector.h>
 
 namespace Compress {
+
 class Zlib {
 public:
     Zlib(ReadonlyBytes data);
 
-    Vector<u8> decompress();
+    ByteBuffer decompress();
     u32 checksum();
 
+    static ByteBuffer decompress_all(ReadonlyBytes bytes)
+    {
+        Zlib zlib { bytes };
+        return zlib.decompress();
+    }
+
 private:
     u8 m_compression_method;
     u8 m_compression_info;

+ 53 - 5
Userland/test-compress.cpp

@@ -52,11 +52,59 @@ TEST_CASE(deflate_decompress_compressed_block)
 
     const u8 uncompressed[] = "This is a simple text file :)";
 
-    const auto decompressed = Compress::DeflateStream::decompress_all({ compressed, sizeof(compressed) });
-    EXPECT(compare({ uncompressed, sizeof(uncompressed) - 1 }, decompressed.span()));
+    const auto decompressed = Compress::DeflateDecompressor::decompress_all({ compressed, sizeof(compressed) });
+    EXPECT(compare({ uncompressed, sizeof(uncompressed) - 1 }, decompressed.bytes()));
 }
 
-TEST_CASE(zlib_simple_decompress)
+TEST_CASE(deflate_decompress_uncompressed_block)
+{
+    const u8 compressed[] = {
+        0x01, 0x0d, 0x00, 0xf2, 0xff, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
+        0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21
+    };
+
+    const u8 uncompressed[] = "Hello, World!";
+
+    const auto decompressed = Compress::DeflateDecompressor::decompress_all({ compressed, sizeof(compressed) });
+    EXPECT(compare({ uncompressed, sizeof(uncompressed) - 1 }, decompressed.bytes()));
+}
+
+TEST_CASE(deflate_decompress_multiple_blocks)
+{
+    const u8 compressed[] = {
+        0x00, 0x1f, 0x00, 0xe0, 0xff, 0x54, 0x68, 0x65, 0x20, 0x66, 0x69, 0x72,
+        0x73, 0x74, 0x20, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x20, 0x69, 0x73, 0x20,
+        0x75, 0x6e, 0x63, 0x6f, 0x6d, 0x70, 0x72, 0x65, 0x73, 0x73, 0x65, 0x64,
+        0x53, 0x48, 0xcc, 0x4b, 0x51, 0x28, 0xc9, 0x48, 0x55, 0x28, 0x4e, 0x4d,
+        0xce, 0x07, 0x32, 0x93, 0x72, 0xf2, 0x93, 0xb3, 0x15, 0x32, 0x8b, 0x15,
+        0x92, 0xf3, 0x73, 0x0b, 0x8a, 0x52, 0x8b, 0x8b, 0x53, 0x53, 0xf4, 0x00
+    };
+
+    const u8 uncompressed[] = "The first block is uncompressed and the second block is compressed.";
+
+    const auto decompressed = Compress::DeflateDecompressor::decompress_all({ compressed, sizeof(compressed) });
+    EXPECT(compare({ uncompressed, sizeof(uncompressed) - 1 }, decompressed.bytes()));
+}
+
+// FIXME: The following test uses a dynamic encoding which isn't supported by DeflateDecompressor yet.
+
+/*
+TEST_CASE(deflate_decompress_zeroes)
+{
+    const u8 compressed[] = {
+        0xed, 0xc1, 0x01, 0x0d, 0x00, 0x00, 0x00, 0xc2, 0xa0, 0xf7, 0x4f, 0x6d,
+        0x0f, 0x07, 0x14, 0x00, 0x00, 0x00, 0xf0, 0x6e
+    };
+
+    u8 uncompressed[4096];
+    Bytes { uncompressed, sizeof(uncompressed) }.fill(0);
+
+    const auto decompressed = Compress::DeflateDecompressor::decompress_all({ compressed, sizeof(compressed) });
+    EXPECT(compare({ uncompressed, sizeof(uncompressed) }, decompressed.bytes()));
+}
+*/
+
+TEST_CASE(zlib_decompress_simple)
 {
     const u8 compressed[] = {
         0x78, 0x01, 0x01, 0x1D, 0x00, 0xE2, 0xFF, 0x54, 0x68, 0x69, 0x73, 0x20,
@@ -67,8 +115,8 @@ TEST_CASE(zlib_simple_decompress)
 
     const u8 uncompressed[] = "This is a simple text file :)";
 
-    const auto decompressed = Compress::Zlib { { compressed, sizeof(compressed) } }.decompress();
-    EXPECT(compare({ uncompressed, sizeof(uncompressed) - 1 }, decompressed.span()));
+    const auto decompressed = Compress::Zlib::decompress_all({ compressed, sizeof(compressed) });
+    EXPECT(compare({ uncompressed, sizeof(uncompressed) - 1 }, decompressed.bytes()));
 }
 
 TEST_MAIN(Compress)