Training with GRPOTrainer

Have you heard about DeepSeek's impressive new reasoning model? One of the key innovations behind its success is a training technique called Group Relative Policy Optimization (GRPO). This variant of PPO not only improves the model's training convergence but also uses less memory than traditional approaches - a core idea that helped DeepSeek's R1 model achieve state-of-the-art performance with very low training costs. Let's go over the basics of how to use GRPOTrainer, why it works, and how it can be used to train new models.

GRPO is an online learning algorithm that improves iteratively using data generated by the model during training. The key innovation is that it eliminates the need for a critic model by estimating the baseline from group scores, significantly reducing training resources compared to PPO.

The algorithm works in four main steps:

  1. Generating completions: For each prompt in a batch, generate G different completions using the old policy
  2. Computing rewards: Score each completion using rule-based reward functions
  3. Normalizing advantages: Calculate the relative advantage by normalizing rewards using the group's mean and standard deviation
  4. Optimizing the policy: Update the model using a clipped PPO-like objective that keeps the new policy close to the old one

This approach is particularly efficient because it eliminates the need for a separate critic network, uses group statistics to establish baselines, maintains policy stability through clipping, and reduces memory usage compared to traditional PPO.
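
Step 4 can be illustrated for a single token. The following is a minimal sketch, not TRL's implementation: the function name `clipped_surrogate` and the clipping range `epsilon=0.2` are assumptions chosen for illustration.

```python
def clipped_surrogate(ratio: float, advantage: float, epsilon: float = 0.2) -> float:
    """PPO-style clipped objective for one token (a quantity to maximize).

    ratio     -- pi_new(token) / pi_old(token)
    advantage -- group-relative advantage of the completion the token belongs to
    epsilon   -- clipping range (illustrative value, an assumption)
    """
    clipped_ratio = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
    # Taking the minimum removes the incentive to push the ratio beyond
    # [1 - epsilon, 1 + epsilon] in the direction the advantage favors.
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage: the gain is capped once the ratio exceeds 1 + epsilon.
print(clipped_surrogate(1.5, 1.0))
# Negative advantage: the penalty is not clipped when the ratio has grown.
print(clipped_surrogate(1.5, -1.0))
```

Because the gain is capped once the ratio leaves the clipping range, a single update has little reason to move the new policy far from the old one.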

"Aha Moment"

During DeepSeek-R1's training, researchers observed what they called an "aha moment" - the emergence of sophisticated multi-step reasoning behavior through GRPO training. The model demonstrated a form of emergent chain-of-thought reasoning where it would:

  1. Generate an initial solution approach
  2. Automatically verify the correctness of its reasoning
  3. Identify potential flaws or gaps in logic
  4. Generate alternative approaches when necessary
  5. Compare multiple solution paths before settling on a final answer

This behavior manifested particularly in mathematical and logical reasoning tasks, for example when solving complex word problems. What made this particularly interesting was that this verification behavior emerged purely through the GRPO reward optimization process, without any explicit chain-of-thought prompting or human feedback. The group-relative advantage computation naturally encouraged the model to develop more thorough reasoning strategies, as solutions that included verification steps consistently achieved higher rewards within their comparison groups.

Four key emergent behaviors were observed:

  • Self-verification: The model learned to automatically check its work and validate its answers.
  • Extended chain-of-thought: Solutions became more detailed with explicit step-by-step reasoning.
  • Exploratory reasoning: The model would try multiple approaches before settling on the best solution.
  • Reflection: The model developed the ability to question and revise its own reasoning process.

This emergent behavior aligns with theoretical predictions about the benefits of group-relative policy optimization: by comparing multiple solution attempts within the same context group, the training process naturally selects for more robust and self-correcting reasoning patterns. The KL divergence constraint ensures these patterns develop gradually and stably, rather than through sudden policy shifts.

In simpler terms, it's like the model learned to "check its work" all on its own - similar to how a careful student might solve a math problem, then go back to make sure their answer makes sense, and try a different approach if something seems off. The training method essentially rewarded this careful, thorough approach by comparing different solution attempts side by side, like a teacher grading multiple versions of the same homework assignment and giving better scores to students who show their work and verify their answers. The model wasn't explicitly taught to do this - it just naturally developed these good study habits because they consistently led to better outcomes!

Theory

GRPO is an online learning algorithm that iteratively improves by using data generated by the model during training. The key innovation of GRPO is its approach to estimating advantages through group-based normalization and its efficient handling of KL divergence constraints.

At its core, GRPO operates by first generating multiple completions for each input prompt. For each prompt in a batch, the model generates \(G\) different completions using the current policy \(\pi_{\theta}\). These completions form a "group" that will be used for relative comparison, which is fundamental to how GRPO estimates advantages.

The advantage estimation process is what gives GRPO its name. For each completion, a reward \(r_i\) is computed using a reward function or reward model. Instead of using a critic network to estimate advantages, GRPO normalizes rewards within each group, where \(\mathbf{r} = (r_1, \dots, r_G)\) collects the rewards of that group:

$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$

This group-relative normalization approach is particularly effective as it naturally compares rewards relative to other completions in the same group, providing a built-in baseline for advantage estimation.
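
As a rough illustration of the normalization above, here is a small sketch in plain Python. The helper name `group_relative_advantages`, the use of the population standard deviation, and the small `eps` added to the denominator (to avoid dividing by zero when all rewards in a group are equal) are assumptions, not details taken from TRL.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize the rewards of one group (all completions for one prompt)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; an assumption
    # eps keeps the division defined when every reward in the group is identical
    return [(r - mean) / (std + eps) for r in rewards]


# One prompt, G = 4 completions, two of which were rewarded.
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
print(advantages)
```

Completions that score above their group's mean get a positive advantage and those below get a negative one, regardless of the absolute scale of the rewards.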

To ensure stable training, GRPO employs KL divergence estimation using Schulman's approximator. This term helps ensure the policy doesn't deviate too far from the reference policy during training. The complete GRPO loss function combines advantage maximization with this KL divergence penalty.
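
The estimator usually attributed to Schulman can be sketched per token as below. This assumes the form pi_ref/pi_theta - log(pi_ref/pi_theta) - 1. The function name `kl_estimate` and its arguments are hypothetical.

```python
import math


def kl_estimate(logp_policy: float, logp_ref: float) -> float:
    """Per-token KL estimate between the policy and the reference policy.

    logp_policy -- log-probability of the sampled token under the current policy
    logp_ref    -- log-probability of the same token under the reference policy
    """
    log_ratio = logp_ref - logp_policy  # log(pi_ref / pi_theta)
    # exp(x) - x - 1 is non-negative for all x and equals 0 only at x = 0
    return math.exp(log_ratio) - log_ratio - 1.0


print(kl_estimate(-1.0, -1.0))  # identical log-probabilities give no penalty
print(kl_estimate(-1.0, -2.0))  # any disagreement gives a positive penalty
```

In the loss, this estimate is scaled by the KL penalty coefficient (`beta` in the Usage section) and subtracted from the advantage term, so tokens where the policy has drifted from the reference are penalized.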

By eliminating the need for a separate critic model, GRPO achieves better memory efficiency. The group-relative advantage computation helps reduce variance in training, while the KL divergence term prevents policy collapse. The normalization within groups makes the method robust to reward scaling, and the single update per generation (compared to multiple updates in PPO) simplifies implementation. This design makes GRPO particularly well-suited for integration with existing language model architectures.

To put it in simpler terms, imagine GRPO as a teacher grading essays in a classroom. Instead of having a separate teaching assistant (critic model) help grade papers, the teacher looks at small groups of essays together. For each group, they rank the essays relative to each other - which ones are better or worse compared to the average of that specific group. This is like normalizing the rewards within groups.

The teacher also makes sure not to completely change their grading criteria between assignments (KL divergence constraint) - they want to stay somewhat consistent with how they've graded before. This helps students understand what's expected and improve gradually rather than facing wildly different standards each time.

This approach is more efficient than traditional methods because you don't need that extra teaching assistant, and comparing essays within small groups is more manageable than trying to compare every essay to every other essay. It's like the difference between grading on a curve for each small discussion section versus trying to curve grades across the entire university!

Usage

Here's how to use GRPO via the TRL framework:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

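# Load your dataset
dataset = load_dataset("trl-lib/tldr", split="train")

# Placeholder reward so this script runs on its own: longer completions score higher.
# Swap in your own reward function (see the custom reward example later in this section).
def reward_func(completions, **kwargs):
    return [float(len(completion)) for completion in completions]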

# Configure training arguments
training_args = GRPOConfig(
    output_dir="output",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-6,
    logging_steps=10,
    num_generations=8,
    max_prompt_length=512,
    max_completion_length=256,
    temperature=0.9,
    beta=0.04,  # KL penalty coefficient
)

# Initialize trainer
trainer = GRPOTrainer(
    model="your/base/model",
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_func,
)

# Start training
trainer.train()

The key training parameters are (defaults as listed here; they can differ between TRL releases):

  • num_generations: Number of completions to generate per prompt (default: 8)
  • beta: KL penalty coefficient to control how far the model can deviate from reference policy (default: 0.04)
  • temperature: Sampling temperature for generation (default: 0.9)
  • max_prompt_length: Maximum length for input prompts (default: 512)
  • max_completion_length: Maximum length for generated completions (default: 256)

GRPO logs several important metrics during training:

  • completion_length: Average completion length
  • reward/{reward_func_name}: Reward computed by each reward function
  • reward: Average reward
  • reward_std: Average standard deviation within reward groups
  • kl: Average KL divergence between model and reference model
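
To make the `reward` and `reward_std` descriptions concrete, here is a small sketch of how such values could be aggregated from grouped rewards. The helper `summarize_rewards` is hypothetical and uses the population standard deviation; TRL's exact aggregation may differ.

```python
import statistics


def summarize_rewards(grouped_rewards: list[list[float]]) -> dict[str, float]:
    """grouped_rewards[i] holds the rewards of all completions for prompt i."""
    all_rewards = [r for group in grouped_rewards for r in group]
    per_group_std = [statistics.pstdev(group) for group in grouped_rewards]
    return {
        "reward": statistics.fmean(all_rewards),        # average reward
        "reward_std": statistics.fmean(per_group_std),  # average within-group std
    }


# Two prompts with four completions each; the second group is uniform.
print(summarize_rewards([[1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]]))
```

A group whose rewards are all identical has zero standard deviation, so its normalized advantages are all zero and it contributes no learning signal. A `reward_std` that stays near zero is therefore worth investigating.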

To accelerate generation, which is often the main bottleneck in online methods, you can use vLLM:

from trl import GRPOConfig

training_args = GRPOConfig(
    # ... other args ...
    use_vllm=True,
    vllm_gpu_memory_utilization=0.9,
)

You can also define custom reward functions to guide the training. The function must:

  1. Accept prompts and completions as keyword arguments
  2. Return a list of floats representing rewards for each completion

Example reward function that scores completions on length, sentence variety, and punctuation:

import math
import statistics

def reward_func(completions, **kwargs):
    """Heuristic text-quality reward (assumes each completion is a plain string).

    Combines:
    - Length (with diminishing returns after target length)
    - Sentence length variety
    - Presence of sentence-ending punctuation
    - A penalty for very short sentences or fragments
    """
    rewards = []
    target_length = 200  # Ideal completion length
    
    for completion in completions:
        # Base score
        score = 0.0
        
        # Length reward with diminishing returns
        length = len(completion)
        if length < target_length:
            length_score = length / target_length
        else:
            length_score = 1.0 + 0.1 * math.log(length / target_length)
        score += length_score
        
        # Reward sentence variety (look for mix of short and long sentences)
        sentences = [s.strip() for s in completion.split('.') if s.strip()]
        if sentences:
            lengths = [len(s) for s in sentences]
            variance = statistics.variance(lengths) if len(lengths) > 1 else 0
            score += min(0.5, variance / 500)  # Cap variance reward
            
        # Reward proper punctuation
        if any(p in completion for p in ['.', '!', '?']):
            score += 0.3
            
        # Penalize very short sentences or fragments
        if any(len(s) < 10 for s in sentences):
            score -= 0.2
            
        rewards.append(float(score))
        
    return rewards
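
For comparison, a rule-based reward of the kind mentioned in the algorithm overview can be much simpler. The sketch below checks only the layout of each completion. The `<think>`/`<answer>` tag names and the function name `format_reward` are assumptions for illustration, and completions are assumed to be plain strings.

```python
import re

# Hypothetical target layout: reasoning in <think> tags, then the answer.
FORMAT_PATTERN = re.compile(
    r"^<think>.*?</think>\s*<answer>.*?</answer>$", re.DOTALL
)


def format_reward(completions, **kwargs):
    """Return 1.0 for completions that follow the expected layout, else 0.0."""
    return [
        1.0 if FORMAT_PATTERN.match(completion.strip()) else 0.0
        for completion in completions
    ]


print(format_reward(["<think>2 + 2</think><answer>4</answer>", "4"]))
```

Like the example above, it ignores any extra keyword arguments through `**kwargs` and returns one float per completion.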

GRPO supports multiple reward functions - just pass them as a list:

trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    # rest of arguments
)

The final reward will be computed as the sum of rewards from each function.
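
The summation can be pictured with a short sketch. This is not TRL's code: `combine_rewards` and the two toy reward functions are hypothetical, and only an unweighted sum is shown.

```python
def length_reward(completions, **kwargs):
    """Toy reward: 0.1 per character."""
    return [len(c) / 10 for c in completions]


def period_reward(completions, **kwargs):
    """Toy reward: 1.0 if the completion ends with a period."""
    return [1.0 if c.endswith(".") else 0.0 for c in completions]


def combine_rewards(reward_funcs, prompts, completions):
    """Call every reward function, then sum the scores per completion."""
    per_function = [
        func(prompts=prompts, completions=completions) for func in reward_funcs
    ]
    return [sum(scores) for scores in zip(*per_function)]


totals = combine_rewards(
    [length_reward, period_reward],
    prompts=["Say hello", "Say hello"],
    completions=["Hello.", "Hi"],
)
print(totals)
```

Each function contributes one score per completion, so the summed reward is what enters the group normalization.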