Browse Source

LibVideo: Ensure that syntax element counts don't overflow

Integer overflow could sometimes occur due to counts going above 255,
where the values should instead be clamped at their maximum to avoid
wrapping to 0.
Zaggy1024 2 years ago
parent
commit
7d27273dc7
1 changed files with 26 additions and 23 deletions
  1. 26 23
      Userland/Libraries/LibVideo/VP9/TreeParser.cpp

+ 26 - 23
Userland/Libraries/LibVideo/VP9/TreeParser.cpp

@@ -668,80 +668,83 @@ u8 TreeParser::calculate_token_probability(u8 node)
 
 
 void TreeParser::count_syntax_element(SyntaxElementType type, int value)
 void TreeParser::count_syntax_element(SyntaxElementType type, int value)
 {
 {
+    auto increment = [](u8& count) {
+        count = min(static_cast<u32>(count) + 1, 255);
+    };
     switch (type) {
     switch (type) {
     case SyntaxElementType::Partition:
     case SyntaxElementType::Partition:
-        m_decoder.m_syntax_element_counter->m_counts_partition[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_partition[m_ctx][value]);
         return;
         return;
     case SyntaxElementType::IntraMode:
     case SyntaxElementType::IntraMode:
     case SyntaxElementType::SubIntraMode:
     case SyntaxElementType::SubIntraMode:
-        m_decoder.m_syntax_element_counter->m_counts_intra_mode[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_intra_mode[m_ctx][value]);
         return;
         return;
     case SyntaxElementType::UVMode:
     case SyntaxElementType::UVMode:
-        m_decoder.m_syntax_element_counter->m_counts_uv_mode[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_uv_mode[m_ctx][value]);
         return;
         return;
     case SyntaxElementType::Skip:
     case SyntaxElementType::Skip:
-        m_decoder.m_syntax_element_counter->m_counts_skip[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_skip[m_ctx][value]);
         return;
         return;
     case SyntaxElementType::IsInter:
     case SyntaxElementType::IsInter:
-        m_decoder.m_syntax_element_counter->m_counts_is_inter[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_is_inter[m_ctx][value]);
         return;
         return;
     case SyntaxElementType::CompMode:
     case SyntaxElementType::CompMode:
-        m_decoder.m_syntax_element_counter->m_counts_comp_mode[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_comp_mode[m_ctx][value]);
         return;
         return;
     case SyntaxElementType::CompRef:
     case SyntaxElementType::CompRef:
-        m_decoder.m_syntax_element_counter->m_counts_comp_ref[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_comp_ref[m_ctx][value]);
         return;
         return;
     case SyntaxElementType::SingleRefP1:
     case SyntaxElementType::SingleRefP1:
-        m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][0][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][0][value]);
         return;
         return;
     case SyntaxElementType::SingleRefP2:
     case SyntaxElementType::SingleRefP2:
-        m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][1][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][1][value]);
         return;
         return;
     case SyntaxElementType::MVSign:
     case SyntaxElementType::MVSign:
-        m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]);
         return;
         return;
     case SyntaxElementType::MVClass0Bit:
     case SyntaxElementType::MVClass0Bit:
-        m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]);
         return;
         return;
     case SyntaxElementType::MVBit:
     case SyntaxElementType::MVBit:
         VERIFY(m_mv_bit < MV_OFFSET_BITS);
         VERIFY(m_mv_bit < MV_OFFSET_BITS);
-        m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]);
         m_mv_bit = 0xFF;
         m_mv_bit = 0xFF;
         return;
         return;
     case SyntaxElementType::TXSize:
     case SyntaxElementType::TXSize:
-        m_decoder.m_syntax_element_counter->m_counts_tx_size[m_decoder.m_max_tx_size][m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_tx_size[m_decoder.m_max_tx_size][m_ctx][value]);
         return;
         return;
     case SyntaxElementType::InterMode:
     case SyntaxElementType::InterMode:
-        m_decoder.m_syntax_element_counter->m_counts_inter_mode[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_inter_mode[m_ctx][value]);
         return;
         return;
     case SyntaxElementType::InterpFilter:
     case SyntaxElementType::InterpFilter:
-        m_decoder.m_syntax_element_counter->m_counts_interp_filter[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_interp_filter[m_ctx][value]);
         return;
         return;
     case SyntaxElementType::MVJoint:
     case SyntaxElementType::MVJoint:
-        m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]);
         return;
         return;
     case SyntaxElementType::MVClass:
     case SyntaxElementType::MVClass:
-        m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]);
         return;
         return;
     case SyntaxElementType::MVClass0FR:
     case SyntaxElementType::MVClass0FR:
         VERIFY(m_mv_class0_bit < CLASS0_SIZE);
         VERIFY(m_mv_class0_bit < CLASS0_SIZE);
-        m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]);
         m_mv_class0_bit = 0xFF;
         m_mv_class0_bit = 0xFF;
         return;
         return;
     case SyntaxElementType::MVClass0HP:
     case SyntaxElementType::MVClass0HP:
-        m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]);
         return;
         return;
     case SyntaxElementType::MVFR:
     case SyntaxElementType::MVFR:
-        m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]);
         return;
         return;
     case SyntaxElementType::MVHP:
     case SyntaxElementType::MVHP:
-        m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]);
         return;
         return;
     case SyntaxElementType::Token:
     case SyntaxElementType::Token:
-        m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]);
         return;
         return;
     case SyntaxElementType::MoreCoefs:
     case SyntaxElementType::MoreCoefs:
-        m_decoder.m_syntax_element_counter->m_counts_more_coefs[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_more_coefs[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][value]);
         return;
         return;
     case SyntaxElementType::DefaultIntraMode:
     case SyntaxElementType::DefaultIntraMode:
     case SyntaxElementType::DefaultUVMode:
     case SyntaxElementType::DefaultUVMode: