Deflate.cpp 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032
  1. /*
  2. * Copyright (c) 2020, the SerenityOS developers.
  3. * Copyright (c) 2021, Idan Horowitz <idan.horowitz@serenityos.org>
  4. *
  5. * SPDX-License-Identifier: BSD-2-Clause
  6. */
  7. #include <AK/Array.h>
  8. #include <AK/Assertions.h>
  9. #include <AK/BinaryHeap.h>
  10. #include <AK/BinarySearch.h>
  11. #include <AK/BitStream.h>
  12. #include <LibCore/MemoryStream.h>
  13. #include <string.h>
  14. #include <LibCompress/Deflate.h>
  15. namespace Compress {
  16. static constexpr u8 deflate_special_code_length_copy = 16;
  17. static constexpr u8 deflate_special_code_length_zeros = 17;
  18. static constexpr u8 deflate_special_code_length_long_zeros = 18;
  19. CanonicalCode const& CanonicalCode::fixed_literal_codes()
  20. {
  21. static CanonicalCode code;
  22. static bool initialized = false;
  23. if (initialized)
  24. return code;
  25. code = CanonicalCode::from_bytes(fixed_literal_bit_lengths).value();
  26. initialized = true;
  27. return code;
  28. }
  29. CanonicalCode const& CanonicalCode::fixed_distance_codes()
  30. {
  31. static CanonicalCode code;
  32. static bool initialized = false;
  33. if (initialized)
  34. return code;
  35. code = CanonicalCode::from_bytes(fixed_distance_bit_lengths).value();
  36. initialized = true;
  37. return code;
  38. }
  39. Optional<CanonicalCode> CanonicalCode::from_bytes(ReadonlyBytes bytes)
  40. {
  41. // FIXME: I can't quite follow the algorithm here, but it seems to work.
  42. CanonicalCode code;
  43. auto non_zero_symbols = 0;
  44. auto last_non_zero = -1;
  45. for (size_t i = 0; i < bytes.size(); i++) {
  46. if (bytes[i] != 0) {
  47. non_zero_symbols++;
  48. last_non_zero = i;
  49. }
  50. }
  51. if (non_zero_symbols == 1) { // special case - only 1 symbol
  52. code.m_symbol_codes.append(0b10);
  53. code.m_symbol_values.append(last_non_zero);
  54. code.m_bit_codes[last_non_zero] = 0;
  55. code.m_bit_code_lengths[last_non_zero] = 1;
  56. return code;
  57. }
  58. auto next_code = 0;
  59. for (size_t code_length = 1; code_length <= 15; ++code_length) {
  60. next_code <<= 1;
  61. auto start_bit = 1 << code_length;
  62. for (size_t symbol = 0; symbol < bytes.size(); ++symbol) {
  63. if (bytes[symbol] != code_length)
  64. continue;
  65. if (next_code > start_bit)
  66. return {};
  67. code.m_symbol_codes.append(start_bit | next_code);
  68. code.m_symbol_values.append(symbol);
  69. code.m_bit_codes[symbol] = fast_reverse16(start_bit | next_code, code_length); // DEFLATE writes huffman encoded symbols as lsb-first
  70. code.m_bit_code_lengths[symbol] = code_length;
  71. next_code++;
  72. }
  73. }
  74. if (next_code != (1 << 15)) {
  75. return {};
  76. }
  77. return code;
  78. }
  79. ErrorOr<u32> CanonicalCode::read_symbol(LittleEndianInputBitStream& stream) const
  80. {
  81. u32 code_bits = 1;
  82. for (;;) {
  83. code_bits = code_bits << 1 | TRY(stream.read_bits(1));
  84. if (code_bits >= (1 << 16))
  85. return Error::from_string_literal("Symbol exceeds maximum symbol number");
  86. // FIXME: This is very inefficient and could greatly be improved by implementing this
  87. // algorithm: https://www.hanshq.net/zip.html#huffdec
  88. size_t index;
  89. if (binary_search(m_symbol_codes.span(), code_bits, &index))
  90. return m_symbol_values[index];
  91. }
  92. }
  93. ErrorOr<void> CanonicalCode::write_symbol(LittleEndianOutputBitStream& stream, u32 symbol) const
  94. {
  95. TRY(stream.write_bits(m_bit_codes[symbol], m_bit_code_lengths[symbol]));
  96. return {};
  97. }
  98. DeflateDecompressor::CompressedBlock::CompressedBlock(DeflateDecompressor& decompressor, CanonicalCode literal_codes, Optional<CanonicalCode> distance_codes)
  99. : m_decompressor(decompressor)
  100. , m_literal_codes(literal_codes)
  101. , m_distance_codes(distance_codes)
  102. {
  103. }
  104. ErrorOr<bool> DeflateDecompressor::CompressedBlock::try_read_more()
  105. {
  106. if (m_eof == true)
  107. return false;
  108. auto const symbol = TRY(m_literal_codes.read_symbol(*m_decompressor.m_input_stream));
  109. if (symbol >= 286)
  110. return Error::from_string_literal("Invalid deflate literal/length symbol");
  111. if (symbol < 256) {
  112. u8 byte_symbol = symbol;
  113. m_decompressor.m_output_buffer.write({ &byte_symbol, sizeof(byte_symbol) });
  114. return true;
  115. } else if (symbol == 256) {
  116. m_eof = true;
  117. return false;
  118. } else {
  119. if (!m_distance_codes.has_value())
  120. return Error::from_string_literal("Distance codes have not been initialized");
  121. auto const length = TRY(m_decompressor.decode_length(symbol));
  122. auto const distance_symbol = TRY(m_distance_codes.value().read_symbol(*m_decompressor.m_input_stream));
  123. if (distance_symbol >= 30)
  124. return Error::from_string_literal("Invalid deflate distance symbol");
  125. auto const distance = TRY(m_decompressor.decode_distance(distance_symbol));
  126. for (size_t idx = 0; idx < length; ++idx) {
  127. u8 byte = 0;
  128. TRY(m_decompressor.m_output_buffer.read_with_seekback({ &byte, sizeof(byte) }, distance));
  129. m_decompressor.m_output_buffer.write({ &byte, sizeof(byte) });
  130. }
  131. return true;
  132. }
  133. }
  134. DeflateDecompressor::UncompressedBlock::UncompressedBlock(DeflateDecompressor& decompressor, size_t length)
  135. : m_decompressor(decompressor)
  136. , m_bytes_remaining(length)
  137. {
  138. }
  139. ErrorOr<bool> DeflateDecompressor::UncompressedBlock::try_read_more()
  140. {
  141. if (m_bytes_remaining == 0)
  142. return false;
  143. Array<u8, 4096> temporary_buffer;
  144. auto readable_bytes = temporary_buffer.span().trim(min(m_bytes_remaining, m_decompressor.m_output_buffer.empty_space()));
  145. auto read_bytes = TRY(m_decompressor.m_input_stream->read(readable_bytes));
  146. auto written_bytes = m_decompressor.m_output_buffer.write(read_bytes);
  147. VERIFY(read_bytes.size() == written_bytes);
  148. m_bytes_remaining -= written_bytes;
  149. return true;
  150. }
  151. ErrorOr<NonnullOwnPtr<DeflateDecompressor>> DeflateDecompressor::construct(MaybeOwned<AK::Stream> stream)
  152. {
  153. auto output_buffer = TRY(CircularBuffer::create_empty(32 * KiB));
  154. return TRY(adopt_nonnull_own_or_enomem(new (nothrow) DeflateDecompressor(move(stream), move(output_buffer))));
  155. }
  156. DeflateDecompressor::DeflateDecompressor(MaybeOwned<AK::Stream> stream, CircularBuffer output_buffer)
  157. : m_input_stream(make<LittleEndianInputBitStream>(move(stream)))
  158. , m_output_buffer(move(output_buffer))
  159. {
  160. }
  161. DeflateDecompressor::~DeflateDecompressor()
  162. {
  163. if (m_state == State::ReadingCompressedBlock)
  164. m_compressed_block.~CompressedBlock();
  165. if (m_state == State::ReadingUncompressedBlock)
  166. m_uncompressed_block.~UncompressedBlock();
  167. }
  168. ErrorOr<Bytes> DeflateDecompressor::read(Bytes bytes)
  169. {
  170. size_t total_read = 0;
  171. while (total_read < bytes.size()) {
  172. auto slice = bytes.slice(total_read);
  173. if (m_state == State::Idle) {
  174. if (m_read_final_bock)
  175. break;
  176. m_read_final_bock = TRY(m_input_stream->read_bit());
  177. auto const block_type = TRY(m_input_stream->read_bits(2));
  178. if (block_type == 0b00) {
  179. m_input_stream->align_to_byte_boundary();
  180. LittleEndian<u16> length, negated_length;
  181. TRY(m_input_stream->read(length.bytes()));
  182. TRY(m_input_stream->read(negated_length.bytes()));
  183. if ((length ^ 0xffff) != negated_length)
  184. return Error::from_string_literal("Calculated negated length does not equal stored negated length");
  185. m_state = State::ReadingUncompressedBlock;
  186. new (&m_uncompressed_block) UncompressedBlock(*this, length);
  187. continue;
  188. }
  189. if (block_type == 0b01) {
  190. m_state = State::ReadingCompressedBlock;
  191. new (&m_compressed_block) CompressedBlock(*this, CanonicalCode::fixed_literal_codes(), CanonicalCode::fixed_distance_codes());
  192. continue;
  193. }
  194. if (block_type == 0b10) {
  195. CanonicalCode literal_codes;
  196. Optional<CanonicalCode> distance_codes;
  197. TRY(decode_codes(literal_codes, distance_codes));
  198. m_state = State::ReadingCompressedBlock;
  199. new (&m_compressed_block) CompressedBlock(*this, literal_codes, distance_codes);
  200. continue;
  201. }
  202. return Error::from_string_literal("Unhandled block type for Idle state");
  203. }
  204. if (m_state == State::ReadingCompressedBlock) {
  205. auto nread = m_output_buffer.read(slice).size();
  206. while (nread < slice.size() && TRY(m_compressed_block.try_read_more())) {
  207. nread += m_output_buffer.read(slice.slice(nread)).size();
  208. }
  209. total_read += nread;
  210. if (nread == slice.size())
  211. break;
  212. m_compressed_block.~CompressedBlock();
  213. m_state = State::Idle;
  214. continue;
  215. }
  216. if (m_state == State::ReadingUncompressedBlock) {
  217. auto nread = m_output_buffer.read(slice).size();
  218. while (nread < slice.size() && TRY(m_uncompressed_block.try_read_more())) {
  219. nread += m_output_buffer.read(slice.slice(nread)).size();
  220. }
  221. total_read += nread;
  222. if (nread == slice.size())
  223. break;
  224. m_uncompressed_block.~UncompressedBlock();
  225. m_state = State::Idle;
  226. continue;
  227. }
  228. VERIFY_NOT_REACHED();
  229. }
  230. return bytes.slice(0, total_read);
  231. }
  232. bool DeflateDecompressor::is_eof() const { return m_state == State::Idle && m_read_final_bock; }
  233. ErrorOr<size_t> DeflateDecompressor::write(ReadonlyBytes)
  234. {
  235. return Error::from_errno(EBADF);
  236. }
  237. bool DeflateDecompressor::is_open() const
  238. {
  239. return true;
  240. }
  241. void DeflateDecompressor::close()
  242. {
  243. }
  244. ErrorOr<ByteBuffer> DeflateDecompressor::decompress_all(ReadonlyBytes bytes)
  245. {
  246. auto memory_stream = TRY(Core::Stream::FixedMemoryStream::construct(bytes));
  247. auto deflate_stream = TRY(DeflateDecompressor::construct(move(memory_stream)));
  248. Core::Stream::AllocatingMemoryStream output_stream;
  249. auto buffer = TRY(ByteBuffer::create_uninitialized(4096));
  250. while (!deflate_stream->is_eof()) {
  251. auto const slice = TRY(deflate_stream->read(buffer));
  252. TRY(output_stream.write_entire_buffer(slice));
  253. }
  254. auto output_buffer = TRY(ByteBuffer::create_uninitialized(output_stream.used_buffer_size()));
  255. TRY(output_stream.read_entire_buffer(output_buffer));
  256. return output_buffer;
  257. }
  258. ErrorOr<u32> DeflateDecompressor::decode_length(u32 symbol)
  259. {
  260. // FIXME: I can't quite follow the algorithm here, but it seems to work.
  261. if (symbol <= 264)
  262. return symbol - 254;
  263. if (symbol <= 284) {
  264. auto extra_bits = (symbol - 261) / 4;
  265. return (((symbol - 265) % 4 + 4) << extra_bits) + 3 + TRY(m_input_stream->read_bits(extra_bits));
  266. }
  267. if (symbol == 285)
  268. return 258;
  269. VERIFY_NOT_REACHED();
  270. }
  271. ErrorOr<u32> DeflateDecompressor::decode_distance(u32 symbol)
  272. {
  273. // FIXME: I can't quite follow the algorithm here, but it seems to work.
  274. if (symbol <= 3)
  275. return symbol + 1;
  276. if (symbol <= 29) {
  277. auto extra_bits = (symbol / 2) - 1;
  278. return ((symbol % 2 + 2) << extra_bits) + 1 + TRY(m_input_stream->read_bits(extra_bits));
  279. }
  280. VERIFY_NOT_REACHED();
  281. }
  282. ErrorOr<void> DeflateDecompressor::decode_codes(CanonicalCode& literal_code, Optional<CanonicalCode>& distance_code)
  283. {
  284. auto literal_code_count = TRY(m_input_stream->read_bits(5)) + 257;
  285. auto distance_code_count = TRY(m_input_stream->read_bits(5)) + 1;
  286. auto code_length_count = TRY(m_input_stream->read_bits(4)) + 4;
  287. // First we have to extract the code lengths of the code that was used to encode the code lengths of
  288. // the code that was used to encode the block.
  289. u8 code_lengths_code_lengths[19] = { 0 };
  290. for (size_t i = 0; i < code_length_count; ++i) {
  291. code_lengths_code_lengths[code_lengths_code_lengths_order[i]] = TRY(m_input_stream->read_bits(3));
  292. }
  293. // Now we can extract the code that was used to encode the code lengths of the code that was used to
  294. // encode the block.
  295. auto code_length_code_result = CanonicalCode::from_bytes({ code_lengths_code_lengths, sizeof(code_lengths_code_lengths) });
  296. if (!code_length_code_result.has_value())
  297. return Error::from_string_literal("Failed to decode code length code");
  298. auto const code_length_code = code_length_code_result.value();
  299. // Next we extract the code lengths of the code that was used to encode the block.
  300. Vector<u8> code_lengths;
  301. while (code_lengths.size() < literal_code_count + distance_code_count) {
  302. auto symbol = TRY(code_length_code.read_symbol(*m_input_stream));
  303. if (symbol < deflate_special_code_length_copy) {
  304. code_lengths.append(static_cast<u8>(symbol));
  305. continue;
  306. } else if (symbol == deflate_special_code_length_zeros) {
  307. auto nrepeat = 3 + TRY(m_input_stream->read_bits(3));
  308. for (size_t j = 0; j < nrepeat; ++j)
  309. code_lengths.append(0);
  310. continue;
  311. } else if (symbol == deflate_special_code_length_long_zeros) {
  312. auto nrepeat = 11 + TRY(m_input_stream->read_bits(7));
  313. for (size_t j = 0; j < nrepeat; ++j)
  314. code_lengths.append(0);
  315. continue;
  316. } else {
  317. VERIFY(symbol == deflate_special_code_length_copy);
  318. if (code_lengths.is_empty())
  319. return Error::from_string_literal("Found no codes to copy before a copy block");
  320. auto nrepeat = 3 + TRY(m_input_stream->read_bits(2));
  321. for (size_t j = 0; j < nrepeat; ++j)
  322. code_lengths.append(code_lengths.last());
  323. }
  324. }
  325. if (code_lengths.size() != literal_code_count + distance_code_count)
  326. return Error::from_string_literal("Number of code lengths does not match the sum of codes");
  327. // Now we extract the code that was used to encode literals and lengths in the block.
  328. auto literal_code_result = CanonicalCode::from_bytes(code_lengths.span().trim(literal_code_count));
  329. if (!literal_code_result.has_value())
  330. return Error::from_string_literal("Failed to decode the literal code");
  331. literal_code = literal_code_result.value();
  332. // Now we extract the code that was used to encode distances in the block.
  333. if (distance_code_count == 1) {
  334. auto length = code_lengths[literal_code_count];
  335. if (length == 0)
  336. return {};
  337. else if (length != 1)
  338. return Error::from_string_literal("Length for a single distance code is longer than 1");
  339. }
  340. auto distance_code_result = CanonicalCode::from_bytes(code_lengths.span().slice(literal_code_count));
  341. if (!distance_code_result.has_value())
  342. return Error::from_string_literal("Failed to decode the distance code");
  343. distance_code = distance_code_result.value();
  344. return {};
  345. }
  346. ErrorOr<NonnullOwnPtr<DeflateCompressor>> DeflateCompressor::construct(MaybeOwned<AK::Stream> stream, CompressionLevel compression_level)
  347. {
  348. auto bit_stream = TRY(LittleEndianOutputBitStream::construct(move(stream)));
  349. auto deflate_compressor = TRY(adopt_nonnull_own_or_enomem(new (nothrow) DeflateCompressor(move(bit_stream), compression_level)));
  350. return deflate_compressor;
  351. }
  352. DeflateCompressor::DeflateCompressor(NonnullOwnPtr<LittleEndianOutputBitStream> stream, CompressionLevel compression_level)
  353. : m_compression_level(compression_level)
  354. , m_compression_constants(compression_constants[static_cast<int>(m_compression_level)])
  355. , m_output_stream(move(stream))
  356. {
  357. m_symbol_frequencies.fill(0);
  358. m_distance_frequencies.fill(0);
  359. }
  360. DeflateCompressor::~DeflateCompressor()
  361. {
  362. VERIFY(m_finished);
  363. }
  364. ErrorOr<Bytes> DeflateCompressor::read(Bytes)
  365. {
  366. return Error::from_errno(EBADF);
  367. }
  368. ErrorOr<size_t> DeflateCompressor::write(ReadonlyBytes bytes)
  369. {
  370. VERIFY(!m_finished);
  371. if (bytes.size() == 0)
  372. return 0; // recursion base case
  373. auto n_written = bytes.copy_trimmed_to(pending_block().slice(m_pending_block_size));
  374. m_pending_block_size += n_written;
  375. if (m_pending_block_size == block_size)
  376. TRY(flush());
  377. return n_written + TRY(write(bytes.slice(n_written)));
  378. }
  379. bool DeflateCompressor::is_eof() const
  380. {
  381. return true;
  382. }
  383. bool DeflateCompressor::is_open() const
  384. {
  385. return m_output_stream->is_open();
  386. }
  387. void DeflateCompressor::close()
  388. {
  389. }
  390. // Knuth's multiplicative hash on 4 bytes
  391. u16 DeflateCompressor::hash_sequence(u8 const* bytes)
  392. {
  393. constexpr const u32 knuth_constant = 2654435761; // shares no common factors with 2^32
  394. return ((bytes[0] | bytes[1] << 8 | bytes[2] << 16 | bytes[3] << 24) * knuth_constant) >> (32 - hash_bits);
  395. }
  396. size_t DeflateCompressor::compare_match_candidate(size_t start, size_t candidate, size_t previous_match_length, size_t maximum_match_length)
  397. {
  398. VERIFY(previous_match_length < maximum_match_length);
  399. // We firstly check that the match is at least (prev_match_length + 1) long, we check backwards as there's a higher chance the end mismatches
  400. for (ssize_t i = previous_match_length; i >= 0; i--) {
  401. if (m_rolling_window[start + i] != m_rolling_window[candidate + i])
  402. return 0;
  403. }
  404. // Find the actual length
  405. auto match_length = previous_match_length + 1;
  406. while (match_length < maximum_match_length && m_rolling_window[start + match_length] == m_rolling_window[candidate + match_length]) {
  407. match_length++;
  408. }
  409. VERIFY(match_length > previous_match_length);
  410. VERIFY(match_length <= maximum_match_length);
  411. return match_length;
  412. }
  413. size_t DeflateCompressor::find_back_match(size_t start, u16 hash, size_t previous_match_length, size_t maximum_match_length, size_t& match_position)
  414. {
  415. auto max_chain_length = m_compression_constants.max_chain;
  416. if (previous_match_length == 0)
  417. previous_match_length = min_match_length - 1; // we only care about matches that are at least min_match_length long
  418. if (previous_match_length >= maximum_match_length)
  419. return 0; // we can't improve a maximum length match
  420. if (previous_match_length >= m_compression_constants.max_lazy_length)
  421. return 0; // the previous match is already pretty, we shouldn't waste another full search
  422. if (previous_match_length >= m_compression_constants.good_match_length)
  423. max_chain_length /= 4; // we already have a pretty good much, so do a shorter search
  424. auto candidate = m_hash_head[hash];
  425. auto match_found = false;
  426. while (max_chain_length--) {
  427. if (candidate == empty_slot)
  428. break; // no remaining candidates
  429. VERIFY(candidate < start);
  430. if (start - candidate > window_size)
  431. break; // outside the window
  432. auto match_length = compare_match_candidate(start, candidate, previous_match_length, maximum_match_length);
  433. if (match_length != 0) {
  434. match_found = true;
  435. match_position = candidate;
  436. previous_match_length = match_length;
  437. if (match_length == maximum_match_length)
  438. return match_length; // bail if we got the maximum possible length
  439. }
  440. candidate = m_hash_prev[candidate % window_size];
  441. }
  442. if (!match_found)
  443. return 0; // we didn't find any matches
  444. return previous_match_length; // we found matches, but they were at most previous_match_length long
  445. }
  446. ALWAYS_INLINE u8 DeflateCompressor::distance_to_base(u16 distance)
  447. {
  448. return (distance <= 256) ? distance_to_base_lo[distance - 1] : distance_to_base_hi[(distance - 1) >> 7];
  449. }
  450. template<size_t Size>
  451. void DeflateCompressor::generate_huffman_lengths(Array<u8, Size>& lengths, Array<u16, Size> const& frequencies, size_t max_bit_length, u16 frequency_cap)
  452. {
  453. VERIFY((1u << max_bit_length) >= Size);
  454. u16 heap_keys[Size]; // Used for O(n) heap construction
  455. u16 heap_values[Size];
  456. u16 huffman_links[Size * 2 + 1] = { 0 };
  457. size_t non_zero_freqs = 0;
  458. for (size_t i = 0; i < Size; i++) {
  459. auto frequency = frequencies[i];
  460. if (frequency == 0)
  461. continue;
  462. if (frequency > frequency_cap) {
  463. frequency = frequency_cap;
  464. }
  465. heap_keys[non_zero_freqs] = frequency; // sort symbols by frequency
  466. heap_values[non_zero_freqs] = Size + non_zero_freqs; // huffman_links "links"
  467. non_zero_freqs++;
  468. }
  469. // special case for only 1 used symbol
  470. if (non_zero_freqs < 2) {
  471. for (size_t i = 0; i < Size; i++)
  472. lengths[i] = (frequencies[i] == 0) ? 0 : 1;
  473. return;
  474. }
  475. BinaryHeap<u16, u16, Size> heap { heap_keys, heap_values, non_zero_freqs };
  476. // build the huffman tree - binary heap is used for efficient frequency comparisons
  477. while (heap.size() > 1) {
  478. u16 lowest_frequency = heap.peek_min_key();
  479. u16 lowest_link = heap.pop_min();
  480. u16 second_lowest_frequency = heap.peek_min_key();
  481. u16 second_lowest_link = heap.pop_min();
  482. u16 new_link = heap.size() + 2;
  483. heap.insert(lowest_frequency + second_lowest_frequency, new_link);
  484. huffman_links[lowest_link] = new_link;
  485. huffman_links[second_lowest_link] = new_link;
  486. }
  487. non_zero_freqs = 0;
  488. for (size_t i = 0; i < Size; i++) {
  489. if (frequencies[i] == 0) {
  490. lengths[i] = 0;
  491. continue;
  492. }
  493. u16 link = huffman_links[Size + non_zero_freqs];
  494. non_zero_freqs++;
  495. size_t bit_length = 1;
  496. while (link != 2) {
  497. bit_length++;
  498. link = huffman_links[link];
  499. }
  500. if (bit_length > max_bit_length) {
  501. VERIFY(frequency_cap != 1);
  502. return generate_huffman_lengths(lengths, frequencies, max_bit_length, frequency_cap / 2);
  503. }
  504. lengths[i] = bit_length;
  505. }
  506. }
  507. void DeflateCompressor::lz77_compress_block()
  508. {
  509. for (auto& slot : m_hash_head) { // initialize chained hash table
  510. slot = empty_slot;
  511. }
  512. auto insert_hash = [&](auto pos, auto hash) {
  513. auto window_pos = pos % window_size;
  514. m_hash_prev[window_pos] = m_hash_head[hash];
  515. m_hash_head[hash] = window_pos;
  516. };
  517. auto emit_literal = [&](auto literal) {
  518. VERIFY(m_pending_symbol_size <= block_size + 1);
  519. auto index = m_pending_symbol_size++;
  520. m_symbol_buffer[index].distance = 0;
  521. m_symbol_buffer[index].literal = literal;
  522. m_symbol_frequencies[literal]++;
  523. };
  524. auto emit_back_reference = [&](auto distance, auto length) {
  525. VERIFY(m_pending_symbol_size <= block_size + 1);
  526. auto index = m_pending_symbol_size++;
  527. m_symbol_buffer[index].distance = distance;
  528. m_symbol_buffer[index].length = length;
  529. m_symbol_frequencies[length_to_symbol[length]]++;
  530. m_distance_frequencies[distance_to_base(distance)]++;
  531. };
  532. size_t previous_match_length = 0;
  533. size_t previous_match_position = 0;
  534. VERIFY(m_compression_constants.great_match_length <= max_match_length);
  535. // our block starts at block_size and is m_pending_block_size in length
  536. auto block_end = block_size + m_pending_block_size;
  537. size_t current_position;
  538. for (current_position = block_size; current_position < block_end - min_match_length + 1; current_position++) {
  539. auto hash = hash_sequence(&m_rolling_window[current_position]);
  540. size_t match_position;
  541. auto match_length = find_back_match(current_position, hash, previous_match_length,
  542. min(m_compression_constants.great_match_length, block_end - current_position), match_position);
  543. insert_hash(current_position, hash);
  544. // if the previous match is as good as the new match, just use it
  545. if (previous_match_length != 0 && previous_match_length >= match_length) {
  546. emit_back_reference((current_position - 1) - previous_match_position, previous_match_length);
  547. // skip all the bytes that are included in this match
  548. for (size_t j = current_position + 1; j < min(current_position - 1 + previous_match_length, block_end - min_match_length + 1); j++) {
  549. insert_hash(j, hash_sequence(&m_rolling_window[j]));
  550. }
  551. current_position = (current_position - 1) + previous_match_length - 1;
  552. previous_match_length = 0;
  553. continue;
  554. }
  555. if (match_length == 0) {
  556. VERIFY(previous_match_length == 0);
  557. emit_literal(m_rolling_window[current_position]);
  558. continue;
  559. }
  560. // if this is a lazy match, and the new match is better than the old one, output previous as literal
  561. if (previous_match_length != 0) {
  562. emit_literal(m_rolling_window[current_position - 1]);
  563. }
  564. previous_match_length = match_length;
  565. previous_match_position = match_position;
  566. }
  567. // clean up leftover lazy match
  568. if (previous_match_length != 0) {
  569. emit_back_reference((current_position - 1) - previous_match_position, previous_match_length);
  570. current_position = (current_position - 1) + previous_match_length;
  571. }
  572. // output remaining literals
  573. while (current_position < block_end) {
  574. emit_literal(m_rolling_window[current_position++]);
  575. }
  576. }
  577. size_t DeflateCompressor::huffman_block_length(Array<u8, max_huffman_literals> const& literal_bit_lengths, Array<u8, max_huffman_distances> const& distance_bit_lengths)
  578. {
  579. size_t length = 0;
  580. for (size_t i = 0; i < 286; i++) {
  581. auto frequency = m_symbol_frequencies[i];
  582. length += literal_bit_lengths[i] * frequency;
  583. if (i >= 257) // back reference length symbols
  584. length += packed_length_symbols[i - 257].extra_bits * frequency;
  585. }
  586. for (size_t i = 0; i < 30; i++) {
  587. auto frequency = m_distance_frequencies[i];
  588. length += distance_bit_lengths[i] * frequency;
  589. length += packed_distances[i].extra_bits * frequency;
  590. }
  591. return length;
  592. }
  593. size_t DeflateCompressor::uncompressed_block_length()
  594. {
  595. auto padding = 8 - ((m_output_stream->bit_offset() + 3) % 8);
  596. // 3 bit block header + align to byte + 2 * 16 bit length fields + block contents
  597. return 3 + padding + (2 * 16) + m_pending_block_size * 8;
  598. }
  599. size_t DeflateCompressor::fixed_block_length()
  600. {
  601. // block header + fixed huffman encoded block contents
  602. return 3 + huffman_block_length(fixed_literal_bit_lengths, fixed_distance_bit_lengths);
  603. }
  604. size_t DeflateCompressor::dynamic_block_length(Array<u8, max_huffman_literals> const& literal_bit_lengths, Array<u8, max_huffman_distances> const& distance_bit_lengths, Array<u8, 19> const& code_lengths_bit_lengths, Array<u16, 19> const& code_lengths_frequencies, size_t code_lengths_count)
  605. {
  606. // block header + literal code count + distance code count + code length count
  607. auto length = 3 + 5 + 5 + 4;
  608. // 3 bits per code_length
  609. length += 3 * code_lengths_count;
  610. for (size_t i = 0; i < code_lengths_frequencies.size(); i++) {
  611. auto frequency = code_lengths_frequencies[i];
  612. length += code_lengths_bit_lengths[i] * frequency;
  613. if (i == deflate_special_code_length_copy) {
  614. length += 2 * frequency;
  615. } else if (i == deflate_special_code_length_zeros) {
  616. length += 3 * frequency;
  617. } else if (i == deflate_special_code_length_long_zeros) {
  618. length += 7 * frequency;
  619. }
  620. }
  621. return length + huffman_block_length(literal_bit_lengths, distance_bit_lengths);
  622. }
  623. ErrorOr<void> DeflateCompressor::write_huffman(CanonicalCode const& literal_code, Optional<CanonicalCode> const& distance_code)
  624. {
  625. auto has_distances = distance_code.has_value();
  626. for (size_t i = 0; i < m_pending_symbol_size; i++) {
  627. if (m_symbol_buffer[i].distance == 0) {
  628. TRY(literal_code.write_symbol(*m_output_stream, m_symbol_buffer[i].literal));
  629. continue;
  630. }
  631. VERIFY(has_distances);
  632. auto symbol = length_to_symbol[m_symbol_buffer[i].length];
  633. TRY(literal_code.write_symbol(*m_output_stream, symbol));
  634. // Emit extra bits if needed
  635. TRY(m_output_stream->write_bits<u16>(m_symbol_buffer[i].length - packed_length_symbols[symbol - 257].base_length, packed_length_symbols[symbol - 257].extra_bits));
  636. auto base_distance = distance_to_base(m_symbol_buffer[i].distance);
  637. TRY(distance_code.value().write_symbol(*m_output_stream, base_distance));
  638. // Emit extra bits if needed
  639. TRY(m_output_stream->write_bits<u16>(m_symbol_buffer[i].distance - packed_distances[base_distance].base_distance, packed_distances[base_distance].extra_bits));
  640. }
  641. return {};
  642. }
  643. size_t DeflateCompressor::encode_huffman_lengths(Array<u8, max_huffman_literals + max_huffman_distances> const& lengths, size_t lengths_count, Array<code_length_symbol, max_huffman_literals + max_huffman_distances>& encoded_lengths)
  644. {
  645. size_t encoded_count = 0;
  646. size_t i = 0;
  647. while (i < lengths_count) {
  648. if (lengths[i] == 0) {
  649. auto zero_count = 0;
  650. for (size_t j = i; j < min(lengths_count, i + 138) && lengths[j] == 0; j++)
  651. zero_count++;
  652. if (zero_count < 3) { // below minimum repeated zero count
  653. encoded_lengths[encoded_count++].symbol = 0;
  654. i++;
  655. continue;
  656. }
  657. if (zero_count <= 10) {
  658. encoded_lengths[encoded_count].symbol = deflate_special_code_length_zeros;
  659. encoded_lengths[encoded_count++].count = zero_count;
  660. } else {
  661. encoded_lengths[encoded_count].symbol = deflate_special_code_length_long_zeros;
  662. encoded_lengths[encoded_count++].count = zero_count;
  663. }
  664. i += zero_count;
  665. continue;
  666. }
  667. encoded_lengths[encoded_count++].symbol = lengths[i++];
  668. auto copy_count = 0;
  669. for (size_t j = i; j < min(lengths_count, i + 6) && lengths[j] == lengths[i - 1]; j++)
  670. copy_count++;
  671. if (copy_count >= 3) {
  672. encoded_lengths[encoded_count].symbol = deflate_special_code_length_copy;
  673. encoded_lengths[encoded_count++].count = copy_count;
  674. i += copy_count;
  675. continue;
  676. }
  677. }
  678. return encoded_count;
  679. }
  680. size_t DeflateCompressor::encode_block_lengths(Array<u8, max_huffman_literals> const& literal_bit_lengths, Array<u8, max_huffman_distances> const& distance_bit_lengths, Array<code_length_symbol, max_huffman_literals + max_huffman_distances>& encoded_lengths, size_t& literal_code_count, size_t& distance_code_count)
  681. {
  682. literal_code_count = max_huffman_literals;
  683. distance_code_count = max_huffman_distances;
  684. VERIFY(literal_bit_lengths[256] != 0); // Make sure at least the EndOfBlock marker is present
  685. while (literal_bit_lengths[literal_code_count - 1] == 0)
  686. literal_code_count--;
  687. // Drop trailing zero lengths, keeping at least one
  688. while (distance_bit_lengths[distance_code_count - 1] == 0 && distance_code_count > 1)
  689. distance_code_count--;
  690. Array<u8, max_huffman_literals + max_huffman_distances> all_lengths {};
  691. size_t lengths_count = 0;
  692. for (size_t i = 0; i < literal_code_count; i++) {
  693. all_lengths[lengths_count++] = literal_bit_lengths[i];
  694. }
  695. for (size_t i = 0; i < distance_code_count; i++) {
  696. all_lengths[lengths_count++] = distance_bit_lengths[i];
  697. }
  698. return encode_huffman_lengths(all_lengths, lengths_count, encoded_lengths);
  699. }
  700. ErrorOr<void> DeflateCompressor::write_dynamic_huffman(CanonicalCode const& literal_code, size_t literal_code_count, Optional<CanonicalCode> const& distance_code, size_t distance_code_count, Array<u8, 19> const& code_lengths_bit_lengths, size_t code_length_count, Array<code_length_symbol, max_huffman_literals + max_huffman_distances> const& encoded_lengths, size_t encoded_lengths_count)
  701. {
  702. TRY(m_output_stream->write_bits(literal_code_count - 257, 5));
  703. TRY(m_output_stream->write_bits(distance_code_count - 1, 5));
  704. TRY(m_output_stream->write_bits(code_length_count - 4, 4));
  705. for (size_t i = 0; i < code_length_count; i++) {
  706. TRY(m_output_stream->write_bits(code_lengths_bit_lengths[code_lengths_code_lengths_order[i]], 3));
  707. }
  708. auto code_lengths_code = CanonicalCode::from_bytes(code_lengths_bit_lengths);
  709. VERIFY(code_lengths_code.has_value());
  710. for (size_t i = 0; i < encoded_lengths_count; i++) {
  711. auto encoded_length = encoded_lengths[i];
  712. TRY(code_lengths_code->write_symbol(*m_output_stream, encoded_length.symbol));
  713. if (encoded_length.symbol == deflate_special_code_length_copy) {
  714. TRY(m_output_stream->write_bits<u8>(encoded_length.count - 3, 2));
  715. } else if (encoded_length.symbol == deflate_special_code_length_zeros) {
  716. TRY(m_output_stream->write_bits<u8>(encoded_length.count - 3, 3));
  717. } else if (encoded_length.symbol == deflate_special_code_length_long_zeros) {
  718. TRY(m_output_stream->write_bits<u8>(encoded_length.count - 11, 7));
  719. }
  720. }
  721. TRY(write_huffman(literal_code, distance_code));
  722. return {};
  723. }
  724. ErrorOr<void> DeflateCompressor::flush()
  725. {
  726. TRY(m_output_stream->write_bits(m_finished, 1));
  727. // if this is just an empty block to signify the end of the deflate stream use the smallest block possible (10 bits total)
  728. if (m_pending_block_size == 0) {
  729. VERIFY(m_finished); // we shouldn't be writing empty blocks unless this is the final one
  730. TRY(m_output_stream->write_bits(0b01u, 2)); // fixed huffman codes
  731. TRY(m_output_stream->write_bits(0b0000000u, 7)); // end of block symbol
  732. TRY(m_output_stream->align_to_byte_boundary());
  733. return {};
  734. }
  735. auto write_uncompressed = [&]() -> ErrorOr<void> {
  736. TRY(m_output_stream->write_bits(0b00u, 2)); // no compression
  737. TRY(m_output_stream->align_to_byte_boundary());
  738. LittleEndian<u16> len = m_pending_block_size;
  739. TRY(m_output_stream->write_entire_buffer(len.bytes()));
  740. LittleEndian<u16> nlen = ~m_pending_block_size;
  741. TRY(m_output_stream->write_entire_buffer(nlen.bytes()));
  742. TRY(m_output_stream->write_entire_buffer(pending_block().slice(0, m_pending_block_size)));
  743. return {};
  744. };
  745. if (m_compression_level == CompressionLevel::STORE) { // disabled compression fast path
  746. TRY(write_uncompressed());
  747. m_pending_block_size = 0;
  748. return {};
  749. }
  750. // The following implementation of lz77 compression and huffman encoding is based on the reference implementation by Hans Wennborg https://www.hanshq.net/zip.html
  751. // this reads from the pending block and writes to m_symbol_buffer
  752. lz77_compress_block();
  753. // insert EndOfBlock marker to the symbol buffer
  754. m_symbol_buffer[m_pending_symbol_size].distance = 0;
  755. m_symbol_buffer[m_pending_symbol_size++].literal = 256;
  756. m_symbol_frequencies[256]++;
  757. // generate optimal dynamic huffman code lengths
  758. Array<u8, max_huffman_literals> dynamic_literal_bit_lengths {};
  759. Array<u8, max_huffman_distances> dynamic_distance_bit_lengths {};
  760. generate_huffman_lengths(dynamic_literal_bit_lengths, m_symbol_frequencies, 15); // deflate data huffman can use up to 15 bits per symbol
  761. generate_huffman_lengths(dynamic_distance_bit_lengths, m_distance_frequencies, 15);
  762. // encode literal and distance lengths together in deflate format
  763. Array<code_length_symbol, max_huffman_literals + max_huffman_distances> encoded_lengths {};
  764. size_t literal_code_count;
  765. size_t distance_code_count;
  766. auto encoded_lengths_count = encode_block_lengths(dynamic_literal_bit_lengths, dynamic_distance_bit_lengths, encoded_lengths, literal_code_count, distance_code_count);
  767. // count code length frequencies
  768. Array<u16, 19> code_lengths_frequencies { 0 };
  769. for (size_t i = 0; i < encoded_lengths_count; i++) {
  770. code_lengths_frequencies[encoded_lengths[i].symbol]++;
  771. }
  772. // generate optimal huffman code lengths code lengths
  773. Array<u8, 19> code_lengths_bit_lengths {};
  774. generate_huffman_lengths(code_lengths_bit_lengths, code_lengths_frequencies, 7); // deflate code length huffman can use up to 7 bits per symbol
  775. // calculate actual code length code lengths count (without trailing zeros)
  776. auto code_lengths_count = code_lengths_bit_lengths.size();
  777. while (code_lengths_bit_lengths[code_lengths_code_lengths_order[code_lengths_count - 1]] == 0)
  778. code_lengths_count--;
  779. auto uncompressed_size = uncompressed_block_length();
  780. auto fixed_huffman_size = fixed_block_length();
  781. auto dynamic_huffman_size = dynamic_block_length(dynamic_literal_bit_lengths, dynamic_distance_bit_lengths, code_lengths_bit_lengths, code_lengths_frequencies, code_lengths_count);
  782. // If the compression somehow didn't reduce the size enough, just write out the block uncompressed as it allows for much faster decompression
  783. if (uncompressed_size <= min(fixed_huffman_size, dynamic_huffman_size)) {
  784. TRY(write_uncompressed());
  785. } else if (fixed_huffman_size <= dynamic_huffman_size) {
  786. // If the fixed and dynamic huffman codes come out the same size, prefer the fixed version, as it takes less time to decode fixed huffman codes.
  787. TRY(m_output_stream->write_bits(0b01u, 2));
  788. TRY(write_huffman(CanonicalCode::fixed_literal_codes(), CanonicalCode::fixed_distance_codes()));
  789. } else {
  790. // dynamic huffman codes
  791. TRY(m_output_stream->write_bits(0b10u, 2));
  792. auto literal_code = CanonicalCode::from_bytes(dynamic_literal_bit_lengths);
  793. VERIFY(literal_code.has_value());
  794. auto distance_code = CanonicalCode::from_bytes(dynamic_distance_bit_lengths);
  795. TRY(write_dynamic_huffman(literal_code.value(), literal_code_count, distance_code, distance_code_count, code_lengths_bit_lengths, code_lengths_count, encoded_lengths, encoded_lengths_count));
  796. }
  797. if (m_finished)
  798. TRY(m_output_stream->align_to_byte_boundary());
  799. // reset all block specific members
  800. m_pending_block_size = 0;
  801. m_pending_symbol_size = 0;
  802. m_symbol_frequencies.fill(0);
  803. m_distance_frequencies.fill(0);
  804. // On the final block this copy will potentially produce an invalid search window, but since its the final block we dont care
  805. pending_block().copy_trimmed_to({ m_rolling_window, block_size });
  806. return {};
  807. }
  808. ErrorOr<void> DeflateCompressor::final_flush()
  809. {
  810. VERIFY(!m_finished);
  811. m_finished = true;
  812. TRY(flush());
  813. return {};
  814. }
  815. ErrorOr<ByteBuffer> DeflateCompressor::compress_all(ReadonlyBytes bytes, CompressionLevel compression_level)
  816. {
  817. auto output_stream = TRY(try_make<Core::Stream::AllocatingMemoryStream>());
  818. auto deflate_stream = TRY(DeflateCompressor::construct(MaybeOwned<AK::Stream>(*output_stream), compression_level));
  819. TRY(deflate_stream->write_entire_buffer(bytes));
  820. TRY(deflate_stream->final_flush());
  821. auto buffer = TRY(ByteBuffer::create_uninitialized(output_stream->used_buffer_size()));
  822. TRY(output_stream->read_entire_buffer(buffer));
  823. return buffer;
  824. }
  825. }