   --data_dir=<YOUR DATA DIR>
 """
 from scripts.vlm import gemma3vl_utils as train_utils
+
 # Need to run these filters before importing nemo.
 train_utils.filter_warnings()
 train_utils.filter_grad_bucket_logs()

 import argparse
 import time
+
 import torch
+
 torch.autograd.set_detect_anomaly(True)
 import os
-from lightning.pytorch.loggers import WandbLogger
-from lightning.pytorch.loggers import TensorBoardLogger
+
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig
+from transformers import Gemma3ImageProcessor, Gemma3Processor
+
 from nemo import lightning as nl
 from nemo.collections import llm, vlm
-
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
 from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule
 from nemo.collections.vlm.gemma3vl.data.mock import Gemma3VLMockDataModule
+from nemo.collections.vlm.gemma3vl.data.task_encoder import TaskEncoder as Gemma3VLTaskEncoder
+from nemo.collections.vlm.gemma3vl.data.task_encoder import TaskEncoderConfig as Gemma3VLTaskEncoderConfig
 from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
 from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
-from nemo.utils.exp_manager import TimingCallback
 from nemo.utils import logging
-from nemo.collections.vlm.gemma3vl.data.task_encoder import (
-    TaskEncoder as Gemma3VLTaskEncoder,
-    TaskEncoderConfig as Gemma3VLTaskEncoderConfig,
-)
-from transformers import Gemma3ImageProcessor, Gemma3Processor
+from nemo.utils.exp_manager import TimingCallback


 def main(args):
@@ -149,18 +150,14 @@ def main(args):
         name=args.exp_name,
         ckpt=checkpoint_callback,
         tensorboard=TensorBoardLogger(save_dir="tensorboard", name=""),
-        wandb=WandbLogger(project=args.wandb_project, name=args.exp_name)
-        if args.wandb_project is not None
-        else None,
+        wandb=WandbLogger(project=args.wandb_project, name=args.exp_name) if args.wandb_project is not None else None,
     )

     # Auto resume setup
     resume = nl.AutoResume(
         resume_if_exists=False,
         resume_ignore_no_checkpoint=True,
-        restore_config=nl.RestoreConfig(path=args.resume_from_ckpt)
-        if args.resume_from_ckpt is not None
-        else None,
+        restore_config=nl.RestoreConfig(path=args.resume_from_ckpt) if args.resume_from_ckpt is not None else None,
     )

     # Optimizer and scheduler setup
@@ -205,7 +202,7 @@ def main(args):
     parser.add_argument(
         "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint"
     )
-    parser.add_argument("--log_dir", type=str, required=False, default="/logs", help="Path to the log folder")
+    parser.add_argument("--log_dir", type=str, required=False, default="/logs", help="Path to the log folder")
     parser.add_argument("--tp_size", type=int, required=False, default=1)
     parser.add_argument("--pp_size", type=int, required=False, default=1)
     parser.add_argument("--num_nodes", type=int, required=False, default=1)
@@ -216,14 +213,20 @@ def main(args):
     parser.add_argument("--val_check_interval", type=int, required=False, default=10)
     parser.add_argument("--limit_val_batches", type=float, required=False, default=1.0)
     parser.add_argument("--lr", type=float, required=False, default=2.0e-06, help="Learning rate")
-    parser.add_argument("--hf_model_id", type=str, required=False, default="google/gemma-3-4b-it", help="HuggingFace Gemma3VL model ids")
+    parser.add_argument(
+        "--hf_model_id",
+        type=str,
+        required=False,
+        default="google/gemma-3-4b-it",
+        help="HuggingFace Gemma3VL model ids",
+    )
     parser.add_argument("--gbs", type=int, required=False, default=32, help="Global batch size")
     parser.add_argument("--mbs", type=int, required=False, default=1, help="Micro batch size")
     parser.add_argument("--save_top_k", type=int, required=False, default=1, help="Save top k")
-    parser.add_argument("--num_workers", type=int, required=False, default=2, help="The num of workers for data loader")
     parser.add_argument(
-        "--max_sequence_length", type=int, required=False, default=512, help="Maximum sequence length"
+        "--num_workers", type=int, required=False, default=2, help="The num of workers for data loader"
     )
+    parser.add_argument("--max_sequence_length", type=int, required=False, default=512, help="Maximum sequence length")
     parser.add_argument(
         "--resume_from_ckpt",
         type=str,
@@ -232,9 +235,7 @@ def main(args):
         help="Path to restore model from checkpoint",
     )
     parser.add_argument("--wandb_project", type=str, required=False, default=None)
-    parser.add_argument(
-        "--exp_name", type=str, required=False, default="gemma3vl_finetune"
-    )
+    parser.add_argument("--exp_name", type=str, required=False, default="gemma3vl_finetune")

     args = parser.parse_args()
     main(args)
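
For reference, the flags defined in the argparse section above can be exercised with an invocation along these lines. This is a minimal sketch: the script path (scripts/vlm/gemma3vl_finetune.py) and the torchrun process count are assumptions, since neither appears in the diff; every flag and value shown is taken from the docstring fragment or the parser defaults above.

    # Hypothetical launch command; script path and --nproc_per_node are assumed,
    # all flags and defaults come from the argparse definitions in the diff.
    torchrun --nproc_per_node=8 scripts/vlm/gemma3vl_finetune.py \
        --data_dir=<YOUR DATA DIR> \
        --hf_model_id=google/gemma-3-4b-it \
        --log_dir=/logs \
        --tp_size=1 --pp_size=1 --num_nodes=1 \
        --gbs=32 --mbs=1 \
        --max_sequence_length=512 \
        --lr=2.0e-06 \
        --exp_name=gemma3vl_finetune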