瀏覽代碼

LibVideo: Parameterize all tree parsing for motion vectors in VP9

Zaggy1024 2 年之前
父節點
當前提交
e906bcc696

+ 16 - 17
Userland/Libraries/LibVideo/VP9/Parser.cpp

@@ -1338,7 +1338,7 @@ DecoderErrorOr<void> Parser::read_mv(u8 ref)
 {
     m_use_hp = m_allow_high_precision_mv && use_mv_hp(m_best_mv[ref]);
     MotionVector diff_mv;
-    auto mv_joint = TRY_READ(m_tree_parser->parse_tree<MvJoint>(SyntaxElementType::MVJoint));
+    auto mv_joint = TRY_READ(TreeParser::parse_motion_vector_joint(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter));
     if (mv_joint == MvJointHzvnz || mv_joint == MvJointHnzvnz)
         diff_mv.set_row(TRY(read_mv_component(0)));
     if (mv_joint == MvJointHnzvz || mv_joint == MvJointHnzvnz)
@@ -1352,27 +1352,26 @@ DecoderErrorOr<void> Parser::read_mv(u8 ref)
 
 DecoderErrorOr<i32> Parser::read_mv_component(u8 component)
 {
-    m_tree_parser->set_mv_component(component);
-    auto mv_sign = TRY_READ(m_tree_parser->parse_tree<bool>(SyntaxElementType::MVSign));
-    auto mv_class = TRY_READ(m_tree_parser->parse_tree<MvClass>(SyntaxElementType::MVClass));
-    u32 mag;
+    auto mv_sign = TRY_READ(TreeParser::parse_motion_vector_sign(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component));
+    auto mv_class = TRY_READ(TreeParser::parse_motion_vector_class(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component));
+    u32 magnitude;
     if (mv_class == MvClass0) {
-        u32 mv_class0_bit = TRY_READ(m_tree_parser->parse_tree<bool>(SyntaxElementType::MVClass0Bit));
-        u32 mv_class0_fr = TRY_READ(m_tree_parser->parse_mv_class0_fr(mv_class0_bit));
-        u32 mv_class0_hp = TRY_READ(m_tree_parser->parse_tree<bool>(SyntaxElementType::MVClass0HP));
-        mag = ((mv_class0_bit << 3) | (mv_class0_fr << 1) | mv_class0_hp) + 1;
+        auto mv_class0_bit = TRY_READ(TreeParser::parse_motion_vector_class0_bit(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component));
+        auto mv_class0_fr = TRY_READ(TreeParser::parse_motion_vector_class0_fr(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component, mv_class0_bit));
+        auto mv_class0_hp = TRY_READ(TreeParser::parse_motion_vector_class0_hp(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component, m_use_hp));
+        magnitude = ((mv_class0_bit << 3) | (mv_class0_fr << 1) | mv_class0_hp) + 1;
     } else {
-        u32 d = 0;
+        u32 bits = 0;
         for (u8 i = 0; i < mv_class; i++) {
-            u32 mv_bit = TRY_READ(m_tree_parser->parse_mv_bit(i));
-            d |= mv_bit << i;
+            auto mv_bit = TRY_READ(TreeParser::parse_motion_vector_bit(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component, i));
+            bits |= mv_bit << i;
         }
-        mag = CLASS0_SIZE << (mv_class + 2);
-        u32 mv_fr = TRY_READ(m_tree_parser->parse_tree<u8>(SyntaxElementType::MVFR));
-        u32 mv_hp = TRY_READ(m_tree_parser->parse_tree<bool>(SyntaxElementType::MVHP));
-        mag += ((d << 3) | (mv_fr << 1) | mv_hp) + 1;
+        magnitude = CLASS0_SIZE << (mv_class + 2);
+        auto mv_fr = TRY_READ(TreeParser::parse_motion_vector_fr(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component));
+        auto mv_hp = TRY_READ(TreeParser::parse_motion_vector_hp(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component, m_use_hp));
+        magnitude += ((bits << 3) | (mv_fr << 1) | mv_hp) + 1;
     }
-    return (mv_sign ? -1 : 1) * static_cast<i32>(mag);
+    return (mv_sign ? -1 : 1) * static_cast<i32>(magnitude);
 }
 
 Gfx::Point<size_t> Parser::get_decoded_point_for_plane(u32 column, u32 row, u8 plane)

+ 71 - 68
Userland/Libraries/LibVideo/VP9/TreeParser.cpp

@@ -581,29 +581,85 @@ ErrorOr<bool> TreeParser::parse_single_ref_part_2(BitStream& bit_stream, Probabi
     return value;
 }
 
+ErrorOr<MvJoint> TreeParser::parse_motion_vector_joint(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter)
+{
+    auto value = TRY(parse_tree_new<MvJoint>(bit_stream, { mv_joint_tree }, [&](u8 node) { return probability_table.mv_joint_probs()[node]; }));
+    increment_counter(counter.m_counts_mv_joint[value]);
+    return value;
+}
+
+ErrorOr<bool> TreeParser::parse_motion_vector_sign(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component)
+{
+    auto value = TRY(parse_tree_new<bool>(bit_stream, { binary_tree }, [&](u8) { return probability_table.mv_sign_prob()[component]; }));
+    increment_counter(counter.m_counts_mv_sign[component][value]);
+    return value;
+}
+
+ErrorOr<MvClass> TreeParser::parse_motion_vector_class(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component)
+{
+    // Spec doesn't mention node, but the probabilities table has an extra dimension
+    // so we will use node for that.
+    auto value = TRY(parse_tree_new<MvClass>(bit_stream, { mv_class_tree }, [&](u8 node) { return probability_table.mv_class_probs()[component][node]; }));
+    increment_counter(counter.m_counts_mv_class[component][value]);
+    return value;
+}
+
+ErrorOr<bool> TreeParser::parse_motion_vector_class0_bit(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component)
+{
+    auto value = TRY(parse_tree_new<bool>(bit_stream, { binary_tree }, [&](u8) { return probability_table.mv_class0_bit_prob()[component]; }));
+    increment_counter(counter.m_counts_mv_class0_bit[component][value]);
+    return value;
+}
+
+ErrorOr<u8> TreeParser::parse_motion_vector_class0_fr(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, bool class_0_bit)
+{
+    auto value = TRY(parse_tree_new<u8>(bit_stream, { mv_fr_tree }, [&](u8 node) { return probability_table.mv_class0_fr_probs()[component][class_0_bit][node]; }));
+    increment_counter(counter.m_counts_mv_class0_fr[component][class_0_bit][value]);
+    return value;
+}
+
+ErrorOr<bool> TreeParser::parse_motion_vector_class0_hp(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, bool use_hp)
+{
+    TreeParser::TreeSelection tree { 1 };
+    if (use_hp)
+        tree = { binary_tree };
+    auto value = TRY(parse_tree_new<bool>(bit_stream, tree, [&](u8) { return probability_table.mv_class0_hp_prob()[component]; }));
+    increment_counter(counter.m_counts_mv_class0_hp[component][value]);
+    return value;
+}
+
+ErrorOr<bool> TreeParser::parse_motion_vector_bit(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, u8 bit_index)
+{
+    auto value = TRY(parse_tree_new<bool>(bit_stream, { binary_tree }, [&](u8) { return probability_table.mv_bits_prob()[component][bit_index]; }));
+    increment_counter(counter.m_counts_mv_bits[component][bit_index][value]);
+    return value;
+}
+
+ErrorOr<u8> TreeParser::parse_motion_vector_fr(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component)
+{
+    auto value = TRY(parse_tree_new<u8>(bit_stream, { mv_fr_tree }, [&](u8 node) { return probability_table.mv_fr_probs()[component][node]; }));
+    increment_counter(counter.m_counts_mv_fr[component][value]);
+    return value;
+}
+
+ErrorOr<bool> TreeParser::parse_motion_vector_hp(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, bool use_hp)
+{
+    TreeParser::TreeSelection tree { 1 };
+    if (use_hp)
+        tree = { binary_tree };
+    auto value = TRY(parse_tree_new<u8>(bit_stream, tree, [&](u8) { return probability_table.mv_hp_prob()[component]; }));
+    increment_counter(counter.m_counts_mv_hp[component][value]);
+    return value;
+}
+
 /*
  * Select a tree value based on the type of syntax element being parsed, as well as some parser state, as specified in section 9.3.1
  */
 TreeParser::TreeSelection TreeParser::select_tree(SyntaxElementType type)
 {
     switch (type) {
-    case SyntaxElementType::MVSign:
-    case SyntaxElementType::MVClass0Bit:
-    case SyntaxElementType::MVBit:
     case SyntaxElementType::MoreCoefs:
         return { binary_tree };
-    case SyntaxElementType::MVJoint:
-        return { mv_joint_tree };
-    case SyntaxElementType::MVClass:
-        return { mv_class_tree };
-    case SyntaxElementType::MVClass0FR:
-    case SyntaxElementType::MVFR:
-        return { mv_fr_tree };
-    case SyntaxElementType::MVClass0HP:
-    case SyntaxElementType::MVHP:
-        if (m_decoder.m_use_hp)
-            return { binary_tree };
-        return { 1 };
     case SyntaxElementType::Token:
         return { token_tree };
     default:
@@ -618,28 +674,6 @@ TreeParser::TreeSelection TreeParser::select_tree(SyntaxElementType type)
 u8 TreeParser::select_tree_probability(SyntaxElementType type, u8 node)
 {
     switch (type) {
-    case SyntaxElementType::MVSign:
-        return m_decoder.m_probability_tables->mv_sign_prob()[m_mv_component];
-    case SyntaxElementType::MVClass0Bit:
-        return m_decoder.m_probability_tables->mv_class0_bit_prob()[m_mv_component];
-    case SyntaxElementType::MVBit:
-        VERIFY(m_mv_bit < MV_OFFSET_BITS);
-        return m_decoder.m_probability_tables->mv_bits_prob()[m_mv_component][m_mv_bit];
-    case SyntaxElementType::MVJoint:
-        return m_decoder.m_probability_tables->mv_joint_probs()[node];
-    case SyntaxElementType::MVClass:
-        // Spec doesn't mention node, but the probabilities table has an extra dimension
-        // so we will use node for that.
-        return m_decoder.m_probability_tables->mv_class_probs()[m_mv_component][node];
-    case SyntaxElementType::MVClass0FR:
-        VERIFY(m_mv_class0_bit < CLASS0_SIZE);
-        return m_decoder.m_probability_tables->mv_class0_fr_probs()[m_mv_component][m_mv_class0_bit][node];
-    case SyntaxElementType::MVClass0HP:
-        return m_decoder.m_probability_tables->mv_class0_hp_prob()[m_mv_component];
-    case SyntaxElementType::MVFR:
-        return m_decoder.m_probability_tables->mv_fr_probs()[m_mv_component][node];
-    case SyntaxElementType::MVHP:
-        return m_decoder.m_probability_tables->mv_hp_prob()[m_mv_component];
     case SyntaxElementType::Token:
         return calculate_token_probability(node);
     case SyntaxElementType::MoreCoefs:
@@ -738,37 +772,6 @@ void TreeParser::count_syntax_element(SyntaxElementType type, int value)
         increment_counter(count);
     };
     switch (type) {
-    case SyntaxElementType::MVSign:
-        increment(m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]);
-        return;
-    case SyntaxElementType::MVClass0Bit:
-        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]);
-        return;
-    case SyntaxElementType::MVBit:
-        VERIFY(m_mv_bit < MV_OFFSET_BITS);
-        increment(m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]);
-        m_mv_bit = 0xFF;
-        return;
-    case SyntaxElementType::MVJoint:
-        increment(m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]);
-        return;
-    case SyntaxElementType::MVClass:
-        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]);
-        return;
-    case SyntaxElementType::MVClass0FR:
-        VERIFY(m_mv_class0_bit < CLASS0_SIZE);
-        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]);
-        m_mv_class0_bit = 0xFF;
-        return;
-    case SyntaxElementType::MVClass0HP:
-        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]);
-        return;
-    case SyntaxElementType::MVFR:
-        increment(m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]);
-        return;
-    case SyntaxElementType::MVHP:
-        increment(m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]);
-        return;
     case SyntaxElementType::Token:
         increment(m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]);
         return;

+ 10 - 17
Userland/Libraries/LibVideo/VP9/TreeParser.h

@@ -81,6 +81,16 @@ public:
     static ErrorOr<bool> parse_single_ref_part_1(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, Optional<bool> above_single, Optional<bool> left_single, Optional<bool> above_intra, Optional<bool> left_intra, Optional<ReferenceFramePair> above_ref_frame, Optional<ReferenceFramePair> left_ref_frame);
     static ErrorOr<bool> parse_single_ref_part_2(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, Optional<bool> above_single, Optional<bool> left_single, Optional<bool> above_intra, Optional<bool> left_intra, Optional<ReferenceFramePair> above_ref_frame, Optional<ReferenceFramePair> left_ref_frame);
 
+    static ErrorOr<MvJoint> parse_motion_vector_joint(BitStream&, ProbabilityTables const&, SyntaxElementCounter&);
+    static ErrorOr<bool> parse_motion_vector_sign(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component);
+    static ErrorOr<MvClass> parse_motion_vector_class(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component);
+    static ErrorOr<bool> parse_motion_vector_class0_bit(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component);
+    static ErrorOr<u8> parse_motion_vector_class0_fr(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component, bool class_0_bit);
+    static ErrorOr<bool> parse_motion_vector_class0_hp(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component, bool use_hp);
+    static ErrorOr<bool> parse_motion_vector_bit(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component, u8 bit_index);
+    static ErrorOr<u8> parse_motion_vector_fr(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component);
+    static ErrorOr<bool> parse_motion_vector_hp(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component, bool use_hp);
+
     void set_default_intra_mode_variables(u8 idx, u8 idy)
     {
         m_idx = idx;
@@ -95,23 +105,6 @@ public:
         m_start_y = start_y;
     }
 
-    void set_mv_component(u8 component)
-    {
-        m_mv_component = component;
-    }
-
-    ErrorOr<bool> parse_mv_bit(u8 bit)
-    {
-        m_mv_bit = bit;
-        return parse_tree<bool>(SyntaxElementType::MVBit);
-    }
-
-    ErrorOr<u8> parse_mv_class0_fr(bool mv_class0_bit)
-    {
-        m_mv_class0_bit = mv_class0_bit;
-        return parse_tree<u8>(SyntaxElementType::MVClass0FR);
-    }
-
 private:
     u8 calculate_token_probability(u8 node);
     u8 calculate_more_coefs_probability();