dumbo
A fun little game engine.
 All Classes
mcts.h
1 /*
2  * Copyright (c) 2019. David Fridovich-Keil.
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are
7  * met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  * notice, this list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above
13  * copyright notice, this list of conditions and the following
14  * disclaimer in the documentation and/or other materials provided
15  * with the distribution.
16  *
17  * 3. Neither the name of the copyright holder nor the names of its
18  * contributors may be used to endorse or promote products derived
19  * from this software without specific prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
25  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
28  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31  * POSSIBILITY OF SUCH DAMAGE.
32  *
33  * Please contact the author(s) of this library if you have any questions.
34  * Author: David Fridovich-Keil ( dfk@eecs.berkeley.edu )
35  */
36 
37 ///////////////////////////////////////////////////////////////////////////////
38 //
39 // Generic implementation of Monte Carlo Tree Search.
40 //
41 ///////////////////////////////////////////////////////////////////////////////
42 
43 #ifndef DUMBO_CORE_MCTS_H
44 #define DUMBO_CORE_MCTS_H
45 
#include <algorithm>
#include <chrono>
#include <cmath>
#include <iterator>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include <glog/logging.h>

#include <dumbo/core/game_state.h>
#include <dumbo/core/move.h>
#include <dumbo/core/solver.h>

57 
58 namespace dumbo {
59 namespace core {
60 
61 template <typename M, typename G>
62 class MCTS : public Solver<M, G> {
63  public:
64  ~MCTS() {}
65  MCTS(double max_time_per_move = 1.0) : Solver<M, G>(max_time_per_move) {}
66 
67  // Run the solver on the specified game state. Returns a move.
68  M Run(const G& state);
69 
70  private:
71  struct Node {
72  G state;
73  std::unordered_map<M, std::shared_ptr<Node>, typename M::Hasher> children;
74  std::shared_ptr<Node> parent = nullptr;
75  double wins = 0.0;
76  double total = 0.0;
77 
78  // Update this node and all of its parents with a win/loss/draw result.
79  // Winning is encoded as 1.0, loss as 0.0, and draw as 0.5.
80  // The input is whether the computer won the rollout; within each node, the
81  // 'wins' variable stores the number of wins for the player who played the
82  // move immediately leading to that node.
83  void Update(double win) {
84  total += 1.0;
85 
86  if (state.IsMyTurn())
87  wins += 1.0 - win;
88  else
89  wins += win;
90 
91  if (parent) parent->Update(win);
92  }
93  }; //\struct Node
94 }; //\class MCTS
95 
96 // ---------------------------- IMPLEMENTATION ----------------------------- //
97 
98 template <typename M, typename G>
99 M MCTS<M, G>::Run(const G& state) {
100  CHECK(state.IsMyTurn());
101 
102  // Start a clock.
103  const auto start_time = std::chrono::high_resolution_clock::now();
104 
105  // Root an empty tree at the initial state.
106  std::shared_ptr<Node> root(new Node);
107  root->state = state;
108 
109  // Keep track of all nodes we've seen so far.
110  std::vector<std::shared_ptr<Node>> registry = {root};
111 
112  while (std::chrono::duration<double>(
113  std::chrono::high_resolution_clock::now() - start_time)
114  .count() < this->max_time_per_move_) {
115  // (1) Pick a node to expand.
116  auto compare_ucbs = [](const std::pair<M, std::shared_ptr<Node>>& entry1,
117  const std::pair<M, std::shared_ptr<Node>>& entry2) {
118  const std::shared_ptr<Node>& n1 = entry1.second;
119  const std::shared_ptr<Node>& n2 = entry2.second;
120  constexpr double kNumStddevs = 1.414; // std::sqrt(2.0);
121 
122  CHECK_GT(n1->total, 0.0);
123  CHECK_GT(n2->total, 0.0);
124 
125  // Compute UCBs for both nodes.
126  // NOTE: this is the UCT rule which may be found at
127  // https://en.wikipedia.org/wiki/Monte_Carlo_tree_search.
128  const double parent_total1 = (n1->parent) ? n1->parent->total : n1->total;
129  const double ucb1 =
130  (n1->wins / n1->total) +
131  kNumStddevs * std::sqrt(std::log(parent_total1) / n1->total);
132 
133  const double parent_total2 = (n2->parent) ? n2->parent->total : n2->total;
134  const double ucb2 =
135  (n2->wins / n2->total) +
136  kNumStddevs * std::sqrt(std::log(parent_total2) / n2->total);
137 
138  // Compare UCBs.
139  return ucb1 < ucb2;
140  }; // compare_ucbs
141 
142  // Start at the root and walk down to a leaf node, using the above
143  // comparitor to choose moves for both players.
144  // NOTE: always use max UCB since each node stores the wins for the
145  // player whose move it is at that node.
146  std::shared_ptr<Node> node = root;
147  while (!node->children.empty() &&
148  node->children.size() == node->state.LegalMoves().size()) {
149  auto iter = std::max_element(node->children.begin(), node->children.end(),
150  compare_ucbs);
151  node = iter->second;
152  }
153 
154  // If this is a terminal node, then just try again.
155  double win = 0.0;
156  if (node->state.IsTerminal(&win)) {
157  VLOG(1) << "Hit a terminal leaf node. Updating and continuing.";
158  node->Update(win);
159  continue;
160  }
161 
162  // (2) Expand the node by choosing a random move not already tried yet.
163  M move = node->state.RandomMove();
164  while (node->children.count(move)) {
165  move = node->state.RandomMove();
166  }
167 
168  std::shared_ptr<Node> expansion(new Node);
169  registry.emplace_back(expansion);
170 
171  CHECK(node->state.NextState(move, &expansion->state));
172  expansion->parent = node;
173  node->children.emplace(move, expansion);
174 
175  // (3) Sample random game trajectory from the expanded node.
176  G current_state = expansion->state;
177  while (!current_state.IsTerminal(&win)) {
178  const M move = current_state.RandomMove();
179 
180  G next_state;
181  CHECK(current_state.NextState(move, &next_state));
182 
183  current_state = next_state;
184  }
185 
186  // (3) Update all ancestors of leaf node.
187  expansion->Update(win);
188  }
189 
190  // Ran out of time. Pick the best move from the root.
191  const auto iter =
192  std::max_element(root->children.begin(), root->children.end(),
193  [](const std::pair<M, std::shared_ptr<Node>>& entry1,
194  const std::pair<M, std::shared_ptr<Node>>& entry2) {
195  return entry1.second->wins / entry1.second->total <
196  entry2.second->wins / entry2.second->total;
197  });
198  return iter->first;
199 }
200 
201 } // namespace core
202 } // namespace dumbo
203 
204 #endif