Monte Carlo Tree Search for Theorem Proving
Talk by Fred Rajasekaran.
Plan:
- AlphaZero Setup
- Modified MCTS
- Application to Theorem Proving
AlphaZero: Think of Chess/Go for now.
Neural network with parameters $\theta$ and two outputs:
- $p_ \theta(s)$ policy.
- $V_ \theta(s)$ value.
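- $s$: state of the game,
- $p_\theta(s) = (p_\theta(a_1|s),\dots, p_\theta(a_n|s))$: probability distribution over the $n$ actions from $s$,
- action $a$: move in the game,
- $V_\theta(s)$: model's estimate of how good $s$ is, i.e. the expected game outcome in $[-1,1]$ for the player to move at $s$.
Chess: $n = 8\times 8\times 73 = 4672$ (a from-square on the $8\times 8$ board times 73 move types).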
Training examples: $(s_ t, \pi_ t, z_ t)_ {t=0}^T$, where
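- $\pi_t$: improved policy at $s_t$ produced by MCTS,
- $s_t$: state at time $t$,
- $z_t \in \{-1,0,+1\}$: final game outcome (loss, draw, win) from the perspective of the player to move at $s_t$.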
The loss function is $$\ell = \sum_t\left((V_\theta(s_t) - z_t)^2 - \pi_t \cdot \log p_\theta(s_t)\right).$$
The second term is the cross-entropy between $\pi_t$ and $p_\theta(s_t)$, which equals the KL divergence up to a term that does not depend on $\theta$: $$D_{KL}(\pi_t(\cdot | s_t) \,\|\, p_\theta(\cdot | s_t)) = -\sum_a \pi_t(a|s_t)\log p_\theta(a|s_t) + \sum_a \pi_t(a|s_t)\log \pi_t(a|s_t).$$
So minimizing $\ell$ pushes $p_\theta$ toward $\pi_t$ and makes the value function predict the game outcome more accurately.
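A minimal plain-Python sketch of the per-example loss, with policies stored as lists of probabilities. The function names and the numbers are made up for illustration, and AlphaZero's L2 weight penalty is left out, as in the formula above. The last line checks numerically that the cross-entropy term equals $D_{KL}(\pi_t \| p_\theta)$ plus the entropy of $\pi_t$, which is why minimizing one over $\theta$ minimizes the other.

```python
import math


def alphazero_loss(v_pred, z, pi, p):
    """Loss for one example: (V_theta(s_t) - z_t)^2 - sum_a pi_t(a) log p_theta(a|s_t)."""
    value_term = (v_pred - z) ** 2
    # Cross-entropy term; actions with pi_t(a) = 0 contribute nothing.
    policy_term = -sum(pa * math.log(qa) for pa, qa in zip(pi, p) if pa > 0)
    return value_term + policy_term


def kl_divergence(pi, p):
    """D_KL(pi || p) = sum_a pi(a) log(pi(a) / p(a))."""
    return sum(pa * math.log(pa / qa) for pa, qa in zip(pi, p) if pa > 0)


def entropy(pi):
    """H(pi) = -sum_a pi(a) log pi(a); it does not depend on theta."""
    return -sum(pa * math.log(pa) for pa in pi if pa > 0)


pi_t = [0.7, 0.2, 0.1]  # made-up MCTS policy over 3 actions
p_t = [0.5, 0.3, 0.2]   # made-up network policy
ce = alphazero_loss(v_pred=0.0, z=0, pi=pi_t, p=p_t)  # value term is 0 here
# Cross-entropy = KL + entropy, so minimizing it over theta minimizes the KL.
print(math.isclose(ce, kl_divergence(pi_t, p_t) + entropy(pi_t)))  # True
```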
MCTS
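Goal: generate a policy $\pi$ that is better than the raw network policy $p_\theta$.
For each state-action pair $(s,a)$ in the search tree, maintain:
- $W(s,a)$: total value accumulated over all simulations that took action $a$ from $s$,
- $N(s,a)$: the number of times action $a$ has been taken from $s$,
- $Q(s,a) := W(s,a)/N(s,a)$: mean value of taking $a$ from $s$,
- $P(s,a) = p_\theta(a|s)$: prior (initial) estimate of how good it is to take action $a$ from $s$.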
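Upper Confidence Bound (the PUCT rule used by AlphaZero): $$U(s,a) = Q(s,a) + C_{PUCT}\cdot P(s,a)\frac{\sqrt{\sum_b N(s,b)}}{1+N(s,a)}.$$ The first term favors actions that have worked well so far; the second favors actions with a high prior and few visits, and it shrinks as $N(s,a)$ grows.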
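A sketch of the selection rule at a single node, with the statistics stored as dicts keyed by action. The value `c_puct=1.5` is an arbitrary choice for illustration, not AlphaZero's setting, and the helper names are hypothetical.

```python
import math


def puct_scores(Q, N, P, c_puct=1.5):
    """U(s,a) = Q(s,a) + c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a)) for each action a."""
    sqrt_total = math.sqrt(sum(N.values()))
    return {a: Q[a] + c_puct * P[a] * sqrt_total / (1 + N[a]) for a in P}


def select_action(Q, N, P, c_puct=1.5):
    """a* = argmax_a U(s,a); ties go to the first action in dict order."""
    scores = puct_scores(Q, N, P, c_puct)
    return max(scores, key=scores.get)


# Action "a" looks good (Q = 0.5) after 10 visits; "b" has the same prior but was never tried.
Q = {"a": 0.5, "b": 0.0}
N = {"a": 10, "b": 0}
P = {"a": 0.5, "b": 0.5}
print(puct_scores(Q, N, P))    # "b" gets the larger score from its exploration bonus
print(select_action(Q, N, P))  # b
```

Even though `a` has the higher mean value, the unvisited `b` wins the argmax: its score is $1.5 \cdot 0.5 \cdot \sqrt{10} \approx 2.37$, versus $0.5 + 1.5 \cdot 0.5 \cdot \sqrt{10}/11 \approx 0.72$ for `a`.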
Algorithm (one simulation): start at the root $s$.
1. Compute $a^\ast = \mathrm{argmax}_a U(s,a)$, take $a^\ast$, and end up in $s'$.
2a. If $s'$ is already in the tree, set $s \leftarrow s'$ and recurse (go back to step 1).
2b. If not, add $s'$ to the tree and initialize $$P(s',\cdot) = p_\theta(s'),$$ $$V(s') = V_\theta(s'),$$ $$W(s',a) = N(s',a) = Q(s',a) = 0, \forall a.$$ (If $s'$ is terminal, use the actual game result as $V(s')$ instead of $V_\theta(s')$.)
3. Propagate up the path in the tree and update the statistics along $s_0\xrightarrow{a_0}s_1\xrightarrow{a_1}\cdots\xrightarrow{a_{L-1}}s_L = s'$. In a 2-player game, $s_t$ is $L-t$ steps above $s'$ and the value changes sign at every step between the two players, so $$V_t = (-1)^{L-t} V(s').$$ (In a single-player setting, $V_t = V(s')$.) For each $(s_t, a_t)$, $t=0,\dots,L-1$: $$N(s_t,a_t)\leftarrow N(s_t,a_t)+1,$$ $$W(s_t,a_t)\leftarrow W(s_t,a_t)+V_t,$$ $$Q(s_t,a_t)\leftarrow \frac{W(s_t, a_t)}{N(s_t, a_t)}.$$
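A compact sketch of selection, expansion and backup on a toy game, so it runs without a neural network. Everything specific here is a hypothetical choice for illustration: the game (a pile of stones, each move removes 1 or 2, whoever takes the last stone wins), the `evaluate` stand-in for $(p_\theta(s), V_\theta(s))$ with a uniform prior and value 0, `c_puct = 1.5`, and keying the tree by state. Values are from the perspective of the player to move, so the edge $(s_t, a_t)$, which sits $L-t$ steps above the leaf, receives the leaf value with its sign flipped $L-t$ times. The very first call finds the root outside the tree and simply expands it.

```python
import math


# Hypothetical toy 2-player game standing in for chess/Go:
# a pile of s stones, each move removes 1 or 2, taking the last stone wins.
def legal_actions(s):
    return [a for a in (1, 2) if a <= s]


def step(s, a):
    return s - a


def evaluate(s):
    """Hypothetical stand-in for the network (p_theta(s), V_theta(s)): uniform prior, value 0."""
    acts = legal_actions(s)
    return {a: 1.0 / len(acts) for a in acts}, 0.0


class SearchTree:
    def __init__(self, c_puct=1.5):
        self.c_puct = c_puct
        self.P, self.N, self.W, self.Q = {}, {}, {}, {}  # state -> {action: statistic}

    def expand(self, s):
        """Step 2b: add s with P = p_theta(s) and N = W = Q = 0; return V_theta(s)."""
        priors, value = evaluate(s)
        self.P[s] = priors
        self.N[s] = {a: 0 for a in priors}
        self.W[s] = {a: 0.0 for a in priors}
        self.Q[s] = {a: 0.0 for a in priors}
        return value

    def select(self, s):
        """Step 1: a* = argmax_a U(s,a)."""
        sqrt_total = math.sqrt(sum(self.N[s].values()))

        def u(a):
            return self.Q[s][a] + self.c_puct * self.P[s][a] * sqrt_total / (1 + self.N[s][a])

        return max(self.P[s], key=u)

    def simulate(self, root):
        path, s = [], root
        # Step 2a: keep descending while the current state is already in the tree.
        while s in self.P:
            a = self.select(s)
            path.append((s, a))
            s = step(s, a)
        if not legal_actions(s):
            leaf_value = -1.0            # terminal: the player to move has lost
        else:
            leaf_value = self.expand(s)  # step 2b: new leaf, value from the stand-in network
        # Step 3: s_t is L - t steps above the leaf, so flip the sign L - t times.
        L = len(path)
        for t, (s_t, a_t) in enumerate(path):
            v_t = (-1) ** (L - t) * leaf_value
            self.N[s_t][a_t] += 1
            self.W[s_t][a_t] += v_t
            self.Q[s_t][a_t] = self.W[s_t][a_t] / self.N[s_t][a_t]


tree = SearchTree()
for _ in range(200):  # K = 200 simulations from a pile of 4 stones
    tree.simulate(4)
print(tree.N[4])  # most visits go to taking 1 stone, which leaves the opponent at 3
```

Because states are keyed by pile size, positions reached along different paths share statistics; that is harmless in this game since the value of a pile does not depend on which player faces it.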
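Run $K$ simulations from the root. Each simulation adds at most one new node, so the tree has at most $K+1$ vertices (including the root).
Then the improved policy at the root $s$ is $\pi(s,a) = N(s,a)^{\frac{1}{\tau}}/\sum_b N(s,b)^{\frac{1}{\tau}}$, where $\tau > 0$ is a temperature.
Sample an action from $\pi$, play it to reach the next state $s'$, run MCTS again from $s'$, and repeat until the game ends.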
AlphaZero used $K=800$ simulations per move. The temperature $\tau$ controls how greedy $\pi$ is:
- $\tau=1$: policy proportional to the number of visits.
- $\tau\rightarrow 0$: greedily selecting the most visited.
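A sketch of turning root visit counts into $\pi$ and sampling an action from it. The `tau == 0` branch is an assumed way to implement the $\tau \to 0$ limit (plain argmax, ties to the first action), and the counts and move names are made up. Dividing every count by the largest one before exponentiating does not change $\pi$, because the common factor cancels in the normalization, but it keeps $N^{1/\tau}$ from overflowing when $\tau$ is small. The sketch assumes at least one count is positive, which holds at the root after a search.

```python
import random


def visit_policy(counts, tau=1.0):
    """pi(a) = N(s,a)^(1/tau) / sum_b N(s,b)^(1/tau); tau == 0 means argmax."""
    if tau == 0:
        best = max(counts, key=counts.get)  # greedy: most-visited action
        return {a: 1.0 if a == best else 0.0 for a in counts}
    # Divide by the largest count first so N^(1/tau) cannot overflow for small tau.
    m = max(counts.values())
    weights = {a: (n / m) ** (1.0 / tau) for a, n in counts.items()}
    total = sum(weights.values())
    return {a: w / total for a, w in weights.items()}


def sample_action(pi, rng):
    """Draw one action with probability pi(a)."""
    actions = list(pi)
    return rng.choices(actions, weights=[pi[a] for a in actions], k=1)[0]


counts = {"e4": 600, "d4": 150, "Nf3": 50}  # made-up visit counts, K = 800
print(visit_policy(counts, tau=1.0))  # proportional to visits: about 0.75 / 0.1875 / 0.0625
print(visit_policy(counts, tau=0.5))  # sharper: e4 gets about 0.935
print(visit_policy(counts, tau=0))    # {'e4': 1.0, 'd4': 0.0, 'Nf3': 0.0}
print(sample_action(visit_policy(counts), random.Random(0)))
```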
After the game, store the trajectory $(s_t, \pi_t, z_t)$, $t=0,\dots,T$, and update the network parameters $\theta$ by minimizing the loss $\ell$.
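Which player's point of view $z_t$ is measured from matters when building these examples. In AlphaZero it is the final result from the perspective of the player to move at $s_t$, so in a two-player game the sign alternates along the trajectory. A minimal sketch, with states and policies as opaque placeholders and a hypothetical function name:

```python
def make_training_examples(states, policies, first_player_result):
    """Attach z_t to each (s_t, pi_t) from one finished self-play game.

    Assumes two players alternate moves, starting with the player to move at s_0,
    and first_player_result is +1 (win), 0 (draw) or -1 (loss) for that player.
    z_t is the result from the perspective of the player to move at s_t.
    """
    examples = []
    for t, (s_t, pi_t) in enumerate(zip(states, policies)):
        z_t = first_player_result if t % 2 == 0 else -first_player_result
        examples.append((s_t, pi_t, z_t))
    return examples


game = make_training_examples(["s0", "s1", "s2"], ["pi0", "pi1", "pi2"], +1)
print([z for _, _, z in game])  # [1, -1, 1]
```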
From MCTS to Theorem Proving
This part of the talk was mostly presented on slides.
How do we deal with the "infinitely many moves" issue (the space of possible proof steps is unbounded)?
In Aristotle, the entire proof history is appended to the prompt.
How do we make the game end? Train on easy problems first, then gradually increase the difficulty.