Pārlūkot izejas kodu

LibCrypto: Add UnsignedBigInteger multiplication

Also added documentation for the runtime complexity of some operations.
Itamar 5 gadi atpakaļ
vecāks
revīzija
2959c4a5e9

+ 94 - 2
Libraries/LibCrypto/BigInt/UnsignedBigInteger.cpp

@@ -28,7 +28,10 @@
 
 namespace Crypto {
 
-UnsignedBigInteger UnsignedBigInteger::add(const UnsignedBigInteger& other)
+/**
+ * Complexity: O(N) where N is the number of words in the larger number
+ */
+UnsignedBigInteger UnsignedBigInteger::add(const UnsignedBigInteger& other) const
 {
     const UnsignedBigInteger* const longer = (length() > other.length()) ? this : &other;
     const UnsignedBigInteger* const shorter = (longer == &other) ? this : &other;
@@ -64,7 +67,10 @@ UnsignedBigInteger UnsignedBigInteger::add(const UnsignedBigInteger& other)
     return result;
 }
 
-UnsignedBigInteger UnsignedBigInteger::sub(const UnsignedBigInteger& other)
+/**
+ * Complexity: O(N) where N is the number of words in the larger number
+ */
+UnsignedBigInteger UnsignedBigInteger::sub(const UnsignedBigInteger& other) const
 {
     UnsignedBigInteger result;
 
@@ -96,6 +102,92 @@ UnsignedBigInteger UnsignedBigInteger::sub(const UnsignedBigInteger& other)
     return result;
 }
 
+/**
+ * Complexity: O(N^2) where N is the number of words in the larger number
+ * Multiplcation method:
+ * An integer is equal to the sum of the powers of two
+ * according to the indexes of its 'on' bits.
+ * So to multiple x*y, we go over each '1' bit in x (say the i'th bit), 
+ * and add y<<i to the result.
+ */
+UnsignedBigInteger UnsignedBigInteger::multiply(const UnsignedBigInteger& other) const
+{
+    UnsignedBigInteger result;
+    // iterate all bits
+    for (size_t word_index = 0; word_index < length(); ++word_index) {
+        for (size_t bit_index = 0; bit_index < UnsignedBigInteger::BITS_IN_WORD; ++bit_index) {
+            // If the bit is off - skip over it
+            if (!(m_words[word_index] & (1 << bit_index)))
+                continue;
+
+            const size_t shift_amount = word_index * UnsignedBigInteger::BITS_IN_WORD + bit_index;
+            auto shift_result = other.shift_left(shift_amount);
+            result = result.add(shift_result);
+        }
+    }
+    return result;
+}
+
+UnsignedBigInteger UnsignedBigInteger::shift_left(size_t num_bits) const
+{
+    // We can only do shift operations on individual words
+    // where the shift amount is <= size of word (32).
+    // But we do know how to shift by a multiple of word size (e.g 64=32*2)
+    // So we first shift the result by how many whole words fit in 'num_bits'
+    UnsignedBigInteger temp_result = shift_left_by_n_words(num_bits / UnsignedBigInteger::BITS_IN_WORD);
+
+    // And now we shift by the leftover amount of bits
+    num_bits %= UnsignedBigInteger::BITS_IN_WORD;
+
+    UnsignedBigInteger result(temp_result);
+
+    for (size_t i = 0; i < temp_result.length(); ++i) {
+        u32 current_word_of_temp_result = temp_result.shift_left_get_one_word(num_bits, i);
+        result.m_words[i] = current_word_of_temp_result;
+    }
+
+    // Shifting the last word can produce a carry
+    u32 carry_word = temp_result.shift_left_get_one_word(num_bits, temp_result.length());
+    if (carry_word != 0) {
+        result = result.add(UnsignedBigInteger(carry_word).shift_left_by_n_words(temp_result.length()));
+    }
+    return result;
+}
+
+UnsignedBigInteger UnsignedBigInteger::shift_left_by_n_words(const size_t number_of_words) const
+{
+    // shifting left by N words means just inserting N zeroes to the beginning of the words vector
+    UnsignedBigInteger result;
+    for (size_t i = 0; i < number_of_words; ++i) {
+        result.m_words.append(0);
+    }
+    for (size_t i = 0; i < length(); ++i) {
+        result.m_words.append(m_words[i]);
+    }
+    return result;
+}
+
+/**
+ * Returns the word at a requested index in the result of a shift operation
+ */
+u32 UnsignedBigInteger::shift_left_get_one_word(const size_t num_bits, const size_t result_word_index) const
+{
+    // "<= length()" (rather than length() - 1) is intentional,
+    // The result inedx of length() is used when calculating the carry word
+    ASSERT(result_word_index <= length());
+    ASSERT(num_bits <= UnsignedBigInteger::BITS_IN_WORD);
+    u32 result = 0;
+
+    // we need to check for "num_bits != 0" since shifting right by 32 is apparently undefined behaviour!
+    if (result_word_index > 0 && num_bits != 0) {
+        result += m_words[result_word_index - 1] >> (UnsignedBigInteger::BITS_IN_WORD - num_bits);
+    }
+    if (result_word_index < length() && num_bits < 32) {
+        result += m_words[result_word_index] << num_bits;
+    }
+    return result;
+}
+
 bool UnsignedBigInteger::operator==(const UnsignedBigInteger& other) const
 {
     if (trimmed_length() != other.trimmed_length()) {

+ 14 - 2
Libraries/LibCrypto/BigInt/UnsignedBigInteger.h

@@ -33,14 +33,23 @@ namespace Crypto {
 class UnsignedBigInteger {
 public:
     UnsignedBigInteger(u32 x) { m_words.append(x); }
+
+    UnsignedBigInteger(AK::Vector<u32>&& words)
+        : m_words(words)
+    {
+    }
+
     UnsignedBigInteger() {}
 
     static UnsignedBigInteger create_invalid();
 
     const AK::Vector<u32>& words() const { return m_words; }
 
-    UnsignedBigInteger add(const UnsignedBigInteger& other);
-    UnsignedBigInteger sub(const UnsignedBigInteger& other);
+    UnsignedBigInteger add(const UnsignedBigInteger& other) const;
+    UnsignedBigInteger sub(const UnsignedBigInteger& other) const;
+    UnsignedBigInteger multiply(const UnsignedBigInteger& other) const;
+    UnsignedBigInteger shift_left(size_t num_bits) const;
+    UnsignedBigInteger shift_left_by_n_words(const size_t number_of_words) const;
 
     size_t length() const { return m_words.size(); }
 
@@ -54,6 +63,9 @@ public:
     bool is_invalid() const { return m_is_invalid; }
 
 private:
+    u32 shift_left_get_one_word(const size_t num_bits, const size_t result_word_index) const;
+
+    static constexpr size_t BITS_IN_WORD = 32;
     AK::Vector<u32> m_words;
 
     // Used to indicate a negative result, or a result of an invalid operation

+ 42 - 0
Userland/test-crypto.cpp

@@ -305,6 +305,7 @@ void hmac_sha512_test_process();
 void bigint_test_fibo500();
 void bigint_addition_edgecases();
 void bigint_subtraction();
+void bigint_multiplication();
 
 int aes_cbc_tests()
 {
@@ -799,6 +800,7 @@ int bigint_tests()
     bigint_test_fibo500();
     bigint_addition_edgecases();
     bigint_subtraction();
+    bigint_multiplication();
     return 0;
 }
 
@@ -851,6 +853,8 @@ void bigint_addition_edgecases()
             PASS;
         } else {
             FAIL(Incorrect Result);
+        }
+    }
 }
 
 void bigint_subtraction()
@@ -902,3 +906,41 @@ void bigint_subtraction()
         }
     }
 }
+
+void bigint_multiplication()
+{
+    {
+        I_TEST((BigInteger | Simple Multipliction));
+        Crypto::UnsignedBigInteger num1(8);
+        Crypto::UnsignedBigInteger num2(251);
+        Crypto::UnsignedBigInteger result = num1.multiply(num2);
+        dbg() << "result: " << result;
+        if (result.words() == Vector<u32> { 2008 }) {
+            PASS;
+        } else {
+            FAIL(Incorrect Result);
+        }
+    }
+    {
+        I_TEST((BigInteger | Multiplications with big numbers 1));
+        Crypto::UnsignedBigInteger num1 = bigint_fibonacci(200);
+        Crypto::UnsignedBigInteger num2(12345678);
+        Crypto::UnsignedBigInteger result = num1.multiply(num2);
+        if (result.words() == Vector<u32> { 669961318, 143970113, 4028714974, 3164551305, 1589380278, 2 }) {
+            PASS;
+        } else {
+            FAIL(Incorrect Result);
+        }
+    }
+    {
+        I_TEST((BigInteger | Multiplications with big numbers 2));
+        Crypto::UnsignedBigInteger num1 = bigint_fibonacci(200);
+        Crypto::UnsignedBigInteger num2 = bigint_fibonacci(341);
+        Crypto::UnsignedBigInteger result = num1.multiply(num2);
+        if (result.words() == Vector<u32> { 3017415433, 2741793511, 1957755698, 3731653885, 3154681877, 785762127, 3200178098, 4260616581, 529754471, 3632684436, 1073347813, 2516430 }) {
+            PASS;
+        } else {
+            FAIL(Incorrect Result);
+        }
+    }
+}