Commit 98b0874

refactor config matching
Signed-off-by: Youngeun Kwon <[email protected]>
1 parent 7124e44

File tree

1 file changed (+4, -19 lines)

nemo_rl/models/megatron/community_import.py

Lines changed: 4 additions & 19 deletions
@@ -53,25 +53,10 @@ def import_model_from_hf_name(
     orig_pipeline_dtype = model_provider.pipeline_dtype

     if megatron_config is not None:
-        model_provider.tensor_model_parallel_size = megatron_config[
-            "tensor_model_parallel_size"
-        ]
-        model_provider.pipeline_model_parallel_size = megatron_config[
-            "pipeline_model_parallel_size"
-        ]
-        model_provider.expert_model_parallel_size = megatron_config[
-            "expert_model_parallel_size"
-        ]
-        model_provider.expert_tensor_parallel_size = megatron_config[
-            "expert_tensor_parallel_size"
-        ]
-        model_provider.num_layers_in_first_pipeline_stage = megatron_config[
-            "num_layers_in_first_pipeline_stage"
-        ]
-        model_provider.num_layers_in_last_pipeline_stage = megatron_config[
-            "num_layers_in_last_pipeline_stage"
-        ]
-        model_provider.pipeline_dtype = megatron_config["pipeline_dtype"]
+        for k in megatron_config.keys():
+            if hasattr(model_provider, k):
+                setattr(model_provider, k, megatron_config[k])  # type: ignore
+
     model_provider.finalize()
     model_provider.initialize_model_parallel(seed=0)
     megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)
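The refactor replaces seven hand-written field assignments with a generic loop that copies any config entry whose key matches an attribute on the provider. Below is a minimal, self-contained sketch of that hasattr/setattr pattern; DummyProvider, apply_config, and the example dict are hypothetical stand-ins for Megatron's model provider and the megatron_config mapping, not code from the repo.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DummyProvider:
    # Hypothetical stand-in for the Megatron model provider.
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    pipeline_dtype: Optional[str] = None


def apply_config(provider: Any, config: dict) -> None:
    # Copy every config entry whose key names an existing provider
    # attribute; keys the provider does not know about are skipped.
    for k in config.keys():
        if hasattr(provider, k):
            setattr(provider, k, config[k])


provider = DummyProvider()
apply_config(
    provider,
    {
        "tensor_model_parallel_size": 2,
        "pipeline_dtype": "bfloat16",
        "unknown_key": 42,  # ignored: DummyProvider has no such attribute
    },
)
assert provider.tensor_model_parallel_size == 2
assert provider.pipeline_dtype == "bfloat16"

The usual trade-off of this pattern: a misspelled config key no longer raises a KeyError or AttributeError; it is silently ignored, so validation of the config schema must happen elsewhere.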
