Decoder.cpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. /*
  2. * Copyright (c) 2021, Hunter Salyer <thefalsehonesty@gmail.com>
  3. * Copyright (c) 2022, Gregory Bertilson <zaggy1024@gmail.com>
  4. *
  5. * SPDX-License-Identifier: BSD-2-Clause
  6. */
  7. #include "Decoder.h"
  8. #include "Utilities.h"
  9. namespace Video::VP9 {
  10. Decoder::Decoder()
  11. : m_parser(make<Parser>(*this))
  12. {
  13. }
  14. ErrorOr<void> Decoder::decode_frame(ByteBuffer const& frame_data)
  15. {
  16. TRY(m_parser->parse_frame(frame_data));
  17. // TODO:
  18. // - #2
  19. // - #3
  20. // - #4
  21. TRY(update_reference_frames());
  22. return {};
  23. }
  24. void Decoder::dump_frame_info()
  25. {
  26. m_parser->dump_info();
  27. }
  28. u8 Decoder::merge_prob(u8 pre_prob, u8 count_0, u8 count_1, u8 count_sat, u8 max_update_factor)
  29. {
  30. auto total_decode_count = count_0 + count_1;
  31. auto prob = (total_decode_count == 0) ? 128 : clip_3(1, 255, (count_0 * 256 + (total_decode_count >> 1)) / total_decode_count);
  32. auto count = min(total_decode_count, count_sat);
  33. auto factor = (max_update_factor * count) / count_sat;
  34. return round_2(pre_prob * (256 - factor) + (prob * factor), 8);
  35. }
  36. u8 Decoder::merge_probs(int const* tree, int index, u8* probs, u8* counts, u8 count_sat, u8 max_update_factor)
  37. {
  38. auto s = tree[index];
  39. auto left_count = (s <= 0) ? counts[-s] : merge_probs(tree, s, probs, counts, count_sat, max_update_factor);
  40. auto r = tree[index + 1];
  41. auto right_count = (r <= 0) ? counts[-r] : merge_probs(tree, r, probs, counts, count_sat, max_update_factor);
  42. probs[index >> 1] = merge_prob(probs[index >> 1], left_count, right_count, count_sat, max_update_factor);
  43. return left_count + right_count;
  44. }
  45. ErrorOr<void> Decoder::adapt_coef_probs()
  46. {
  47. u8 update_factor;
  48. if (m_parser->m_frame_is_intra || m_parser->m_last_frame_type != KeyFrame)
  49. update_factor = 112;
  50. else
  51. update_factor = 128;
  52. for (size_t t = 0; t < 4; t++) {
  53. for (size_t i = 0; i < 2; i++) {
  54. for (size_t j = 0; j < 2; j++) {
  55. for (size_t k = 0; k < 6; k++) {
  56. size_t max_l = (k == 0) ? 3 : 6;
  57. for (size_t l = 0; l < max_l; l++) {
  58. auto& coef_probs = m_parser->m_probability_tables->coef_probs()[t][i][j][k][l];
  59. merge_probs(small_token_tree, 2, coef_probs,
  60. m_parser->m_syntax_element_counter->m_counts_token[t][i][j][k][l],
  61. 24, update_factor);
  62. merge_probs(binary_tree, 0, coef_probs,
  63. m_parser->m_syntax_element_counter->m_counts_more_coefs[t][i][j][k][l],
  64. 24, update_factor);
  65. }
  66. }
  67. }
  68. }
  69. }
  70. return {};
  71. }
  72. #define ADAPT_PROB_TABLE(name, size) \
  73. do { \
  74. for (size_t i = 0; i < (size); i++) { \
  75. auto table = probs.name##_prob(); \
  76. table[i] = adapt_prob(table[i], counter.m_counts_##name[i]); \
  77. } \
  78. } while (0)
  79. #define ADAPT_TREE(tree_name, prob_name, count_name, size) \
  80. do { \
  81. for (size_t i = 0; i < (size); i++) { \
  82. adapt_probs(tree_name##_tree, probs.prob_name##_probs()[i], counter.m_counts_##count_name[i]); \
  83. } \
  84. } while (0)
  85. ErrorOr<void> Decoder::adapt_non_coef_probs()
  86. {
  87. auto& probs = *m_parser->m_probability_tables;
  88. auto& counter = *m_parser->m_syntax_element_counter;
  89. ADAPT_PROB_TABLE(is_inter, IS_INTER_CONTEXTS);
  90. ADAPT_PROB_TABLE(comp_mode, COMP_MODE_CONTEXTS);
  91. ADAPT_PROB_TABLE(comp_ref, REF_CONTEXTS);
  92. for (size_t i = 0; i < REF_CONTEXTS; i++) {
  93. for (size_t j = 0; j < 2; j++)
  94. probs.single_ref_prob()[i][j] = adapt_prob(probs.single_ref_prob()[i][j], counter.m_counts_single_ref[i][j]);
  95. }
  96. ADAPT_TREE(inter_mode, inter_mode, inter_mode, INTER_MODE_CONTEXTS);
  97. ADAPT_TREE(intra_mode, y_mode, intra_mode, INTER_MODE_CONTEXTS);
  98. ADAPT_TREE(intra_mode, uv_mode, uv_mode, INTER_MODE_CONTEXTS);
  99. ADAPT_TREE(partition, partition, partition, INTER_MODE_CONTEXTS);
  100. ADAPT_PROB_TABLE(skip, SKIP_CONTEXTS);
  101. if (m_parser->m_interpolation_filter == Switchable) {
  102. ADAPT_TREE(interp_filter, interp_filter, interp_filter, INTERP_FILTER_CONTEXTS);
  103. }
  104. if (m_parser->m_tx_mode == TXModeSelect) {
  105. for (size_t i = 0; i < TX_SIZE_CONTEXTS; i++) {
  106. auto& tx_probs = probs.tx_probs();
  107. auto& tx_counts = counter.m_counts_tx_size;
  108. adapt_probs(tx_size_8_tree, tx_probs[TX_8x8][i], tx_counts[TX_8x8][i]);
  109. adapt_probs(tx_size_16_tree, tx_probs[TX_16x16][i], tx_counts[TX_16x16][i]);
  110. adapt_probs(tx_size_32_tree, tx_probs[TX_32x32][i], tx_counts[TX_32x32][i]);
  111. }
  112. }
  113. adapt_probs(mv_joint_tree, probs.mv_joint_probs(), counter.m_counts_mv_joint);
  114. for (size_t i = 0; i < 2; i++) {
  115. probs.mv_sign_prob()[i] = adapt_prob(probs.mv_sign_prob()[i], counter.m_counts_mv_sign[i]);
  116. adapt_probs(mv_class_tree, probs.mv_class_probs()[i], counter.m_counts_mv_class[i]);
  117. probs.mv_class0_bit_prob()[i] = adapt_prob(probs.mv_class0_bit_prob()[i], counter.m_counts_mv_class0_bit[i]);
  118. for (size_t j = 0; j < MV_OFFSET_BITS; j++)
  119. probs.mv_bits_prob()[i][j] = adapt_prob(probs.mv_bits_prob()[i][j], counter.m_counts_mv_bits[i][j]);
  120. for (size_t j = 0; j < CLASS0_SIZE; j++)
  121. adapt_probs(mv_fr_tree, probs.mv_class0_fr_probs()[i][j], counter.m_counts_mv_class0_fr[i][j]);
  122. adapt_probs(mv_fr_tree, probs.mv_fr_probs()[i], counter.m_counts_mv_fr[i]);
  123. if (m_parser->m_allow_high_precision_mv) {
  124. probs.mv_class0_hp_prob()[i] = adapt_prob(probs.mv_class0_hp_prob()[i], counter.m_counts_mv_class0_hp[i]);
  125. probs.mv_hp_prob()[i] = adapt_prob(probs.mv_hp_prob()[i], counter.m_counts_mv_hp[i]);
  126. }
  127. }
  128. return {};
  129. }
  130. void Decoder::adapt_probs(int const* tree, u8* probs, u8* counts)
  131. {
  132. merge_probs(tree, 0, probs, counts, COUNT_SAT, MAX_UPDATE_FACTOR);
  133. }
  134. u8 Decoder::adapt_prob(u8 prob, u8 counts[2])
  135. {
  136. return merge_prob(prob, counts[0], counts[1], COUNT_SAT, MAX_UPDATE_FACTOR);
  137. }
  138. ErrorOr<void> Decoder::predict_intra(size_t, u32, u32, bool, bool, bool, TXSize, u32)
  139. {
  140. // TODO: Implement
  141. return Error::from_string_literal("predict_intra not implemented");
  142. }
  143. ErrorOr<void> Decoder::predict_inter(size_t, u32, u32, u32, u32, u32)
  144. {
  145. // TODO: Implement
  146. return Error::from_string_literal("predict_inter not implemented");
  147. }
  148. ErrorOr<void> Decoder::reconstruct(size_t, u32, u32, TXSize)
  149. {
  150. // TODO: Implement
  151. return Error::from_string_literal("reconstruct not implemented");
  152. }
  153. ErrorOr<void> Decoder::update_reference_frames()
  154. {
  155. for (auto i = 0; i < NUM_REF_FRAMES; i++) {
  156. dbgln("updating frame {}? {}", i, (m_parser->m_refresh_frame_flags & (1 << i)) == 1);
  157. if ((m_parser->m_refresh_frame_flags & (1 << i)) != 1)
  158. continue;
  159. m_parser->m_ref_frame_width[i] = m_parser->m_frame_width;
  160. m_parser->m_ref_frame_height[i] = m_parser->m_frame_height;
  161. // TODO: 1.3-1.7
  162. }
  163. // TODO: 2.1-2.2
  164. return {};
  165. }
  166. }