Ver código fonte

LibVideo: Ensure that syntax element counts don't overflow

Integer overflow could sometimes occur due to counts going above 255,
where the values should instead be clamped at their maximum to avoid
wrapping to 0.
Zaggy1024 2 anos atrás
pai
commit
7d27273dc7
1 arquivos alterados com 26 adições e 23 exclusões
  1. 26 23
      Userland/Libraries/LibVideo/VP9/TreeParser.cpp

+ 26 - 23
Userland/Libraries/LibVideo/VP9/TreeParser.cpp

@@ -668,80 +668,83 @@ u8 TreeParser::calculate_token_probability(u8 node)
 
 void TreeParser::count_syntax_element(SyntaxElementType type, int value)
 {
+    auto increment = [](u8& count) {
+        count = min(static_cast<u32>(count) + 1, 255);
+    };
     switch (type) {
     case SyntaxElementType::Partition:
-        m_decoder.m_syntax_element_counter->m_counts_partition[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_partition[m_ctx][value]);
         return;
     case SyntaxElementType::IntraMode:
     case SyntaxElementType::SubIntraMode:
-        m_decoder.m_syntax_element_counter->m_counts_intra_mode[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_intra_mode[m_ctx][value]);
         return;
     case SyntaxElementType::UVMode:
-        m_decoder.m_syntax_element_counter->m_counts_uv_mode[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_uv_mode[m_ctx][value]);
         return;
     case SyntaxElementType::Skip:
-        m_decoder.m_syntax_element_counter->m_counts_skip[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_skip[m_ctx][value]);
         return;
     case SyntaxElementType::IsInter:
-        m_decoder.m_syntax_element_counter->m_counts_is_inter[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_is_inter[m_ctx][value]);
         return;
     case SyntaxElementType::CompMode:
-        m_decoder.m_syntax_element_counter->m_counts_comp_mode[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_comp_mode[m_ctx][value]);
         return;
     case SyntaxElementType::CompRef:
-        m_decoder.m_syntax_element_counter->m_counts_comp_ref[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_comp_ref[m_ctx][value]);
         return;
     case SyntaxElementType::SingleRefP1:
-        m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][0][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][0][value]);
         return;
     case SyntaxElementType::SingleRefP2:
-        m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][1][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][1][value]);
         return;
     case SyntaxElementType::MVSign:
-        m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]);
         return;
     case SyntaxElementType::MVClass0Bit:
-        m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]);
         return;
     case SyntaxElementType::MVBit:
         VERIFY(m_mv_bit < MV_OFFSET_BITS);
-        m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]);
         m_mv_bit = 0xFF;
         return;
     case SyntaxElementType::TXSize:
-        m_decoder.m_syntax_element_counter->m_counts_tx_size[m_decoder.m_max_tx_size][m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_tx_size[m_decoder.m_max_tx_size][m_ctx][value]);
         return;
     case SyntaxElementType::InterMode:
-        m_decoder.m_syntax_element_counter->m_counts_inter_mode[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_inter_mode[m_ctx][value]);
         return;
     case SyntaxElementType::InterpFilter:
-        m_decoder.m_syntax_element_counter->m_counts_interp_filter[m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_interp_filter[m_ctx][value]);
         return;
     case SyntaxElementType::MVJoint:
-        m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]);
         return;
     case SyntaxElementType::MVClass:
-        m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]);
         return;
     case SyntaxElementType::MVClass0FR:
         VERIFY(m_mv_class0_bit < CLASS0_SIZE);
-        m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]);
         m_mv_class0_bit = 0xFF;
         return;
     case SyntaxElementType::MVClass0HP:
-        m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]);
         return;
     case SyntaxElementType::MVFR:
-        m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]);
         return;
     case SyntaxElementType::MVHP:
-        m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]);
         return;
     case SyntaxElementType::Token:
-        m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]);
         return;
     case SyntaxElementType::MoreCoefs:
-        m_decoder.m_syntax_element_counter->m_counts_more_coefs[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][value]++;
+        increment(m_decoder.m_syntax_element_counter->m_counts_more_coefs[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][value]);
         return;
     case SyntaxElementType::DefaultIntraMode:
     case SyntaxElementType::DefaultUVMode: