浏览代码

ChessEngine: Don't throw away useful branches from last tree

Computation from last turn might have produced some nodes that are still
accurate. Keeping them should make the engine a bit smarter.
Lucas CHOLLET 2 年之前
父节点
当前提交
d5979516b4

+ 12 - 2
Userland/Services/ChessEngine/ChessEngine.cpp

@@ -38,7 +38,14 @@ void ChessEngine::handle_go(GoCommand const& command)
 
     auto elapsed_time = Core::ElapsedTimer::start_new();
 
-    MCTSTree mcts(m_board);
+    auto mcts = [this]() -> MCTSTree {
+        if (!m_last_tree.has_value())
+            return { m_board };
+        auto x = m_last_tree.value().child_with_move(m_board.last_move().value());
+        if (x.has_value())
+            return move(x.value());
+        return { m_board };
+    }();
 
     int rounds = 0;
     while (elapsed_time.elapsed() <= command.movetime.value()) {
@@ -47,7 +54,10 @@ void ChessEngine::handle_go(GoCommand const& command)
     }
     dbgln("MCTS finished {} rounds.", rounds);
     dbgln("MCTS evaluation {}", mcts.expected_value());
-    auto best_move = mcts.best_move();
+    auto& best_node = mcts.best_node();
+    auto const& best_move = best_node.last_move();
     dbgln("MCTS best move {}", best_move.to_long_algebraic());
     send_command(BestMoveCommand(best_move));
+
+    m_last_tree = move(best_node);
 }

+ 2 - 0
Userland/Services/ChessEngine/ChessEngine.h

@@ -6,6 +6,7 @@
 
 #pragma once
 
+#include "MCTSTree.h"
 #include <LibChess/Chess.h>
 #include <LibChess/UCIEndpoint.h>
 
@@ -26,4 +27,5 @@ private:
     }
 
     Chess::Board m_board;
+    Optional<MCTSTree> m_last_tree;
 };

+ 32 - 4
Userland/Services/ChessEngine/MCTSTree.cpp

@@ -16,6 +16,19 @@ MCTSTree::MCTSTree(Chess::Board const& board, MCTSTree* parent)
 {
 }
 
+MCTSTree::MCTSTree(MCTSTree&& other)
+    : m_children(move(other.m_children))
+    , m_parent(other.m_parent)
+    , m_white_points(other.m_white_points)
+    , m_simulations(other.m_simulations)
+    , m_board(move(other.m_board))
+    , m_last_move(move(other.m_last_move))
+    , m_turn(other.m_turn)
+    , m_moves_generated(other.m_moves_generated)
+{
+    other.m_parent = nullptr;
+}
+
 MCTSTree& MCTSTree::select_leaf()
 {
     if (!expanded() || m_children.size() == 0)
@@ -117,22 +130,37 @@ void MCTSTree::do_round()
     node.apply_result(result);
 }
 
-Chess::Move MCTSTree::best_move() const
+Optional<MCTSTree&> MCTSTree::child_with_move(Chess::Move chess_move)
+{
+    for (auto& node : m_children) {
+        if (node.last_move() == chess_move)
+            return node;
+    }
+    return {};
+}
+
+MCTSTree& MCTSTree::best_node()
 {
     int score_multiplier = (m_turn == Chess::Color::White) ? 1 : -1;
 
-    Chess::Move best_move = { { 0, 0 }, { 0, 0 } };
+    MCTSTree* best_node_ptr = nullptr;
     double best_score = -double(INFINITY);
     VERIFY(m_children.size());
     for (auto& node : m_children) {
         double node_score = node.expected_value() * score_multiplier;
         if (node_score >= best_score) {
-            best_move = node.m_last_move.value();
+            best_node_ptr = &node;
             best_score = node_score;
         }
     }
+    VERIFY(best_node_ptr);
 
-    return best_move;
+    return *best_node_ptr;
+}
+
+Chess::Move MCTSTree::last_move() const
+{
+    return m_last_move.value();
 }
 
 double MCTSTree::expected_value() const

+ 6 - 1
Userland/Services/ChessEngine/MCTSTree.h

@@ -20,6 +20,7 @@ public:
     };
 
     MCTSTree(Chess::Board const& board, MCTSTree* parent = nullptr);
+    MCTSTree(MCTSTree&&);
 
     MCTSTree& select_leaf();
     MCTSTree& expand();
@@ -28,7 +29,11 @@ public:
     void apply_result(int game_score);
     void do_round();
 
-    Chess::Move best_move() const;
+    Optional<MCTSTree&> child_with_move(Chess::Move);
+
+    MCTSTree& best_node();
+
+    Chess::Move last_move() const;
     double expected_value() const;
     double uct(Chess::Color color) const;
     bool expanded() const;