TreeParser.cpp 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737
  1. /*
  2. * Copyright (c) 2021, Hunter Salyer <thefalsehonesty@gmail.com>
  3. * Copyright (c) 2022, Gregory Bertilson <zaggy1024@gmail.com>
  4. *
  5. * SPDX-License-Identifier: BSD-2-Clause
  6. */
  7. #include <AK/Function.h>
  8. #include "Context.h"
  9. #include "Enums.h"
  10. #include "LookupTables.h"
  11. #include "Parser.h"
  12. #include "TreeParser.h"
  13. #include "Utilities.h"
  14. namespace Video::VP9 {
  15. // Parsing of binary trees is handled here, as defined in section 9.3 of the VP9 specification.
  16. // Each syntax element is defined in its own section for each overarching section listed here:
  17. // - 9.3.1: Selection of the binary tree to be used.
  18. // - 9.3.2: Probability selection based on context and often the node of the tree.
  19. // - 9.3.4: Counting each syntax element when it is read.
  // Selects the binary tree walked by parse_tree() (9.3.1). A selection holds either a pointer to a
  // tree array, or a single value that parse_tree() returns without reading from the bit stream.
  class TreeSelection {
  public:
      union TreeSelectionValue {
          int const* m_tree;
          int m_value;
      };

      // Left implicit: callers assign `tree = { some_tree_array };` directly.
      constexpr TreeSelection(int const* values)
          : m_selection { .m_tree = values }
          , m_holds_single_value(false)
      {
      }

      // Left implicit: callers write `TreeSelection tree { value };` directly.
      constexpr TreeSelection(int value)
          : m_selection { .m_value = value }
          , m_holds_single_value(true)
      {
      }

      bool is_single_value() const { return m_holds_single_value; }

      // Only meaningful when is_single_value() returns true.
      int single_value() const { return m_selection.m_value; }

      // Only meaningful when is_single_value() returns false.
      int const* tree() const { return m_selection.m_tree; }

  private:
      TreeSelectionValue m_selection;
      bool m_holds_single_value;
  };
  43. template<typename OutputType>
  44. inline ErrorOr<OutputType> parse_tree(BitStream& bit_stream, TreeSelection tree_selection, Function<u8(u8)> const& probability_getter)
  45. {
  46. // 9.3.3: The tree decoding function.
  47. if (tree_selection.is_single_value())
  48. return static_cast<OutputType>(tree_selection.single_value());
  49. int const* tree = tree_selection.tree();
  50. int n = 0;
  51. do {
  52. u8 node = n >> 1;
  53. n = tree[n + TRY(bit_stream.read_bool(probability_getter(node)))];
  54. } while (n > 0);
  55. return static_cast<OutputType>(-n);
  56. }
  57. ErrorOr<Partition> TreeParser::parse_partition(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, bool has_rows, bool has_columns, BlockSubsize block_subsize, u8 num_8x8, PartitionContextView above_partition_context, PartitionContextView left_partition_context, u32 row, u32 column, bool frame_is_intra)
  58. {
  59. // Tree array
  60. TreeSelection tree = { PartitionSplit };
  61. if (has_rows && has_columns)
  62. tree = { partition_tree };
  63. else if (has_rows)
  64. tree = { rows_partition_tree };
  65. else if (has_columns)
  66. tree = { cols_partition_tree };
  67. // Probability array
  68. u32 above = 0;
  69. u32 left = 0;
  70. auto bsl = mi_width_log2_lookup[block_subsize];
  71. auto block_offset = mi_width_log2_lookup[Block_64x64] - bsl;
  72. for (auto i = 0; i < num_8x8; i++) {
  73. if (column + i >= above_partition_context.size())
  74. dbgln("column={}, i={}, size={}", column, i, above_partition_context.size());
  75. above |= above_partition_context[column + i];
  76. if (row + i >= left_partition_context.size())
  77. dbgln("row={}, i={}, size={}", row, i, left_partition_context.size());
  78. left |= left_partition_context[row + i];
  79. }
  80. above = (above & (1 << block_offset)) > 0;
  81. left = (left & (1 << block_offset)) > 0;
  82. auto context = bsl * 4 + left * 2 + above;
  83. u8 const* probabilities = frame_is_intra ? probability_table.kf_partition_probs()[context] : probability_table.partition_probs()[context];
  84. Function<u8(u8)> probability_getter = [&](u8 node) {
  85. if (has_rows && has_columns)
  86. return probabilities[node];
  87. if (has_columns)
  88. return probabilities[1];
  89. return probabilities[2];
  90. };
  91. auto value = TRY(parse_tree<Partition>(bit_stream, tree, probability_getter));
  92. counter.m_counts_partition[context][value]++;
  93. return value;
  94. }
  95. ErrorOr<PredictionMode> TreeParser::parse_default_intra_mode(BitStream& bit_stream, ProbabilityTables const& probability_table, BlockSubsize mi_size, FrameBlockContext above, FrameBlockContext left, Array<PredictionMode, 4> const& block_sub_modes, u8 index_x, u8 index_y)
  96. {
  97. // FIXME: This should use a struct for the above and left contexts.
  98. // Tree
  99. TreeSelection tree = { intra_mode_tree };
  100. // Probabilities
  101. PredictionMode above_mode, left_mode;
  102. if (mi_size >= Block_8x8) {
  103. above_mode = above.sub_modes[2];
  104. left_mode = left.sub_modes[1];
  105. } else {
  106. if (index_y > 0)
  107. above_mode = block_sub_modes[index_x];
  108. else
  109. above_mode = above.sub_modes[2 + index_x];
  110. if (index_x > 0)
  111. left_mode = block_sub_modes[index_y << 1];
  112. else
  113. left_mode = left.sub_modes[1 + (index_y << 1)];
  114. }
  115. u8 const* probabilities = probability_table.kf_y_mode_probs()[to_underlying(above_mode)][to_underlying(left_mode)];
  116. auto value = TRY(parse_tree<PredictionMode>(bit_stream, tree, [&](u8 node) { return probabilities[node]; }));
  117. // Default intra mode is not counted.
  118. return value;
  119. }
  120. ErrorOr<PredictionMode> TreeParser::parse_default_uv_mode(BitStream& bit_stream, ProbabilityTables const& probability_table, PredictionMode y_mode)
  121. {
  122. // Tree
  123. TreeSelection tree = { intra_mode_tree };
  124. // Probabilities
  125. u8 const* probabilities = probability_table.kf_uv_mode_prob()[to_underlying(y_mode)];
  126. auto value = TRY(parse_tree<PredictionMode>(bit_stream, tree, [&](u8 node) { return probabilities[node]; }));
  127. // Default UV mode is not counted.
  128. return value;
  129. }
  130. ErrorOr<PredictionMode> TreeParser::parse_intra_mode(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, BlockSubsize mi_size)
  131. {
  132. // Tree
  133. TreeSelection tree = { intra_mode_tree };
  134. // Probabilities
  135. auto context = size_group_lookup[mi_size];
  136. u8 const* probabilities = probability_table.y_mode_probs()[context];
  137. auto value = TRY(parse_tree<PredictionMode>(bit_stream, tree, [&](u8 node) { return probabilities[node]; }));
  138. counter.m_counts_intra_mode[context][to_underlying(value)]++;
  139. return value;
  140. }
  141. ErrorOr<PredictionMode> TreeParser::parse_sub_intra_mode(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter)
  142. {
  143. // Tree
  144. TreeSelection tree = { intra_mode_tree };
  145. // Probabilities
  146. u8 const* probabilities = probability_table.y_mode_probs()[0];
  147. auto value = TRY(parse_tree<PredictionMode>(bit_stream, tree, [&](u8 node) { return probabilities[node]; }));
  148. counter.m_counts_intra_mode[0][to_underlying(value)]++;
  149. return value;
  150. }
  151. ErrorOr<PredictionMode> TreeParser::parse_uv_mode(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, PredictionMode y_mode)
  152. {
  153. // Tree
  154. TreeSelection tree = { intra_mode_tree };
  155. // Probabilities
  156. u8 const* probabilities = probability_table.uv_mode_probs()[to_underlying(y_mode)];
  157. auto value = TRY(parse_tree<PredictionMode>(bit_stream, tree, [&](u8 node) { return probabilities[node]; }));
  158. counter.m_counts_uv_mode[to_underlying(y_mode)][to_underlying(value)]++;
  159. return value;
  160. }
  161. ErrorOr<u8> TreeParser::parse_segment_id(BitStream& bit_stream, Array<u8, 7> const& probabilities)
  162. {
  163. auto value = TRY(parse_tree<u8>(bit_stream, { segment_tree }, [&](u8 node) { return probabilities[node]; }));
  164. // Segment ID is not counted.
  165. return value;
  166. }
  167. ErrorOr<bool> TreeParser::parse_segment_id_predicted(BitStream& bit_stream, Array<u8, 3> const& probabilities, u8 above_seg_pred_context, u8 left_seg_pred_context)
  168. {
  169. auto context = left_seg_pred_context + above_seg_pred_context;
  170. auto value = TRY(parse_tree<bool>(bit_stream, { binary_tree }, [&](u8) { return probabilities[context]; }));
  171. // Segment ID prediction is not counted.
  172. return value;
  173. }
  174. ErrorOr<PredictionMode> TreeParser::parse_inter_mode(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 mode_context_for_ref_frame_0)
  175. {
  176. // Tree
  177. TreeSelection tree = { inter_mode_tree };
  178. // Probabilities
  179. u8 const* probabilities = probability_table.inter_mode_probs()[mode_context_for_ref_frame_0];
  180. auto value = TRY(parse_tree<u8>(bit_stream, tree, [&](u8 node) { return probabilities[node]; }));
  181. counter.m_counts_inter_mode[mode_context_for_ref_frame_0][value]++;
  182. return static_cast<PredictionMode>(value + to_underlying(PredictionMode::NearestMv));
  183. }
  184. ErrorOr<InterpolationFilter> TreeParser::parse_interpolation_filter(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, FrameBlockContext above, FrameBlockContext left)
  185. {
  186. // FIXME: Above and left context should be provided by a struct.
  187. // Tree
  188. TreeSelection tree = { interp_filter_tree };
  189. // Probabilities
  190. // NOTE: SWITCHABLE_FILTERS is not used in the spec for this function. Therefore, the number
  191. // was demystified by referencing the reference codec libvpx:
  192. // https://github.com/webmproject/libvpx/blob/705bf9de8c96cfe5301451f1d7e5c90a41c64e5f/vp9/common/vp9_pred_common.h#L69
  193. u8 left_interp = !left.is_intra_predicted() ? left.interpolation_filter : SWITCHABLE_FILTERS;
  194. u8 above_interp = !above.is_intra_predicted() ? above.interpolation_filter : SWITCHABLE_FILTERS;
  195. u8 context = SWITCHABLE_FILTERS;
  196. if (above_interp == left_interp || above_interp == SWITCHABLE_FILTERS)
  197. context = left_interp;
  198. else if (left_interp == SWITCHABLE_FILTERS)
  199. context = above_interp;
  200. u8 const* probabilities = probability_table.interp_filter_probs()[context];
  201. auto value = TRY(parse_tree<InterpolationFilter>(bit_stream, tree, [&](u8 node) { return probabilities[node]; }));
  202. counter.m_counts_interp_filter[context][to_underlying(value)]++;
  203. return value;
  204. }
  205. ErrorOr<bool> TreeParser::parse_skip(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, FrameBlockContext above, FrameBlockContext left)
  206. {
  207. // Probabilities
  208. u8 context = 0;
  209. context += static_cast<u8>(above.skip_coefficients);
  210. context += static_cast<u8>(left.skip_coefficients);
  211. u8 probability = probability_table.skip_prob()[context];
  212. auto value = TRY(parse_tree<bool>(bit_stream, { binary_tree }, [&](u8) { return probability; }));
  213. counter.m_counts_skip[context][value]++;
  214. return value;
  215. }
  216. ErrorOr<TransformSize> TreeParser::parse_tx_size(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, TransformSize max_tx_size, FrameBlockContext above, FrameBlockContext left)
  217. {
  218. // FIXME: Above and left contexts should be in structs.
  219. // Tree
  220. TreeSelection tree { tx_size_8_tree };
  221. if (max_tx_size == Transform_16x16)
  222. tree = { tx_size_16_tree };
  223. if (max_tx_size == Transform_32x32)
  224. tree = { tx_size_32_tree };
  225. // Probabilities
  226. auto above_context = max_tx_size;
  227. auto left_context = max_tx_size;
  228. if (above.is_available && !above.skip_coefficients)
  229. above_context = above.transform_size;
  230. if (left.is_available && !left.skip_coefficients)
  231. left_context = left.transform_size;
  232. if (!left.is_available)
  233. left_context = above_context;
  234. if (!above.is_available)
  235. above_context = left_context;
  236. auto context = (above_context + left_context) > max_tx_size;
  237. u8 const* probabilities = probability_table.tx_probs()[max_tx_size][context];
  238. auto value = TRY(parse_tree<TransformSize>(bit_stream, tree, [&](u8 node) { return probabilities[node]; }));
  239. counter.m_counts_tx_size[max_tx_size][context][value]++;
  240. return value;
  241. }
  242. ErrorOr<bool> TreeParser::parse_block_is_inter_predicted(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, FrameBlockContext above, FrameBlockContext left)
  243. {
  244. // FIXME: Above and left contexts should be in structs.
  245. // Probabilities
  246. u8 context = 0;
  247. if (above.is_available && left.is_available)
  248. context = (left.is_intra_predicted() && above.is_intra_predicted()) ? 3 : static_cast<u8>(above.is_intra_predicted() || left.is_intra_predicted());
  249. else if (above.is_available || left.is_available)
  250. context = 2 * static_cast<u8>(above.is_available ? above.is_intra_predicted() : left.is_intra_predicted());
  251. u8 probability = probability_table.is_inter_prob()[context];
  252. auto value = TRY(parse_tree<bool>(bit_stream, { binary_tree }, [&](u8) { return probability; }));
  253. counter.m_counts_is_inter[context][value]++;
  254. return value;
  255. }
  256. ErrorOr<ReferenceMode> TreeParser::parse_comp_mode(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, ReferenceFrameType comp_fixed_ref, FrameBlockContext above, FrameBlockContext left)
  257. {
  258. // FIXME: Above and left contexts should be in structs.
  259. // Probabilities
  260. u8 context;
  261. if (above.is_available && left.is_available) {
  262. if (above.is_single_reference() && left.is_single_reference()) {
  263. auto is_above_fixed = above.ref_frames.primary == comp_fixed_ref;
  264. auto is_left_fixed = left.ref_frames.primary == comp_fixed_ref;
  265. context = is_above_fixed ^ is_left_fixed;
  266. } else if (above.is_single_reference()) {
  267. auto is_above_fixed = above.ref_frames.primary == comp_fixed_ref;
  268. context = 2 + static_cast<u8>(is_above_fixed || above.is_intra_predicted());
  269. } else if (left.is_single_reference()) {
  270. auto is_left_fixed = left.ref_frames.primary == comp_fixed_ref;
  271. context = 2 + static_cast<u8>(is_left_fixed || left.is_intra_predicted());
  272. } else {
  273. context = 4;
  274. }
  275. } else if (above.is_available) {
  276. if (above.is_single_reference())
  277. context = above.ref_frames.primary == comp_fixed_ref;
  278. else
  279. context = 3;
  280. } else if (left.is_available) {
  281. if (left.is_single_reference())
  282. context = static_cast<u8>(left.ref_frames.primary == comp_fixed_ref);
  283. else
  284. context = 3;
  285. } else {
  286. context = 1;
  287. }
  288. u8 probability = probability_table.comp_mode_prob()[context];
  289. auto value = TRY(parse_tree<ReferenceMode>(bit_stream, { binary_tree }, [&](u8) { return probability; }));
  290. counter.m_counts_comp_mode[context][value]++;
  291. return value;
  292. }
  293. ErrorOr<ReferenceIndex> TreeParser::parse_comp_ref(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, ReferenceFrameType comp_fixed_ref, ReferenceFramePair comp_var_ref, ReferenceIndex variable_reference_index, FrameBlockContext above, FrameBlockContext left)
  294. {
  295. // FIXME: Above and left contexts should be in structs.
  296. // Probabilities
  297. u8 context;
  298. if (above.is_available && left.is_available) {
  299. if (above.is_intra_predicted() && left.is_intra_predicted()) {
  300. context = 2;
  301. } else if (left.is_intra_predicted()) {
  302. if (above.is_single_reference()) {
  303. context = 1 + 2 * (above.ref_frames.primary != comp_var_ref.secondary);
  304. } else {
  305. context = 1 + 2 * (above.ref_frames[variable_reference_index] != comp_var_ref.secondary);
  306. }
  307. } else if (above.is_intra_predicted()) {
  308. if (left.is_single_reference()) {
  309. context = 1 + 2 * (left.ref_frames.primary != comp_var_ref.secondary);
  310. } else {
  311. context = 1 + 2 * (left.ref_frames[variable_reference_index] != comp_var_ref.secondary);
  312. }
  313. } else {
  314. auto var_ref_above = above.is_single_reference() ? above.ref_frames.primary : above.ref_frames[variable_reference_index];
  315. auto var_ref_left = left.is_single_reference() ? left.ref_frames.primary : left.ref_frames[variable_reference_index];
  316. if (var_ref_above == var_ref_left && comp_var_ref.secondary == var_ref_above) {
  317. context = 0;
  318. } else if (left.is_single_reference() && above.is_single_reference()) {
  319. if ((var_ref_above == comp_fixed_ref && var_ref_left == comp_var_ref.primary)
  320. || (var_ref_left == comp_fixed_ref && var_ref_above == comp_var_ref.primary)) {
  321. context = 4;
  322. } else if (var_ref_above == var_ref_left) {
  323. context = 3;
  324. } else {
  325. context = 1;
  326. }
  327. } else if (left.is_single_reference() || above.is_single_reference()) {
  328. auto vrfc = left.is_single_reference() ? var_ref_above : var_ref_left;
  329. auto rfs = above.is_single_reference() ? var_ref_above : var_ref_left;
  330. if (vrfc == comp_var_ref.secondary && rfs != comp_var_ref.secondary) {
  331. context = 1;
  332. } else if (rfs == comp_var_ref.secondary && vrfc != comp_var_ref.secondary) {
  333. context = 2;
  334. } else {
  335. context = 4;
  336. }
  337. } else if (var_ref_above == var_ref_left) {
  338. context = 4;
  339. } else {
  340. context = 2;
  341. }
  342. }
  343. } else if (above.is_available) {
  344. if (above.is_intra_predicted()) {
  345. context = 2;
  346. } else {
  347. if (above.is_single_reference()) {
  348. context = 3 * static_cast<u8>(above.ref_frames.primary != comp_var_ref.secondary);
  349. } else {
  350. context = 4 * static_cast<u8>(above.ref_frames[variable_reference_index] != comp_var_ref.secondary);
  351. }
  352. }
  353. } else if (left.is_available) {
  354. if (left.is_intra_predicted()) {
  355. context = 2;
  356. } else {
  357. if (left.is_single_reference()) {
  358. context = 3 * static_cast<u8>(left.ref_frames.primary != comp_var_ref.secondary);
  359. } else {
  360. context = 4 * static_cast<u8>(left.ref_frames[variable_reference_index] != comp_var_ref.secondary);
  361. }
  362. }
  363. } else {
  364. context = 2;
  365. }
  366. u8 probability = probability_table.comp_ref_prob()[context];
  367. auto value = TRY(parse_tree<ReferenceIndex>(bit_stream, { binary_tree }, [&](u8) { return probability; }));
  368. counter.m_counts_comp_ref[context][to_underlying(value)]++;
  369. return value;
  370. }
  371. ErrorOr<bool> TreeParser::parse_single_ref_part_1(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, FrameBlockContext above, FrameBlockContext left)
  372. {
  373. // FIXME: Above and left contexts should be in structs.
  374. // Probabilities
  375. u8 context;
  376. if (above.is_available && left.is_available) {
  377. if (above.is_intra_predicted() && left.is_intra_predicted()) {
  378. context = 2;
  379. } else if (left.is_intra_predicted()) {
  380. if (above.is_single_reference()) {
  381. context = 4 * (above.ref_frames.primary == ReferenceFrameType::LastFrame);
  382. } else {
  383. context = 1 + (above.ref_frames.primary == ReferenceFrameType::LastFrame || above.ref_frames.secondary == ReferenceFrameType::LastFrame);
  384. }
  385. } else if (above.is_intra_predicted()) {
  386. if (left.is_single_reference()) {
  387. context = 4 * (left.ref_frames.primary == ReferenceFrameType::LastFrame);
  388. } else {
  389. context = 1 + (left.ref_frames.primary == ReferenceFrameType::LastFrame || left.ref_frames.secondary == ReferenceFrameType::LastFrame);
  390. }
  391. } else {
  392. if (left.is_single_reference() && above.is_single_reference()) {
  393. context = 2 * (above.ref_frames.primary == ReferenceFrameType::LastFrame) + 2 * (left.ref_frames.primary == ReferenceFrameType::LastFrame);
  394. } else if (!left.is_single_reference() && !above.is_single_reference()) {
  395. auto above_used_last_frame = above.ref_frames.primary == ReferenceFrameType::LastFrame || above.ref_frames.secondary == ReferenceFrameType::LastFrame;
  396. auto left_used_last_frame = left.ref_frames.primary == ReferenceFrameType::LastFrame || left.ref_frames.secondary == ReferenceFrameType::LastFrame;
  397. context = 1 + (above_used_last_frame || left_used_last_frame);
  398. } else {
  399. auto single_reference_type = above.is_single_reference() ? above.ref_frames.primary : left.ref_frames.primary;
  400. auto compound_reference_a_type = above.is_single_reference() ? left.ref_frames.primary : above.ref_frames.primary;
  401. auto compound_reference_b_type = above.is_single_reference() ? left.ref_frames.secondary : above.ref_frames.secondary;
  402. context = compound_reference_a_type == ReferenceFrameType::LastFrame || compound_reference_b_type == ReferenceFrameType::LastFrame;
  403. if (single_reference_type == ReferenceFrameType::LastFrame)
  404. context += 3;
  405. }
  406. }
  407. } else if (above.is_available) {
  408. if (above.is_intra_predicted()) {
  409. context = 2;
  410. } else {
  411. if (above.is_single_reference()) {
  412. context = 4 * (above.ref_frames.primary == ReferenceFrameType::LastFrame);
  413. } else {
  414. context = 1 + (above.ref_frames.primary == ReferenceFrameType::LastFrame || above.ref_frames.secondary == ReferenceFrameType::LastFrame);
  415. }
  416. }
  417. } else if (left.is_available) {
  418. if (left.is_intra_predicted()) {
  419. context = 2;
  420. } else {
  421. if (left.is_single_reference()) {
  422. context = 4 * (left.ref_frames.primary == ReferenceFrameType::LastFrame);
  423. } else {
  424. context = 1 + (left.ref_frames.primary == ReferenceFrameType::LastFrame || left.ref_frames.secondary == ReferenceFrameType::LastFrame);
  425. }
  426. }
  427. } else {
  428. context = 2;
  429. }
  430. u8 probability = probability_table.single_ref_prob()[context][0];
  431. auto value = TRY(parse_tree<bool>(bit_stream, { binary_tree }, [&](u8) { return probability; }));
  432. counter.m_counts_single_ref[context][0][value]++;
  433. return value;
  434. }
  435. ErrorOr<bool> TreeParser::parse_single_ref_part_2(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, FrameBlockContext above, FrameBlockContext left)
  436. {
  437. // FIXME: Above and left contexts should be in structs.
  438. // Probabilities
  439. u8 context;
  440. if (above.is_available && left.is_available) {
  441. if (above.is_intra_predicted() && left.is_intra_predicted()) {
  442. context = 2;
  443. } else if (left.is_intra_predicted()) {
  444. if (above.is_single_reference()) {
  445. if (above.ref_frames.primary == ReferenceFrameType::LastFrame) {
  446. context = 3;
  447. } else {
  448. context = 4 * (above.ref_frames.primary == ReferenceFrameType::GoldenFrame);
  449. }
  450. } else {
  451. context = 1 + 2 * (above.ref_frames.primary == ReferenceFrameType::GoldenFrame || above.ref_frames.secondary == ReferenceFrameType::GoldenFrame);
  452. }
  453. } else if (above.is_intra_predicted()) {
  454. if (left.is_single_reference()) {
  455. if (left.ref_frames.primary == ReferenceFrameType::LastFrame) {
  456. context = 3;
  457. } else {
  458. context = 4 * (left.ref_frames.primary == ReferenceFrameType::GoldenFrame);
  459. }
  460. } else {
  461. context = 1 + 2 * (left.ref_frames.primary == ReferenceFrameType::GoldenFrame || left.ref_frames.secondary == ReferenceFrameType::GoldenFrame);
  462. }
  463. } else {
  464. if (left.is_single_reference() && above.is_single_reference()) {
  465. auto above_last = above.ref_frames.primary == ReferenceFrameType::LastFrame;
  466. auto left_last = left.ref_frames.primary == ReferenceFrameType::LastFrame;
  467. if (above_last && left_last) {
  468. context = 3;
  469. } else if (above_last) {
  470. context = 4 * (left.ref_frames.primary == ReferenceFrameType::GoldenFrame);
  471. } else if (left_last) {
  472. context = 4 * (above.ref_frames.primary == ReferenceFrameType::GoldenFrame);
  473. } else {
  474. context = 2 * (above.ref_frames.primary == ReferenceFrameType::GoldenFrame) + 2 * (left.ref_frames.primary == ReferenceFrameType::GoldenFrame);
  475. }
  476. } else if (!left.is_single_reference() && !above.is_single_reference()) {
  477. if (above.ref_frames.primary == left.ref_frames.primary && above.ref_frames.secondary == left.ref_frames.secondary) {
  478. context = 3 * (above.ref_frames.primary == ReferenceFrameType::GoldenFrame || above.ref_frames.secondary == ReferenceFrameType::GoldenFrame);
  479. } else {
  480. context = 2;
  481. }
  482. } else {
  483. auto single_reference_type = above.is_single_reference() ? above.ref_frames.primary : left.ref_frames.primary;
  484. auto compound_reference_a_type = above.is_single_reference() ? left.ref_frames.primary : above.ref_frames.primary;
  485. auto compound_reference_b_type = above.is_single_reference() ? left.ref_frames.secondary : above.ref_frames.secondary;
  486. context = compound_reference_a_type == ReferenceFrameType::GoldenFrame || compound_reference_b_type == ReferenceFrameType::GoldenFrame;
  487. if (single_reference_type == ReferenceFrameType::GoldenFrame) {
  488. context += 3;
  489. } else if (single_reference_type != ReferenceFrameType::AltRefFrame) {
  490. context = 1 + (2 * context);
  491. }
  492. }
  493. }
  494. } else if (above.is_available) {
  495. if (above.is_intra_predicted() || (above.ref_frames.primary == ReferenceFrameType::LastFrame && above.is_single_reference())) {
  496. context = 2;
  497. } else if (above.is_single_reference()) {
  498. context = 4 * (above.ref_frames.primary == ReferenceFrameType::GoldenFrame);
  499. } else {
  500. context = 3 * (above.ref_frames.primary == ReferenceFrameType::GoldenFrame || above.ref_frames.secondary == ReferenceFrameType::GoldenFrame);
  501. }
  502. } else if (left.is_available) {
  503. if (left.is_intra_predicted() || (left.ref_frames.primary == ReferenceFrameType::LastFrame && left.is_single_reference())) {
  504. context = 2;
  505. } else if (left.is_single_reference()) {
  506. context = 4 * (left.ref_frames.primary == ReferenceFrameType::GoldenFrame);
  507. } else {
  508. context = 3 * (left.ref_frames.primary == ReferenceFrameType::GoldenFrame || left.ref_frames.secondary == ReferenceFrameType::GoldenFrame);
  509. }
  510. } else {
  511. context = 2;
  512. }
  513. u8 probability = probability_table.single_ref_prob()[context][1];
  514. auto value = TRY(parse_tree<bool>(bit_stream, { binary_tree }, [&](u8) { return probability; }));
  515. counter.m_counts_single_ref[context][1][value]++;
  516. return value;
  517. }
  518. ErrorOr<MvJoint> TreeParser::parse_motion_vector_joint(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter)
  519. {
  520. auto value = TRY(parse_tree<MvJoint>(bit_stream, { mv_joint_tree }, [&](u8 node) { return probability_table.mv_joint_probs()[node]; }));
  521. counter.m_counts_mv_joint[value]++;
  522. return value;
  523. }
  524. ErrorOr<bool> TreeParser::parse_motion_vector_sign(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component)
  525. {
  526. auto value = TRY(parse_tree<bool>(bit_stream, { binary_tree }, [&](u8) { return probability_table.mv_sign_prob()[component]; }));
  527. counter.m_counts_mv_sign[component][value]++;
  528. return value;
  529. }
  530. ErrorOr<MvClass> TreeParser::parse_motion_vector_class(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component)
  531. {
  532. // Spec doesn't mention node, but the probabilities table has an extra dimension
  533. // so we will use node for that.
  534. auto value = TRY(parse_tree<MvClass>(bit_stream, { mv_class_tree }, [&](u8 node) { return probability_table.mv_class_probs()[component][node]; }));
  535. counter.m_counts_mv_class[component][value]++;
  536. return value;
  537. }
  538. ErrorOr<bool> TreeParser::parse_motion_vector_class0_bit(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component)
  539. {
  540. auto value = TRY(parse_tree<bool>(bit_stream, { binary_tree }, [&](u8) { return probability_table.mv_class0_bit_prob()[component]; }));
  541. counter.m_counts_mv_class0_bit[component][value]++;
  542. return value;
  543. }
  544. ErrorOr<u8> TreeParser::parse_motion_vector_class0_fr(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, bool class_0_bit)
  545. {
  546. auto value = TRY(parse_tree<u8>(bit_stream, { mv_fr_tree }, [&](u8 node) { return probability_table.mv_class0_fr_probs()[component][class_0_bit][node]; }));
  547. counter.m_counts_mv_class0_fr[component][class_0_bit][value]++;
  548. return value;
  549. }
  550. ErrorOr<bool> TreeParser::parse_motion_vector_class0_hp(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, bool use_hp)
  551. {
  552. TreeSelection tree { 1 };
  553. if (use_hp)
  554. tree = { binary_tree };
  555. auto value = TRY(parse_tree<bool>(bit_stream, tree, [&](u8) { return probability_table.mv_class0_hp_prob()[component]; }));
  556. counter.m_counts_mv_class0_hp[component][value]++;
  557. return value;
  558. }
  559. ErrorOr<bool> TreeParser::parse_motion_vector_bit(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, u8 bit_index)
  560. {
  561. auto value = TRY(parse_tree<bool>(bit_stream, { binary_tree }, [&](u8) { return probability_table.mv_bits_prob()[component][bit_index]; }));
  562. counter.m_counts_mv_bits[component][bit_index][value]++;
  563. return value;
  564. }
  565. ErrorOr<u8> TreeParser::parse_motion_vector_fr(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component)
  566. {
  567. auto value = TRY(parse_tree<u8>(bit_stream, { mv_fr_tree }, [&](u8 node) { return probability_table.mv_fr_probs()[component][node]; }));
  568. counter.m_counts_mv_fr[component][value]++;
  569. return value;
  570. }
  571. ErrorOr<bool> TreeParser::parse_motion_vector_hp(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, bool use_hp)
  572. {
  573. TreeSelection tree { 1 };
  574. if (use_hp)
  575. tree = { binary_tree };
  576. auto value = TRY(parse_tree<u8>(bit_stream, tree, [&](u8) { return probability_table.mv_hp_prob()[component]; }));
  577. counter.m_counts_mv_hp[component][value]++;
  578. return value;
  579. }
  580. TokensContext TreeParser::get_context_for_first_token(NonZeroTokensView above_non_zero_tokens, NonZeroTokensView left_non_zero_tokens_in_block, TransformSize transform_size, u8 plane, u32 sub_block_column, u32 sub_block_row, bool is_inter, u8 band)
  581. {
  582. u8 transform_size_in_sub_blocks = transform_size_to_sub_blocks(transform_size);
  583. bool above_has_non_zero_tokens = false;
  584. for (u8 x = 0; x < transform_size_in_sub_blocks && x < above_non_zero_tokens[plane].size() - sub_block_column; x++) {
  585. if (above_non_zero_tokens[plane][sub_block_column + x]) {
  586. above_has_non_zero_tokens = true;
  587. break;
  588. }
  589. }
  590. bool left_has_non_zero_tokens = false;
  591. for (u8 y = 0; y < transform_size_in_sub_blocks && y < left_non_zero_tokens_in_block[plane].size() - sub_block_row; y++) {
  592. if (left_non_zero_tokens_in_block[plane][sub_block_row + y]) {
  593. left_has_non_zero_tokens = true;
  594. break;
  595. }
  596. }
  597. u8 context = above_has_non_zero_tokens + left_has_non_zero_tokens;
  598. return TokensContext { transform_size, plane > 0, is_inter, band, context };
  599. }
  600. TokensContext TreeParser::get_context_for_other_tokens(Array<u8, 1024> token_cache, TransformSize transform_size, TransformSet transform_set, u8 plane, u16 token_position, bool is_inter, u8 band)
  601. {
  602. auto transform_size_in_pixels = sub_blocks_to_pixels(transform_size_to_sub_blocks(transform_size));
  603. auto log2_of_transform_size = transform_size + 2;
  604. auto pixel_y = token_position >> log2_of_transform_size;
  605. auto pixel_x = token_position - (pixel_y << log2_of_transform_size);
  606. auto above_token_energy = pixel_y > 0 ? (pixel_y - 1) * transform_size_in_pixels + pixel_x : 0;
  607. auto left_token_energy = pixel_y * transform_size_in_pixels + pixel_x - 1;
  608. u32 neighbor_a, neighbor_b;
  609. if (pixel_y > 0 && pixel_x > 0) {
  610. if (transform_set == TransformSet { TransformType::DCT, TransformType::ADST }) {
  611. neighbor_a = above_token_energy;
  612. neighbor_b = above_token_energy;
  613. } else if (transform_set == TransformSet { TransformType::ADST, TransformType::DCT }) {
  614. neighbor_a = left_token_energy;
  615. neighbor_b = left_token_energy;
  616. } else {
  617. neighbor_a = above_token_energy;
  618. neighbor_b = left_token_energy;
  619. }
  620. } else if (pixel_y > 0) {
  621. neighbor_a = above_token_energy;
  622. neighbor_b = above_token_energy;
  623. } else {
  624. neighbor_a = left_token_energy;
  625. neighbor_b = left_token_energy;
  626. }
  627. u8 context = (1 + token_cache[neighbor_a] + token_cache[neighbor_b]) >> 1;
  628. return TokensContext { transform_size, plane > 0, is_inter, band, context };
  629. }
  630. ErrorOr<bool> TreeParser::parse_more_coefficients(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, TokensContext const& context)
  631. {
  632. auto probability = probability_table.coef_probs()[context.m_tx_size][context.m_is_uv_plane][context.m_is_inter][context.m_band][context.m_context_index][0];
  633. auto value = TRY(parse_tree<u8>(bit_stream, { binary_tree }, [&](u8) { return probability; }));
  634. counter.m_counts_more_coefs[context.m_tx_size][context.m_is_uv_plane][context.m_is_inter][context.m_band][context.m_context_index][value]++;
  635. return value;
  636. }
  637. ErrorOr<Token> TreeParser::parse_token(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, TokensContext const& context)
  638. {
  639. Function<u8(u8)> probability_getter = [&](u8 node) -> u8 {
  640. auto prob = probability_table.coef_probs()[context.m_tx_size][context.m_is_uv_plane][context.m_is_inter][context.m_band][context.m_context_index][min(2, 1 + node)];
  641. if (node < 2)
  642. return prob;
  643. auto x = (prob - 1) / 2;
  644. auto const& pareto_table = probability_table.pareto_table();
  645. if ((prob & 1) != 0)
  646. return pareto_table[x][node - 2];
  647. return (pareto_table[x][node - 2] + pareto_table[x + 1][node - 2]) >> 1;
  648. };
  649. auto value = TRY(parse_tree<Token>(bit_stream, { token_tree }, probability_getter));
  650. counter.m_counts_token[context.m_tx_size][context.m_is_uv_plane][context.m_is_inter][context.m_band][context.m_context_index][min(2, value)]++;
  651. return value;
  652. }
  653. }