Monte Carlo Tree Search

Monte Carlo Tree Search is a robust algorithm used for decision making in game playing and reinforcement learning. It is a heuristic search algorithm that combines the best aspects of tree search and Monte Carlo simulation. MCTS is particularly useful when dealing with large and complex search spaces, such as in neurosymbolic systems and reinforcement learning.

Today let's implement MCTS in Python. First we need to define a Node class that will be used to represent the state.

rom dataclasses import dataclass, field
from collections import defaultdict
import math
from typing import Dict, Set, List, Optional
from abc import ABC, abstractmethod
import random


class Node(ABC):
    """Abstract base class for game state nodes."""

    @abstractmethod
    def find_children(self) -> Set["Node"]:
        """Return all possible child nodes."""

    @abstractmethod
    def get_random_child(self) -> Optional["Node"]:
        """Return a random child node."""

    @abstractmethod
    def is_terminal(self) -> bool:
        """Check if the node is a terminal state."""

    @abstractmethod
    def reward(self) -> float:
        """Return the reward value for a terminal node."""

    @abstractmethod
    def __hash__(self) -> int:
        """Define a hash function for the node."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Define equality comparison for nodes."""

Now we'll define the MCTS class that will hold the search state and perform the tree search.

@dataclass
class MCTS:
    """Monte Carlo Tree Search implementation."""

    exploration_weight: float = 1.0
    total_rewards: Dict[Node, float] = field(default_factory=lambda: defaultdict(float))
    visit_counts: Dict[Node, int] = field(default_factory=lambda: defaultdict(int))
    children: Dict[Node, Set[Node]] = field(default_factory=lambda: defaultdict(set))

Now we'll implement the select_best_move method that will select the best move from the current node.

    def select_best_move(self, node: Node) -> Node:
        """Select the best move from the current node."""
        if node.is_terminal():
            raise ValueError(f"Cannot select move from terminal node: {node}")

        if node not in self.children:
            return node.get_random_child()

In the algorithm we'll use the UCT (Upper Confidence Trees) formula to select the best move.

$$
UCT(v) = \frac{w_v}{n_v} + c \sqrt{\frac{\ln(N)}{n_v}}
$$

Where:

  • $w_v$ is the total reward of the node.
  • $n_v$ is the number of times the node has been visited.
  • $N$ is the total number of simulations.
  • $c$ is the exploration constant.
    def simulate(self, node: Node) -> None:
        """Perform one iteration of the MCTS algorithm."""
        path = self._traverse_tree(node)
        leaf = path[-1]
        self._expand_node(leaf)
        reward = self._simulate_random_playout(leaf)
        self._backpropagate(path, reward)

    def _traverse_tree(self, node: Node) -> List[Node]:
        """Traverse the tree to find an unexplored node."""
        path = []
        while True:
            path.append(node)
            if node not in self.children or not self.children[node]:
                return path
            unexplored = self.children[node] - set(self.children.keys())
            if unexplored:
                path.append(unexplored.pop())
                return path
            node = self._select_uct(node)

    def _expand_node(self, node: Node) -> None:
        """Expand the node by adding its children to the tree."""
        if node in self.children:
            return
        self.children[node] = node.find_children()

Now we'll implement the _simulate_random_playout method that will simulate a random playout from the given node.

    def _simulate_random_playout(self, node: Node) -> float:
        """Simulate a random playout from the given node."""
        invert_reward = True
        while True:
            if node.is_terminal():
                reward = node.reward()
                return 1 - reward if invert_reward else reward
            node = node.get_random_child()
            invert_reward = not invert_reward

Now we'll implement the _backpropagate method that will update the statistics for the nodes in the path.

    def _backpropagate(self, path: List[Node], reward: float) -> None:
        """Update the statistics for the nodes in the path."""
        for node in reversed(path):
            self.visit_counts[node] += 1
            self.total_rewards[node] += reward
            reward = 1 - reward

The select method will select a child node using the UCT metric.

    def _select_uct(self, node: Node) -> Node:
        """Select a child node using the UCT formula."""
        assert all(n in self.children for n in self.children[node])
        log_n_parent = math.log(self.visit_counts[node])

        def uct(n: Node) -> float:
            return self.total_rewards[n] / self.visit_counts[
                n
            ] + self.exploration_weight * math.sqrt(log_n_parent / self.visit_counts[n])

        return max(self.children[node], key=uct)

And then finally we'll implement the _calculate_node_score method that will calculate the score for a node based on its average reward.

    def _calculate_node_score(self, nodeNode) -> float:
        """Calculate the score for a node based on its average reward."""
        if self.visit_counts[node] == 0:
            return float("-inf")
        return self.total_rewards[node] / self.visit_counts[node]

Now to use this we need to define a TicTacToeNode class that will be used to represent the state of the game.

class TicTacToeNode(Node):
    def __init__(self, state: str, player: str):
        self.state = state
        self.player = player

    def find_children(self) -> Set["TicTacToeNode"]:
        if self.is_terminal():
            return set()
        return {
            TicTacToeNode(
                self.state[:i] + self.player + self.state[i + 1 :],
                "O" if self.player == "X" else "X",
            )
            for i, value in enumerate(self.state)
            if value == " "
        }

    def get_random_child(self) -> Optional["TicTacToeNode"]:
        if self.is_terminal():
            return None
        empty_spots = [i for i, value in enumerate(self.state) if value == " "]
        index = random.choice(empty_spots)
        return TicTacToeNode(
            self.state[:index] + self.player + self.state[index + 1 :],
            "O" if self.player == "X" else "X",
        )

    def is_terminal(self) -> bool:
        return self.winner() is not None or " " not in self.state

    def reward(self) -> float:
        winner = self.winner()
        if winner is None:
            return 0.5  # Draw
        return 1.0 if winner == self.player else 0.0

    def __hash__(self) -> int:
        return hash(self.state)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, TicTacToeNode) and self.state == other.state

    def winner(self) -> Optional[str]:
        lines = [
            (0, 1, 2),
            (3, 4, 5),
            (6, 7, 8),  # Rows
            (0, 3, 6),
            (1, 4, 7),
            (2, 5, 8),  # Columns
            (0, 4, 8),
            (2, 4, 6),  # Diagonals
        ]
        for line in lines:
            if self.state[line[0]] == self.state[line[1]] == self.state[line[2]] != " ":
                return self.state[line[0]]
        return None

Now we can use the MCTS class to play a game of TicTacToe.

def play_game():
    state = " " * 9
    mcts = MCTS()
    board = TicTacToeNode(state, "X")

    print("Initial board:")
    print_board(board.state)

    while True:
        # Human player's turn (O)
        human_move = int(input("Enter your move (0-8): "))
        board = TicTacToeNode(
            board.state[:human_move] + "O" + board.state[human_move + 1 :], "X"
        )
        print("\nBoard after your move:")
        print_board(board.state)

        if board.is_terminal():
            break

        # AI player's turn (X)
        for _ in range(1000):  # Number of MCTS iterations
            mcts.simulate(board)
        board = mcts.select_best_move(board)
        print("\nBoard after AI move:")
        print_board(board.state)

        if board.is_terminal():
            break

    winner = board.winner()
    if winner:
        print(f"\nPlayer {winner} wins!")
    else:
        print("\nIt's a draw!")


def print_board(state: str):
    for i in range(0, 9, 3):
        print(" ".join(state[i : i + 3]))


if __name__ == "__main__":
    play_game()