Commit 7277489

Apply isort and black reformatting
Signed-off-by: genquan9 <[email protected]>
1 parent 2bf98a6 commit 7277489

7 files changed: 134 additions and 142 deletions


nemo/collections/llm/gpt/model/gemma3.py

Lines changed: 1 addition & 2 deletions
@@ -21,6 +21,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Annotated, Callable, Optional, Tuple, Union
 
+import torch
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
@@ -30,8 +31,6 @@
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.enums import AttnBackend, AttnMaskType
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
-
-import torch
 from torch import Tensor, nn
 
 from nemo.collections.llm.fn.activation import openai_gelu
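The `import torch` move in this hunk is standard isort behavior: imports are grouped into sections (standard library, then third-party, then first-party, separated by blank lines), and within a section plain `import x` statements sort ahead of `from x import ...` statements. A minimal illustration of the resulting layout (the grouping rules are real isort defaults; treating `nemo` as first-party is an assumption about this repo's isort configuration):

# Section 1: standard library; plain imports first, then from-imports.
import logging
from pathlib import Path

# Section 2: third-party. "import torch" sorts ahead of the from-imports,
# which is exactly why this commit moved it above the megatron lines.
import torch
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from torch import Tensor, nn

# Section 3: first-party (assumed configured as known_first_party = ["nemo"]).
from nemo.collections.llm.fn.activation import openai_gelu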

nemo/collections/vlm/gemma3vl/data/task_encoder.py

Lines changed: 8 additions & 7 deletions
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import json
+import logging
 from dataclasses import dataclass, field
 from typing import Optional
 
@@ -104,7 +104,6 @@ def encode_batch(self, batch_data: DataBatch) -> dict:
         batch_data["media"] = batch_data["media"].reshape(-1, *batch_data["media"].shape[2:])
         return batch_data
 
-
     def encode_vqa_sample_multi_turns(self, input_sample: VQASample):
         images = input_sample.image if isinstance(input_sample.image, list) else [input_sample.image]
 
@@ -116,10 +115,12 @@ def encode_vqa_sample_multi_turns(self, input_sample: VQASample):
             messages.append(context)
 
         # Apply chat template and process with HF processor
-        #`add_generation_prompt=False` because we're providing the full ground truth sequence
+        # `add_generation_prompt=False` because we're providing the full ground truth sequence
         # We remove the <bos> token using removeprefix('<bos>') since we're finetuning.
         # The Processor will add this token before training and the model expects only one.
-        converted_messages = self.hf_processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=False).removeprefix('<bos>')
+        converted_messages = self.hf_processor.apply_chat_template(
+            messages, add_generation_prompt=False, tokenize=False
+        ).removeprefix('<bos>')
         outputs = self.hf_processor(
             images=images,
             text=converted_messages,
@@ -140,7 +141,9 @@ def encode_vqa_sample_multi_turns(self, input_sample: VQASample):
             if context['role'] != 'assistant':
                 continue
             # Tokenize the answer, including the stop string if provided
-            answer_with_stop = context['content'][0]['text'].rstrip().lstrip() + "<end_of_turn>" + (self.config.stop_string or "")
+            answer_with_stop = (
+                context['content'][0]['text'].rstrip().lstrip() + "<end_of_turn>" + (self.config.stop_string or "")
+            )
             answer_with_stop = answer_with_stop.rstrip().lstrip()
             answer_tokens = self.tokenizer.tokenizer(answer_with_stop, add_special_tokens=False)["input_ids"]
             answer_tokens_tensor = torch.tensor(answer_tokens, device=tokens.device)  # Ensure same device
@@ -171,7 +174,6 @@ def encode_vqa_sample_multi_turns(self, input_sample: VQASample):
                 break
         return tokens, labels, images
 
-
     def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
         """Encode a VQA sample into a DataSample format.
 
@@ -228,4 +230,3 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
         )
 
         return sample
-
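Black's changes in this file are pure re-wrapping to respect the line-length limit; no logic changes. As an aside, the `.rstrip().lstrip()` chain that survives the reformat is a long-hand `.strip()`, and `removeprefix` (Python 3.9+) is a no-op when the prefix is absent. A quick illustrative check, not part of the commit:

text = "  some answer text  "

# rstrip() drops trailing whitespace, lstrip() drops leading whitespace;
# chaining the two is exactly what strip() does in one call.
assert text.rstrip().lstrip() == text.strip()

# str.removeprefix removes the marker only if present, which is why the
# task encoder uses it to drop at most one leading <bos>.
assert "<bos>hello".removeprefix("<bos>") == "hello"
assert "hello".removeprefix("<bos>") == "hello"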

nemo/collections/vlm/gemma3vl/model/base.py

Lines changed: 1 addition & 2 deletions
@@ -26,7 +26,6 @@
 from megatron.core.inference_params import InferenceParams
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.parallel_state import get_context_parallel_group
-from nemo.collections.vlm.qwen2vl.data.multimodal_tokens import IGNORE_INDEX
 from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -38,11 +37,11 @@
 from nemo.collections.llm.gpt.model.gemma3 import Gemma3Config
 from nemo.collections.vlm.gemma3vl.model.vision import Gemma3VLMultimodalProjectorConfig, Gemma3VLVisionConfig
 from nemo.collections.vlm.neva.model.base import MODEL_CONFIG_ATTR, NevaModel, restore_model_weights
+from nemo.collections.vlm.qwen2vl.data.multimodal_tokens import IGNORE_INDEX
 from nemo.lightning import io
 from nemo.lightning.pytorch.optim import OptimizerModule
 from nemo.utils.import_utils import safe_import_from
 
-
 TENorm, _ = safe_import_from("megatron.core.extensions.transformer_engine", "TENorm")
 
 HAVE_TEX = True
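The `TENorm, _ = safe_import_from(...)` line whose surrounding blank lines get tightened here is NeMo's guarded-import helper: judging from this call site it returns the requested attribute plus a success flag, degrading gracefully when Transformer Engine is absent. A minimal sketch of that pattern, assuming this simplified behavior (the real helper in `nemo.utils.import_utils` handles more cases):

import importlib
from typing import Any, Tuple


def safe_import_from(module: str, symbol: str) -> Tuple[Any, bool]:
    """Return (attribute, True) if importable, else (None, False)."""
    try:
        mod = importlib.import_module(module)
        return getattr(mod, symbol), True
    except (ImportError, AttributeError):
        return None, False


# Mirrors the call site in gemma3vl/model/base.py: TENorm is only usable
# when megatron's Transformer Engine extensions are installed.
TENorm, HAVE_TE = safe_import_from("megatron.core.extensions.transformer_engine", "TENorm")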

scripts/vlm/gemma3vl_export.py

Lines changed: 58 additions & 58 deletions
@@ -1,73 +1,73 @@
 """Export Gemma3VL NeMo checkpoints to Hugging Face format."""
 
 import argparse
-from huggingface_hub import hf_hub_download
 import importlib
 import os
-from pathlib import Path
 import sys
+from pathlib import Path
+
+from huggingface_hub import hf_hub_download
+
 from nemo.collections import llm
 
 
 def main():
-  parser = argparse.ArgumentParser(
-    description=(
-      "Export NeMo vision language model checkpoint to Hugging Face format."
-    )
-  )
-  parser.add_argument(
-    "--nemo_ckpt_path",
-    type=str,
-    required=True,
-    default=None,
-    help="Path to the NeMo checkpoint directory.",
-  )
-  parser.add_argument(
-    "--output_hf_path",
-    type=str,
-    required=True,
-    default=None,
-    help="Path to save the converted Hugging Face checkpoint.",
-  )
-  parser.add_argument(
-    "--model_name",
-    type=str,
-    required=False,
-    default=None,
-    help="Name of the model on Hugging Face.",
-  )
+    parser = argparse.ArgumentParser(
+        description=("Export NeMo vision language model checkpoint to Hugging Face format.")
+    )
+    parser.add_argument(
+        "--nemo_ckpt_path",
+        type=str,
+        required=True,
+        default=None,
+        help="Path to the NeMo checkpoint directory.",
+    )
+    parser.add_argument(
+        "--output_hf_path",
+        type=str,
+        required=True,
+        default=None,
+        help="Path to save the converted Hugging Face checkpoint.",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        required=False,
+        default=None,
+        help="Name of the model on Hugging Face.",
+    )
 
-  args = parser.parse_args()
+    args = parser.parse_args()
 
-  llm.export_ckpt(
-    path=Path(args.nemo_ckpt_path),
-    target="hf",
-    output_path=Path(args.output_hf_path),
-    overwrite=True,
-  )
-  if args.model_name:
-    # Copy necessary files if exist from HuggingFace for Gemma3VL model export.
-    copy_file_list = [
-      "preprocessor_config.json",
-      "chat_template.json",
-      "config.json",
-      "generation_config.json",
-      "merges.txt",
-      "tokenizer.json",
-      "tokenizer_config.json",
-      "vocab.json",
-    ]
-    for file_name in copy_file_list:
-      try:
-        downloaded_path = hf_hub_download(
-          repo_id=args.model_name,
-          filename=file_name,
-          local_dir=args.output_hf_path,
-        )
-        print(f"Downloaded {downloaded_path} during export gamma3vl models.")
-      except:
-        print(f"Ignore {file_name} during export gamma3vl models.")
+    llm.export_ckpt(
+        path=Path(args.nemo_ckpt_path),
+        target="hf",
+        output_path=Path(args.output_hf_path),
+        overwrite=True,
+    )
+    if args.model_name:
+        # Copy necessary files if exist from HuggingFace for Gemma3VL model export.
+        copy_file_list = [
+            "preprocessor_config.json",
+            "chat_template.json",
+            "config.json",
+            "generation_config.json",
+            "merges.txt",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "vocab.json",
+        ]
+        for file_name in copy_file_list:
+            try:
+                downloaded_path = hf_hub_download(
+                    repo_id=args.model_name,
+                    filename=file_name,
+                    local_dir=args.output_hf_path,
+                )
+                print(f"Downloaded {downloaded_path} during export gamma3vl models.")
+            except:
+                print(f"Ignore {file_name} during export gamma3vl models.")
 
 
 if __name__ == "__main__":
-  main()
+    main()
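One thing the formatters leave untouched here is the bare `except:` around `hf_hub_download`, which also swallows `KeyboardInterrupt` and `SystemExit`. A narrower variant would catch only download failures; the sketch below assumes `huggingface_hub.utils.EntryNotFoundError` is the relevant exception in the installed `huggingface_hub` version, and is not part of this commit:

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

try:
    downloaded_path = hf_hub_download(
        repo_id="google/gemma-3-4b-it",  # illustrative repo id
        filename="preprocessor_config.json",
        local_dir="./export_out",
    )
    print(f"Downloaded {downloaded_path}.")
except EntryNotFoundError:
    # The file simply does not exist in this repo; safe to skip.
    print("File not found in repo; skipping.")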

scripts/vlm/gemma3vl_finetune.py

Lines changed: 23 additions & 22 deletions
@@ -20,34 +20,35 @@
     --data_dir=<YOUR DATA DIR>
 """
 from scripts.vlm import gemma3vl_utils as train_utils
+
 # Need to run these filters before importing nemo.
 train_utils.filter_warnings()
 train_utils.filter_grad_bucket_logs()
 
 import argparse
 import time
+
 import torch
+
 torch.autograd.set_detect_anomaly(True)
 import os
-from lightning.pytorch.loggers import WandbLogger
-from lightning.pytorch.loggers import TensorBoardLogger
+
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig
+from transformers import Gemma3ImageProcessor, Gemma3Processor
+
 from nemo import lightning as nl
 from nemo.collections import llm, vlm
-
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
 from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule
 from nemo.collections.vlm.gemma3vl.data.mock import Gemma3VLMockDataModule
+from nemo.collections.vlm.gemma3vl.data.task_encoder import TaskEncoder as Gemma3VLTaskEncoder
+from nemo.collections.vlm.gemma3vl.data.task_encoder import TaskEncoderConfig as Gemma3VLTaskEncoderConfig
 from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
 from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
-from nemo.utils.exp_manager import TimingCallback
 from nemo.utils import logging
-from nemo.collections.vlm.gemma3vl.data.task_encoder import (
-    TaskEncoder as Gemma3VLTaskEncoder,
-    TaskEncoderConfig as Gemma3VLTaskEncoderConfig,
-)
-from transformers import Gemma3ImageProcessor, Gemma3Processor
+from nemo.utils.exp_manager import TimingCallback
 
 
 def main(args):
@@ -149,18 +150,14 @@ def main(args):
         name=args.exp_name,
         ckpt=checkpoint_callback,
         tensorboard=TensorBoardLogger(save_dir="tensorboard", name=""),
-        wandb=WandbLogger(project=args.wandb_project, name=args.exp_name)
-        if args.wandb_project is not None
-        else None,
+        wandb=WandbLogger(project=args.wandb_project, name=args.exp_name) if args.wandb_project is not None else None,
     )
 
     # Auto resume setup
     resume = nl.AutoResume(
         resume_if_exists=False,
         resume_ignore_no_checkpoint=True,
-        restore_config=nl.RestoreConfig(path=args.resume_from_ckpt)
-        if args.resume_from_ckpt is not None
-        else None,
+        restore_config=nl.RestoreConfig(path=args.resume_from_ckpt) if args.resume_from_ckpt is not None else None,
     )
 
     # Optimizer and scheduler setup
@@ -205,7 +202,7 @@
     parser.add_argument(
         "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint"
     )
-    parser.add_argument("--log_dir", type=str, required=False, default="/logs", help="Path to the log folder")
+    parser.add_argument("--log_dir", type=str, required=False, default="/logs", help="Path to the log folder")
    parser.add_argument("--tp_size", type=int, required=False, default=1)
     parser.add_argument("--pp_size", type=int, required=False, default=1)
     parser.add_argument("--num_nodes", type=int, required=False, default=1)
@@ -216,14 +213,20 @@
     parser.add_argument("--val_check_interval", type=int, required=False, default=10)
     parser.add_argument("--limit_val_batches", type=float, required=False, default=1.0)
     parser.add_argument("--lr", type=float, required=False, default=2.0e-06, help="Learning rate")
-    parser.add_argument("--hf_model_id", type=str, required=False, default="google/gemma-3-4b-it", help="HuggingFace Gemma3VL model ids")
+    parser.add_argument(
+        "--hf_model_id",
+        type=str,
+        required=False,
+        default="google/gemma-3-4b-it",
+        help="HuggingFace Gemma3VL model ids",
+    )
     parser.add_argument("--gbs", type=int, required=False, default=32, help="Global batch size")
     parser.add_argument("--mbs", type=int, required=False, default=1, help="Micro batch size")
     parser.add_argument("--save_top_k", type=int, required=False, default=1, help="Save top k")
-    parser.add_argument("--num_workers", type=int, required=False, default=2, help="The num of workers for data loader")
     parser.add_argument(
-        "--max_sequence_length", type=int, required=False, default=512, help="Maximum sequence length"
+        "--num_workers", type=int, required=False, default=2, help="The num of workers for data loader"
     )
+    parser.add_argument("--max_sequence_length", type=int, required=False, default=512, help="Maximum sequence length")
     parser.add_argument(
         "--resume_from_ckpt",
         type=str,
@@ -232,9 +235,7 @@
         help="Path to restore model from checkpoint",
     )
     parser.add_argument("--wandb_project", type=str, required=False, default=None)
-    parser.add_argument(
-        "--exp_name", type=str, required=False, default="gemma3vl_finetune"
-    )
+    parser.add_argument("--exp_name", type=str, required=False, default="gemma3vl_finetune")
 
     args = parser.parse_args()
     main(args)
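This script intentionally executes code between its imports: the warning filters must run before `nemo` is imported, and `set_detect_anomaly` is enabled before further imports. isort tolerates this because it only sorts within each contiguous import block and never moves imports across intervening statements. When even that is too much, isort can be fenced off explicitly; a small illustrative example (the `# isort: off` / `# isort: on` directives are real isort syntax, the module names are placeholders):

# isort: off
# Imports between these markers are left exactly as written, which is
# useful when import order has side effects (e.g., warning filters).
import my_side_effect_filters  # hypothetical module; must load first
import nemo  # must come after the filters above
# isort: on

# Outside the markers, isort sorts and groups normally.
import argparse
import os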
