
Commit 26fdad4

meatybobby and akoumpa authored
feat: Add NeMo Biencoder (#745)
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: Alexandros Koumparoulis <[email protected]>
1 parent 152e632 commit 26fdad4

File tree

15 files changed: +3455 -1 lines changed


examples/biencoder/finetune.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
from nemo_automodel.recipes.biencoder import TrainBiencoderRecipe


def main(default_config_path="examples/biencoder/llama3_2_1b_biencoder.yaml"):
    """Main entry point for the biencoder fine-tuning recipe.

    Loads the configuration, sets up the recipe, and initiates the training loop.

    Args:
        default_config_path: Path to the default configuration file.
    """
    cfg = parse_args_and_load_config(default_config_path)
    recipe = TrainBiencoderRecipe(cfg)
    recipe.setup()
    recipe.run_train_validation_loop()


if __name__ == "__main__":
    main()
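
Usage note: the example config below documents how to launch this entry point. For convenience, the same commands (single process, then multi-GPU with torchrun) are:

python examples/biencoder/finetune.py --config examples/biencoder/llama3_2_1b_biencoder.yaml
torchrun --nproc-per-node=8 examples/biencoder/finetune.py --config examples/biencoder/llama3_2_1b_biencoder.yaml
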
examples/biencoder/llama3_2_1b_biencoder.yaml

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run this recipe, please use the following command:
# python examples/biencoder/finetune.py --config examples/biencoder/llama3_2_1b_biencoder.yaml
# Or with torchrun for multi-GPU:
# torchrun --nproc-per-node=8 examples/biencoder/finetune.py --config examples/biencoder/llama3_2_1b_biencoder.yaml

seed: 42

step_scheduler:
  global_batch_size: 128
  local_batch_size: 4
  ckpt_every_steps: 500
  val_every_steps: 500
  num_epochs: 1

dist_env:
  backend: nccl
  timeout_minutes: 1

model:
  _target_: nemo_automodel.components.models.biencoder.NeMoAutoModelBiencoder.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.2-1B
  share_encoder: true
  add_linear_pooler: false
  out_dimension: 768
  do_gradient_checkpointing: false
  train_n_passages: 5
  eval_negative_size: 4
  pooling: avg
  l2_normalize: true
  t: 0.02
  use_liger_kernel: true
  use_sdpa_patching: true
  torch_dtype: bfloat16

tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.2-1B

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  dataset:
    _target_: nemo_automodel.components.datasets.llm.make_retrieval_dataset
    data_dir_list:
      - training_datasets/nqsh_shuffled_50k.json
      - training_datasets/miracl_train_es_llama3_1b_4m_512len.json
      - training_datasets/mldr_en_perc95_small.json
    data_type: train
    train_n_passages: 5
    eval_negative_size: 4
    seed: 42
    do_shuffle: true
  collate_fn:
    _target_: nemo_automodel.components.datasets.llm.RetrievalBiencoderCollator
    q_max_len: 512
    p_max_len: 512
    query_prefix: "query:"
    passage_prefix: "passage:"
    pad_to_multiple_of: 8
  shuffle: true
  num_workers: 0

# Optional: Uncomment to enable validation
# validation_dataloader:
#   _target_: torchdata.stateful_dataloader.StatefulDataLoader
#   dataset:
#     _target_: nemo_automodel.components.datasets.llm.make_retrieval_dataset
#     data_dir_list: training_datasets/validation.json
#     data_type: eval
#     train_n_passages: 5
#     eval_negative_size: 4
#     seed: 42
#     do_shuffle: false
#     max_train_samples: 1000
#     train_data_select_offset: 0
#   collate_fn:
#     _target_: nemo_automodel.components.datasets.llm.RetrievalBiencoderCollator
#     q_max_len: 512
#     p_max_len: 512
#     query_prefix: "query:"
#     passage_prefix: "passage:"
#     padding: longest
#     pad_to_multiple_of: 8
#   batch_size: 2
#   shuffle: false
#   num_workers: 0

optimizer:
  _target_: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
  lr: 5.0e-6
  weight_decay: 0.01
  adam_w_mode: true
  bias_correction: true
  master_weights: true

# Learning rate scheduler
lr_scheduler:
  lr_warmup_steps: 100

checkpoint:
  enabled: true
  checkpoint_dir: ./output/llama3_2_1b_biencoder/checkpoints
  model_save_format: torch_save
  save_consolidated: false

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  dp_size: none
  dp_replicate_size: 1
  tp_size: 1
  cp_size: 1
  sequence_parallel: false

# Uncomment and configure for W&B logging
# wandb:
#   project: biencoder-finetuning
#   entity: your_entity
#   name: llama3_2_1b_biencoder
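
For readers less familiar with `_target_`-style configs: each section resolves the dotted `_target_` path and calls it with the remaining keys as keyword arguments. Below is a rough, hand-written Python equivalent of the `model:` section only; the import path and keyword names are taken from the YAML, passing `torch_dtype` as a string is an assumption about how the loader forwards it, and this is a sketch rather than the recipe's actual instantiation code.

# Sketch only: a manual equivalent of the YAML `model:` section above.
# Assumes the `_target_` dotted path resolves to this classmethod and that
# keyword names map one-to-one onto the YAML keys.
from nemo_automodel.components.models.biencoder import NeMoAutoModelBiencoder

model = NeMoAutoModelBiencoder.from_pretrained(
    pretrained_model_name_or_path="meta-llama/Llama-3.2-1B",
    share_encoder=True,
    add_linear_pooler=False,
    out_dimension=768,
    do_gradient_checkpointing=False,
    train_n_passages=5,
    eval_negative_size=4,
    pooling="avg",
    l2_normalize=True,
    t=0.02,
    use_liger_kernel=True,
    use_sdpa_patching=True,
    torch_dtype="bfloat16",  # assumption: accepted as a string, as written in the YAML
)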

nemo_automodel/components/checkpoint/stateful_wrappers.py

Lines changed: 2 additions & 1 deletion
@@ -145,7 +145,8 @@ def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True) -> None:
  if self.is_tied_lm_head and not self.is_peft:
      # PP models don't have tied embeddings. Safe to pass in model[0] here.
      lm_head_weight, lm_head_param_name = _get_lm_head_weight_and_name(self.model[0])
-     if lm_head_param_name not in state_dict:
+     # Skip for Biencoder models as it doesn't have a lm_head at the top level
+     if lm_head_weight is not None and lm_head_param_name not in state_dict:
          # weight tying guarantees this is identical to the embedding weight
          state_dict[lm_head_param_name] = lm_head_weight.detach()
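
The new `is not None` guard is what makes this path safe for the biencoder. A minimal sketch of the failure it prevents, assuming `_get_lm_head_weight_and_name` returns `(None, None)` when the model has no top-level `lm_head` (that return contract is inferred from the comment in the diff, not stated explicitly):

# Sketch: why the guard is needed (assumed (None, None) return for biencoders).
state_dict = {}
lm_head_weight, lm_head_param_name = None, None  # assumed biencoder case

# Old check: `None not in state_dict` is True, so the body would run and
# `None.detach()` would raise AttributeError.
# New check: short-circuits on `lm_head_weight is not None` and skips safely.
if lm_head_weight is not None and lm_head_param_name not in state_dict:
    state_dict[lm_head_param_name] = lm_head_weight.detach()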

nemo_automodel/components/datasets/llm/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -15,11 +15,15 @@
  from .chat_dataset import ChatDataset  # noqa: F401
  from .column_mapped_text_instruction_dataset import ColumnMappedTextInstructionDataset  # noqa: F401
  from .nanogpt_dataset import NanogptDataset  # noqa: F401
+ from .retrieval_collator import RetrievalBiencoderCollator  # noqa: F401
+ from .retrieval_dataset import make_retrieval_dataset  # noqa: F401
  from .squad import make_squad_dataset  # noqa: F401

  __all__ = [
      "NanogptDataset",
      "make_squad_dataset",
+     "make_retrieval_dataset",
+     "RetrievalBiencoderCollator",
      "ColumnMappedTextInstructionDataset",
      "ChatDataset",
  ]
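
The two new exports can also be wired up directly in Python. The sketch below mirrors the `dataloader:` section of the example YAML above; keyword names come from that config, while passing the tokenizer explicitly to the collator and using `batch_size=4` (the config's `local_batch_size`) are assumptions about how the recipe wires things together, not API documentation.

# Sketch only: manual construction mirroring the example YAML's dataloader section.
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer

from nemo_automodel.components.datasets.llm import (
    RetrievalBiencoderCollator,
    make_retrieval_dataset,
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

dataset = make_retrieval_dataset(
    data_dir_list=["training_datasets/nqsh_shuffled_50k.json"],
    data_type="train",
    train_n_passages=5,
    eval_negative_size=4,
    seed=42,
    do_shuffle=True,
)

collator = RetrievalBiencoderCollator(
    tokenizer,  # assumption: the collator receives the tokenizer at construction
    q_max_len=512,
    p_max_len=512,
    query_prefix="query:",
    passage_prefix="passage:",
    pad_to_multiple_of=8,
)

loader = StatefulDataLoader(
    dataset,
    batch_size=4,  # assumption: matches step_scheduler.local_batch_size in the YAML
    collate_fn=collator,
    shuffle=True,
    num_workers=0,
)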
