## A Monte Carlo Tree Search for Reward-Friendly Optimization Problems

This post is the whitepaper for a modification I am making to the AlphaZero algorithm for a family of optimization problems I am interested in. We introduce the topic and end with a technical breakdown of the key data structure used.

## Background

Pick your favorite state space $\mathcal{S}$ and endow each $s\in\mathcal{S}$ with a set $\mathcal{A}(s)$ of actions. For each $a\in\mathcal{A}(s)$, a state $as\in \mathcal{S}$ is formed by taking $a$ from $s$. The value of $as$ may be deterministic or stochastic (e.g., in the context of single-player games) or policy-dependent (e.g., in the context of multiplayer games). If $\mathcal{A}(s) = \emptyset$, then $s$ is *terminal*. Additionally, for each terminal $s$, we have a valuation $v(s)$ which we seek to optimize.

### Chess

In this classic example, we play as White against an opponent playing with some fixed policy. Here,

- $\mathcal{S}$: All possible chess board positions with White to play.
  - This includes piece positions, castling rights, and information about $3$-fold repetition and the $50$-move rule.
- $\mathcal{A}(s)$: All legal moves for White to play.
- $as$: When White takes move $a$, a board $s'$ with Black to play is formed; Black plays according to their policy, producing a board $as$ with White to play.
  - When Black has no move, $as$ is instead the board $s'$ with a decoration indicating that White has won or forced a stalemate.
  - This requires a small modification to the definition of $\mathcal{S}$.
- $v(s)$: If the game is over, a score of $+1,$ $0,$ or $-1$ indicates the winning player.

If White and Black play with respect to stochastic policies $\pi_W, \pi_B$, then we are encouraged to extend $v$ from terminal states to $$v_{\pi_W, \pi_B}(s) := \mathbb{E}[v(s') : \text{from }s,\text{ terminal state }s'\text{ is reached when White and Black follow }\pi_W\text{ and }\pi_B,\text{ respectively}]$$

### Space Invaders

In this example:

- $\mathcal{S}$: All possible single-frame game states in *Space Invaders*.
  - This includes the positions and velocities of the player and all invaders.
  - It also includes the score and other internal game state.
- $\mathcal{A}(s)$: Combinations of button inputs legal in *Space Invaders*.
- $as$: The result of playing a combination $a$ of buttons for one frame from state $s$.
- $v(s)$: If the game is over, the player's final score.

While it may be tempting to extend $v$ to $\mathcal{S}$ with the *current score* function, we follow the previous convention:

$$v_{\pi}(s) := \mathbb{E}[v(s') : \text{from }s, \text{terminal state }s'\text{ is reached when the player follows }\pi]$$

## Policies

A policy maps each nonterminal state $s$ to a probability distribution over $\mathcal{A}(s)$. Some policies are better than others, as determined by component-wise comparison. That is, we may say that $\pi\leq \pi'$ if for all states $s$, $v_\pi(s)\leq v_{\pi'}(s)$. Finding optimal policies is intractable, even when $\mathcal{S}$ and $\mathcal{A}(s)$ are finite and state transitions are deterministic.
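As a concrete illustration, a policy is just a function from states to distributions over legal actions. The tiny game tree, the `ACTIONS` table, and the helper names below are hypothetical, a minimal sketch rather than anything from a real implementation:

```python
import random

# A hypothetical, tiny game tree: states are strings, terminal states
# have an empty action set, matching A(s) = {} for terminal s.
ACTIONS = {"s0": ["a", "b"], "sa": [], "sb": []}

def uniform_policy(s):
    """The simplest policy: equal probability on every legal action."""
    acts = ACTIONS[s]
    return {a: 1.0 / len(acts) for a in acts}

def sample_action(policy, s, rng=random):
    """Draw one action from the distribution policy(s)."""
    dist = policy(s)
    r, acc = rng.random(), 0.0
    for a, p in dist.items():
        acc += p
        if r <= acc:
            return a
    return a  # guard against floating-point shortfall
```

A stochastic policy is then used by repeatedly calling `sample_action` until a terminal state is reached.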

We are forced to make tradeoffs when defining our policy $\pi$:

- **implementation**: how to store $\pi$ so it can quickly evaluate states, but also be fine-tuned with new data,
- **exploitation**: how much to trust that $\pi$ will lead us to the best terminal states if we follow it greedily, and
- **exploration**: how much to distrust $\pi$, taking less preferred actions in order to gather data counteracting our bias.

## Monte Carlo Tree Search

When building a policy $\pi$ from nothing, we conduct a tree search as follows:

- $s_{\text{root}}$: Select a root in $\mathcal{S}$ and build a tree.
- **transitions**: Explore from the root, recording the transitions $(a_1, s_1), \cdots, (a_t, s_t)$.
  - For all $i=1..t$, $a_i\in\mathcal{A}(s_{i-1})$ is the action taken from $s_{i-1}$, which produced $s_i$.
  - $v$: $s_t$ is terminal with value $v = v(s_t)$.
- **biases**: When selecting an action $a\in\mathcal{A}(s)$, we bias ourselves as follows:
  - **exploration**: Infrequently visited state-action pairs are preferable.
    - Log the values $n(s)$, the number of actions taken from $s$ during the tree's lifetime.
    - Log also the values $n(s, a)$, the number of visits to $(s, a)$.
    - If practical, we may have access to $n(sa)$ and do not need to define $n(s, a)$.
  - **score**: Actions which previously led to a high score are preferable.
    - Log $w(s, a)$, the sum of all terminal scores over all explored paths with the transition $(s, a)$.
    - If practical, we may instead log $w(s)$ defined similarly and use $w(sa)$ to inform us about $a$.
      - This complicates how the average of $w$ is calculated.
    - We are biased to take $a$ from $s$ if $w(s, a) / n(s, a)$ is high.
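The bookkeeping above can be sketched in a few lines. The `record_path` helper and string-keyed tables are hypothetical names for illustration; they simply backpropagate one observed terminal value along a path:

```python
from collections import defaultdict

# Statistics from the text: n(s), n(s, a), and w(s, a).
n_state = defaultdict(int)    # n(s): actions taken from s
n_edge = defaultdict(int)     # n(s, a): visits to the pair (s, a)
w_edge = defaultdict(float)   # w(s, a): sum of terminal scores through (s, a)

def record_path(path, v):
    """Backpropagate the terminal value v along every (s, a) on the path."""
    for s, a in path:
        n_state[s] += 1
        n_edge[(s, a)] += 1
        w_edge[(s, a)] += v

record_path([("s0", "a1"), ("s1", "a2")], v=1.0)
record_path([("s0", "a1"), ("s1", "a3")], v=0.0)
# the average score of (s0, a1) is now w(s0, a1) / n(s0, a1) = 0.5
```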

This raises several important questions:

- **prior distributions**: How should we pick actions from unvisited nodes?
  - When the tree is initialized, all values of $n(s), n(s, a),$ and $w(s, a)$ are $0$.
  - *A priori*, we have no bias to suggest which action to take.
    - In most games, uniform distributions are terrible guesses.
  - We encounter this problem each time a new (nonterminal) node is added to the tree.
- **bias terms**: How should we pick actions from visited nodes?
  - Even when $s$ has been visited, what biases us towards one action over another? For example, $$\dfrac{w(s, a)}{n(s, a)} + c\cdot\dfrac{\sqrt{n(s)}}{n(s, a)}.$$
    - This choice is extremely arbitrary. Many others may perform better.
- How practically can we store information about every visited action from every visited state?
- How fast will we converge to an optimal policy, if at all?
- **learning**: Are we expected to trust our observations?
  - Is the policy identical to the entire tree?
  - How practically can we use it in future simulations?
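The example bias term above is cheap to compute. A minimal sketch, with the hypothetical name `uct_score` and the convention that unvisited pairs get infinite priority:

```python
import math

def uct_score(w_sa, n_sa, n_s, c=1.4):
    """Average observed score plus an exploration bonus that shrinks
    as (s, a) is visited more often: w/n + c * sqrt(n(s)) / n(s, a).
    Unvisited pairs (n_sa == 0) are given infinite priority."""
    if n_sa == 0:
        return math.inf
    return w_sa / n_sa + c * math.sqrt(n_s) / n_sa
```

An action is then chosen by maximizing `uct_score` over $a\in\mathcal{A}(s)$.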

## AlphaZero

### Formulation

In the AlphaZero paper, the DeepMind team made crucial modifications to the classical MCTS search. First, they use an artificial neural network to define a model $$f_\theta: \mathcal{S}\to\mathbb{R}^{|\mathcal{A}| + 1}.$$

- By assumption, there exists a finite set $\mathcal{A}$ such that for all $s\in\mathcal{S}$, $\mathcal{A}(s)\subseteq\mathcal{A}$.
- Here, $\theta$ is a high-dimensional real-valued vector of weights and biases for the underlying artificial neural network.

For each $s\in\mathcal{S}$, we interpret $f_\theta(s)$ as the direct sum of $p_\theta(s, \cdot)$ and $v_\theta(s)$. Here:

- $p_\theta(s,\cdot)\in\mathbb{R}^{|\mathcal{A}|}$ may be interpreted as a probability distribution over $\mathcal{A}$.
  - The actions in $\mathcal{A}\setminus\mathcal{A}(s)$ are ignored.
  - We may choose to normalize with the softmax function or a simple average.

- $v_\theta(s)\in\mathbb{R}$: a prediction for the terminal value of $v$ when following a policy defined as a function of $\theta$.

**Note**: In fact, $p_\theta(s, \cdot)$ is also a prediction. The model $f_\theta$ estimates the probability distribution of our agent, whose policy is defined as a function of $f_\theta$!
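This interpretation of the model's raw output can be sketched directly. The `split_output` helper and its argument layout are hypothetical: the first $|\mathcal{A}|$ entries are policy logits, the last is the value, and illegal actions are masked out before a softmax:

```python
import math

def split_output(raw, legal, actions):
    """Interpret a raw output of length |A| + 1 as (p_theta(s, .), v_theta(s)).
    Actions outside A(s) (not in `legal`) are dropped before normalizing."""
    logits, v = raw[:-1], raw[-1]
    exps = {a: math.exp(logits[i]) for i, a in enumerate(actions) if a in legal}
    z = sum(exps.values())
    p = {a: e / z for a, e in exps.items()}
    return p, v
```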

### Learning loop

Rather than store large swaths of information in a Monte Carlo tree, our learning loop is as follows:

- **plant many trees**: Create a batch of search trees to explore in parallel.
  - Select roots with a bias towards high values of $v_\theta$.
  - Avoid batches of roots that are "similar" to each other. Perhaps $\mathcal{S}$ has a distance metric.
- **trust the model**: The model may be used to enhance the search algorithm.
  - $p_\theta(s, a)$: Use this value to adjust the bias towards taking action $a$ from $s$.
    - Dirichlet noise is added to encourage exploration.
  - $v_\theta(s)$: When we encounter a nonterminal, unvisited state $s$, add $s$ to the tree and use $v_\theta(s)$ as its value estimate.
    - Since $s$ is not terminal, we may need to take many actions to reach a terminal state.
    - Trust that the model's predictions will improve over time.
    - If $s$ is visited later, we will play moves from $s$.
- **chopped trees**: Suppose a tree is rooted at $s_0$.
  - After a few thousand simulations, we have explored a few thousand paths.
    - Many of those paths ended by adding a new node to the tree and using $v_\theta$ to estimate the ending value.
    - Many ended with a terminal state.
  - Inside the tree, we have recorded $n(s_0)$ and $n(s_0, a)$ for all $a\in\mathcal{A}(s_0)$.
  - We also have recorded $w(s_0)$ and $w(s_0, a)$ for all $a\in\mathcal{A}(s_0)$.
  - We emit $p_{\text{obs}} = (n(s_0, a)/n(s_0))_{a\in\mathcal{A}}$.
  - We also emit $v_{\text{obs}} = w(s_0)/n(s_0)$.
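Turning root statistics into training targets is a one-liner each. The `emit_observations` helper below is a hypothetical name for this step; actions never taken from the root count as $0$ in $p_{\text{obs}}$:

```python
def emit_observations(n_root, n_edges, w_root, actions):
    """Training targets from root statistics:
    p_obs(a) = n(s0, a) / n(s0)  and  v_obs = w(s0) / n(s0).
    `n_edges` maps each action a to n(s0, a); missing actions count as 0."""
    p_obs = {a: n_edges.get(a, 0) / n_root for a in actions}
    v_obs = w_root / n_root
    return p_obs, v_obs
```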

- **model updates**: With a batch of observations, we adjust the parameters $\theta$ in our model to decrease a loss function.
  - Our batch consists of:
    - **roots**: $(s_i)_{i=1..B}$,
    - **observed values**: $(v_{\text{obs}, i})_{i=1..B}$, and
    - **observed probabilities**: $(p_{\text{obs}, i})_{i=1..B}$.
  - Loss includes terms such as:
    - **mis-estimated values**: $(v_\theta(s_i) - v_{\text{obs}, i})^2$,
    - **cross-entropy**: $H(p_{\text{obs}, i}, p_\theta) = -\sum_a p_{\text{obs}, i}(a)\cdot\log(p_\theta(s_i, a))$, and
    - **overfitting**: $|\theta|_2^2$ (a.k.a. *regularization*).
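For a single example, the three loss terms combine as below. The function name, argument layout, and regularization weight are hypothetical; this is a sketch of the loss shape, not a training loop:

```python
import math

def alphazero_loss(v_pred, v_obs, p_pred, p_obs, theta, reg=1e-4):
    """Sum of the three terms from the text: squared value error,
    cross-entropy between observed and predicted policies, and an
    L2 penalty on the parameter vector theta."""
    value_loss = (v_pred - v_obs) ** 2
    policy_loss = -sum(p_obs[a] * math.log(p_pred[a]) for a in p_obs if p_obs[a] > 0)
    l2 = reg * sum(t * t for t in theta)
    return value_loss + policy_loss + l2
```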

## `IRTree`

This is my tentative name for a modification of AlphaZero to the context of mathematical optimization. Here:

`I(nitial)`

: The value of $c(s)$ is only calculated directly for root states. All remaining values are computed dynamically with the reward function.

`R(eward)`

: When a state $s$ is added to the tree, we immediately record the reward function $r(s, \cdot)$ into memory.

### Assumptions

- There exists a cost function $c:\mathcal{S}\to\mathbb{R}$ to minimize.
- For all $s\in\mathcal{S}$, the reward function $r(s, \cdot): \mathcal{A}(s)\to\mathbb{R}$ is convenient to compute as a function of $s$.
- Here, for all $a\in\mathcal{A}(s)$, $r(s, a) = c(s) - c(as)$.
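Under this assumption, rewards along a path telescope: their sum is the total drop in cost from the start state to the end state. A minimal sketch, with a hypothetical quadratic cost on integer states and integer-offset actions:

```python
# Hypothetical example: states are integers, actions add an offset,
# and the cost to minimize is c(s) = s^2.
def c(s):
    return s * s

def r(s, a):
    # the assumed identity r(s, a) = c(s) - c(as), with as = s + a
    return c(s) - c(s + a)

s, total = 3, 0.0
for a in [-1, -1, 1]:          # take three actions from s = 3
    total += r(s, a)
    s += a
assert total == c(3) - c(s)    # rewards telescope to the total cost drop
```

This is why $c(s)$ only needs to be computed at the root: every other cost is recoverable from accumulated rewards.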

### Transitions

Since we have a cheap reward function, we instead update the tree with the improvement in the cost function *after* visiting $(s, a)$. Suppose $s_0\in\mathcal{S}$ and $\ell\in\mathbb{N}$. Suppose furthermore that $t = (t_i)_{i=1..\ell}\in(\mathcal{A}\times\mathbb{R}\times\mathcal{S})^\ell$ such that for all $i=1..\ell$:

- $t_i = (a_i, r_i, s_i)\in\mathcal{A}\times\mathbb{R}\times\mathcal{S}$,
- $a_i\in\mathcal{A}(s_{i-1})$ and $r_i = r(s_{i-1}, a_i)$, and
- $s_i = a_is_{i-1}$.

Then we say that $t$ are the *transitions (of a path of length $\ell$ corresponding to $c$) from $s_0$*. We define the *max gains* $g^\ast = g_t^\ast = (g_i^\ast)_{i=0..\ell}$ (of $c$, from $s_0$, along $t$) as follows:

- $g_\ell^\ast = 0$, and
- $g_i^\ast = \max(0, r_{i+1} + g_{i+1}^\ast)$ for all $i = 0..\ell-1$.

Note that for all $i = 0..\ell$, $$g_i^\ast = \max_{j = i..\ell}(r_{i+1} +\cdots+r_j).$$ Additionally, if $j^\ast$ attains the maximum with respect to $i$, then $j^\ast$ also attains $$\min_{j=i..\ell}c(s_j).$$ Finally, suppose $j$ attains the above maximum (or maxima) with respect to $i = 0$. Then we say that *$t$ observes an improvement in $c$ of $g_0^\ast$ from $s_0$ (at $s_j$)*.
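The backward recursion for the max gains is a single loop. The function name `max_gains` is hypothetical; `rewards[i]` plays the role of $r_{i+1}$:

```python
def max_gains(rewards):
    """Max gains g*_i via the backward recursion from the text:
    g*_ell = 0 and g*_i = max(0, r_{i+1} + g*_{i+1}).
    Returns the list [g*_0, ..., g*_ell] for ell = len(rewards)."""
    ell = len(rewards)
    g = [0.0] * (ell + 1)
    for i in range(ell - 1, -1, -1):
        g[i] = max(0.0, rewards[i] + g[i + 1])
    return g
```

For example, with rewards $(2, -1, 3)$ the best prefix sum from the start is $2 - 1 + 3 = 4$, so $g_0^\ast = 4$.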

Finally, suppose we have a model (with parameters $\theta$) which defines a prediction function $g_\theta^\ast:\mathcal{S}\to\mathbb{R}$. Then as above, we define the *predicted max gains* $h_{\theta}^\ast = h_{\theta, t}^{\ast} = (h_{i}^{\ast})_{i=0..\ell}$ (of $c$, from $s_0$, along $t$, with respect to $\theta$) as follows:

- $h_\ell^\ast = \max(0, g_\theta^\ast(s_\ell))$, and
- $h_i^\ast = \max(0, r_{i+1} + h_{i+1}^\ast)$ for all $i = 0..\ell-1$.

For convenience, we extend $r = (r_i)_{i=1..\ell}$ to

$$r' = (r_i')_{i=1..\ell+1}$$

by letting $r_i' := r_i$ for all $i = 1..\ell$ and $r_{\ell+1}' := g_\theta^\ast(s_\ell)$.

As before, we note that for all $i = 0..\ell$,

$$h_i^\ast = \max_{j=i..\ell+1}(r_{i+1}'+\cdots+r_j').$$

For a final observation, let

$$c' := (c_i')_{i=0..\ell+1}$$

where $c_i' := c(s_i)$ for all $i = 0..\ell$ and $c_{\ell+1}' := c(s_\ell) - g_\theta^\ast(s_\ell)$, continuing the telescoping pattern $c_j' = c(s_0) - (r_1' + \cdots + r_j')$. Then the maxima of $h_i^\ast$ and the minima of $c_i'$ are attained at the same indices.
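The predicted recursion differs from the plain one only in its seed. A sketch with the hypothetical name `predicted_max_gains`, where `g_tail` stands in for the model's estimate $g_\theta^\ast(s_\ell)$:

```python
def predicted_max_gains(rewards, g_tail):
    """Predicted max gains h*_i: the same backward recursion as the
    max gains, except the terminal value is seeded with the model's
    tail estimate g_tail = g*_theta(s_ell) via h*_ell = max(0, g_tail)."""
    ell = len(rewards)
    h = [0.0] * (ell + 1)
    h[ell] = max(0.0, g_tail)
    for i in range(ell - 1, -1, -1):
        h[i] = max(0.0, rewards[i] + h[i + 1])
    return h
```

With `g_tail = 0` this reduces exactly to the unpredicted max gains, which is a useful sanity check.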

### Model

We model with prediction function $$f_\theta(s) = \left(\begin{array}{c} p_\theta(s, \cdot)\\ g_\theta(s)\\ g_\theta^\ast(s)\\ \tilde{g}_\theta^\ast(s) \end{array}\right)$$

Here, we predict:

- $g_\theta(s)$: the expected value of $c(s) - c(s')$ if a search starting at $s$ ends at terminal node $s'$.
- $g_\theta^\ast(s)$: the maximum value of $c(s) - c(s')$ over all paths explored during the search.
  - The number of paths explored is fixed.
- $\tilde{g}_\theta^\ast(s)$: the expected maximum value of $c(s) - c(s')$ over states $s'$ reached by a search starting at $s$.

We iteratively search and modify a tree $m$ rooted at some $s_{\text{root}}\in\mathcal{S}$. Each search in $m$ begins at $s_0 = s_{\text{root}}$ and corresponds to a sequence $t = (t_i)_{i=1..\ell}$ of transitions as defined above.

Inside $m$, we identify states by a canonical path from $s_{\text{root}}$. Given the decomposition of each reward function as $$r(s, \cdot) = f(s, \cdot) + k(s, \cdot),$$ we store for each visited $s\in\mathcal{S}$ the following information inside $m$, in some implicit form:

- `(mut)` $n(s)\in\mathbb{N}$, the frequency of $s$,
- `(mut)` $\tilde{r}(s, \cdot): \mathcal{A}(s)\to\mathbb{R}$, our best estimate for $r(s, \cdot)$
  - initially $f(s, \cdot)$; $\tilde{r}(s, a)$ is replaced by $f(s, a) + k(s, a)$ when $(s, a)$ is visited, and
- $p_\theta(s, \cdot): \mathcal{A}(s)\to\mathbb{R}$.

Additionally, for each $a\in\mathcal{A}(s)$ taken from $s$ during any search through $m$, we store:

- `(mut)` $n(s, a)\in\mathbb{N}$, the frequency of $(s, a)$, and
- `(mut)` $\gamma(s, a)$, described below.

Here, $\gamma(s, a)$ is the total predicted max gains from $sa$ over all searches where $a$ is taken from $s$. In other words, we increment $\gamma(s, a)$ by $h_j^\ast$ whenever:

- a search in $m$ corresponds to a sequence $t = (t_i)_{i=1..\ell}$, and
- $t_j = (a, r_j, s_j)$ with $s_{j-1} = s$.

### Estimates

With tree $m$ rooted at $s_{\text{root}}$, we search $m$ for states which minimize $c$ . For our policy, we choose the upper estimate function defined by $$u(s, a) := \tilde{r}(s, a) + \overline{\gamma}(s, a) + C\cdot p(s, a)\cdot \dfrac{\sqrt{n(s)}}{1+n(s, a)}$$ where $\overline{\gamma}(s, a) = 0$ if $(s, a)$ is unvisited (i.e., $n(s, a) = 0$) and otherwise, $\overline{\gamma}(s, a) = \gamma(s, a) / n(s, a)$.