Fine-tuning LLMs via policy gradient algorithms
Talk by Shengtong Zhang (Stanford; Cursor).
Pretraining - Midtraining - Reinforcement learning
Problem Statement:
- $x$: prompt
- $\pi_ \theta$: LLM
- Rollout: starting from $x$, sample $o_ 1\sim\pi_ \theta(\cdot|x)$, then $o_ i\sim\pi_ \theta(\cdot | x, o_ 1,\dots, o_ {i-1})$ for $i = 2, 3, \dots$, stopping at the first step $k$ where $o_ k = \langle\textrm{eos}\rangle$. This yields a trajectory $\tau = (o_ 1,\dots, o_ k)$.
- $R(\tau)$: the reward of $\tau$.
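
The rollout procedure can be sketched in a few lines of Python. This is only an illustration: `toy_policy` is a hypothetical stand-in for $\pi_ \theta$ (a real LLM would return next-token probabilities from a forward pass), and the three-token vocabulary and the `max_len` cap are assumptions made here.

```python
import random

EOS = "<eos>"

def toy_policy(prompt, prefix):
    """Hypothetical stand-in for pi_theta(. | x, o_1, ..., o_{i-1}).

    Returns a dict mapping each token to its probability. The chance of
    emitting <eos> grows with the prefix length, so rollouts terminate.
    """
    p_eos = min(1.0, 0.2 + 0.2 * len(prefix))
    p_other = (1.0 - p_eos) / 2
    return {"a": p_other, "b": p_other, EOS: p_eos}

def rollout(policy, prompt, rng, max_len=50):
    """Sample o_1, o_2, ... one token at a time until <eos> (or max_len tokens)."""
    trajectory = []
    while len(trajectory) < max_len:
        probs = policy(prompt, trajectory)
        tokens = list(probs)
        weights = [probs[t] for t in tokens]
        token = rng.choices(tokens, weights=weights, k=1)[0]
        trajectory.append(token)
        if token == EOS:
            break
    return trajectory

rng = random.Random(0)
tau = rollout(toy_policy, "some prompt", rng)
print(tau)  # a list of tokens ending in "<eos>"
```

Each token is drawn conditioned on the prompt and on everything generated so far, which is exactly the autoregressive structure the policy gradient formula relies on.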
Example 1.
One choice of $R(\tau)$ is to look for $\langle \textrm{answer}\rangle ???? \langle \textrm{answer}\rangle$ in the output and give $+1$ if the final answer is correct, and $0$ if the format is wrong or the answer is wrong.
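
A reward of this kind might look like the sketch below. The closing tag `</answer>`, the exact string comparison, and the requirement of exactly one tagged span are assumptions made for illustration; the talk only specifies the $+1$ / $0$ scheme.

```python
import re

# Assumed format: the final answer is wrapped as <answer>...</answer>.
ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

def reward(completion: str, correct_answer: str) -> float:
    """Return 1.0 for a well-formatted, correct answer and 0.0 otherwise."""
    matches = ANSWER_PATTERN.findall(completion)
    if len(matches) != 1:
        # Missing or repeated tags count as a format error.
        return 0.0
    return 1.0 if matches[0].strip() == correct_answer.strip() else 0.0

print(reward("So the result is <answer>42</answer>", "42"))
print(reward("So the result is 42", "42"))
```

The first call scores a correctly tagged answer, the second an untagged one, which the format check rejects even though the number itself is right.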
Goal: $\max_ \theta \mathcal{J}(\theta)$, where $\mathcal{J}(\theta) = \mathbb{E}_ \tau R(\tau) = \mathbb{E}_ {x\sim D}\mathbb{E}_ {\tau\sim \pi_ \theta(\cdot | x)}R(\tau)$ and $D$ is the distribution of prompts.
Policy Gradient Algorithm
Goal: find $\nabla_ \theta \mathcal{J}(\theta)$.
By the chain rule and the log trick $\nabla_ \theta \pi_ \theta = \pi_ \theta \nabla_ \theta \log \pi_ \theta$, one can see that $$\nabla_ \theta \mathcal{J}(\theta) = \mathbb{E}_ {x\sim D}\sum_ {\tau=(o_ 1,\dots, o_ k)}\pi_ \theta(o_ 1,\dots,o_ k | x)\Big(\sum_ {i=1}^k \nabla_ \theta \log \pi_ \theta (o_ i| x, o_ 1, \dots, o_ {i-1})\Big)R(\tau),$$ i.e.
$$\nabla_ \theta \mathcal{J}(\theta) =\mathbb{E}_ \tau \Big[R(\tau)\sum_ {i=1}^k \nabla_ \theta \log \pi_ \theta (o_ i| x, o_ 1, \dots, o_ {i-1})\Big].$$
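
For reference, the log trick step can be spelled out as follows. This is a short derivation for a fixed prompt $x$, writing $\nabla_ \theta$ for the gradient with respect to $\theta$ and $\pi_ \theta(\tau | x)$ for the probability of the whole trajectory (a shorthand introduced here), and assuming $R(\tau)$ does not depend on $\theta$. The trajectory probability factorizes over tokens, so its logarithm is a sum:

$$\pi_ \theta(\tau | x) = \prod_ {i=1}^k \pi_ \theta(o_ i | x, o_ 1,\dots,o_ {i-1}) \quad\Longrightarrow\quad \log \pi_ \theta(\tau | x) = \sum_ {i=1}^k \log \pi_ \theta(o_ i | x, o_ 1,\dots,o_ {i-1}).$$

Differentiating the expectation term by term and using $\nabla_ \theta \pi_ \theta = \pi_ \theta \nabla_ \theta \log \pi_ \theta$ gives

$$\nabla_ \theta \sum_ \tau \pi_ \theta(\tau | x) R(\tau) = \sum_ \tau \pi_ \theta(\tau | x)\, \nabla_ \theta \log \pi_ \theta(\tau | x)\, R(\tau) = \mathbb{E}_ {\tau\sim\pi_ \theta(\cdot | x)}\Big[R(\tau) \sum_ {i=1}^k \nabla_ \theta \log \pi_ \theta(o_ i | x, o_ 1,\dots,o_ {i-1})\Big].$$

Taking $\mathbb{E}_ {x\sim D}$ of both sides recovers the gradient of $\mathcal{J}(\theta)$.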
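In practice, one estimates this expectation by Monte Carlo: sample rollouts from $\pi_ \theta$ and average the term inside the expectation over the samples.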
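
The Monte Carlo estimate can be checked on a toy problem small enough to differentiate by hand. The sketch below is an illustration under stated assumptions, not the setup from the talk: the prompt is fixed and omitted, trajectories have exactly two binary tokens with no $\langle\textrm{eos}\rangle$, the policy is tabular (one sigmoid logit per context), and the reward is $1$ only for the trajectory $(1, 0)$. All names are hypothetical.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Tabular toy policy: theta[context] is the logit of emitting token 1
# after the tokens in `context`. Only these three contexts can occur.
CONTEXTS = [(), (0,), (1,)]

def sample_trajectory(theta, rng):
    """Sample tau = (o_1, o_2); also return sum_i grad log pi(o_i | o_<i)."""
    tokens = []
    score = {c: 0.0 for c in CONTEXTS}
    for _ in range(2):
        context = tuple(tokens)
        p_one = sigmoid(theta[context])
        token = 1 if rng.random() < p_one else 0
        # Derivative of log pi(token | context) w.r.t. the context's logit.
        score[context] += token - p_one
        tokens.append(token)
    return tuple(tokens), score

def reward(tau):
    return 1.0 if tau == (1, 0) else 0.0

def policy_gradient_estimate(theta, num_samples, rng):
    """Average R(tau) * sum_i grad log pi(o_i | o_<i) over sampled rollouts."""
    grad = {c: 0.0 for c in CONTEXTS}
    for _ in range(num_samples):
        tau, score = sample_trajectory(theta, rng)
        r = reward(tau)
        for c in CONTEXTS:
            grad[c] += r * score[c] / num_samples
    return grad

def exact_gradient(theta):
    """Gradient of J = p * (1 - q), with p = pi(1 | ()) and q = pi(1 | (1,))."""
    p = sigmoid(theta[()])
    q = sigmoid(theta[(1,)])
    return {(): p * (1 - p) * (1 - q), (0,): 0.0, (1,): -p * q * (1 - q)}

theta = {c: 0.0 for c in CONTEXTS}
estimate = policy_gradient_estimate(theta, 20000, random.Random(0))
exact = exact_gradient(theta)
for c in CONTEXTS:
    print(c, round(estimate[c], 3), round(exact[c], 3))
```

The estimate only needs sampled rollouts and per-token log-probability gradients, never the full sum over trajectories, which is what makes the approach usable when the trajectory space is as large as an LLM's.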