36 changes: 36 additions & 0 deletions examples/biencoder/finetune.py
@@ -0,0 +1,36 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
from nemo_automodel.recipes.biencoder import TrainBiencoderRecipe


def main(default_config_path="examples/biencoder/llama3_2_1b_biencoder.yaml"):
"""Main entry point for the biencoder fine-tuning recipe.

Loads the configuration, sets up the recipe, and initiates the training loop.

Args:
default_config_path: Path to the default configuration file
"""
cfg = parse_args_and_load_config(default_config_path)
recipe = TrainBiencoderRecipe(cfg)
recipe.setup()
recipe.run_train_validation_loop()


if __name__ == "__main__":
main()
132 changes: 132 additions & 0 deletions examples/biencoder/llama3_2_1b_biencoder.yaml
@@ -0,0 +1,132 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run this recipe, please use the following command:
# python examples/biencoder/finetune.py --config examples/biencoder/llama3_2_1b_biencoder.yaml
# Or with torchrun for multi-GPU:
# torchrun --nproc-per-node=8 examples/biencoder/finetune.py --config examples/biencoder/llama3_2_1b_biencoder.yaml

seed: 42

step_scheduler:
global_batch_size: 128
local_batch_size: 4
ckpt_every_steps: 500
val_every_steps: 500
num_epochs: 1

dist_env:
backend: nccl
timeout_minutes: 1

model:
_target_: nemo_automodel.components.models.biencoder.NeMoAutoModelBiencoder.from_pretrained
pretrained_model_name_or_path: meta-llama/Llama-3.2-1B
share_encoder: true
add_linear_pooler: false
out_dimension: 768
do_gradient_checkpointing: false
train_n_passages: 5
eval_negative_size: 4
pooling: avg
l2_normalize: true
t: 0.02
use_liger_kernel: true
use_sdpa_patching: true
torch_dtype: bfloat16

tokenizer:
_target_: transformers.AutoTokenizer.from_pretrained
pretrained_model_name_or_path: meta-llama/Llama-3.2-1B

dataloader:
_target_: torchdata.stateful_dataloader.StatefulDataLoader
dataset:
_target_: nemo_automodel.components.datasets.llm.make_retrieval_dataset
data_dir_list:
- training_datasets/nqsh_shuffled_50k.json
- training_datasets/miracl_train_es_llama3_1b_4m_512len.json
- training_datasets/mldr_en_perc95_small.json
data_type: train
train_n_passages: 5
eval_negative_size: 4
seed: 42
do_shuffle: true
collate_fn:
_target_: nemo_automodel.components.datasets.llm.RetrievalBiencoderCollator
q_max_len: 512
p_max_len: 512
query_prefix: "query:"
passage_prefix: "passage:"
pad_to_multiple_of: 8
shuffle: true
num_workers: 0

# Optional: Uncomment to enable validation
# validation_dataloader:
# _target_: torchdata.stateful_dataloader.StatefulDataLoader
# dataset:
# _target_: nemo_automodel.components.datasets.llm.make_retrieval_dataset
# data_dir_list: training_datasets/validation.json
# data_type: eval
# train_n_passages: 5
# eval_negative_size: 4
# seed: 42
# do_shuffle: false
# max_train_samples: 1000
# train_data_select_offset: 0
# collate_fn:
# _target_: nemo_automodel.components.datasets.llm.RetrievalBiencoderCollator
# q_max_len: 512
# p_max_len: 512
# query_prefix: "query:"
# passage_prefix: "passage:"
# padding: longest
# pad_to_multiple_of: 8
# batch_size: 2
# shuffle: false
# num_workers: 0

optimizer:
_target_: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
lr: 5.0e-6
weight_decay: 0.01
adam_w_mode: true
bias_correction: true
master_weights: true

# Learning rate scheduler
lr_scheduler:
lr_warmup_steps: 100

checkpoint:
enabled: true
checkpoint_dir: ./output/llama3_2_1b_biencoder/checkpoints
model_save_format: torch_save
save_consolidated: false

distributed:
_target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
dp_size: none
dp_replicate_size: 1
tp_size: 1
cp_size: 1
sequence_parallel: false

# Uncomment and configure for W&B logging
# wandb:
# project: biencoder-finetuning
# entity: your_entity
# name: llama3_2_1b_biencoder

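The loss-side options above (`pooling: avg`, `l2_normalize: true`, `t: 0.02`, `train_n_passages: 5`) describe a standard bi-encoder contrastive objective: queries and passages are embedded independently, pooled, optionally L2-normalized, and scored by a temperature-scaled dot product. The sketch below illustrates that computation under common conventions (positive-first passage ordering, mask-aware mean pooling, in-batch negatives); it is an assumption-laden illustration, not the NeMo Automodel implementation.

```python
# Illustrative bi-encoder contrastive step (sketch only, not the NeMo Automodel code).
# Assumed conventions: each query comes with train_n_passages passages, positive first;
# "avg" pooling is a mask-weighted mean; scores are cosine similarities divided by t.
import torch
import torch.nn.functional as F


def mean_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mask-aware mean pooling: [B, L, H], [B, L] -> [B, H]."""
    mask = mask.unsqueeze(-1).to(hidden.dtype)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)


def contrastive_loss(q_emb, p_emb, train_n_passages=5, t=0.02, l2_normalize=True):
    """q_emb: [B, H] query embeddings; p_emb: [B * train_n_passages, H] passage embeddings."""
    if l2_normalize:
        q_emb = F.normalize(q_emb, dim=-1)
        p_emb = F.normalize(p_emb, dim=-1)
    scores = q_emb @ p_emb.T / t  # [B, B * train_n_passages]
    # Positive for query i assumed at column i * train_n_passages (positive-first layout).
    target = torch.arange(q_emb.size(0), device=q_emb.device) * train_n_passages
    return F.cross_entropy(scores, target)
```

With `t: 0.02`, cosine scores are scaled by a factor of 50 before the softmax, which sharpens the distribution over in-batch and hard negatives; this is a typical temperature for retrieval fine-tuning.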
3 changes: 2 additions & 1 deletion nemo_automodel/components/checkpoint/stateful_wrappers.py
@@ -139,7 +139,8 @@ def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True) -> None:
if self.is_tied_lm_head and not self.is_peft:
# PP models don't have tied embeddings. Safe to pass in model[0] here.
lm_head_weight, lm_head_param_name = _get_lm_head_weight_and_name(self.model[0])
if lm_head_param_name not in state_dict:
# Skip for Biencoder models, as they don't have an lm_head at the top level
if lm_head_weight is not None and lm_head_param_name not in state_dict:
# weight tying guarantees this is identical to the embedding weight
state_dict[lm_head_param_name] = lm_head_weight.detach()

4 changes: 4 additions & 0 deletions nemo_automodel/components/datasets/llm/__init__.py
@@ -14,12 +14,16 @@

from .column_mapped_text_instruction_dataset import ColumnMappedTextInstructionDataset # noqa: F401
from .nanogpt_dataset import NanogptDataset # noqa: F401
from .retrieval_collator import RetrievalBiencoderCollator # noqa: F401
from .retrieval_dataset import make_retrieval_dataset # noqa: F401
from .squad import make_squad_dataset # noqa: F401
from .tool_calling_chat_dataset import ToolCallingChatDataset # noqa: F401

__all__ = [
"NanogptDataset",
"make_squad_dataset",
"make_retrieval_dataset",
"RetrievalBiencoderCollator",
"ColumnMappedTextInstructionDataset",
"ToolCallingChatDataset",
]
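For reference, the newly exported `make_retrieval_dataset` and `RetrievalBiencoderCollator` correspond one-to-one to the `dataloader` block of the example YAML. A minimal sketch of using them directly is below; the keyword arguments are copied from that config, while the tokenizer wiring is an assumption (the recipe may inject it differently).

```python
# Sketch mirroring the dataloader block of llama3_2_1b_biencoder.yaml.
# Keyword arguments come from the YAML; the `tokenizer=` parameter name is an assumption.
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer

from nemo_automodel.components.datasets.llm import (
    RetrievalBiencoderCollator,
    make_retrieval_dataset,
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

dataset = make_retrieval_dataset(
    data_dir_list=["training_datasets/nqsh_shuffled_50k.json"],
    data_type="train",
    train_n_passages=5,
    eval_negative_size=4,
    seed=42,
    do_shuffle=True,
)

collate_fn = RetrievalBiencoderCollator(
    tokenizer=tokenizer,  # assumed parameter name
    q_max_len=512,
    p_max_len=512,
    query_prefix="query:",
    passage_prefix="passage:",
    pad_to_multiple_of=8,
)

loader = StatefulDataLoader(dataset, collate_fn=collate_fn, shuffle=True, num_workers=0)
```

In the recipe itself these objects are built from the YAML via `_target_` instantiation, so `examples/biencoder/finetune.py` does not construct them by hand.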