On-Policy Distillation

Overview

On-policy distillation is a training approach where a student model learns from a teacher model’s logits distribution through knowledge distillation during the rollout phase. This technique enables efficient transfer of knowledge from larger or better-performing teacher models to smaller or faster student models on targeted datasets.

In Cosmos RL, on-policy distillation is integrated into the training pipeline where:

  1. The student model generates completions during rollout

  2. The teacher model receives all completions the student model generates to get the logits and probability distributions

  3. The student model receives the teacher’s logit probability distributions and is trained to match the teacher’s distribution using simple reserve KL or Jensen-Shannon Divergence (JSD) loss

  4. This happens within the standard training loop alongside other optimization objectives

Quick Start

Example Commands

DeepMath Dataset (Qwen3-8B)

To start an on-policy distillation job for the Qwen3-8B model on the DeepMath dataset:

cosmos-rl --config configs/qwen3/qwen3-8b-distill-deepmath.toml tools/dataset/deepmath_distill.py

Breaking down the command:

  • cosmos-rl: Main CLI entry point

  • --config: Path to the TOML configuration file

  • tools/dataset/deepmath_distill.py: Custom dataset and reward function script that handles data loading and evaluation

Countdown Dataset (Qwen2.5-1.5B)

To start an on-policy distillation job for the Qwen2.5-1.5B model on the Countdown dataset:

cosmos-rl --config configs/qwen2-5/qwen2-5-1.5b-distill-countdown.toml tools/dataset/countdown_distill.py

This example uses a smaller model (1.5B) with a different task (Countdown numbers game) compared to the DeepMath example.

Configuration File Structure

The configuration file (TOML format) contains multiple sections. Here’s the example section in the toml config for distillation:

[distillation]
enable = true
model_name_or_path = "Qwen/Qwen3-8B"
compile = true
mini_batch = 8
master_dtype = "float32"
param_dtype = "bfloat16"
logprob_dtype = "float32"
fsdp_reduce_dtype = "float32"
fsdp_offload = false
fsdp_reshard_after_forward = "default"
batch_size_per_replica = 16
top_k = 0
jsd_beta = 1
include_prompt = false
trainer_token_ids_from_teacher = true
rollout_top_k_recompute = false

[distillation.parallelism]
n_init_replicas = 1
tp_size = 1
cp_size = 1
dp_shard_size = 2
pp_size = 1
dp_replicate_size = 1

Distillation Parameters

Core Configuration

enable (boolean)

Enable/disable distillation during training. Set to true to activate the distillation loss.

model_name_or_path (string)

Path or HuggingFace model identifier for the teacher model. This is the model from which logits will be extracted during rollout.

Example: "Qwen/Qwen3-8B"

mini_batch (integer)

Batch size at each GPU for each forward execution when the teacher model generates logits and logprobs using the student-generated prompt and completion. Smaller values reduce peak memory usage during teacher inference.

Default: 1

batch_size_per_replica (integer)

Total number of samples processed per teacher model replica per distillation step. This may be first split to each GPU in the replica and then further split into multiple mini_batch forward passes for memory efficiency. For example, if batch_size_per_replica=32 and mini_batch=8 and there are 2 GPUs in the replica, the teacher will make 2 forward passes per GPU.

Formula: number_of_teacher_forward_passes = batch_size_per_replica // (number_of_gpus_in_replica * mini_batch)

Default: 8

Parallelism Configuration

The [distillation.parallelism] section configures GPU parallelism for distillation:

n_init_replicas (integer)

Number of parallel distillation workers. Keep at 1 unless running multi-node distillation.

tp_size (integer)

Tensor parallelism degree. Splits model parameters across GPUs. Use for very large models.

dp_shard_size (integer)

Data parallelism shard size. Number of GPUs for data parallelism.

Advanced Options

Teacher Model Setting Configuration

compile, master_dtype, param_dtype, logprob_dtype, fsdp_reduce_dtype, fsdp_offload, fsdp_reshard_after_forward

These parameters work the same as in the [train] section for normal model settings. They control compilation, mixed precision forward, FSDP behavior, and memory management for the teacher model. See the [train] section documentation for detailed explanations.

Loss & Sampling Algorithm Configuration

jsd_beta (float)

Interpolation coefficient between 0.0 and 1.0 of the Generalized Jensen-Shannon Divergence loss.

  • When beta is 0.0, the loss is the KL divergence

  • When beta is 1.0, the loss is the Inverse KL divergence

  • Values between 0 and 1 interpolate between these two extremes

Default: 0.5

top_k (integer)

Controls the distillation loss formulation:

  • When 0: Uses simple reverse KL for loss (as described in On-Policy Distillation)

  • When > 0: Uses the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div, restricting to top-k most likely tokens. See Eq. (1) of this paper for the definition

Default: 0

Model Training Framework Configuration

include_prompt (boolean)

Include the prompt tokens in the distillation loss computation. When false, only completion tokens are considered.

  • Set to true if you want the student to also match teacher’s prompt embeddings

  • Set to false (default) to focus on generation quality

Default: false

trainer_token_ids_from_teacher (boolean)

Whether the trainer gets all top_k token ids directly from its redis-connected teacher model during distillation rather than from the rollout structure. This can simplify the rollout payload when being transferred in the framework.

Note: When top_k <= 0, this parameter is automatically set to false.

Default: true

rollout_top_k_recompute (boolean)

Whether to recompute all top-k logprobs with top-k token ids after the full sequence is generated during rollout for distillation. This can ensure the completion generation process doesn’t keep large top-k values that would degrade generation efficiency.

Default: false

Launching on SLURM

You can launch an on-policy distillation task on SLURM using the Cosmos RL dispatch helper:

python $PATH_TO_COSMOS_RL_ROOT/tools/slurm/dispatch_job.py \
    --ngpu-per-node 8 \
    --config-path $PATH_TO_COSMOS_RL_ROOT/configs/qwen3/qwen3-8b-distill-deepmath.toml \
    --output-root-path $PATH_TO_COSMOS_RL_ROOT/s_output \
    --cosmos-container $SQSH_PATH \
    --slurm-partition batch \
    --slurm-account sw_aidot \
    --repo-root-path $PATH_TO_COSMOS_RL_ROOT \
    ./cosmos_rl/tools/dataset/deepmath_distill.py

Notes

  • Set $PATH_TO_COSMOS_RL_ROOT to your local Cosmos RL repository root.

  • Set $SQSH_PATH to the container image path (.sqsh).

  • You can swap the config and dataset entry script for other distillation tasks.

Launching on Lepton

For cloud-based training on NVIDIA Lepton, you can clone an existing reference job and customize it for your needs.

Reference Job

Start with the Qwen8B DeepMath distillation reference job:

Qwen8B DeepMath Distillation Job

Steps to Clone and Customize

  1. Clone the Job: - Click the “Clone” or “Create from Template” button in the job details view - This will create a copy of the job configuration with all settings pre-populated

  2. Customize Configuration: - Modify the job name and description for your experiment - Update the toml configuration file content and path (config.toml) - Update the launch entry script or module if needed (cosmos_rl.tools.dataset.deepmath_distill) - Adjust resource allocation (GPU count, memory) based on your model size and requirements

  3. Submit and Monitor: - Click “Submit Job” to launch the training - Monitor training progress through the Lepton dashboard and wandb logs

See Also

  • Single node example - Basic training setup

  • ../parallelism/index - Distributed training configuration

  • ../rollout/index - Rollout configuration details