A Monte Carlo Tree Search for Reward-Friendly Optimization Problems

Posted on || 13 minute read

This post is the whitepaper for a modification I am making to the AlphaZero algorithm for a family of optimization problems I am interested in. We introduce the topic and end with a technical breakdown of the key data structure used.

Background

Pick your favorite state space $\mathcal{S}$ and endow each $s\in\mathcal{S}$ with a set $\mathcal{A}(s)$ of actions. For each $a\in\mathcal{A}(s)$, the state $as\in \mathcal{S}$ is formed by taking $a$ from $s$. The value of $as$ may be deterministic or stochastic (e.g., in the context of single-player games) or policy-dependent (e.g., in the context of multiplayer games). If $\mathcal{A}(s) = \emptyset$, then $s$ is terminal. Additionally, for each terminal $s$, we have a valuation $v(s)$ which we seek to optimize.

Chess

In this classic example, we play as White against an opponent playing with some fixed policy. Here,

  • $\mathcal{S}$: All possible chess board positions with White to play.
    • This includes piece positions, the states of castling, and information about $3$-fold repetition and the $50$-move rules.
  • $\mathcal{A}$: All legal moves for White to play.
  • $as$: When White takes move $a$, a board $s'$ with Black to play is formed; Black plays according to their policy, producing a board $as$ with White to play.
    • When Black has no move, instead $as$ is the board $s'$ with a decoration that White has won or forced a stalemate.
    • This requires a small modification to the definition of $\mathcal{S}$.
  • $v(s)$: If the game is over, a score of $+1,$ $0,$ or $-1$ indicates the winning player.

If White and Black play with respect to stochastic policies $\pi_W, \pi_B$, then we are encouraged to extend $v$ from terminal states to $$v_{\pi_W, \pi_B}(s) := \mathbb{E}[v(s') : \text{from }s, \text{terminal state }s'\text{ is reached when White and Black follow }\pi_W\text{ and }\pi_B,\text{ respectively}]$$

Space Invaders

In this example:

  • $\mathcal{S}$: all possible single-frame game states in the game Space Invaders
    • this includes positions and velocities of the player and all asteroids
    • it also includes score and other internal game states
  • $\mathcal{A}(s)$: combinations of button inputs legal in Space Invaders
  • $as$: the result of playing a combination $a$ of buttons for a frame from state $s$
  • $v(s)$: if the game is over, the player's current score.

While it may be tempting to extend $v$ to $\mathcal{S}$ with the current score function, we follow the previous convention:

$$v_{\pi}(s) := \mathbb{E}[v(s') : \text{from }s, \text{terminal state }s'\text{ is reached when the player follows }\pi]$$

Policies

A policy maps each nonterminal state $s$ to a probability distribution over $\mathcal{A}(s)$. Some policies are better than others, as determined by component-wise comparison. That is, we may say that $\pi\leq \pi'$ if for all states $s$, $v_\pi(s)\leq v_{\pi'}(s)$. Finding optimal policies is intractable, even when $\mathcal{S}$ and $\mathcal{A}(s)$ are finite and state transitions are deterministic.

We are forced to make tradeoffs when defining our policy $\pi$:

  • implementation: how to store $\pi$ so it can quickly evaluate states, but also be fine-tuned with new data,
  • exploitation: how much to trust that $\pi$ will lead us to the best terminal states if we follow it greedily, and
  • exploration: how much to distrust $\pi$, taking less preferred actions in order to gather data counteracting our bias.

When building a policy $\pi$ from nothing, we conduct a tree search as follows:

  • $s_{\text{root}}$: select a root in $\mathcal{S}$ and build a tree
  • transitions: Explore from the root, recording the transitions $(a_1, s_1), \cdots, (a_t, s_t)$.
    • For all $i=1..t$, $a_i\in\mathcal{A}(s_{i-1})$ is the action taken from $s_{i-1}$, which produced $s_i$.
    • $v$: $s_t$ is terminal with value $v = v(s_t)$.
  • biases: When selecting an action $a\in\mathcal{A}(s)$, we bias ourselves to:
    • exploration: Infrequently visited state-action pairs are preferable.
      • Log the values $n(s)$, the number of actions taken from $s$ during the tree's lifetime.
      • Log also the values $n(s, a)$, the number of visits to $(s, a)$.
        • If practical, we may have access to $n(as)$ and do not need to define $n(s, a)$.
    • score: Actions which previously led to a high score are preferable.
      • Log $w(s, a)$, the sum of all terminal scores over all explored paths with the transition $(s, a)$.
        • If practical, we may instead log $w(s)$ defined similarly and use $w(as)$ to inform us about $a$.
        • This complicates how the average of $w$ is calculated.
      • We are biased to take $a$ from $s$ if $w(s, a) / n(s, a)$ is high.
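
The bookkeeping described above can be sketched in a few lines. This is only an illustration under simple assumptions: states and actions are hashable, the class name `TreeStats` and its methods are hypothetical, and every transition on a path is credited with the same terminal value.

```python
from collections import defaultdict


class TreeStats:
    """Visit counts and score sums for one search tree (hypothetical helper)."""

    def __init__(self):
        self.n_state = defaultdict(int)    # n(s): actions taken from s
        self.n_edge = defaultdict(int)     # n(s, a): visits to (s, a)
        self.w_edge = defaultdict(float)   # w(s, a): sum of terminal scores

    def backup(self, s_root, transitions, value):
        """Record one explored path [(a_1, s_1), ..., (a_t, s_t)] ending with `value`."""
        s_prev = s_root
        for a, s_next in transitions:
            self.n_state[s_prev] += 1
            self.n_edge[(s_prev, a)] += 1
            self.w_edge[(s_prev, a)] += value
            s_prev = s_next

    def mean_score(self, s, a):
        """Return w(s, a) / n(s, a), or 0.0 when (s, a) is unvisited."""
        n = self.n_edge.get((s, a), 0)
        return self.w_edge.get((s, a), 0.0) / n if n else 0.0
```

Returning `0.0` for unvisited pairs is an arbitrary placeholder; how to treat unvisited actions is exactly the question raised next.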

This raises several important questions:

  • prior distributions: How should we pick actions from unvisited nodes?
    • When the tree is initialized, all values of $n(s), n(s, a),$ and $w(s, a)$ are $0$.
    • A priori, we have no bias to suggest what action to take.
    • In most games, uniform distributions are terrible guesses.
    • We encounter this problem each time a new (nonterminal) node is added to the tree.
  • bias terms: How should we pick actions from visited nodes?
    • Even when $s$ has been visited, what biases us towards one action over another? For example, we might score each action by $$\dfrac{w(s, a)}{n(s, a)} + c\cdot\dfrac{\sqrt{n(s)}}{n(s, a)}.$$
      • This choice is extremely arbitrary. Many others may perform better.
    • How practically can we store information about every visited action from every visited state?
    • How fast will we converge to an optimal policy, if at all?
  • learning: Are we expected to trust our observations?
    • Is the policy identical to the entire tree?
    • How practically can we use it in future simulations?
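
To make the example bias term concrete, here is a minimal sketch of action selection with the score $w(s, a)/n(s, a) + c\cdot\sqrt{n(s)}/n(s, a)$. The function names are hypothetical. Since the score divides by $n(s, a)$, I assume here that unvisited actions score $+\infty$ and are tried first; that is one convention among many.

```python
import math


def selection_score(w_sa, n_sa, n_s, c=1.0):
    """Average score plus an exploration bonus that shrinks with n(s, a)."""
    if n_sa == 0:
        return math.inf  # assumption: unvisited actions are always tried first
    return w_sa / n_sa + c * math.sqrt(n_s) / n_sa


def select_action(actions, w, n, n_s, c=1.0):
    """Pick the action with the highest score; `w` and `n` map actions to w(s, a), n(s, a)."""
    return max(actions, key=lambda a: selection_score(w.get(a, 0.0), n.get(a, 0), n_s, c))
```

With $c = 0$ the choice is purely greedy in the observed averages; larger $c$ shifts weight toward rarely visited actions.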

AlphaZero

Formulation

In the AlphaZero paper, the DeepMind team made crucial modifications to classical MCTS. First, they use an artificial neural network to define a model $$f_\theta: \mathcal{S}\to\mathbb{R}^{|\mathcal{A}| + 1}.$$

  • By assumption, there exists a finite set $\mathcal{A}$ such that for all $s\in\mathcal{S}$, $\mathcal{A}(s)\subseteq\mathcal{A}$.
  • Here, $\theta$ is a high-dimensional real-valued vector of weights and biases for the underlying artificial neural network.

For each $s\in\mathcal{S}$, we interpret $f_\theta(s)$ as the direct sum of $p_\theta(s, \cdot)$ and $v_\theta(s)$. Here:

  • $p_\theta(s,\cdot)\in\mathbb{R}^{|\mathcal{A}|}$ may be interpreted as a probability distribution over $\mathcal{A}$.
    • The actions in $\mathcal{A}\setminus\mathcal{A}(s)$ are ignored.
    • We may choose to normalize with the softmax function or a simple average.
  • $v_\theta(s)\in\mathbb{R}$: a prediction for the terminal value of $v$ when following a policy defined as a function of $\theta$.

Note: In fact, $p_\theta(s, \cdot)$ is also a prediction. The model $f_\theta$ estimates the probability distribution of our agent, whose policy is defined as a function of $f_\theta$!
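
As a small illustration of how an output vector of $f_\theta$ might be read, the sketch below splits it into a policy part and a value part, ignores actions outside $\mathcal{A}(s)$, and normalizes with softmax. The function name and the boolean `legal` mask are assumptions for this example, and the state is assumed nonterminal so that at least one action is legal.

```python
import math


def split_prediction(output, legal):
    """Split f_theta(s) into (probabilities over all actions, predicted value).

    `output` has one raw score per action followed by the value;
    `legal[i]` says whether action i belongs to A(s).
    """
    *logits, value = output
    top = max(x for x, ok in zip(logits, legal) if ok)  # for numerical stability
    weights = [math.exp(x - top) if ok else 0.0 for x, ok in zip(logits, legal)]
    total = sum(weights)
    return [w / total for w in weights], value


probs, value = split_prediction([0.0, 0.0, 5.0, 0.3], [True, True, False])
print(probs, value)  # [0.5, 0.5, 0.0] 0.3
```

The third action has the largest raw score but receives probability zero because it is not legal from this state.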

Learning loop

Rather than store large swaths of information in a Monte Carlo tree, our learning loop is as follows:

  • plant many trees: Create a batch of search trees to explore in parallel.
    • Select roots with a bias towards high values of $v_\theta$.
    • Avoid batches of roots that are "similar" to each other. Perhaps $\mathcal{S}$ has a distance metric.
  • trust the model: The model may be used to enhance the search algorithm.
  • $p_\theta(s, a)$: Use this value to adjust the bias towards taking action $a$ from $s$.
    • Dirichlet noise is added to encourage exploration.
  • $v_\theta(s)$: When we encounter a state $s$ that is nonterminal and unvisited, add $s$ to the tree and use $v_\theta(s)$ as the estimated value of the path.
    • Since $s$ is not terminal, we may need to take many actions to reach a terminal state.
    • Trust the model's prediction will improve over time.
    • If $s$ is visited later, we will play moves from $s$.
  • chopped trees: Suppose a tree is rooted at $s_0$.
    • After a few thousand simulations, we have explored a few thousand paths.
    • Many of those paths ended by adding a new node to the tree and using $v_\theta$ to estimate the ending value.
    • Many ended with a terminal state.
    • Inside the tree, we have recorded $n(s_0)$ and $n(s_0, a)$ for all $a\in\mathcal{A}(s_0)$.
    • We also have recorded $w(s_0)$ and $w(s_0, a)$ for all $a\in\mathcal{A}(s_0)$.
    • We emit $p_{\text{obs}} = (n(s_0, a)/n(s_0))_{a\in\mathcal{A}}$.
    • We also emit $v_{\text{obs}} = w(s_0)/n(s_0)$.
  • model updates: With a batch of observations, we adjust the parameters $\theta$ in our model to decrease a loss function.
    • Our batch consists of:
      • roots: $(s_i)_{i=1..B}$,
      • observed values: $(v_{\text{obs}, i})_{i=1..B}$, and
      • observed probabilities: $(p_{\text{obs}, i})_{i=1..B}$.
    • Loss includes terms such as:
      • mis-estimated values: $(v_\theta(s_i) - v_{\text{obs}, i})^2$
      • cross-entropy: $H(p_{\text{obs}, i}, p_\theta(s_i, \cdot)) = -\sum_a p_{\text{obs}, i}(a)\cdot\log(p_\theta(s_i, a))$
      • overfitting: $\|\theta\|_2^2$ (a.k.a. regularization)
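
The emitted targets and the loss terms above can be sketched for a single root as follows. The function names are hypothetical, and so is the choice to add the three terms with a single regularization coefficient `reg`; the text only says the loss "includes terms such as" these.

```python
import math


def observed_targets(n_root, n_root_actions, w_root):
    """Return (p_obs, v_obs) from the root statistics n(s_0), n(s_0, a), w(s_0)."""
    p_obs = [n_a / n_root for n_a in n_root_actions]
    v_obs = w_root / n_root
    return p_obs, v_obs


def example_loss(v_pred, p_pred, v_obs, p_obs, theta, reg=1e-4):
    """Squared value error + cross-entropy + L2 penalty (weights are assumptions)."""
    value_term = (v_pred - v_obs) ** 2
    # Skip actions never observed so that log(0) is not evaluated for them.
    cross_entropy = -sum(po * math.log(pp) for po, pp in zip(p_obs, p_pred) if po > 0)
    l2 = reg * sum(x * x for x in theta)
    return value_term + cross_entropy + l2
```

For a batch, one would average `example_loss` over the $B$ roots before taking a gradient step on $\theta$.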

IRTree

This is my tentative name for a modification of AlphaZero to the context of mathematical optimization. Here:

  • I(nitial): The value of $c(s)$ is only calculated directly for root states. All remaining values are computed dynamically using the reward function.
  • R(eward): When a state $s$ is added to the tree, we immediately record the reward function $r(s, \cdot)$ into memory.

Assumptions

  • There exists a cost function $c:\mathcal{S}\to\mathbb{R}$ to minimize.
  • For all $s\in\mathcal{S}$, the reward function $r(s, \cdot): \mathcal{A}(s)\to\mathbb{R}$ is convenient to compute as a function of $s$.
    • Here, for all $a\in\mathcal{A}(s)$, $r(s, a) = c(s) - c(as)$.

Transitions

Since we have a cheap reward function, we instead update the tree with the improvement in the cost function after visiting $(s, a)$. Suppose $s_0\in\mathcal{S}$ and $\ell\in\mathbb{N}$. Suppose furthermore that $t = (t_i)_{i=1..\ell}\in(\mathcal{A}\times\mathbb{R}\times\mathcal{S})^\ell$ such that for all $i=1..\ell$:

  • $t_i = (a_i, r_i, s_i)\in\mathcal{A}\times\mathbb{R}\times\mathcal{S}$,
  • $a_i\in\mathcal{A}(s_{i-1})$ and $r_i = r(s_{i-1}, a_i)$, and
  • $s_i = a_is_{i-1}$.

Then we say that $t$ is the sequence of transitions (of a path of length $\ell$ corresponding to $c$) from $s_0$. We define the max gains $g^\ast = g_t^\ast = (g_i^\ast)_{i=0..\ell}$ (of $c$, from $s_0$, along $t$) as follows:

  • $g_\ell^\ast = 0$, and
  • $g_i^\ast = \max(0, r_{i+1} + g_{i+1}^\ast)$ for all $i = 0..\ell-1$.

Note that for all $i = 0..\ell$, $$g_i^\ast = \max_{j = i..\ell}(r_{i+1} +\cdots+r_j),$$ where the empty sum ($j = i$) is $0$. Since $r_{i+1}+\cdots+r_j = c(s_i) - c(s_j)$, if $j^\ast$ attains this maximum with respect to $i$, then $j^\ast$ also attains $$\min_{j=i..\ell}c(s_j).$$ Finally, suppose $j$ attains the above maximum with respect to $i = 0$. Then we say that $t$ observes an improvement in $c$ of $g_0^\ast$ from $s_0$ (at $s_j$).
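
The backward recursion for the max gains is short enough to sketch directly. The function name is hypothetical; the input is the list of rewards $r_1, r_2, \ldots$ along the path in order, and the output has one more entry than the input, ending with the $0$ at the last index.

```python
def max_gains(rewards):
    """Return [g_0*, ..., g_last*] for rewards [r_1, ..., r_last], computed backwards."""
    gains = [0.0] * (len(rewards) + 1)
    for i in range(len(rewards) - 1, -1, -1):
        # rewards[i] is r_{i+1} because the Python list is 0-indexed.
        gains[i] = max(0.0, rewards[i] + gains[i + 1])
    return gains


# Costs along a path: c(s_0), ..., c(s_4); rewards are r_i = c(s_{i-1}) - c(s_i).
costs = [10, 12, 7, 9, 8]
rewards = [costs[i] - costs[i + 1] for i in range(len(costs) - 1)]
print(max_gains(rewards))  # [3.0, 5.0, 0.0, 1.0, 0.0]
```

In this example $g_0^\ast = 3 = c(s_0) - c(s_2)$, so the path observes an improvement of $3$ from $s_0$ at $s_2$, the state of minimum cost.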

Finally, suppose we have a model (with parameters $\theta$) which defines a prediction function $g_\theta^\ast:\mathcal{S}\to\mathbb{R}$. Then as above, we define the predicted max gains $h_{\theta}^\ast = h_{\theta, t}^{\ast} = (h_{i}^{\ast})_{i=0..\ell}$ (of $c$, from $s_0$, along $t$, with respect to $\theta$) as follows:

  • $h_\ell^\ast = \max(0, g_\theta^\ast(s_\ell))$, and
  • $h_i^\ast = \max(0, r_{i+1} + h_{i+1}^\ast)$ for all $i = 0..\ell-1$.

For convenience, we extend $r = (r_i)_{i=1..\ell}$ to

$$r' = (r_i')_{i=1..\ell+1}$$

by letting $r_i' := r_i$ for all $i = 1..\ell$ and $r_{\ell+1}' := g_\theta^\ast(s_\ell)$.

As before, we note that for all $i = 0..\ell$,

$$h_i^\ast = \max_{j=i..\ell+1}(r_{i+1}'+\cdots+r_j').$$

For a final observation, let

$$c' := (c_i')_{i=0..\ell+1}$$

where $c_i' := c(s_i)$ for all $i = 0..\ell$ and $c_{\ell+1}' := c(s_\ell) - g_\theta^\ast(s_\ell)$. Then for each $i$, the indices $j$ attaining the maximum in $h_i^\ast$ are exactly those attaining $\min_{j=i..\ell+1}c_j'$.
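
The predicted max gains use the same backward recursion as the max gains, except that the last entry is seeded with the model's prediction at the final state of the path, clamped at $0$. A minimal sketch, with a hypothetical function name and the prediction $g_\theta^\ast$ at the final state passed in as a plain number:

```python
def predicted_max_gains(rewards, predicted_gain):
    """Return [h_0*, ..., h_last*] given rewards [r_1, ..., r_last] and the
    model's predicted gain at the final state of the path."""
    gains = [0.0] * (len(rewards) + 1)
    gains[-1] = max(0.0, predicted_gain)
    for i in range(len(rewards) - 1, -1, -1):
        gains[i] = max(0.0, rewards[i] + gains[i + 1])
    return gains


print(predicted_max_gains([-2, 5, -2, 1], 4.0))   # [6.0, 8.0, 3.0, 5.0, 4.0]
print(predicted_max_gains([-2, 5, -2, 1], -1.0))  # [3.0, 5.0, 0.0, 1.0, 0.0]
```

When the prediction is not positive, the result coincides with the observed max gains; an optimistic prediction raises every entry that can reach the end of the path without a better stopping point.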

Model

We model with prediction function $$f_\theta(s) = \left(\begin{array}{r} p_\theta(s, \cdot)\\ g_\theta(s)\\ g_\theta^\ast(s)\\ \tilde{g}_\theta^\ast(s) \end{array}\right)$$

Here, we predict:

  • $g_\theta(s)$: the expected value of $c(s) - c(s')$ if a search starting with $s$ ends at terminal node $s'$.
  • $g_\theta^\ast(s)$: the maximum value of $c(s) - c(s')$ over all paths explored during the search.
    • The number of paths explored is fixed.
  • $\tilde{g}_\theta^\ast(s)$: the expected maximum value of $c(s) - c(s')$ if a search starts with $s$ and passes through $s'$.

We iteratively search and modify a tree $m$ rooted at some $s_{\text{root}}\in\mathcal{S}$. Each search in $m$ begins at $s_0 = s_{\text{root}}$ and corresponds to a sequence $t = (t_i)_{i=1..\ell}$ of transitions as defined above.

Inside $m$, we identify states by a canonical path from $s_{\text{root}}$. Given the decomposition of each reward function as $$r(s, \cdot) = f(s, \cdot) + k(s, \cdot),$$ we store for each visited $s\in\mathcal{S}$ the following information inside $m$, in some implicit form (entries marked (mut) are updated as searches proceed; the rest are unchanging):

  • (mut) $n(s)\in\mathbb{N}$, the frequency of $s$
  • (mut) $\tilde{r}(s, \cdot): \mathcal{A}(s)\to\mathbb{R}$, our best estimate for $r(s, \cdot)$
    • initially $f(s, \cdot)$, and $\tilde{r}(s, a)$ is replaced by $f(s, a) + k(s, a)$ when $(s, a)$ is visited
  • $p_\theta(s, \cdot): \mathcal{A}(s)\to\mathbb{R}$

Additionally, for each $a\in\mathcal{A}(s)$ taken from $s$ during any search through $m$, we store:

  • (mut) $n(s, a)\in\mathbb{N}$, the frequency of $(s, a)$
  • (mut) $\gamma(s, a)$, described below.

Here, $\gamma(s, a)$ is the total predicted max gains from $as$ over all searches where $a$ is taken from $s$. In other words, we increment $\gamma(s, a)$ by $\tilde{g}_j^\ast$ in the following circumstance:

  • a search in $m$ corresponds to a sequence $t = (t_i)_{i=1..\ell}$
  • $t_j = (a, r_j, s_j)$ and $s_{j-1} = s$.
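
One possible in-memory layout for the per-state data is sketched below. The class and field names are hypothetical, actions are assumed hashable, and I assume $n(s)$ is incremented once per action taken from $s$. The real structure may store these values implicitly rather than in dictionaries.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Data stored for one visited state s (illustrative layout only)."""
    prior: dict                 # p_theta(s, .), fixed when the node is created
    reward_estimate: dict       # r~(s, .), initially f(s, .)
    visits: int = 0             # n(s)
    edge_visits: dict = field(default_factory=dict)    # n(s, a), visited actions only
    edge_gain_sum: dict = field(default_factory=dict)  # gamma(s, a), visited actions only

    def record(self, action, exact_reward, gain_from_child):
        """Update the mutable entries after a search takes `action` from s.

        `exact_reward` is f(s, a) + k(s, a); `gain_from_child` is the
        increment to gamma(s, a) contributed by this search.
        """
        self.visits += 1
        self.reward_estimate[action] = exact_reward
        self.edge_visits[action] = self.edge_visits.get(action, 0) + 1
        self.edge_gain_sum[action] = self.edge_gain_sum.get(action, 0.0) + gain_from_child
```

Keeping the per-action dictionaries empty until an action is first taken means unvisited actions cost no memory beyond their prior and reward estimate.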

Estimates

With tree $m$ rooted at $s_{\text{root}}$, we search $m$ for states which minimize $c$. For our policy, we choose the upper estimate function defined by $$u(s, a) := \tilde{r}(s, a) + \overline{\gamma}(s, a) + C\cdot p_\theta(s, a)\cdot \dfrac{\sqrt{n(s)}}{1+n(s, a)}$$ where $\overline{\gamma}(s, a) = 0$ if $(s, a)$ is unvisited (i.e., $n(s, a) = 0$) and otherwise, $\overline{\gamma}(s, a) = \gamma(s, a) / n(s, a)$.
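
As a sanity check on the formula, here is a direct transcription of $u(s, a)$ into code. The function and parameter names are hypothetical, and the exploration constant $C$ is passed as `c`; nothing here says what value it should take.

```python
import math


def upper_estimate(reward_estimate, gain_sum, n_sa, n_s, prior, c=1.0):
    """u(s, a) = r~(s, a) + mean gamma(s, a) + C * p(s, a) * sqrt(n(s)) / (1 + n(s, a))."""
    mean_gain = gain_sum / n_sa if n_sa > 0 else 0.0  # 0 for unvisited (s, a)
    return reward_estimate + mean_gain + c * prior * math.sqrt(n_s) / (1 + n_sa)


# A visited action: r~ = 1.5, gamma = 6.0 over 2 visits, n(s) = 9, prior 0.5, C = 2.
print(upper_estimate(1.5, 6.0, 2, 9, 0.5, c=2.0))  # 5.5
# An unvisited action relies on its reward estimate and prior alone.
print(upper_estimate(1.0, 0.0, 0, 4, 0.25))        # 1.5
```

Unlike the earlier example bias term, the denominator $1 + n(s, a)$ keeps the score finite for unvisited actions, so no special case is needed beyond $\overline{\gamma}(s, a) = 0$.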