소스 검색

LibCompress: Convert DeflateDecompressor from recursive to iterative

This way a deflate blob that contains a large amount of small blocks
wont cause a stack overflow.
Idan Horowitz 4 년 전
부모
커밋
974a981ded
1개의 변경된 파일82개의 추가작업 그리고 74개의 파일을 삭제
  1. 82 74
      Userland/Libraries/LibCompress/Deflate.cpp

+ 82 - 74
Userland/Libraries/LibCompress/Deflate.cpp

@@ -223,113 +223,121 @@ DeflateDecompressor::~DeflateDecompressor()
 
 size_t DeflateDecompressor::read(Bytes bytes)
 {
-    if (has_any_error())
-        return 0;
+    size_t total_read = 0;
+    while (total_read < bytes.size()) {
+        if (has_any_error())
+            break;
 
-    if (m_state == State::Idle) {
-        if (m_read_final_bock)
-            return 0;
-
-        m_read_final_bock = m_input_stream.read_bit();
-        const auto block_type = m_input_stream.read_bits(2);
-
-        if (m_input_stream.has_any_error()) {
-            set_fatal_error();
-            return 0;
-        }
+        auto slice = bytes.slice(total_read);
 
-        if (block_type == 0b00) {
-            m_input_stream.align_to_byte_boundary();
+        if (m_state == State::Idle) {
+            if (m_read_final_bock)
+                break;
 
-            LittleEndian<u16> length, negated_length;
-            m_input_stream >> length >> negated_length;
+            m_read_final_bock = m_input_stream.read_bit();
+            const auto block_type = m_input_stream.read_bits(2);
 
             if (m_input_stream.has_any_error()) {
                 set_fatal_error();
-                return 0;
+                break;
             }
 
-            if ((length ^ 0xffff) != negated_length) {
-                set_fatal_error();
-                return 0;
-            }
-
-            m_state = State::ReadingUncompressedBlock;
-            new (&m_uncompressed_block) UncompressedBlock(*this, length);
+            if (block_type == 0b00) {
+                m_input_stream.align_to_byte_boundary();
 
-            return read(bytes);
-        }
+                LittleEndian<u16> length, negated_length;
+                m_input_stream >> length >> negated_length;
 
-        if (block_type == 0b01) {
-            m_state = State::ReadingCompressedBlock;
-            new (&m_compressed_block) CompressedBlock(*this, CanonicalCode::fixed_literal_codes(), CanonicalCode::fixed_distance_codes());
+                if (m_input_stream.has_any_error()) {
+                    set_fatal_error();
+                    break;
+                }
 
-            return read(bytes);
-        }
+                if ((length ^ 0xffff) != negated_length) {
+                    set_fatal_error();
+                    break;
+                }
 
-        if (block_type == 0b10) {
-            CanonicalCode literal_codes;
-            Optional<CanonicalCode> distance_codes;
-            decode_codes(literal_codes, distance_codes);
+                m_state = State::ReadingUncompressedBlock;
+                new (&m_uncompressed_block) UncompressedBlock(*this, length);
 
-            if (m_input_stream.has_any_error()) {
-                set_fatal_error();
-                return 0;
+                continue;
             }
 
-            m_state = State::ReadingCompressedBlock;
-            new (&m_compressed_block) CompressedBlock(*this, literal_codes, distance_codes);
+            if (block_type == 0b01) {
+                m_state = State::ReadingCompressedBlock;
+                new (&m_compressed_block) CompressedBlock(*this, CanonicalCode::fixed_literal_codes(), CanonicalCode::fixed_distance_codes());
 
-            return read(bytes);
-        }
+                continue;
+            }
 
-        set_fatal_error();
-        return 0;
-    }
+            if (block_type == 0b10) {
+                CanonicalCode literal_codes;
+                Optional<CanonicalCode> distance_codes;
+                decode_codes(literal_codes, distance_codes);
 
-    if (m_state == State::ReadingCompressedBlock) {
-        auto nread = m_output_stream.read(bytes);
+                if (m_input_stream.has_any_error()) {
+                    set_fatal_error();
+                    break;
+                }
 
-        while (nread < bytes.size() && m_compressed_block.try_read_more()) {
-            nread += m_output_stream.read(bytes.slice(nread));
-        }
+                m_state = State::ReadingCompressedBlock;
+                new (&m_compressed_block) CompressedBlock(*this, literal_codes, distance_codes);
+
+                continue;
+            }
 
-        if (m_input_stream.has_any_error()) {
             set_fatal_error();
-            return 0;
+            break;
         }
 
-        if (nread == bytes.size())
-            return nread;
+        if (m_state == State::ReadingCompressedBlock) {
+            auto nread = m_output_stream.read(slice);
 
-        m_compressed_block.~CompressedBlock();
-        m_state = State::Idle;
+            while (nread < slice.size() && m_compressed_block.try_read_more()) {
+                nread += m_output_stream.read(slice.slice(nread));
+            }
 
-        return nread + read(bytes.slice(nread));
-    }
+            if (m_input_stream.has_any_error()) {
+                set_fatal_error();
+                break;
+            }
 
-    if (m_state == State::ReadingUncompressedBlock) {
-        auto nread = m_output_stream.read(bytes);
+            total_read += nread;
+            if (nread == slice.size())
+                break;
 
-        while (nread < bytes.size() && m_uncompressed_block.try_read_more()) {
-            nread += m_output_stream.read(bytes.slice(nread));
-        }
+            m_compressed_block.~CompressedBlock();
+            m_state = State::Idle;
 
-        if (m_input_stream.has_any_error()) {
-            set_fatal_error();
-            return 0;
+            continue;
         }
 
-        if (nread == bytes.size())
-            return nread;
+        if (m_state == State::ReadingUncompressedBlock) {
+            auto nread = m_output_stream.read(slice);
 
-        m_uncompressed_block.~UncompressedBlock();
-        m_state = State::Idle;
+            while (nread < slice.size() && m_uncompressed_block.try_read_more()) {
+                nread += m_output_stream.read(slice.slice(nread));
+            }
 
-        return nread + read(bytes.slice(nread));
-    }
+            if (m_input_stream.has_any_error()) {
+                set_fatal_error();
+                break;
+            }
 
-    VERIFY_NOT_REACHED();
+            total_read += nread;
+            if (nread == slice.size())
+                break;
+
+            m_uncompressed_block.~UncompressedBlock();
+            m_state = State::Idle;
+
+            continue;
+        }
+
+        VERIFY_NOT_REACHED();
+    }
+    return total_read;
 }
 
 bool DeflateDecompressor::read_or_error(Bytes bytes)