sdan/vlm-gym

Status: Experimental

vlm-gym

A simple reinforcement learning gym for vision-language models, written in JAX. Drop in any environment, any model, and train with PPO.

Core components:

  • envs/ — Pluggable vision environments (GeoGuessr, NLVR2, captioning)
  • models/ — VLM implementations (Qwen3-VL-4B-Instruct reference)
  • core/train.py — Trainer runs PPO on the environment
  • core/rollout.py — Inference engine runs the VLM on the environment
  • core/eval.py — Evaluation harness runs the VLM on the environment and compares it to the Hugging Face baseline
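core/train.py is described as running PPO. The core of PPO is the clipped surrogate objective, which limits how far one update can move the policy away from the policy that generated the rollouts. The following is a minimal, framework-free sketch of that objective over per-sample log-probabilities. The function name and the default clip range of 0.2 are assumptions for illustration; this is not the repo's JAX trainer code.

```python
import math


def ppo_clipped_loss(new_logps, old_logps, advantages, clip_eps=0.2):
    """Negative PPO clipped surrogate, averaged over samples.

    new_logps / old_logps: log-probabilities of the sampled actions under the
    current policy and the rollout (behaviour) policy.
    advantages: one advantage estimate per sample.
    """
    assert len(new_logps) == len(old_logps) == len(advantages)
    total = 0.0
    for new_lp, old_lp, adv in zip(new_logps, old_logps, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # PPO keeps the pessimistic (smaller) of the two surrogate terms.
        total += min(ratio * adv, clipped * adv)
    # Optimisers minimise, so return the negative of the objective.
    return -total / len(advantages)
```

With a ratio of 2 and a positive advantage, the clipped term (1.2 × advantage) wins, so a single update gains nothing from pushing the ratio past 1 + clip_eps.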

Install and convert the Hugging Face checkpoint to JAX (default model: Qwen3-VL-4B-Instruct)

uv sync 
uv run python -m utils.hf_to_jax --model_dir checkpoints/qwen3vl_4b
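The internals of utils.hf_to_jax are not described here. As a rough illustration of what converting a PyTorch-style checkpoint into a JAX parameter pytree often involves, the sketch below nests flat dotted parameter names into a dictionary and transposes linear-layer weights, because PyTorch nn.Linear stores weights as (out_features, in_features) while Flax Dense kernels are (in_features, out_features). The function names, the "*_proj.weight → kernel" rule, and plain nested lists standing in for arrays are all hypothetical simplifications.

```python
def transpose(matrix):
    """Transpose a 2-D nested list (stand-in for an array)."""
    return [list(row) for row in zip(*matrix)]


def to_nested_params(flat_state_dict):
    """Convert {"a.b.c": value} into {"a": {"b": {"c": value}}}.

    Hypothetical rule: weights of *_proj linear layers are transposed from
    (out, in) to (in, out) and renamed "weight" -> "kernel"; every other
    tensor (norms, biases, embeddings) is copied unchanged.
    """
    params = {}
    for name, value in flat_state_dict.items():
        *path, leaf = name.split(".")
        if leaf == "weight" and path and path[-1].endswith("_proj"):
            leaf, value = "kernel", transpose(value)
        node = params
        for key in path:
            node = node.setdefault(key, {})
        node[leaf] = value
    return params
```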

Run a VLM to play GeoGuessr

uv run python -m core.rollout \
  --model_dir checkpoints/qwen3vl_4b \
  --env_name geospot

Train a VLM to play GeoGuessr

Training uses a hierarchical curriculum that progressively sharpens geolocation accuracy:

  • Stage 1 (0-100 episodes): Country-level coarse matching (wide tolerance)
  • Stage 2 (100-300): Country refinement (tighter kernels)
  • Stage 3 (300-600): Add region signal
  • Stage 4 (600-1000): Introduce city-level precision
  • Stage 5 (1000+): Full hierarchical task (country + region + city + coords)
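The stage schedule above is a simple piecewise mapping from episode count to stage. A minimal sketch of that lookup, assuming half-open intervals (episode 100 is the first episode of Stage 2); the function and constant names are hypothetical:

```python
import bisect

# First episode of each stage, taken from the list above.
STAGE_STARTS = [0, 100, 300, 600, 1000]


def curriculum_stage(episode: int) -> int:
    """Return the 1-based curriculum stage for a given episode count."""
    if episode < 0:
        raise ValueError("episode must be non-negative")
    # bisect_right counts how many stage starts are <= episode.
    return bisect.bisect_right(STAGE_STARTS, episode)
```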

Each field (country/region/city/coords) uses geodesic distance with exponential decay kernels. Weights blend progressively to guide learning from coarse → fine localization.
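A minimal sketch of a distance-kernel reward of this kind. It uses the haversine great-circle distance (a spherical approximation of geodesic distance) and an exponential decay kernel, then blends per-field rewards with weights. The per-field scales and weights are hypothetical inputs; the values the trainer actually uses, and how a distance is derived for the country/region/city fields, are not specified here.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (spherical approximation)


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km between two points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlam = math.radians(lon2 - lon1)
    a = (math.sin(dphi / 2) ** 2
         + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(min(1.0, a)))


def kernel_reward(distance_km, scale_km):
    """Exponential decay: 1.0 at zero distance, 1/e when distance == scale."""
    return math.exp(-distance_km / scale_km)


def blended_reward(distances_km, scales_km, weights):
    """Weighted mean of per-field kernel rewards.

    All three arguments are dicts keyed by field name, e.g.
    "country", "region", "city", "coords".
    """
    total_weight = sum(weights[field] for field in distances_km)
    if total_weight <= 0:
        raise ValueError("at least one field needs a positive weight")
    weighted = sum(weights[field] * kernel_reward(d, scales_km[field])
                   for field, d in distances_km.items())
    return weighted / total_weight
```

A coarse-to-fine schedule can then be expressed by shifting weight from wide-scale fields (country) toward narrow-scale ones (city, coords) as training progresses.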

# Train on OpenStreetView-5M dataset
uv run python core/train.py \
  --model_dir checkpoints/qwen3vl_4b \
  --env_name geospot \
  --lr 5e-7 \
  --total_steps 10000

Sample

uv run python -m core.rollout \
  --model_dir checkpoints/qwen3vl_4b \
  --env_name geospot \
  --episodes 1 \
  --batch_size 1 \
  --temperature 0.7 \
  --top_p 0.9 \
  --top_k 5 \
  --max_new_tokens 128 \
  --seed 0
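--temperature, --top_k and --top_p are the usual decoding knobs. How core.rollout combines them is not documented here; the sketch below shows one common ordering (temperature scaling, then top-k, then top-p) in plain Python. The function name and the ordering are assumptions for illustration.

```python
import math
import random


def sample_token(logits, temperature=0.7, top_k=5, top_p=0.9, rng=None):
    """Sample one token id from a list of raw logits.

    Assumed order: temperature scaling, top-k filtering, top-p (nucleus)
    filtering, then a draw from the renormalised distribution.
    """
    rng = rng or random.Random(0)
    if temperature <= 0:
        # Treat non-positive temperature as greedy decoding.
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    # Top-k: keep the k highest-scoring token ids, best first.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    candidates = order[:top_k]
    # Softmax over the survivors (subtract the max for numerical stability).
    best = scaled[candidates[0]]
    exps = [math.exp(scaled[i] - best) for i in candidates]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Top-p: keep the smallest prefix whose cumulative probability >= top_p.
    kept, cumulative = [], 0.0
    for token_id, p in zip(candidates, probs):
        kept.append((token_id, p))
        cumulative += p
        if cumulative >= top_p:
            break
    ids, weights = zip(*kept)
    # random.choices accepts unnormalised weights.
    return rng.choices(ids, weights=weights, k=1)[0]
```

Under this ordering, --top_k 5 caps the candidate set at five tokens before --top_p 0.9 trims it further, and a fixed seed makes the draw reproducible.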

Train

# Train on any environment
uv run python core/train.py \
  --model_dir=checkpoints/qwen3vl_4b \
  --env_name=geospot \
  --groups_per_batch=8 \
  --group_size=1 \
  --lr=5e-7 \
  --total_steps=10000

Evaluate

uv run python core/eval.py \
  --model_dir checkpoints/qwen3vl_4b \
  --compare_hf \
  --hf_model_name Qwen/Qwen3-VL-4B-Instruct \
  --benchmark_runs 2 \
  --max_new_tokens=128 \
  --prompt "Give me a short introduction to large language models."

Currently, the initial XLA compile is slow (first-token latency is ~28 s vs 0.16 s for Hugging Face, and includes compile time), and steady-state decoding is also slower than the Hugging Face baseline (about 12.8 vs 19.8 tokens/s). TODO: optimize the JAX sampler to improve throughput.

Preliminary benchmarks (Qwen3-VL-4B, A100 80GB, Oct 2025):

| Metric | JAX sampler | HF baseline |
|---|---|---|
| Mean throughput (tokens/s) | 12.73 ± 0.06 | 16.79 ± 3.03 |
| First-token latency (s) | 28.19 | 0.16 |
| Steady-state throughput (tokens/s) | 12.79 | 19.82 |

First-token latency includes XLA compile time. Reproduce with core/eval.py --compare_hf --benchmark_runs 2.
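The table does not spell out how each metric is computed. As an illustration only, here is one common way to derive first-token latency, steady-state throughput, and a mean ± standard deviation across runs from per-token wall-clock timestamps. The function names are hypothetical, and core/eval.py may define the metrics differently (for example, whether the ± column is a sample or population standard deviation).

```python
import statistics


def first_token_latency(start_time, token_times):
    """Seconds from issuing the request to the first generated token.

    For a JIT-compiled sampler, the first call also pays the compile cost.
    """
    return token_times[0] - start_time


def steady_state_tps(token_times):
    """Tokens per second after the first token (excludes compile/prefill)."""
    if len(token_times) < 2:
        raise ValueError("need at least two token timestamps")
    return (len(token_times) - 1) / (token_times[-1] - token_times[0])


def mean_and_std(per_run_values):
    """Mean and sample standard deviation across benchmark runs (>= 2 runs)."""
    return statistics.mean(per_run_values), statistics.stdev(per_run_values)
```

Separating the first token from the rest is what lets a 28 s compile show up in first-token latency without dragging down the steady-state figure.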


Environments

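Creating a custom environment is simple: extend envs.base.BaseEnv and implement reset and step:

from envs.base import BaseEnv


class MyEnv(BaseEnv):
    def reset(self, idx):
        # Load dataset example `idx`; build the environment state and the
        # observation (e.g. image + prompt) the VLM will see.
        state, obs = None, None  # placeholder: build your own state/observation
        return state, obs

    def step(self, state, action_tokens):
        # Decode `action_tokens` (the VLM's generated tokens), score the
        # answer, and report whether the episode is finished.
        reward, done, info = 0.0, True, {}  # placeholder: your scoring logic
        return state, [], reward, done, info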
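For experimenting outside the repo, here is a standalone toy environment with the same reset/step shape as MyEnv. It does not import envs.base, and the class names, the decode callback, and the prefix-match reward are hypothetical; the meaning of the empty list in step's return value is not documented, so it is simply mirrored.

```python
from dataclasses import dataclass


@dataclass
class ToyState:
    idx: int
    answer: str  # ground-truth label, e.g. "True" or "False"


class ToyTrueFalseEnv:
    """Standalone stand-in for a BaseEnv subclass (hypothetical)."""

    def __init__(self, examples, decode):
        self.examples = examples  # list of (prompt, answer) pairs
        self.decode = decode      # maps a list of token ids to text

    def reset(self, idx):
        prompt, answer = self.examples[idx]
        return ToyState(idx, answer), {"prompt": prompt}

    def step(self, state, action_tokens):
        text = self.decode(action_tokens).strip().lower()
        # Reward 1.0 when the decoded answer starts with the label.
        reward = 1.0 if text.startswith(state.answer.lower()) else 0.0
        done = True  # single-turn task: one answer ends the episode
        info = {"prediction": text}
        return state, [], reward, done, info
```

A real environment would decode with the model's tokenizer; a character-level decode is enough to exercise the interface.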

Built-in environments:

  • geospot: GeoGuessr-style street-view geolocation with hierarchical rewards (country → region → city → coords)
  • nlvr2: two-image True/False reasoning

Requirements

  • Python 3.10+
  • Linux, CUDA 12, NVIDIA GPU (80GB+ recommended for training; inference requires ~10GB for the 4B model)
  • JAX 0.6.1 (CUDA 12 build)
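A back-of-the-envelope check on the memory figures: weights alone for a ~4B-parameter model in bf16 take about 8 GB, consistent with the ~10 GB inference figure once the KV cache and activations are added. The per-parameter training cost depends on the optimizer and precision, which are not specified here; the sketch assumes full fine-tuning with Adam in a common mixed-precision layout, and the nominal 4e9 parameter count is an approximation.

```python
def param_memory_gb(n_params, bytes_per_param):
    """Memory in GB (10**9 bytes) needed to hold n_params values."""
    return n_params * bytes_per_param / 1e9


# Nominal size of a "4B" model; the real parameter count differs somewhat.
N_PARAMS = 4e9

# Inference: bf16 weights only (2 bytes/param). KV cache and activations are extra.
inference_gb = param_memory_gb(N_PARAMS, 2)

# Full fine-tuning with Adam in an assumed mixed-precision layout:
# bf16 weights (2) + bf16 grads (2) + fp32 master weights (4) + fp32 Adam m and v (4 + 4).
training_gb = param_memory_gb(N_PARAMS, 2 + 2 + 4 + 4 + 4)

print(f"inference weights ~{inference_gb:.0f} GB, training state ~{training_gb:.0f} GB")
```

Under these assumptions the training state alone is ~64 GB before activations, which is in line with the 80GB+ recommendation.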


Citation

If you use vlm-gym in your research, please cite:

@software{dantuluri2025vlmgym,
  author = {Dantuluri, Surya},
  title = {vlm-gym: Reinforcement Learning Gym for Vision-Language Models},
  year = {2025},
  url = {https://github.com/sdan/vlm-gym}
}

License

See LICENSE and NOTICE.
