LoRA Post-training for Sports Video Generation
Author: Arslan Ali
Organization: NVIDIA
| Model | Workload | Use Case |
|---|---|---|
| Cosmos Predict 2.5 | Post-training | Sports Video Generation |
This guide provides instructions on running LoRA (Low-Rank Adaptation) post-training with the Cosmos Predict 2.5 models for sports video generation tasks, supporting Text2World, Image2World, and Video2World generation modes.
Motivation
While the base Cosmos Predict 2.5 model excels at general video generation, sports content demands specialized understanding of athletic dynamics and game rules. Post-training addresses critical gaps in player kinematic realism and physics, ensuring natural body movements and accurate ball trajectories. The adapted model achieves higher rule-coherence scores by respecting sport-specific constraints like offside lines, field boundaries, and valid player positions. Additionally, post-training significantly improves identity consistency, maintaining stable player appearances, jersey numbers, and team colors throughout generated sequences—essential for realistic sports simulation and analysis applications.
Table of Contents
- Prerequisites
- What is LoRA?
- Preparing Data
- LoRA Post-training
- Configuration
- Training
- Inference with LoRA Post-trained checkpoint
- Converting DCP Checkpoint to Consolidated PyTorch Format
- Running Inference
Prerequisites
1. Environment Setup
Follow the Setup guide for general environment setup instructions, including installing dependencies.
2. Hugging Face Configuration
Model checkpoints are automatically downloaded during post-training if they are not present. Configure Hugging Face as follows:
# Login with your Hugging Face token (required for downloading models)
hf auth login
# Set custom cache directory for HF models
# Default: ~/.cache/huggingface
export HF_HOME=/path/to/your/hf/cache
💡 Tip: Ensure you have sufficient disk space in HF_HOME.
3. Training Output Directory
Configure where training checkpoints and artifacts will be saved:
# Set output directory for training checkpoints and artifacts
# Default: /tmp/imaginaire4-output
export IMAGINAIRE_OUTPUT_ROOT=/path/to/your/output/directory
💡 Tip: Ensure you have sufficient disk space in IMAGINAIRE_OUTPUT_ROOT.
Weights & Biases (W&B) Logging
By default, training will attempt to log metrics to Weights & Biases. You have two options:
Option 1: Enable W&B
To enable full experiment tracking with W&B:
- Create a free account at wandb.ai
- Get your API key from https://wandb.ai/authorize
- Set the environment variable:
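For example (WANDB_API_KEY is the standard environment variable read by the wandb client; substitute your own key):

```bash
# Store your W&B API key in the environment (never commit it to source control)
export WANDB_API_KEY=your_api_key_here
```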
⚠️ Security Warning: Store API keys in environment variables or secure vaults. Never commit API keys to source control.
Option 2: Disable W&B
Add job.wandb_mode=disabled to your training command to disable wandb logging.
What is LoRA?
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that allows you to adapt large pre-trained models to specific domains or tasks by training only a small number of additional parameters.
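Conceptually, LoRA freezes the pre-trained weight matrix W and learns a low-rank update ΔW = (α/r)·B·A, where r is the rank and α a scaling factor. A minimal PyTorch sketch (illustrative only, not the Cosmos implementation):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Conceptual sketch of a LoRA-adapted linear layer.

    The frozen base weight W is augmented with a trainable low-rank update
    (alpha / rank) * B @ A, so only rank * (in + out) extra parameters are trained.
    """

    def __init__(self, base: nn.Linear, rank: int = 32, alpha: float = 32.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze the pre-trained weights
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init: no change at start
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * (x @ self.lora_a.T) @ self.lora_b.T

# Example: wrap a 1024x1024 projection with a rank-32 adapter
layer = LoRALinear(nn.Linear(1024, 1024), rank=32, alpha=32.0)
```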
Key Benefits of LoRA Post-Training
- Memory Efficiency: Only trains ~1-2% of total model parameters
- Faster Training: Significantly reduced training time per iteration
- Storage Efficiency: LoRA checkpoints are much smaller than full model checkpoints
- Flexibility: Can maintain multiple LoRA adapters for different domains
- Preserved Base Capabilities: Retains the original model's capabilities while adding domain-specific improvements
When to Use LoRA vs Full Fine-tuning
Use LoRA when:
- You have limited compute resources
- You want to create domain-specific adapters
- You need to preserve the base model's general capabilities
- You're working with smaller datasets
Use full fine-tuning when:
- You need maximum model adaptation
- You have sufficient compute and storage
- You're making fundamental changes to model behavior
1. Preparing Data
1.1 Understanding Training Data Requirements
The training approach uses the same video dataset to train all three generation modes:
- Text2World (0 frames): Uses only text prompts, videos serve as ground truth for reconstruction
- Image2World (1 frame): Uses first frame as condition, generates remaining frames
- Video2World (2+ frames): Uses initial frames as condition, continues the video generation
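A conceptual sketch of how the conditioning mode could be sampled per training example, mirroring the conditional_frames_probs setting shown in the Section 2.1 configuration (illustrative, not the actual trainer code):

```python
import random

# Mirrors conditional_frames_probs={0: 0.333, 1: 0.333, 2: 0.334} from the config below.
# 0 conditional frames -> Text2World, 1 -> Image2World, 2 -> Video2World.
conditional_frames_probs = {0: 0.333, 1: 0.333, 2: 0.334}
num_conditional_frames = random.choices(
    population=list(conditional_frames_probs.keys()),
    weights=list(conditional_frames_probs.values()),
    k=1,
)[0]
```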
1.2 Dataset Location
The sports dataset should be organized in a directory structure that you'll specify in the configuration. Set your dataset path to point to your video collection:
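For example (a hypothetical placeholder path; it simply has to match the dataset_dir field used in the Section 2.1 configuration):

```python
# Hypothetical placeholder — point this at the directory that actually holds your clips
dataset_dir = "/path/to/sports/videos"
```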
Replace this path with the actual location of your sports video dataset. The dataset should contain sports video clips in MP4 format at 720p resolution (704x1280). For the current training, we prepared approximately 4,350 training clips and 50 clips for validation/inference to support both Image2World and Video2World generation tasks.
1.3 Dataset Formats
The system supports two caption formats:
Text Format (.txt files)
- Simple text files containing one caption per file
- Files should be placed in a metas/ directory
- Filename should match the video filename (e.g., video1.mp4 → video1.txt)
JSON Format (.json files)
- More flexible format supporting multiple prompt variations
- Files should be placed in a captions/ directory
- Supports long, short, and medium prompt types
JSON Caption File Format:
{
  "model_name": {
    "long": "Detailed description of the sports action...",
    "short": "Brief summary of the sports play...",
    "medium": "Moderate length description of the sports scene..."
  }
}
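A small sanity-check sketch for the pairing conventions above; it assumes the metas/ and captions/ directories sit directly under the dataset root and that clips can live anywhere below it (adjust to your actual layout):

```python
import json
from pathlib import Path

def check_captions(dataset_dir: str, prompt_type: str = "long") -> None:
    """Verify that every *.mp4 has a matching metas/<name>.txt or captions/<name>.json caption."""
    root = Path(dataset_dir)
    for video in sorted(root.rglob("*.mp4")):
        txt = root / "metas" / f"{video.stem}.txt"
        js = root / "captions" / f"{video.stem}.json"
        if txt.exists():
            caption = txt.read_text().strip()
        elif js.exists():
            entry = next(iter(json.loads(js.read_text()).values()))  # first model entry
            caption = entry[prompt_type]  # "long", "medium", or "short"
        else:
            print(f"missing caption for {video.name}")
            continue
        assert caption, f"empty caption for {video.name}"
```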
2. LoRA Post-training
2.1 Configuration
Two configurations for sports are provided:
- predict2_lora_training_2b_cosmos_sports_assets_txt - For text caption format
- predict2_lora_training_2b_cosmos_sports_assets_json_rank32 - For JSON caption format with long prompts
The configurations can be found in the Cosmos Predict 2.5 GitHub repository.
Complete Sports Configuration
from imaginaire.lazy_config import LazyCall as L
from projects.cosmos.predict2.datasets.local_datasets.dataset_video import (
VideoDataset,
get_generic_dataloader,
get_sampler,
)
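# Note: helpers referenced below (e.g. _lora_defaults, get_checkpoint_path, DEFAULT_CHECKPOINT)
# are provided by the surrounding configuration code and are not repeated in this excerpt.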
# Dataset configuration for sports videos (Text format)
example_dataset_cosmos_sports_assets_lora_txt = L(VideoDataset)(
dataset_dir="/path/to/sports/videos",
num_frames=93,
video_size=(704, 1280),
)
# Dataset configuration for sports videos (JSON format with long prompts)
example_dataset_cosmos_sports_assets_lora_json = L(VideoDataset)(
dataset_dir="/path/to/sports/videos",
num_frames=93,
video_size=(704, 1280),
caption_format="json",
prompt_type="long",
)
# Dataloader configuration
dataloader_train_cosmos_sports_assets_lora_txt = L(get_generic_dataloader)(
dataset=example_dataset_cosmos_sports_assets_lora_txt,
sampler=L(get_sampler)(dataset=example_dataset_cosmos_sports_assets_lora_txt),
batch_size=1,
drop_last=True,
num_workers=4,
pin_memory=True,
)
dataloader_train_cosmos_sports_assets_lora_json = L(get_generic_dataloader)(
dataset=example_dataset_cosmos_sports_assets_lora_json,
sampler=L(get_sampler)(dataset=example_dataset_cosmos_sports_assets_lora_json),
batch_size=1,
drop_last=True,
num_workers=4,
pin_memory=True,
)
# Model configuration with LoRA
_lora_model_config = dict(
config=dict(
# Enable LoRA training
use_lora=True,
# LoRA configuration parameters
lora_rank=32, # Rank of LoRA adaptation matrices (higher for sports complexity)
lora_alpha=32, # LoRA scaling parameter
lora_target_modules="q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2",
init_lora_weights=True, # Properly initialize LoRA weights
# Training configuration for all three modes
# The model will randomly sample between 0, 1, and 2 conditional frames during training
min_num_conditional_frames=0, # Allow text2world (0 frames)
max_num_conditional_frames=2, # Allow up to video2world (2 frames)
# Probability distribution for sampling number of conditional frames
# This controls how often each mode is trained:
# - 0 frames: text2world (33.3%)
# - 1 frame: image2world (33.3%)
# - 2 frames: video2world (33.4%)
conditional_frames_probs={0: 0.333, 1: 0.333, 2: 0.334},
# Optional: set conditional_frame_timestep for better control
conditional_frame_timestep=-1.0, # The default of -1 leaves this option inactive
# Keep the default conditioning strategy
conditioning_strategy="frame_replace",
denoise_replace_gt_frames=True,
),
)
# Training configuration
_lora_trainer = dict(
logging_iter=100,
max_iter=10000,
callbacks=dict(
heart_beat=dict(save_s3=False),
iter_speed=dict(hit_thres=1000, save_s3=False),
device_monitor=dict(save_s3=False),
every_n_sample_reg=dict(every_n=1000, save_s3=False),
every_n_sample_ema=dict(every_n=1000, save_s3=False),
wandb=dict(save_s3=False),
wandb_10x=dict(save_s3=False),
dataloader_speed=dict(save_s3=False),
),
)
# Optimizer configuration
_lora_optimizer = dict(
lr=0.0001,
weight_decay=0.001,
)
# Scheduler configuration
_lora_scheduler = dict(
f_max=[0.5],
f_min=[0.2],
warm_up_steps=[2_000],
cycle_lengths=[100000],
)
# Checkpoint configuration
_lora_checkpoint_base = dict(
load_path=get_checkpoint_path(DEFAULT_CHECKPOINT.s3.uri),
load_from_object_store=dict(enabled=False),
save_to_object_store=dict(enabled=False),
save_iter=1000, # Save checkpoint every 1000 iterations
)
# Model parallel configuration
_lora_model_parallel = dict(
context_parallel_size=1,
)
# Complete experiment configurations
predict2_lora_training_2b_cosmos_sports_assets_txt = dict(
defaults=_lora_defaults,
job=dict(
project="cosmos_predict_v2p5",
group="lora",
name="2b_cosmos_sports_assets_lora",
),
dataloader_train=dataloader_train_cosmos_sports_assets_lora_txt,
checkpoint=_lora_checkpoint_base,
optimizer=_lora_optimizer,
scheduler=_lora_scheduler,
trainer=_lora_trainer,
model=_lora_model_config,
model_parallel=_lora_model_parallel,
)
predict2_lora_training_2b_cosmos_sports_assets_json_rank32 = dict(
defaults=_lora_defaults,
job=dict(
project="cosmos_predict_v2p5",
group="lora",
name="2b_cosmos_sports_assets_json_lora",
),
dataloader_train=dataloader_train_cosmos_sports_assets_lora_json,
checkpoint=_lora_checkpoint_base,
optimizer=_lora_optimizer,
scheduler=_lora_scheduler,
trainer=_lora_trainer,
model=_lora_model_config,
model_parallel=_lora_model_parallel,
)
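Before launching a run, a quick generic PyTorch check (not a Cosmos utility) can confirm that use_lora leaves only a small fraction of parameters trainable, in line with the ~1-2% figure mentioned earlier:

```python
import torch.nn as nn

def report_trainable(model: nn.Module) -> None:
    """Print how many parameters require gradients vs. the total."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} / {total:,} ({100.0 * trainable / total:.2f}%)")
```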
2.2 Training
Run the LoRA post-training using one of the following configurations:
Using Text Caption Format
torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train \
--config=projects/cosmos/predict2/configs/video2world/config.py -- \
experiment=predict2_lora_training_2b_cosmos_sports_assets_txt
Using JSON Caption Format with Long Prompts
torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train \
--config=projects/cosmos/predict2/configs/video2world/config.py -- \
experiment=predict2_lora_training_2b_cosmos_sports_assets_json_rank32
Disabling W&B Logging
Add job.wandb_mode=disabled to disable wandb:
torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train \
--config=projects/cosmos/predict2/configs/video2world/config.py -- \
experiment=predict2_lora_training_2b_cosmos_sports_assets_txt \
job.wandb_mode=disabled
Checkpoints are saved to:
- Text format: ${IMAGINAIRE_OUTPUT_ROOT}/cosmos_predict_v2p5/lora/2b_cosmos_sports_assets_lora/checkpoints
- JSON format: ${IMAGINAIRE_OUTPUT_ROOT}/cosmos_predict_v2p5/lora/2b_cosmos_sports_assets_json_lora/checkpoints
Note: By default, IMAGINAIRE_OUTPUT_ROOT is /tmp/imaginaire4-output. We strongly recommend setting IMAGINAIRE_OUTPUT_ROOT to a location with sufficient storage space for your checkpoints.
Checkpointing
Training uses two checkpoint formats:
1. Distributed Checkpoint (DCP) Format
Primary format for training checkpoints.
- Structure: Multi-file directory with sharded model weights
- Used for: Saving checkpoints during training, resuming training
- Advantages:
- Efficient parallel I/O for multi-GPU training
- Supports FSDP (Fully Sharded Data Parallel)
- Optimized for distributed workloads
Example directory structure:
checkpoints/
├── iter_{NUMBER}/
│ ├── model/
│ │ ├── .metadata
│ │ └── __0_0.distcp
│ ├── optim/
│ ├── scheduler/
│ └── trainer/
└── latest_checkpoint.txt
2. Consolidated PyTorch (.pt) Format
Single-file format for inference and distribution.
- Structure: Single .pt file containing the complete model state
- Used for: Inference, model sharing, initial post-training
- Advantages:
- Easy to distribute and version control
- Standard PyTorch format
- Simpler for single-GPU workflows
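To peek inside a consolidated checkpoint, for example to see which tensors it contains, standard PyTorch loading is enough. A minimal sketch, assuming the .pt file holds a flat state dict of tensors:

```python
import torch

# Assumes the consolidated file is a flat state dict of tensors.
state = torch.load("model_ema_bf16.pt", map_location="cpu")
print(f"{len(state)} entries")
for name in list(state)[:10]:  # show the first few tensor names and shapes
    print(name, tuple(state[name].shape))
```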
3. Inference with LoRA Post-trained checkpoint
3.1 Converting DCP Checkpoint to Consolidated PyTorch Format
Since the checkpoints are saved in DCP format during training, you need to convert them to consolidated PyTorch format (.pt) for inference. Use the convert_distcp_to_pt.py script:
# For text format checkpoint
CHECKPOINTS_DIR=${IMAGINAIRE_OUTPUT_ROOT:-/tmp/imaginaire4-output}/cosmos_predict_v2p5/lora/2b_cosmos_sports_assets_lora/checkpoints
CHECKPOINT_ITER=$(cat $CHECKPOINTS_DIR/latest_checkpoint.txt)
CHECKPOINT_DIR=$CHECKPOINTS_DIR/$CHECKPOINT_ITER
# Convert DCP checkpoint to PyTorch format
python scripts/convert_distcp_to_pt.py $CHECKPOINT_DIR/model $CHECKPOINT_DIR
This conversion will create three files:
- model.pt: Full checkpoint containing both regular and EMA weights
- model_ema_fp32.pt: EMA weights only, in float32 precision
- model_ema_bf16.pt: EMA weights only, in bfloat16 precision (recommended for inference)
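If you trained the JSON caption configuration, the same conversion applies to its checkpoint directory; this mirrors the commands above and defines the $JSON_CHECKPOINT_DIR variable used in the inference examples below:

```bash
# For JSON format checkpoint
JSON_CHECKPOINTS_DIR=${IMAGINAIRE_OUTPUT_ROOT:-/tmp/imaginaire4-output}/cosmos_predict_v2p5/lora/2b_cosmos_sports_assets_json_lora/checkpoints
JSON_CHECKPOINT_ITER=$(cat $JSON_CHECKPOINTS_DIR/latest_checkpoint.txt)
JSON_CHECKPOINT_DIR=$JSON_CHECKPOINTS_DIR/$JSON_CHECKPOINT_ITER

# Convert DCP checkpoint to PyTorch format
python scripts/convert_distcp_to_pt.py $JSON_CHECKPOINT_DIR/model $JSON_CHECKPOINT_DIR
```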
3.2 Running Inference
After converting the checkpoint, you can run inference with your post-trained model using the command-line interface.
The same checkpoint can be used for any generation mode. Simply pass the appropriate input JSON file together with the corresponding experiment name:
Text2World Generation
# Using Text format checkpoint
torchrun --nproc_per_node=8 examples/inference.py \
-i assets/sports_text2world_prompts.json \
-o outputs/sports_text2world \
--checkpoint-path $CHECKPOINT_DIR/model_ema_bf16.pt \
--experiment predict2_lora_training_2b_cosmos_sports_assets_txt
# Using JSON format checkpoint
torchrun --nproc_per_node=8 examples/inference.py \
-i assets/sports_text2world_prompts.json \
-o outputs/sports_text2world \
--checkpoint-path $JSON_CHECKPOINT_DIR/model_ema_bf16.pt \
--experiment predict2_lora_training_2b_cosmos_sports_assets_json_rank32
Image2World Generation
torchrun --nproc_per_node=8 examples/inference.py \
-i assets/sports_image2world_inputs.json \
-o outputs/sports_image2world \
--checkpoint-path $CHECKPOINT_DIR/model_ema_bf16.pt \
--experiment predict2_lora_training_2b_cosmos_sports_assets_txt
Video2World Generation
torchrun --nproc_per_node=8 examples/inference.py \
-i assets/sports_video2world_inputs.json \
-o outputs/sports_video2world \
--checkpoint-path $CHECKPOINT_DIR/model_ema_bf16.pt \
--experiment predict2_lora_training_2b_cosmos_sports_assets_txt
The model automatically detects the generation mode based on the input:
- Provide text only → Text2World generation
- Provide 1 image frame → Image2World generation
- Provide 2+ video frames → Video2World generation
Generated videos will be saved to the output directory.
Example Prompts for Soccer Video Generation
Image2World/Video2World Generation Example
{
  "inference_type": "image2world",
  "name": "soccer_action_sequence",
  "prompt": "A soccer player in a red and black uniform dribbles the ball past an opponent in an orange and white uniform. The player in red sprints towards the goal, evading the defender. As he approaches the goalpost, the goalkeeper dives to make a save but fails to stop the ball from entering the net. The camera follows the ball as it flies into the goal, capturing the excitement of the moment.",
  "input_path": "first_frame.mp4",
  "seed": 0,
  "guidance": 3,
  "num_output_frames": 93
}
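A Text2World entry would omit the input media. The exact schema is defined by examples/inference.py and its sample assets; the following is a hypothetical example that simply mirrors the fields above without input_path:

```json
{
  "inference_type": "text2world",
  "name": "soccer_counterattack",
  "prompt": "A wide broadcast shot of a soccer counterattack: a midfielder in a blue uniform plays a long through ball, a striker in blue outpaces two defenders in white, controls the ball at the edge of the penalty box, and curls a shot toward the far post as the goalkeeper stretches to reach it.",
  "seed": 0,
  "guidance": 3,
  "num_output_frames": 93
}
```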
Evaluation and Results
Comparison: Base Model vs Post-Trained Model
The LoRA post-training significantly improves the quality and realism of generated soccer videos. Below is a comparison of videos generated by the base model versus the post-trained model:
[Video comparisons: Sample 1, Sample 2, Sample 3 - base model vs. post-trained model]
Key Improvements After Post-Training
Even with a limited dataset, the post-training experiment demonstrated clear model improvements, although some physics artifacts remain. The post-trained model shows substantial enhancements in several critical areas:
1. Rule Coherence and Field Semantics
The post-trained model shows significantly better understanding of soccer rules and field layout:
- Players stay within field boundaries and respect offside lines
- Goal areas and penalty boxes are rendered with proper dimensions
- Ball physics follow realistic trajectories during passes and shots
2. Identity and Team Preservation
Team identities and player characteristics are better maintained throughout the generated sequences:
- Consistent jersey colors and numbers across all frames
- Individual player features remain stable during complex movements
- Team formations and tactical positioning are more realistic
3. Reduced Broadcast Artifacts
The post-trained model produces cleaner, broadcast-quality videos:
- Minimized motion blur during fast-paced action
- Reduced ghosting effects around players and the ball
- Cleaner rendering of stadium elements and crowd backgrounds
- Improved temporal consistency across frame sequences
4. Sport-Specific Motion Dynamics
The model better captures soccer-specific movements:
- Realistic dribbling patterns and ball control
- Natural goalkeeper diving and saving motions
- Accurate representation of tackles, passes, and shots
- Proper player acceleration and deceleration patterns
These improvements make the post-trained model particularly suitable for:
- Training computer vision systems for sports analytics
- Generating synthetic data for referee training
- Creating realistic game simulations for tactical analysis
- Producing content for sports broadcasting and entertainment
For more inference options and advanced usage, see the Cosmos Predict 2 inference documentation.