43 #ifndef DUMBO_CORE_MCTS_H
44 #define DUMBO_CORE_MCTS_H
46 #include <dumbo/core/game_state.h>
47 #include <dumbo/core/move.h>
48 #include <dumbo/core/solver.h>
50 #include <glog/logging.h>
55 #include <unordered_map>
// Declaration of Run, parameterized on M and G. It takes a G by const
// reference and returns an M. Elsewhere in this header M is used as the key
// type of a node's `children` map and as the result of `RandomMove()`, and G
// is the type of the per-node `state`.
// NOTE(review): the enclosing class declaration is not visible in this
// excerpt — confirm which class this member belongs to.
61 template <
typename M,
typename G>
68 M Run(
const G& state);
// Child nodes, keyed by a value of type M and hashed with M::Hasher. The
// search code in this header inserts each child under the move that was
// applied to reach it.
73 std::unordered_map<M, std::shared_ptr<Node>,
typename M::Hasher> children;
// Shared-ownership pointer to the parent node; null unless assigned.
// NOTE(review): children are held by shared_ptr as well, so a linked
// parent/child pair references each other through shared_ptr. Confirm how
// these nodes are released; a std::weak_ptr or non-owning pointer here would
// avoid the cycle.
74 std::shared_ptr<Node> parent =
nullptr;
// Update: receives a `win` value and, when this node has a parent, forwards
// the same value to the parent's Update, so the call walks up the parent
// chain until a node without a parent is reached.
// NOTE(review): the statements that apply `win` to this node's own fields and
// the closing braces are not visible in this excerpt.
83 void Update(
double win) {
91 if (parent) parent->Update(win);
// Definition of the tree-search routine. The signature line and several
// statements are not visible in this excerpt.
// NOTE(review): presumably this is the out-of-class definition of Run — confirm.
98 template <
typename M,
typename G>
// Abort unless the supplied state reports that it is our turn.
100 CHECK(state.IsMyTurn());
// Remember when the search started so the loop below can enforce the budget.
103 const auto start_time = std::chrono::high_resolution_clock::now();
// Root node of the search tree.
106 std::shared_ptr<Node> root(
new Node);
// Holds a shared_ptr to the root and to every node created during expansion
// (see the emplace_back below).
110 std::vector<std::shared_ptr<Node>> registry = {root};
// Keep iterating while the elapsed time, measured in seconds as a double, is
// below max_time_per_move_.
112 while (std::chrono::duration<double>(
113 std::chrono::high_resolution_clock::now() - start_time)
114 .count() < this->max_time_per_move_) {
// Comparator over (M, node) map entries. For each side the visible expressions
// compute wins/total + kNumStddevs * sqrt(log(parent total) / total); a node
// without a parent uses its own total inside the log.
// NOTE(review): the lines naming these two values and the return statement
// are not visible in this excerpt.
116 auto compare_ucbs = [](
const std::pair<M, std::shared_ptr<Node>>& entry1,
117 const std::pair<M, std::shared_ptr<Node>>& entry2) {
118 const std::shared_ptr<Node>& n1 = entry1.second;
119 const std::shared_ptr<Node>& n2 = entry2.second;
// Weight on the square-root term; 1.414 is approximately sqrt(2).
120 constexpr
double kNumStddevs = 1.414;
// Both totals must be strictly positive; they are used as divisors below.
122 CHECK_GT(n1->total, 0.0);
123 CHECK_GT(n2->total, 0.0);
128 const double parent_total1 = (n1->parent) ? n1->parent->total : n1->total;
130 (n1->wins / n1->total) +
131 kNumStddevs * std::sqrt(std::log(parent_total1) / n1->total);
133 const double parent_total2 = (n2->parent) ? n2->parent->total : n2->total;
135 (n2->wins / n2->total) +
136 kNumStddevs * std::sqrt(std::log(parent_total2) / n2->total);
// Descend from the root while the current node has at least one child and
// has exactly as many children as its state has legal moves.
146 std::shared_ptr<Node> node = root;
147 while (!node->children.empty() &&
148 node->children.size() == node->state.LegalMoves().size()) {
// Pick one child entry with std::max_element.
// NOTE(review): the comparator argument and the step that advances `node`
// are not visible here — presumably compare_ucbs is used; confirm.
149 auto iter = std::max_element(node->children.begin(), node->children.end(),
// The reached node's state may already be terminal; IsTerminal receives
// &win as an out-parameter.
156 if (node->state.IsTerminal(&win)) {
157 VLOG(1) <<
"Hit a terminal leaf node. Updating and continuing.";
// Draw random moves from the node's state until one is found that is not yet
// a key in `children`.
// NOTE(review): this is a retry loop; it can take many draws when few
// unexplored moves remain.
163 M move = node->state.RandomMove();
164 while (node->children.count(move)) {
165 move = node->state.RandomMove();
// Create the new node, record it in the registry, fill its state by applying
// `move` to the current node's state (NextState must succeed), and link it
// to `node` in both directions.
168 std::shared_ptr<Node> expansion(
new Node);
169 registry.emplace_back(expansion);
171 CHECK(node->state.NextState(move, &expansion->state));
172 expansion->parent = node;
173 node->children.emplace(move, expansion);
// Random playout on a copy of the new node's state: keep applying random
// moves until IsTerminal reports true (it receives &win as an out-parameter).
176 G current_state = expansion->state;
177 while (!current_state.IsTerminal(&win)) {
178 const M move = current_state.RandomMove();
181 CHECK(current_state.NextState(move, &next_state));
183 current_state = next_state;
// Report the playout result to the new node; Update forwards it to each
// ancestor through the parent pointers.
187 expansion->Update(win);
// Among the root's children, find the entry with the largest wins/total ratio.
192 std::max_element(root->children.begin(), root->children.end(),
193 [](
const std::pair<M, std::shared_ptr<Node>>& entry1,
194 const std::pair<M, std::shared_ptr<Node>>& entry2) {
195 return entry1.second->wins / entry1.second->total <
196 entry2.second->wins / entry2.second->total;