瀏覽代碼

AK: Implement RedBlackTree container

This container is based on a balanced binary search tree, and as such
allows for O(logn) worst-case insertion, removal, and search, as well
as O(n) sorted iteration.
Idan Horowitz 4 年之前
父節點
當前提交
e962254eb2
共有 3 個文件被更改，包括 662 次插入和 0 次删除
  1. 551 0
      AK/RedBlackTree.h
  2. 1 0
      AK/Tests/CMakeLists.txt
  3. 110 0
      AK/Tests/TestRedBlackTree.cpp

+ 551 - 0
AK/RedBlackTree.h

@@ -0,0 +1,551 @@
+/*
+ * Copyright (c) 2021, Idan Horowitz <idan.horowitz@gmail.com>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#pragma once
+
+#include <AK/Concepts.h>
+
+namespace AK {
+
+template<Integral K>
+class BaseRedBlackTree {
+public:
+    [[nodiscard]] size_t size() const { return m_size; }
+    [[nodiscard]] bool is_empty() const { return m_size == 0; }
+
+    enum class Color : bool {
+        Red,
+        Black
+    };
+    // A single tree node, keyed by K. A freshly constructed node is red and has
+    // no parent or children; this base type does not destroy its children itself.
+    struct Node {
+        Node* left_child { nullptr };
+        Node* right_child { nullptr };
+        Node* parent { nullptr };
+
+        Color color { Color::Red };
+
+        K key;
+
+        Node(K key)
+            : key(key)
+        {
+        }
+        // virtual so that derived node types can be destroyed through a Node*
+        virtual ~Node() = default;
+    };
+
+protected:
+    BaseRedBlackTree() = default; // These are protected to ensure no one instantiates the leaky base red black tree directly
+    virtual ~BaseRedBlackTree() {};
+
+    // Left-rotates the subtree rooted at subtree_root: its right child takes its
+    // place and subtree_root becomes that child's left child.
+    // Both subtree_root and its right child must be non-null.
+    void rotate_left(Node* subtree_root)
+    {
+        VERIFY(subtree_root);
+        Node* const new_top = subtree_root->right_child;
+        VERIFY(new_top);
+        Node* const old_parent = subtree_root->parent;
+        Node* const inner_subtree = new_top->left_child;
+
+        // hand the new top's left subtree over to the old subtree root
+        subtree_root->right_child = inner_subtree;
+        if (inner_subtree)
+            inner_subtree->parent = subtree_root;
+
+        // the old subtree root now hangs off the new top's left side
+        new_top->left_child = subtree_root;
+        subtree_root->parent = new_top;
+
+        // splice the new top in where the old subtree root used to be
+        new_top->parent = old_parent;
+        if (old_parent == nullptr)
+            m_root = new_top;
+        else if (old_parent->left_child == subtree_root)
+            old_parent->left_child = new_top;
+        else
+            old_parent->right_child = new_top;
+    }
+
+    // Right-rotates the subtree rooted at subtree_root: its left child takes its
+    // place and subtree_root becomes that child's right child.
+    // Both subtree_root and its left child must be non-null.
+    void rotate_right(Node* subtree_root)
+    {
+        VERIFY(subtree_root);
+        Node* const new_top = subtree_root->left_child;
+        VERIFY(new_top);
+        Node* const old_parent = subtree_root->parent;
+        Node* const inner_subtree = new_top->right_child;
+
+        // hand the new top's right subtree over to the old subtree root
+        subtree_root->left_child = inner_subtree;
+        if (inner_subtree)
+            inner_subtree->parent = subtree_root;
+
+        // the old subtree root now hangs off the new top's right side
+        new_top->right_child = subtree_root;
+        subtree_root->parent = new_top;
+
+        // splice the new top in where the old subtree root used to be
+        new_top->parent = old_parent;
+        if (old_parent == nullptr)
+            m_root = new_top;
+        else if (old_parent->left_child == subtree_root)
+            old_parent->left_child = new_top;
+        else
+            old_parent->right_child = new_top;
+    }
+
+    // Searches the subtree rooted at node for key.
+    // Returns the matching node, or nullptr when the key is not present.
+    static Node* find(Node* node, K key)
+    {
+        for (auto* current = node; current; current = key < current->key ? current->left_child : current->right_child) {
+            if (current->key == key)
+                return current;
+        }
+        return nullptr;
+    }
+
+    // Returns the node with the largest key that is <= key within the subtree
+    // rooted at node, or nullptr when every key in it is greater than key.
+    static Node* find_largest_not_above(Node* node, K key)
+    {
+        Node* best_so_far = nullptr;
+        while (node) {
+            if (key < node->key) {
+                node = node->left_child;
+                continue;
+            }
+            if (key == node->key)
+                return node;
+            // node's key is below the requested key: remember it and look for a closer one
+            best_so_far = node;
+            node = node->right_child;
+        }
+        return best_so_far;
+    }
+
+    // Links an already-allocated node into the tree according to its key,
+    // rebalances, and updates m_size and m_minimum.
+    // Keys equal to an existing key descend to the right.
+    void insert(Node* node)
+    {
+        VERIFY(node);
+
+        // descend to the leaf position where the node belongs
+        Node* attach_to = nullptr;
+        for (Node* cursor = m_root; cursor;) {
+            attach_to = cursor;
+            cursor = node->key < cursor->key ? cursor->left_child : cursor->right_child;
+        }
+
+        if (!attach_to) { // empty tree: the node becomes the (black) root
+            node->color = Color::Black;
+            m_root = node;
+            m_size = 1;
+            m_minimum = node;
+            return;
+        }
+
+        if (node->key < attach_to->key)
+            attach_to->left_child = node;
+        else
+            attach_to->right_child = node;
+        node->parent = attach_to;
+
+        if (attach_to->parent) // no fixups to be done for a height <= 2 tree
+            insert_fixups(node);
+
+        ++m_size;
+        if (m_minimum->left_child == node)
+            m_minimum = node;
+    }
+
+    // Restores the red-black colouring after a red node has been linked in.
+    // Loops upwards while node's parent is also red: a red uncle is handled by
+    // recolouring and continuing from the grandparent, a black or absent uncle by
+    // one or two rotations. The root is forced to black at the end.
+    void insert_fixups(Node* node)
+    {
+        VERIFY(node && node->color == Color::Red);
+        while (node->parent && node->parent->color == Color::Red) {
+            // NOTE(review): grand_parent is dereferenced unchecked; this relies on a red
+            // parent never being the root (the root is recoloured black below)
+            auto* grand_parent = node->parent->parent;
+            if (grand_parent->right_child == node->parent) {
+                // parent is the right child, so the uncle is on the left
+                auto* uncle = grand_parent->left_child;
+                if (uncle && uncle->color == Color::Red) {
+                    // red uncle: recolour and continue from the grandparent
+                    node->parent->color = Color::Black;
+                    uncle->color = Color::Black;
+                    grand_parent->color = Color::Red;
+                    node = grand_parent;
+                } else {
+                    if (node->parent->left_child == node) {
+                        // inner (right-left) case: rotate first so it becomes the outer case
+                        node = node->parent;
+                        rotate_right(node);
+                    }
+                    node->parent->color = Color::Black;
+                    grand_parent->color = Color::Red;
+                    rotate_left(grand_parent);
+                }
+            } else {
+                // mirror image: parent is the left child, so the uncle is on the right
+                auto* uncle = grand_parent->right_child;
+                if (uncle && uncle->color == Color::Red) {
+                    node->parent->color = Color::Black;
+                    uncle->color = Color::Black;
+                    grand_parent->color = Color::Red;
+                    node = grand_parent;
+                } else {
+                    if (node->parent->right_child == node) {
+                        // inner (left-right) case: rotate first so it becomes the outer case
+                        node = node->parent;
+                        rotate_left(node);
+                    }
+                    node->parent->color = Color::Black;
+                    grand_parent->color = Color::Red;
+                    rotate_right(grand_parent);
+                }
+            }
+        }
+        m_root->color = Color::Black; // the root should always be black
+    }
+
+    // Unlinks node from the tree and rebalances; the node itself is not freed.
+    // The node's own child/parent pointers are left as they were at unlink time,
+    // so callers that destroy the node must detach its children first.
+    void remove(Node* node)
+    {
+        VERIFY(node);
+
+        // special case: deleting the only node
+        if (m_size == 1) {
+            m_root = nullptr;
+            m_minimum = nullptr; // otherwise begin() would hand out the removed node
+            m_size = 0;
+            return;
+        }
+
+        if (m_minimum == node)
+            m_minimum = successor(node);
+
+        // removal assumes the node has 0 or 1 child, so if we have 2, relink with the successor first (by definition the successor has no left child)
+        // FIXME: since we dont know how a value is represented in the node, we cant simply swap the values and keys, and instead we relink the nodes
+        //  in place, this is quite a bit more expensive, as well as much less readable, is there a better way?
+        if (node->left_child && node->right_child) {
+            auto* successor_node = successor(node); // this is always non-null as all nodes besides the maximum node have a successor, and the maximum node has no right child
+            // when the successor is node's direct right child the two nodes reference each other, which needs special handling below
+            auto neighbour_swap = successor_node->parent == node;
+            node->left_child->parent = successor_node;
+            if (!neighbour_swap)
+                node->right_child->parent = successor_node;
+            if (node->parent) {
+                if (node->parent->left_child == node) {
+                    node->parent->left_child = successor_node;
+                } else {
+                    node->parent->right_child = successor_node;
+                }
+            } else {
+                m_root = successor_node;
+            }
+            if (successor_node->right_child)
+                successor_node->right_child->parent = node;
+            if (neighbour_swap) {
+                successor_node->parent = node->parent;
+                node->parent = successor_node;
+            } else {
+                if (successor_node->parent) {
+                    if (successor_node->parent->left_child == successor_node) {
+                        successor_node->parent->left_child = node;
+                    } else {
+                        successor_node->parent->right_child = node;
+                    }
+                } else {
+                    m_root = node;
+                }
+                swap(node->parent, successor_node->parent);
+            }
+            swap(node->left_child, successor_node->left_child);
+            if (neighbour_swap) {
+                node->right_child = successor_node->right_child;
+                successor_node->right_child = node;
+            } else {
+                swap(node->right_child, successor_node->right_child);
+            }
+            // colours belong to tree positions, not to nodes, so they trade places too
+            swap(node->color, successor_node->color);
+        }
+
+        // from here on node has at most one child
+        auto* child = node->left_child ?: node->right_child;
+
+        if (child)
+            child->parent = node->parent;
+        if (node->parent) {
+            if (node->parent->left_child == node)
+                node->parent->left_child = child;
+            else
+                node->parent->right_child = child;
+        } else {
+            m_root = child;
+        }
+
+        // if the node is red then child must be black, and just replacing the node with its child should result in a valid tree (no change to black height)
+        if (node->color != Color::Red)
+            remove_fixups(child, node->parent);
+
+        m_size--;
+    }
+
+    // Restores the red-black colouring after a black node has been unlinked.
+    // node is the child that took the removed node's place (possibly null) and
+    // parent is its parent; we maintain parent as a separate argument since node might be null.
+    // NOTE(review): sibling is dereferenced without a check throughout; this relies on the
+    // black-height invariant guaranteeing a sibling whenever node's side is one black node short.
+    void remove_fixups(Node* node, Node* parent)
+    {
+        while (node != m_root && (!node || node->color == Color::Black)) {
+            if (parent->left_child == node) {
+                auto* sibling = parent->right_child;
+                if (sibling->color == Color::Red) {
+                    // red sibling: rotate so that node gets a black sibling instead
+                    sibling->color = Color::Black;
+                    parent->color = Color::Red;
+                    rotate_left(parent);
+                    sibling = parent->right_child;
+                }
+                if ((!sibling->left_child || sibling->left_child->color == Color::Black) && (!sibling->right_child || sibling->right_child->color == Color::Black)) {
+                    // neither of the sibling's children is red: recolour and continue one level up
+                    sibling->color = Color::Red;
+                    node = parent;
+                } else {
+                    if (!sibling->right_child || sibling->right_child->color == Color::Black) {
+                        sibling->left_child->color = Color::Black; // non-null: this branch means a child is red, and it is not the right one
+                        sibling->color = Color::Red;
+                        rotate_right(sibling);
+                        sibling = parent->right_child;
+                    }
+                    sibling->color = parent->color;
+                    parent->color = Color::Black;
+                    sibling->right_child->color = Color::Black; // non-null: either it was red already, or it is the old sibling rotated down above
+                    rotate_left(parent);
+                    node = m_root; // fixed
+                }
+            } else {
+                // mirror image of the branch above
+                auto* sibling = parent->left_child;
+                if (sibling->color == Color::Red) {
+                    sibling->color = Color::Black;
+                    parent->color = Color::Red;
+                    rotate_right(parent);
+                    sibling = parent->left_child;
+                }
+                if ((!sibling->left_child || sibling->left_child->color == Color::Black) && (!sibling->right_child || sibling->right_child->color == Color::Black)) {
+                    sibling->color = Color::Red;
+                    node = parent;
+                } else {
+                    if (!sibling->left_child || sibling->left_child->color == Color::Black) {
+                        sibling->right_child->color = Color::Black; // non-null: this branch means a child is red, and it is not the left one
+                        sibling->color = Color::Red;
+                        rotate_left(sibling);
+                        sibling = parent->left_child;
+                    }
+                    sibling->color = parent->color;
+                    parent->color = Color::Black;
+                    sibling->left_child->color = Color::Black; // non-null: either it was red already, or it is the old sibling rotated down above
+                    rotate_right(parent);
+                    node = m_root; // fixed
+                }
+            }
+            // both branches leave node non-null (either the old parent or the root)
+            parent = node->parent;
+        }
+        node->color = Color::Black; // by this point node cant be null
+    }
+
+    // Returns the in-order successor of node, or nullptr when node is the last
+    // node in key order.
+    static Node* successor(Node* node)
+    {
+        VERIFY(node);
+        if (auto* next = node->right_child) {
+            // leftmost node of the right subtree
+            while (next->left_child)
+                next = next->left_child;
+            return next;
+        }
+        // otherwise climb until we leave a left subtree
+        auto* ancestor = node->parent;
+        while (ancestor && ancestor->right_child == node) {
+            node = ancestor;
+            ancestor = ancestor->parent;
+        }
+        return ancestor;
+    }
+
+    // Returns the in-order predecessor of node, or nullptr when node is the
+    // first node in key order.
+    static Node* predecessor(Node* node)
+    {
+        VERIFY(node);
+        if (auto* previous = node->left_child) {
+            // rightmost node of the left subtree
+            while (previous->right_child)
+                previous = previous->right_child;
+            return previous;
+        }
+        // otherwise climb until we leave a right subtree
+        auto* ancestor = node->parent;
+        while (ancestor && ancestor->left_child == node) {
+            node = ancestor;
+            ancestor = ancestor->parent;
+        }
+        return ancestor;
+    }
+
+    Node* m_root { nullptr };
+    size_t m_size { 0 };
+    Node* m_minimum { nullptr }; // maintained for O(1) begin()
+};
+
+// Iterator over a tree's values in key order.
+// m_node is the current node (null once the end is reached); m_prev is the node
+// that was current before the last increment (null while at the beginning).
+template<typename TreeType, typename ElementType>
+class RedBlackTreeIterator {
+public:
+    RedBlackTreeIterator() = default;
+    bool operator!=(const RedBlackTreeIterator& other) const { return !(m_node == other.m_node); }
+    // Advances to the next key; a no-op on an end iterator.
+    RedBlackTreeIterator& operator++()
+    {
+        if (m_node) {
+            // the complexity is O(logn) for each successor call, but the total complexity for all elements comes out to O(n), meaning the amortized cost for a single call is O(1)
+            auto* current = m_node;
+            m_node = static_cast<typename TreeType::Node*>(TreeType::successor(current));
+            m_prev = current;
+        }
+        return *this;
+    }
+    // Steps back to the previously visited key; a no-op when there is none.
+    RedBlackTreeIterator& operator--()
+    {
+        if (m_prev) {
+            auto* previous = m_prev;
+            m_prev = static_cast<typename TreeType::Node*>(TreeType::predecessor(previous));
+            m_node = previous;
+        }
+        return *this;
+    }
+    ElementType& operator*() { return m_node->value; }
+    ElementType* operator->() { return &m_node->value; }
+    [[nodiscard]] bool is_end() const { return m_node == nullptr; }
+    [[nodiscard]] bool is_begin() const { return m_prev == nullptr; }
+
+private:
+    friend TreeType;
+    explicit RedBlackTreeIterator(typename TreeType::Node* node, typename TreeType::Node* prev = nullptr)
+        : m_node(node)
+        , m_prev(prev)
+    {
+    }
+    typename TreeType::Node* m_node { nullptr };
+    typename TreeType::Node* m_prev { nullptr };
+};
+
+template<Integral K, typename V>
+class RedBlackTree : public BaseRedBlackTree<K> {
+public:
+    RedBlackTree() = default;
+    // The tree owns its nodes through raw pointers and frees them in the destructor,
+    // so a member-wise copy would leave two trees deleting the same nodes.
+    // Copying is therefore disabled.
+    RedBlackTree(const RedBlackTree&) = delete;
+    RedBlackTree& operator=(const RedBlackTree&) = delete;
+    // Frees every node still in the tree.
+    virtual ~RedBlackTree() override
+    {
+        clear();
+    }
+
+    using BaseTree = BaseRedBlackTree<K>;
+
+    // Returns a pointer to the value stored under key, or nullptr if there is none.
+    V* find(K key)
+    {
+        if (auto* match = static_cast<Node*>(BaseTree::find(this->m_root, key)))
+            return &match->value;
+        return nullptr;
+    }
+
+    // Returns a pointer to the value whose key is the largest one <= key,
+    // or nullptr if every key in the tree is greater than key.
+    V* find_largest_not_above(K key)
+    {
+        if (auto* match = static_cast<Node*>(BaseTree::find_largest_not_above(this->m_root, key)))
+            return &match->value;
+        return nullptr;
+    }
+
+    // Inserts a copy of value under key.
+    void insert(K key, const V& value)
+    {
+        V copy(value);
+        insert(key, move(copy));
+    }
+
+    // Inserts value under key, moving it into a newly allocated node owned by the tree.
+    void insert(K key, V&& value)
+    {
+        BaseTree::insert(new Node(key, move(value)));
+    }
+
+    // Mutable iteration in key order; begin_from() starts at the node holding key,
+    // or yields the end iterator when key is absent.
+    using Iterator = RedBlackTreeIterator<RedBlackTree, V>;
+    friend Iterator;
+    Iterator begin()
+    {
+        auto* first = static_cast<Node*>(this->m_minimum);
+        return Iterator(first);
+    }
+    Iterator end() { return Iterator {}; }
+    Iterator begin_from(K key)
+    {
+        auto* start = static_cast<Node*>(BaseTree::find(this->m_root, key));
+        return Iterator(start);
+    }
+
+    // Read-only counterparts of the above.
+    using ConstIterator = RedBlackTreeIterator<const RedBlackTree, const V>;
+    friend ConstIterator;
+    ConstIterator begin() const
+    {
+        auto* first = static_cast<Node*>(this->m_minimum);
+        return ConstIterator(first);
+    }
+    ConstIterator end() const { return ConstIterator {}; }
+    ConstIterator begin_from(K key) const
+    {
+        auto* start = static_cast<Node*>(BaseTree::find(this->m_root, key));
+        return ConstIterator(start);
+    }
+
+    // Removes the entry stored under key and returns its value.
+    // The key must be present (checked with VERIFY).
+    V unsafe_remove(K key)
+    {
+        auto* base_node = BaseTree::find(this->m_root, key);
+        VERIFY(base_node);
+
+        BaseTree::remove(base_node);
+
+        V removed_value = move(static_cast<Node*>(base_node)->value);
+
+        // detach the children first: ~Node() deletes whatever children are still attached
+        base_node->left_child = nullptr;
+        base_node->right_child = nullptr;
+        delete base_node;
+
+        return removed_value;
+    }
+
+    // Removes the entry stored under key.
+    // Returns true if an entry was removed, false if the key was not present.
+    bool remove(K key)
+    {
+        auto* base_node = BaseTree::find(this->m_root, key);
+        if (base_node == nullptr)
+            return false;
+
+        BaseTree::remove(base_node);
+
+        // detach the children first: ~Node() deletes whatever children are still attached
+        base_node->left_child = nullptr;
+        base_node->right_child = nullptr;
+        delete base_node;
+
+        return true;
+    }
+
+    // Destroys every node and resets the tree to its empty state.
+    void clear()
+    {
+        // deleting the root tears down the whole tree through ~Node();
+        // deleting a null pointer is a no-op, so an empty tree needs no special case
+        delete this->m_root;
+        this->m_root = nullptr;
+        this->m_minimum = nullptr;
+        this->m_size = 0;
+    }
+
+private:
+    // Tree node carrying the mapped value. Destroying a node also destroys
+    // any children still attached to it.
+    struct Node : BaseRedBlackTree<K>::Node {
+
+        V value;
+
+        Node(K key, V value)
+            : BaseRedBlackTree<K>::Node(key)
+            , value(move(value))
+        {
+        }
+
+        ~Node()
+        {
+            // deleting a null child is a no-op, so no explicit checks are needed
+            delete this->left_child;
+            delete this->right_child;
+        }
+    };
+};
+
+}
+
+using AK::RedBlackTree;

+ 1 - 0
AK/Tests/CMakeLists.txt

@@ -39,6 +39,7 @@ set(AK_TEST_SOURCES
     TestOptional.cpp
     TestQueue.cpp
     TestQuickSort.cpp
+    TestRedBlackTree.cpp
     TestRefPtr.cpp
     TestSinglyLinkedList.cpp
     TestSourceGenerator.cpp

+ 110 - 0
AK/Tests/TestRedBlackTree.cpp

@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2021, Idan Horowitz <idan.horowitz@gmail.com>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include <AK/TestSuite.h>
+
+#include <AK/Random.h>
+#include <AK/RedBlackTree.h>
+
+// A default-constructed tree reports itself as empty.
+TEST_CASE(construct)
+{
+    RedBlackTree<int, int> empty;
+    EXPECT(empty.is_empty());
+    EXPECT(empty.size() == 0);
+}
+
+// Basic insert / find / remove round trip on a three-element tree.
+TEST_CASE(ints)
+{
+    RedBlackTree<int, int> ints;
+    // keys are inserted out of order on purpose
+    ints.insert(1, 10);
+    ints.insert(3, 20);
+    ints.insert(2, 30);
+    EXPECT_EQ(ints.size(), 3u);
+    EXPECT_EQ(*ints.find(3), 20);
+    EXPECT_EQ(*ints.find(2), 30);
+    EXPECT_EQ(*ints.find(1), 10);
+    // removing a key that was never inserted must report failure
+    EXPECT(!ints.remove(4));
+    EXPECT(ints.remove(2));
+    EXPECT(ints.remove(1));
+    EXPECT(ints.remove(3));
+    EXPECT_EQ(ints.size(), 0u);
+}
+
+// find_largest_not_above() returns the value of the largest key <= the query,
+// or nullptr when every key is above it.
+TEST_CASE(largest_smaller_than)
+{
+    RedBlackTree<int, int> ints;
+    ints.insert(1, 10);
+    ints.insert(11, 20);
+    ints.insert(21, 30);
+    EXPECT_EQ(ints.size(), 3u);
+    // queries that fall strictly between (or above) the stored keys
+    EXPECT_EQ(*ints.find_largest_not_above(3), 10);
+    EXPECT_EQ(*ints.find_largest_not_above(17), 20);
+    EXPECT_EQ(*ints.find_largest_not_above(22), 30);
+    // "not above" is inclusive: an exact key match must return that key's value
+    EXPECT_EQ(*ints.find_largest_not_above(1), 10);
+    EXPECT_EQ(*ints.find_largest_not_above(11), 20);
+    EXPECT_EQ(*ints.find_largest_not_above(21), 30);
+    // below the smallest key there is nothing to return
+    EXPECT_EQ(ints.find_largest_not_above(-5), nullptr);
+}
+
+// Inserts the keys 0..amount-1 in random order and checks that iteration visits
+// every one of them in ascending key order, then removes them all again.
+TEST_CASE(key_ordered_iteration)
+{
+    constexpr auto amount = 10000;
+    RedBlackTree<int, size_t> test;
+    Array<int, amount> keys {};
+
+    // generate random key order
+    for (int i = 0; i < amount; i++) {
+        keys[i] = i;
+    }
+    for (size_t i = 0; i < amount; i++) {
+        swap(keys[i], keys[get_random<size_t>() % amount]);
+    }
+
+    // insert random keys
+    for (size_t i = 0; i < amount; i++) {
+        test.insert(keys[i], keys[i]);
+    }
+
+    // check key-ordered iteration
+    size_t index = 0;
+    for (auto& value : test) {
+        EXPECT(value == index++);
+    }
+    // the loop above passes vacuously if iteration stops early, so check that it saw everything
+    EXPECT_EQ(index, static_cast<size_t>(amount));
+
+    // ensure we can remove all of them (aka, tree structure is not destroyed somehow)
+    for (size_t i = 0; i < amount; i++) {
+        EXPECT(test.remove(i));
+    }
+    EXPECT(test.is_empty());
+}
+
+// clear() must empty the tree and leave it in a usable state.
+TEST_CASE(clear)
+{
+    RedBlackTree<size_t, size_t> test;
+    for (size_t i = 0; i < 1000; i++) {
+        test.insert(i, i);
+    }
+    test.clear();
+    EXPECT_EQ(test.size(), 0u);
+    EXPECT(test.is_empty());
+    EXPECT_EQ(test.find(0), nullptr);
+
+    // the tree must still accept new elements after being cleared
+    test.insert(5, 50);
+    EXPECT_EQ(test.size(), 1u);
+    EXPECT_EQ(*test.find(5), 50u);
+}
+
+TEST_MAIN(RedBlackTree)