Training Small R1-like Reasoning Models
So DeepSeek just dropped their R1 model, and it's kind of a big deal in open-source AI. It shows that we can train near-frontier reasoning models with comparatively limited hardware and small reasoning datasets. I'm going to walk through how we can recreate that magic "aha moment" by taking a smaller Llama model and fine-tuning it with their technique to produce reasoning traces that exhibit emergent reflection and backtracking behavior. The key ideas are actually pretty straightforward once you break them down.
At the heart of R1's training is Group Relative Policy Optimization (GRPO), which DeepSeek introduced in their earlier DeepSeekMath paper. Unlike traditional Proximal Policy Optimization (PPO), GRPO eliminates the need for a separate value function model and instead estimates the baseline from group scores. For each prompt, the model generates multiple outputs, and each output is scored. The average reward of the group serves as the baseline, and each output's advantage is its reward relative to that baseline, normalized by the group's standard deviation. The policy is then optimized to maximize the GRPO objective, which combines these advantages with a KL divergence penalty that keeps the policy close to a reference model. This is the core innovation that led to the breakthrough, and it shows promise for future models.
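For reference, here is the GRPO objective in the sequence-level form given in the DeepSeek R1 paper. G is the number of sampled outputs o_i per question q, r_i is the reward of output o_i, ε is the clipping range, and β weights the KL penalty against the reference policy π_ref (TRL exposes β as the `beta` setting of its GRPO config):

```latex
\begin{aligned}
\hat{A}_i &= \frac{r_i - \operatorname{mean}(\{r_1,\dots,r_G\})}{\operatorname{std}(\{r_1,\dots,r_G\})},
\qquad \rho_i = \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\mathrm{old}}}(o_i \mid q)} \\
\mathcal{J}_{\mathrm{GRPO}}(\theta) &= \mathbb{E}_{q,\,\{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\mathrm{old}}}(O \mid q)}
\left[ \frac{1}{G} \sum_{i=1}^{G} \Big( \min\big( \rho_i \hat{A}_i,\ \operatorname{clip}(\rho_i, 1-\varepsilon, 1+\varepsilon)\, \hat{A}_i \big) - \beta\, \mathbb{D}_{\mathrm{KL}}(\pi_\theta \,\|\, \pi_{\mathrm{ref}}) \Big) \right] \\
\mathbb{D}_{\mathrm{KL}}(\pi_\theta \,\|\, \pi_{\mathrm{ref}}) &= \frac{\pi_{\mathrm{ref}}(o_i \mid q)}{\pi_\theta(o_i \mid q)} - \log \frac{\pi_{\mathrm{ref}}(o_i \mid q)}{\pi_\theta(o_i \mid q)} - 1
\end{aligned}
```

The mean in the numerator of the advantage is the group baseline described above. Because it is computed from the sampled group itself, no value network is needed.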
The full training of DeepSeek R1 involves a multi-stage approach. It begins with pure reinforcement learning, where the model learns to reason without any prior supervised fine-tuning. This stage produces a model the original paper calls DeepSeek-R1-Zero, which is trained exclusively with RL using rule-based accuracy and format rewards. This initial phase leads to the emergence of reasoning traces in the model's responses without any explicit training on this reasoning structure, which was a surprising and impressive result.
Using only question and answer pairs, it is possible to generate Chain-of-Thought reasoning from scratch by leveraging GRPO. This approach allows us to bootstrap the reasoning process, which means creating the reasoning trace without needing a large dataset with explicit step-by-step solutions. The process begins with a model that might not initially exhibit reasoning, but through GRPO and rule-based reward functions, it learns to generate intermediate steps that reveal part of the thought process. Instead of relying on human-annotated CoT data, the model generates multiple outputs for a single question and is rewarded for both its accuracy and correct formatting. By iteratively refining its outputs based on these rewards, the model learns to self-generate reasoning traces and self-improve its problem-solving abilities. This is a big step towards the so-called "reasoning data flywheel" that many people have been hypothesizing about for several years now.
The model's responses follow a specific XML structure: first its reasoning steps within <reasoning>...</reasoning> tags, followed by the final answer within <answer>...</answer> tags. The goal of this training is to enable the model to develop self-verification and search abilities autonomously, with a clear separation between its reasoning process and its final answer.
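The original DeepSeek R1 paper describes two rule-based rewards used to train R1-Zero: a format reward and an accuracy reward. (The paper itself uses <think> tags; below, the rewards are adapted to this post's tags and to an arithmetic task in which a given set of numbers must be combined into an equation that hits a target number.)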
- Format Reward: This reward function checks whether the generated text has the correct format, specifically whether the reasoning steps are enclosed in <reasoning>...</reasoning> tags and the final answer in <answer>...</answer> tags. It assigns a score of 1.0 if the format is correct, and 0.0 otherwise.
- Accuracy Reward: This reward function evaluates the mathematical correctness of the equation inside the <answer> tags. It extracts the equation from the <answer> tag, checks that all the provided numbers are used exactly once, and evaluates the equation. If the equation is valid and evaluates to the target number, the reward is 1.0; otherwise, it is 0.0.
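The two rules can be sketched in plain Python. This is a minimal illustration under stated assumptions, not DeepSeek's code: `format_reward` and `accuracy_reward` are hypothetical names, the equation is assumed to use only `+`, `-`, `*`, `/` and parentheses, and it is evaluated by walking Python's `ast` rather than calling `eval`, so arbitrary code in a completion cannot run.

```python
import ast
import operator
import re
from collections import Counter

# Assumed layout: a <reasoning> block followed by an <answer> block.
FORMAT_RE = re.compile(
    r"^\s*<reasoning>.*?</reasoning>\s*<answer>.*?</answer>\s*$", re.DOTALL
)
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
# Only these binary operators are allowed in the answer equation.
OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def format_reward(completion: str) -> float:
    """1.0 if the completion has the <reasoning>/<answer> structure, else 0.0."""
    return 1.0 if FORMAT_RE.match(completion) else 0.0


def _safe_eval(node):
    """Evaluate a parsed arithmetic expression, rejecting anything else."""
    if isinstance(node, ast.Expression):
        return _safe_eval(node.body)
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in OPS:
        return OPS[type(node.op)](_safe_eval(node.left), _safe_eval(node.right))
    raise ValueError("unsupported expression")


def accuracy_reward(completion: str, numbers: list, target: int) -> float:
    """1.0 if the answer equation uses each number exactly once and hits target."""
    match = ANSWER_RE.search(completion)
    if not match:
        return 0.0
    equation = match.group(1).strip()
    # Every provided number must appear exactly once (as a multiset).
    used = [int(n) for n in re.findall(r"\d+", equation)]
    if Counter(used) != Counter(numbers):
        return 0.0
    try:
        value = _safe_eval(ast.parse(equation, mode="eval"))
    except (SyntaxError, ValueError, ZeroDivisionError):
        return 0.0
    return 1.0 if abs(value - target) < 1e-6 else 0.0
```

Comparing `Counter`s enforces the "each number exactly once" rule, and any parse or division error simply yields a reward of 0.0 instead of crashing the training loop.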
Following this, the model undergoes additional SFT and RL stages, using the R1-Zero output to generate "cold-start" SFT data for instruction fine-tuning and further RL training. After the additional RL stage, the model undergoes another round of SFT data collection. In this phase, the most recent model checkpoint is used to generate 600K CoT SFT examples which are then used to train the model on the reasoning traces.
We're only going to focus on fine-tuning a small language model (Llama-3.1-8B) with GRPO, which is step 1 of this multi-stage process. However, there are many good open-source reasoning datasets out there that you can experiment with, including:
- openai/gsm8k
- AI-MO/NuminaMath-CoT
- SkunkworksAI/reasoning-0.01
- open-thoughts/OpenThoughts-114k
- livebench/reasoning
- open-r1/OpenR1-Math-220k
- AI-MO/NuminaMath-1.5M
Environment Setup
To get started, we'll need to set up our Python environment with the required dependencies. The core of our setup relies on PyTorch 2.5.1 with CUDA 12.1 support, so we can train on A100/H100 GPUs with FlashAttention.
For the training pipeline, we need the core Hugging Face libraries including datasets, transformers, and peft. For training optimization, you'll need either:
- unsloth for single/low GPU setups
- deepspeed 0.15.4 and accelerate 1.3.0 for multi-GPU clusters
Most importantly, we need a specific development commit of trl (e95f9fb) that contains the latest GRPOTrainer implementation with bug fixes and optimizations. We'll also need vllm 0.7.0 for fast generation of completions during training.
Here's the complete setup:
pip install "torch==2.5.1" --index-url https://download.pytorch.org/whl/cu121
pip install flash-attn --no-build-isolation
pip install datasets transformers peft
pip install "deepspeed==0.15.4"
pip install "accelerate==1.3.0"
pip install "vllm==0.7.0"
pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
# Optional
# pip install unsloth
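Because the setup depends on exact versions, it can help to check what actually got installed. Below is a small sketch using only the standard library's `importlib.metadata`; `check_versions` is a hypothetical helper, the pins mirror the commands above, and trl is left out because it is installed from a git commit rather than a release.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the install commands above.
PINNED = {
    "torch": "2.5.1",
    "deepspeed": "0.15.4",
    "accelerate": "1.3.0",
    "vllm": "0.7.0",
}


def check_versions(pinned):
    """Map each package to (installed version or None, matches the pin)."""
    report = {}
    for name, wanted in pinned.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = (None, False)
            continue
        # Drop local build tags such as "+cu121" before comparing.
        report[name] = (installed, installed.split("+")[0] == wanted)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_versions(PINNED).items():
        print(f"{name:12} {installed or 'missing':16} {'ok' if ok else 'MISMATCH'}")
```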
Data Preparation
We're going to use the data preparation step developed by @willccb to post-process the GSM8K dataset, but you could easily adapt it to the other reasoning datasets by adjusting the parsing to each dataset's answer format.
The data pipeline transforms the GSM8K dataset from its original format into a structured XML-based Chain-of-Thought format. The source data contains mathematical word problems paired with step-by-step solutions that conclude with a final answer marked by "####". The preprocessing converts this into a chat-message format where the reasoning steps are wrapped in XML tags.
We use a system prompt that enforces the target format, requiring responses to contain reasoning steps within <reasoning> tags followed by the final answer in <answer> tags. The extraction functions parse both the original format (using the #### delimiter) and the target XML format, which keeps the two consistent and enables reward computation during training.
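A minimal sketch of these helpers, assuming GSM8K rows with `question` and `answer` fields. `SYSTEM_PROMPT` matches the name imported from `prep_data` later in this post, but its wording here, and the helpers `extract_hash_answer`, `extract_xml_answer` and `to_chat_example`, are illustrative assumptions rather than @willccb's exact code.

```python
from typing import Optional

# Assumed prompt wording; the real prep_data.SYSTEM_PROMPT may differ.
SYSTEM_PROMPT = """Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>"""


def extract_hash_answer(text: str) -> Optional[str]:
    """Return the GSM8K ground truth after the '####' delimiter, if present."""
    if "####" not in text:
        return None
    # Keep what follows the delimiter and drop any thousands separators.
    return text.split("####")[-1].strip().replace(",", "")


def extract_xml_answer(text: str) -> str:
    """Return the contents of the last <answer>...</answer> block.

    If there is no <answer> tag, the whole (stripped) text is returned.
    """
    answer = text.split("<answer>")[-1]
    return answer.split("</answer>")[0].strip()


def to_chat_example(row: dict) -> dict:
    """Turn a raw GSM8K row into chat-style prompt messages plus the answer."""
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": row["question"]},
        ],
        "answer": extract_hash_answer(row["answer"]),
    }
```

With the Hugging Face datasets library, something like `load_dataset("openai/gsm8k", "main")["train"].map(to_chat_example)` would apply this to the whole training split.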
For example, a GSM8K problem about calculating wages gets transformed from:
Bob gets paid $5 an hour for regular hours and $6 for overtime...
...
Bob got paid $6/hour x 12 hours = $72 for overtime
In 2 weeks Bob earned $400 + $72 = $472 #### 472
Into the structured format:
<reasoning>
Bob has worked 40 hours/week x 2 weeks = 80 regular hours
Bob got paid 80 hours x $5/hour = $400 for his regular hours
...
In 2 weeks Bob earned $400 + $72 = $472
</reasoning>
<answer>
472
</answer>
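The transformation shown above can be sketched as a single function. This is an illustrative assumption rather than the post's code: `to_xml_cot` is a hypothetical name, and it also strips GSM8K's inline calculator annotations (of the form `<<6*12=72>>`), which the cleaned-up example above no longer contains.

```python
import re

# GSM8K solutions embed calculator annotations such as "<<6*12=72>>".
CALC_ANNOTATION = re.compile(r"<<[^>]*>>")


def to_xml_cot(solution: str) -> str:
    """Convert a GSM8K solution ending in '#### <answer>' to the XML format."""
    reasoning, _, answer = solution.partition("####")
    reasoning = CALC_ANNOTATION.sub("", reasoning).strip()
    return (
        f"<reasoning>\n{reasoning}\n</reasoning>\n"
        f"<answer>\n{answer.strip()}\n</answer>"
    )
```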
Reward Functions
The training process uses five complementary reward functions, which we define in reward.py. These guide the model towards generating well-structured trajectories that end in the correct answer. Unlike a process reward model, we don't need to define a step-wise value function that grades each step, which makes this a lot simpler to work with.
- Correctness Reward (2.0 points): The primary reward function, which evaluates answer accuracy. It extracts the answer from the XML structure and compares it to the ground truth, awarding 2.0 points for exact matches. This is the most heavily weighted reward, to emphasize the importance of mathematical correctness.
- Integer Format Reward (0.5 points): This function checks that the answer contains only digits, awarding 0.5 points for purely numerical responses. This discourages the model from including units, explanatory text, or other non-numeric content in the answer section.
- Strict Format Reward (0.5 points): Enforces precise XML formatting using regex pattern matching. It checks for exact newline placement and proper tag structure, requiring the format:
<reasoning>
[reasoning steps]
</reasoning>
<answer>
[answer]
</answer>
- Soft Format Reward (0.5 points): A more lenient formatting check that allows flexible whitespace while still requiring the basic XML structure. This gives the model a partial reward signal while it is still learning the exact format.
- XML Component Reward (up to 0.5 points): Provides fine-grained scoring by examining individual XML components. It awards 0.125 points for each correctly placed tag (<reasoning>, </reasoning>, <answer>, </answer>), with small penalties for trailing content after the closing tags. This granular approach helps guide the model toward the proper XML structure during training.
These reward functions combine to create a comprehensive scoring system that can award up to 4.0 points per response. The correctness reward provides the primary learning signal, while the formatting rewards ensure the model maintains the structured output format that makes its reasoning process explicit and machine-readable.
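Here is a compact sketch of the five rewards in the shape TRL's GRPOTrainer expects: each function receives a batch of completions (plus extra dataset columns such as `answer` as keyword arguments) and returns one float per completion. This is an illustration, not the contents of reward.py. It assumes conversational completions of the form `[{"role": "assistant", "content": ...}]`, and the regexes and the 0.001-per-character trailing-text penalty are my own choices.

```python
import re

# Assumed message layout: each completion is [{"role": "assistant", "content": ...}].
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
STRICT_RE = re.compile(
    r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?\Z", re.DOTALL
)
SOFT_RE = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)
TAGS = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")


def _texts(completions):
    return [c[0]["content"] for c in completions]


def _answer(text):
    """Contents of the first <answer> block, or "" if there is none."""
    match = ANSWER_RE.search(text)
    return match.group(1).strip() if match else ""


def correctness_reward_func(prompts, completions, answer, **kwargs):
    # 2.0 for an exact string match with the ground truth.
    return [2.0 if _answer(t) == a else 0.0 for t, a in zip(_texts(completions), answer)]


def int_reward_func(completions, **kwargs):
    # 0.5 if the extracted answer consists of digits only.
    return [0.5 if _answer(t).isdigit() else 0.0 for t in _texts(completions)]


def strict_format_reward_func(completions, **kwargs):
    return [0.5 if STRICT_RE.match(t) else 0.0 for t in _texts(completions)]


def soft_format_reward_func(completions, **kwargs):
    return [0.5 if SOFT_RE.search(t) else 0.0 for t in _texts(completions)]


def xmlcount_reward_func(completions, **kwargs):
    scores = []
    for t in _texts(completions):
        # 0.125 for each tag that appears exactly once.
        score = sum(0.125 for tag in TAGS if t.count(tag) == 1)
        # Small penalty per character of text after the closing </answer>.
        trailing = t.split("</answer>")[-1].strip() if "</answer>" in t else ""
        scores.append(score - 0.001 * len(trailing))
    return scores
```

A perfectly formatted, correct response scores 2.0 + 0.5 + 0.5 + 0.5 + 0.5 = 4.0, matching the maximum described above.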
It's likely that DeepSeek uses more sophisticated reward functions that grade an output's answer beyond strict string matching. In cases where exact matching isn't sufficient, an LLM judge could compare answers, recognizing semantically equivalent phrasings and verifying mathematical correctness even when the wording differs. Note, though, that for R1-Zero the paper says they deliberately avoided neural reward models because they are prone to reward hacking, relying on rule-based rewards instead. Beyond that, the paper only hints at more advanced reward mechanisms, so anything further is speculation: checks for mathematical consistency across reasoning steps, verification that intermediate calculations support the final answer, rewards for showing complete work rather than just stating results, or learned reward models trained on human preferences about high-quality mathematical reasoning. Either way, their reward engineering is almost certainly more sophisticated than what we've implemented here.
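As one cheap, rule-based step beyond exact string matching (well short of an LLM judge), numeric answers can be compared by value. Here is a sketch using Python's `fractions` module; `answers_match` is a hypothetical helper, not something from DeepSeek or this post's reward.py.

```python
from fractions import Fraction
from typing import Optional


def _as_number(text: str) -> Optional[Fraction]:
    """Parse '1/2', '0.5', '1,000' or '$472' as an exact Fraction, else None."""
    cleaned = text.replace(",", "").replace("$", "").strip()
    try:
        return Fraction(cleaned)
    except (ValueError, ZeroDivisionError):
        return None


def answers_match(predicted: str, truth: str) -> bool:
    """Compare numerically when both sides parse, otherwise as trimmed strings."""
    a, b = _as_number(predicted), _as_number(truth)
    if a is not None and b is not None:
        return a == b
    return predicted.strip() == truth.strip()
```

Using this in place of the `==` comparison in a correctness reward would give credit for equivalent forms such as 0.5 and 1/2, at the cost of also accepting answers the integer format reward would otherwise penalize.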
Training
The actual training implementation uses the trl library, which provides a convenient GRPOTrainer class that handles the complexities of Group Relative Policy Optimization. The training setup combines multiple reward functions that work together to guide the model toward generating well-structured reasoning traces.
The GRPOTrainer is initialized with our base model, tokenizer, and a list of reward functions that evaluate different aspects of the model's output:
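from trl import GRPOTrainer
# model_name, tokenizer, training_args, data and the reward functions
# are defined elsewhere (training_args comes from the GRPOConfig below).
trainer = GRPOTrainer(
    model=model_name,  # a Hub model ID or local path
    processing_class=tokenizer,
    reward_funcs=[
        correctness_reward_func,
        int_reward_func,
        strict_format_reward_func,
        soft_format_reward_func,
        xmlcount_reward_func,
    ],
    args=training_args,
    train_dataset=data,
)
trainer.train()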
During training, the GRPOTrainer generates multiple completions for each prompt (controlled by num_generations in the config). These completions form a "group" that's used to compute the baseline reward. Each completion is scored by every reward function, and the scores are summed into a single reward that reflects both answer correctness and formatting.
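To make the grouping concrete, here is a plain-Python sketch of the bookkeeping: per-function rewards are summed per completion, the flat batch is split into groups of `num_generations` consecutive completions, and each reward is normalized against its own group. It is a simplified illustration of the idea, not TRL's implementation; `grouped_advantages` is a hypothetical name, and it uses the sample standard deviation plus a small `eps`.

```python
import statistics


def grouped_advantages(reward_matrix, num_generations, eps=1e-4):
    """Group-relative advantages from per-function rewards.

    reward_matrix[i][j] is the reward of completion i from reward function j.
    Completions are assumed to be ordered prompt by prompt, with
    num_generations consecutive completions per prompt.
    """
    totals = [sum(row) for row in reward_matrix]  # sum across reward functions
    if len(totals) % num_generations:
        raise ValueError("batch size must be a multiple of num_generations")
    advantages = []
    for start in range(0, len(totals), num_generations):
        group = totals[start:start + num_generations]
        mean = statistics.fmean(group)  # the group baseline
        std = statistics.stdev(group) if len(group) > 1 else 0.0
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages
```

If every completion in a group receives the same total reward, all of its advantages are zero, so that prompt contributes no policy-gradient signal for that step.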
The training configuration can be customized through the GRPOConfig class, which allows you to set parameters like learning rate, batch size, and the number of generations per prompt:
from trl import GRPOConfig
training_args = GRPOConfig(
    output_dir="outputs/Llama-3.1-8B-GRPO",  # hypothetical path; choose your own
    learning_rate=5e-6,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    num_generations=6,  # Number of completions per prompt
    max_prompt_length=256,
    max_completion_length=200,
    max_steps=250,
    save_steps=250,
    max_grad_norm=0.1,
)
If you want to run the training across a GPU cluster, you can use the following zero3.yaml config file to distribute the training across multiple GPUs with DeepSpeed. You'll need one dedicated GPU for vLLM inference (cuda:3 below) plus three H100s for training, so four GPUs in total. It takes at least 50 steps for the Llama-3.1-8B model to even get the reasoning format right, but it converges fairly quickly after that.
output_dir: outputs/Llama-3.1-8B-Reasoning
model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
model_revision: main
# Hyperparameters
learning_rate: 5.0e-7
lr_scheduler_type: cosine
max_steps: 450
gradient_accumulation_steps: 8
per_device_train_batch_size: 1
warmup_ratio: 0.03
beta: 0.001
max_prompt_length: 256
max_completion_length: 1024
num_generations: 8
seed: 1337
# Inference Compute
use_vllm: true
vllm_gpu_memory_utilization: 0.5
vllm_device: "cuda:3"
# Memory Usage
torch_dtype: bfloat16
attn_implementation: flash_attention_2
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
# Logging
logging_strategy: steps
logging_steps: 2
report_to:
- wandb
# Saving
save_strategy: "steps"
save_steps: 25
Now you can run the training script with the following command:
#!/usr/bin/env bash
set -e
nohup accelerate launch \
  --num_processes 3 \
  --config_file configs/accelerate_configs/deepspeed_zero3.yaml train_trl.py \
  --config zero3.yaml > train.log 2>&1 &
I've also included an Unsloth implementation (which uses more efficient CUDA kernels and a PEFT LoRA adapter) that could in theory train the model on a single A100. However, I'm not entirely certain you'll be able to run it alongside vLLM without running into out-of-memory errors, though in theory it should be possible.
We can, however, use the Unsloth library to test the model efficiently: either load the full model directly from your saved checkpoint, or load the LoRA adapter if you used the low-GPU setup.
from unsloth import FastLanguageModel
from vllm import SamplingParams
from prep_data import SYSTEM_PROMPT
LORA_NAME = "grpo_saved_lora"
SAMPLE_PROMPT = "Which is greater, 9.10 or 9.9?"
model, tokenizer = FastLanguageModel.from_pretrained(
    # When using the LoRA adapter, load the base model instead:
    # model_name="meta-llama/Llama-3.1-8B-Instruct",
    model_name="./outputs/Llama-3.1-8B-Reasoning",
    load_in_4bit=True,
    fast_inference=True,  # generate with vLLM
    gpu_memory_utilization=0.6,
)
# Render the system prompt and question with the model's chat template
text = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": SAMPLE_PROMPT},
    ],
    tokenize=False,
    add_generation_prompt=True,
)
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=1024,
)
output = model.fast_generate(
    text,
    sampling_params=sampling_params,
    # If using a LoRA adapter
    # lora_request=model.load_lora(LORA_NAME),
)
answer = output[0].outputs[0].text
print(answer)
This is only a simple first pass at building a training pipeline for a small reasoning model, but it's a good start. From here, you would want to implement the subsequent RL and SFT stages to produce a full R1-like model. The full source code for this experiment is available on GitHub.