Fine-tuning LLMs via policy gradient algorithms

Shurui Liu

Talk by Shengtong Zhang (Stanford; Cursor).

Training stages: pretraining → midtraining → reinforcement learning.

Problem Statement:

  • $x$: prompt
  • $\pi_ \theta$: LLM
  • Rollout: starting from $x$, sample $o_ 1\sim\pi_ \theta(\cdot|x)$, then $o_ i\sim\pi_ \theta(\cdot | x, o_ 1,\dots, o_ {i-1})$ for $i = 2, 3, \dots$, stopping at the first step $k$ at which $o_ k\sim\pi_ \theta(\cdot | x, o_ 1,\dots, o_ {k-1})$ is the $\langle \textrm{eos}\rangle$ token. This yields a trajectory $\tau = (o_ 1,\dots, o_ k)$.
  • $R(\tau)$: the reward of $\tau$.
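
The rollout procedure can be sketched in Python as follows. This is only an illustration: `toy_policy`, its three-token vocabulary, and the `max_len` cap are hypothetical stand-ins for a real LLM, which would return a distribution over its full vocabulary.

```python
import math
import random

EOS = "<eos>"

def toy_policy(prompt, prefix):
    """Hypothetical stand-in for pi_theta(. | x, o_1..o_{i-1}): returns {token: prob}."""
    # Assumption for the toy: the chance of stopping grows with the prefix length.
    p_stop = min(1.0, 0.2 * (len(prefix) + 1))
    p_other = (1.0 - p_stop) / 2
    return {"a": p_other, "b": p_other, EOS: p_stop}

def rollout(policy, prompt, rng, max_len=50):
    """Sample tau = (o_1, ..., o_k) token by token; return (tau, log pi(tau | x))."""
    tau, log_prob = [], 0.0
    while len(tau) < max_len:
        dist = policy(prompt, tau)            # next-token distribution given x and the prefix
        tokens = list(dist)
        weights = [dist[t] for t in tokens]
        token = rng.choices(tokens, weights=weights, k=1)[0]
        log_prob += math.log(dist[token])     # log-probabilities add up along the trajectory
        tau.append(token)
        if token == EOS:                      # the rollout ends at the first <eos>
            break
    return tau, log_prob
```

Calling `rollout(toy_policy, "some prompt", random.Random(0))` returns one sampled trajectory together with its log-probability, which is the quantity the policy gradient later differentiates.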

Example 1.

Here $R(\tau)$ looks for $\langle \textrm{answer}\rangle ???? \langle \textrm{answer}\rangle$ in the output and gives $+1$ if the final answer is correct, and $0$ if the format is wrong or the answer is wrong.
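
A reward of this kind might look like the sketch below. It makes several assumptions that the notes do not specify: the tags are spelled `<answer>...</answer>`, the reward is computed on the decoded text of the trajectory, the last tagged span counts as the final answer, and correctness means an exact string match with a reference answer `gold` (a hypothetical parameter).

```python
import re

# Assumed tag spelling; the talk's exact format may differ.
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

def reward(completion: str, gold: str) -> float:
    """Return 1.0 if the last tagged answer equals `gold`, else 0.0."""
    matches = ANSWER_RE.findall(completion)
    if not matches:
        return 0.0                      # wrong format: no tagged answer found
    final_answer = matches[-1].strip()  # treat the last tagged span as the final answer
    return 1.0 if final_answer == gold.strip() else 0.0
```

For instance, `reward("so the result is <answer>42</answer>", "42")` gives `1.0`, while a completion without tags gives `0.0` even if it mentions the right number.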

Goal: maximize over $\theta$ the expected reward $\mathcal{J}(\theta) = \mathbb{E}_ \tau R(\tau) = \mathbb{E}_ {x\sim D}\mathbb{E}_ {\tau\sim \pi_ \theta(\cdot | x)}R(\tau)$, where $D$ is the distribution of prompts.

Policy Gradient Algorithm

Goal: compute $\nabla_ \theta \mathcal{J}(\theta)$.

By the chain rule and the log trick, for a fixed prompt $x$, $$\nabla_ \theta \mathbb{E}_ {\tau\sim \pi_ \theta(\cdot | x)}R(\tau) = \sum_ {\tau=(o_ 1,\dots, o_ k)}\pi_ \theta(o_ 1,\dots,o_ k | x)\Big(\sum_ {i=1}^k \nabla_ \theta \log \pi_ \theta (o_ i| x, o_ 1, \dots, o_ {i-1})\Big)R(\tau),$$ i.e.

$$\nabla_ \theta \mathcal{J}(\theta) =\mathbb{E}_ \tau \Big[R(\tau)\sum_ {i=1}^k \nabla_ \theta \log \pi_ \theta (o_ i| x, o_ 1, \dots, o_ {i-1})\Big],$$

where $\mathbb{E}_ \tau$ is taken over $x\sim D$ and $\tau\sim \pi_ \theta(\cdot | x)$.
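
The log trick can be spelled out as follows. Fix a prompt $x$, let $\nabla_ \theta$ denote the gradient with respect to $\theta$, and write $\pi_ \theta(\tau | x) = \prod_ {i=1}^k \pi_ \theta(o_ i | x, o_ 1,\dots,o_ {i-1})$ for the probability of a whole trajectory (shorthand introduced here for brevity). Since $R(\tau)$ does not depend on $\theta$, and assuming the gradient can be exchanged with the sum over trajectories,

$$\nabla_ \theta \sum_ \tau \pi_ \theta(\tau | x) R(\tau) = \sum_ \tau \nabla_ \theta \pi_ \theta(\tau | x) R(\tau) = \sum_ \tau \pi_ \theta(\tau | x) \nabla_ \theta \log \pi_ \theta(\tau | x) R(\tau),$$

where the last step uses $\nabla_ \theta \log \pi_ \theta = \nabla_ \theta \pi_ \theta / \pi_ \theta$. Taking the logarithm turns the product over tokens into a sum,

$$\nabla_ \theta \log \pi_ \theta(\tau | x) = \sum_ {i=1}^k \nabla_ \theta \log \pi_ \theta(o_ i | x, o_ 1,\dots,o_ {i-1}),$$

which is the per-token sum that multiplies $R(\tau)$ in the gradient formula.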

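In practice, the expectation in the gradient formula is estimated by Monte Carlo: sample rollouts $\tau$ from $\pi_ \theta$ and average the quantity inside the expectation over the samples.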
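
A minimal sketch of this Monte Carlo estimate, under a strong simplifying assumption: the policy is a context-free softmax over a tiny vocabulary with logits `theta`, so that the gradient of $\log \pi_ \theta(o_ i)$ has the closed form "one-hot of $o_ i$ minus the softmax probabilities". A real LLM would obtain these gradients by backpropagation instead. All names (`EOS`, `sample_rollout`, `score`, `mc_policy_gradient`, `max_len`) are hypothetical.

```python
import math
import random

EOS = 0  # hypothetical id of the end-of-sequence token

def softmax(theta):
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    z = sum(exps)
    return [e / z for e in exps]

def sample_rollout(theta, rng, max_len):
    """Sample tokens from softmax(theta) until EOS or max_len tokens."""
    p = softmax(theta)
    tau = []
    for _ in range(max_len):
        token = rng.choices(range(len(p)), weights=p, k=1)[0]
        tau.append(token)
        if token == EOS:
            break
    return tau

def score(theta, tau):
    """sum_i grad_theta log pi_theta(o_i): onehot(o_i) - p for each token in tau."""
    p = softmax(theta)
    g = [-len(tau) * pj for pj in p]
    for token in tau:
        g[token] += 1.0
    return g

def mc_policy_gradient(theta, reward_fn, rng, n_rollouts=1000, max_len=5):
    """Average R(tau) * score(tau) over sampled rollouts."""
    total = [0.0] * len(theta)
    for _ in range(n_rollouts):
        tau = sample_rollout(theta, rng, max_len)
        r = reward_fn(tau)
        g = score(theta, tau)
        total = [t + r * gj for t, gj in zip(total, g)]
    return [t / n_rollouts for t in total]
```

For example, `mc_policy_gradient([0.0, 0.5, -0.5], lambda tau: float(1 in tau), random.Random(0))` returns a noisy estimate of the gradient for a reward that checks whether token `1` appears; the noise shrinks as `n_rollouts` grows.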