
Commit 15716ff

ananthsub and NeMo Bot authored and committed
Qwen3 MoE finetune recipes (#1265)
Signed-off-by: Ananth Subramaniam <[email protected]>
Signed-off-by: NeMo Bot <[email protected]>
1 parent f604f70 commit 15716ff

File tree: 3 files changed (+514, -18 lines)


src/megatron/bridge/recipes/qwen/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -45,7 +45,9 @@
 
 # Qwen3 MoE models
 from .qwen3_moe import (
+    qwen3_30b_a3b_finetune_config,
     qwen3_30b_a3b_pretrain_config,
+    qwen3_235b_a22b_finetune_config,
     qwen3_235b_a22b_pretrain_config,
 )
 
@@ -83,7 +85,9 @@
     "qwen3_32b_finetune_config",
     # Qwen3 MoE models
     "qwen3_30b_a3b_pretrain_config",
+    "qwen3_30b_a3b_finetune_config",
     "qwen3_235b_a22b_pretrain_config",
+    "qwen3_235b_a22b_finetune_config",
     # Qwen3-Next models
     "qwen3_next_80b_a3b_pretrain_config",
 ]
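
With these exports in place, the finetuning recipes sit next to the existing pretraining ones in the package's public surface. A minimal usage sketch, assuming the megatron.bridge package from this repo is importable; the run names are illustrative, and how the resulting ConfigContainer is handed to a training entry point is outside the scope of this diff:

from megatron.bridge.recipes.qwen import (
    qwen3_30b_a3b_finetune_config,
    qwen3_235b_a22b_finetune_config,
)

# Defaults: LoRA finetuning on SQuAD (see qwen3_moe.py below); names are illustrative.
cfg_30b = qwen3_30b_a3b_finetune_config(name="qwen3_30b_lora_squad")

# The 235B recipe follows the same calling convention.
cfg_235b = qwen3_235b_a22b_finetune_config(name="qwen3_235b_lora_squad")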

src/megatron/bridge/recipes/qwen/qwen3_moe.py

Lines changed: 294 additions & 1 deletion
@@ -20,7 +20,9 @@
 from typing_extensions import TypedDict, Unpack
 
 from megatron.bridge import AutoBridge
+from megatron.bridge.peft.base import PEFT
 from megatron.bridge.recipes.utils.dataset_utils import get_blend_fields_from_data_paths
+from megatron.bridge.recipes.utils.finetune_utils import default_peft_config, default_squad_config
 from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
 from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE
 from megatron.bridge.training.comm_overlap import CommOverlapConfig
@@ -33,7 +35,7 @@
     TokenizerConfig,
     TrainingConfig,
 )
-from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, bf16_mixed
+from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, bf16_mixed, get_mixed_precision_config
 
 
 class Qwen3MoeCommonKwargs(TypedDict, total=False):
@@ -81,6 +83,45 @@ class Qwen3MoeCommonKwargs(TypedDict, total=False):
     comm_overlap_config: Optional[CommOverlapConfig]
 
 
+class Qwen3MoeFinetuneKwargs(TypedDict, total=False):
+    """Typed options accepted by Qwen3 MoE finetuning recipe helper functions.
+
+    This is separate from Qwen3MoeCommonKwargs to avoid confusion - finetuning
+    uses SQuAD dataset by default, not the data path fields.
+    """
+
+    # Core identifiers
+    dir: Optional[str]
+    name: str
+
+    # Finetuning-specific
+    pretrained_checkpoint: Optional[str]
+    peft: Union[str, PEFT, None]
+    packed_sequence: bool
+
+    # Training hyperparameters
+    train_iters: int
+    global_batch_size: Optional[int]
+    micro_batch_size: int
+    seq_length: Optional[int]
+    eval_interval: int
+    save_interval: int
+
+    # Optimizer
+    finetune_lr: Optional[float]
+    min_lr: float
+    lr_warmup_iters: int
+    lr_decay_iters: Optional[int]
+
+    # W&B logging
+    wandb_project: Optional[str]
+    wandb_entity: Optional[str]
+    wandb_exp_name: Optional[str]
+
+    # Precision
+    precision_config: Optional[Union[MixedPrecisionConfig, str]]
+
+
 def qwen3_30b_a3b_pretrain_config(**user_kwargs: Unpack[Qwen3MoeCommonKwargs]) -> ConfigContainer:
     """Return a pre-training config for Qwen3-30B-A3B MoE.
 
@@ -310,3 +351,255 @@ def _qwen3_moe_common(
     )
 
     return cfg
+
+
+def qwen3_30b_a3b_finetune_config(**user_kwargs: Unpack[Qwen3MoeFinetuneKwargs]) -> ConfigContainer:
+    """Return a finetuning config for Qwen3-30B-A3B MoE.
+
+    Default configuration: 1 node, 8 GPUs, LoRA
+    - LoRA (default): TP=4, PP=1, EP=4, LR=1e-4, dim=8, alpha=16, target_modules=['linear_qkv', 'linear_proj']
+    - DoRA: TP=4, PP=1, EP=4, LR=1e-4, dim=8, alpha=16, target_modules=['linear_qkv', 'linear_proj']
+    - Full SFT (peft=None): TP=4, PP=2, EP=4, LR=5e-6, SP=True
+
+    Matches NeMo2 recipe at nemo/collections/llm/recipes/qwen3_30b_a3b.py
+    """
+    peft = user_kwargs.pop("peft", "lora")
+    is_full_sft = peft is None or (isinstance(peft, str) and peft.lower() == "none")
+
+    # Auto-select LR if not specified
+    finetune_lr = user_kwargs.get("finetune_lr")
+    if finetune_lr is None:
+        finetune_lr = 5e-6 if is_full_sft else 1e-4
+        user_kwargs["finetune_lr"] = finetune_lr
+
+    # Build base config
+    config = _qwen3_moe_finetune_common(hf_path="Qwen/Qwen3-30B-A3B", **user_kwargs)
+
+    # Model-specific parallelism settings (match NeMo pattern)
+    if is_full_sft:
+        config.model.tensor_model_parallel_size = 4
+        config.model.expert_model_parallel_size = 4
+        config.model.pipeline_model_parallel_size = 2
+        config.model.expert_tensor_parallel_size = 1
+        config.model.sequence_parallel = True
+        config.peft = None
+    else:
+        # PEFT (LoRA, DoRA, or custom)
+        config.model.tensor_model_parallel_size = 4
+        config.model.expert_model_parallel_size = 4
+        config.model.pipeline_model_parallel_size = 1
+        config.model.expert_tensor_parallel_size = 1
+        config.model.sequence_parallel = True
+
+        if isinstance(peft, str) and peft.lower() in ["lora", "dora"]:
+            config.peft = default_peft_config(peft)
+            config.peft.dim = 8
+            config.peft.alpha = 16
+            config.peft.target_modules = ["linear_qkv", "linear_proj"]
+        else:
+            config.peft = peft
+
+    return config
+
+
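A quick orientation on the function added above: the peft argument selects both the adapter and the parallel layout, and the learning-rate default follows from that choice. A hedged sketch of the three modes; the run names and the explicit finetune_lr override are illustrative, not recipe defaults:

from megatron.bridge.recipes.qwen import qwen3_30b_a3b_finetune_config

# Default: LoRA on 1 node / 8 GPUs (TP=4, PP=1, EP=4); LR auto-selected as 1e-4.
lora_cfg = qwen3_30b_a3b_finetune_config(name="qwen3_30b_lora")

# DoRA keeps the same layout; only the adapter type changes.
dora_cfg = qwen3_30b_a3b_finetune_config(name="qwen3_30b_dora", peft="dora")

# Full SFT: peft=None (or "none") switches to TP=4, PP=2, EP=4 with LR 5e-6 by default;
# passing finetune_lr explicitly (1e-5 here, purely illustrative) overrides that.
sft_cfg = qwen3_30b_a3b_finetune_config(name="qwen3_30b_sft", peft=None, finetune_lr=1e-5)
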
+def qwen3_235b_a22b_finetune_config(**user_kwargs: Unpack[Qwen3MoeFinetuneKwargs]) -> ConfigContainer:
+    """Return a finetuning config for Qwen3-235B-A22B MoE.
+
+    Default configuration: 8 nodes (LoRA) or 16 nodes (Full SFT), 8 GPUs per node
+    - LoRA (default): TP=4, PP=4, EP=4, LR=1e-4, dim=8, alpha=16, target_modules=['linear_qkv', 'linear_proj']
+      Total: 64 GPUs (8 nodes)
+    - DoRA: TP=4, PP=4, EP=4, LR=1e-4, dim=8, alpha=16, target_modules=['linear_qkv', 'linear_proj']
+      Total: 64 GPUs (8 nodes)
+    - Full SFT (peft=None): TP=4, PP=16, EP=4, LR=5e-6, SP=True
+      Total: 64 GPUs (8 nodes)
+
+    Matches NeMo2 recipe at nemo/collections/llm/recipes/qwen3_235b_a22b.py
+
+    Note: Uses account_for_embedding_in_pipeline_split and account_for_loss_in_pipeline_split
+    for proper layer distribution in pipeline parallelism.
+    """
+    peft = user_kwargs.pop("peft", "lora")
+    is_full_sft = peft is None or (isinstance(peft, str) and peft.lower() == "none")
+
+    # Auto-select LR if not specified
+    finetune_lr = user_kwargs.get("finetune_lr")
+    if finetune_lr is None:
+        finetune_lr = 5e-6 if is_full_sft else 1e-4
+        user_kwargs["finetune_lr"] = finetune_lr
+
+    # Build base config
+    config = _qwen3_moe_finetune_common(hf_path="Qwen/Qwen3-235B-A22B", **user_kwargs)
+
+    # Enable pipeline split accounting (required for 235B model)
+    config.model.account_for_embedding_in_pipeline_split = True
+    config.model.account_for_loss_in_pipeline_split = True
+
+    # Model-specific parallelism settings (match NeMo pattern)
+    if is_full_sft:
+        config.model.tensor_model_parallel_size = 4
+        config.model.pipeline_model_parallel_size = 16
+        config.model.expert_model_parallel_size = 4
+        config.model.expert_tensor_parallel_size = 1
+        config.model.sequence_parallel = True
+        config.peft = None
+    else:
+        # PEFT (LoRA, DoRA, or custom)
+        config.model.tensor_model_parallel_size = 4
+        config.model.pipeline_model_parallel_size = 4
+        config.model.expert_model_parallel_size = 4
+        config.model.expert_tensor_parallel_size = 1
+        config.model.sequence_parallel = True
+
+        if isinstance(peft, str) and peft.lower() in ["lora", "dora"]:
+            config.peft = default_peft_config(peft)
+            config.peft.dim = 8
+            config.peft.alpha = 16
+            config.peft.target_modules = ["linear_qkv", "linear_proj"]
+        else:
+            config.peft = peft
+
+    return config
+
+
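For sizing the 235B runs described in the docstring above, the settings reduce to simple arithmetic. A back-of-the-envelope sketch, assuming context parallelism of 1 and the standard Megatron-Core relation world_size = TP × PP × CP × DP; expert parallelism is mapped onto existing ranks rather than adding GPUs of its own, so it is left out of the count. The helper below is illustrative and not part of the recipe:

def min_gpus_and_dp(tp: int, pp: int, world_size: int, cp: int = 1) -> tuple[int, int]:
    """Return (GPUs per model-parallel replica, data-parallel size) for a given world size.

    Illustrative helper: assumes world_size = tp * pp * cp * dp, with cp defaulting to 1.
    """
    per_replica = tp * pp * cp
    assert world_size % per_replica == 0, "world size must be a multiple of TP * PP * CP"
    return per_replica, world_size // per_replica

# LoRA/DoRA defaults for Qwen3-235B-A22B: TP=4, PP=4 -> 16 GPUs per replica,
# so 64 GPUs (8 nodes of 8) run with DP=4.
print(min_gpus_and_dp(tp=4, pp=4, world_size=64))   # (16, 4)

# Full SFT: TP=4, PP=16 -> 64 GPUs per replica, i.e. DP=1 on the same 64 GPUs.
print(min_gpus_and_dp(tp=4, pp=16, world_size=64))  # (64, 1)
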
+def _qwen3_moe_finetune_common(
+    hf_path: str,
+    dir: Optional[str] = None,
+    name: str = "default",
+    # Finetuning-specific
+    pretrained_checkpoint: Optional[str] = None,
+    packed_sequence: bool = False,
+    # Training hyperparameters
+    train_iters: int = 100,
+    global_batch_size: Optional[int] = None,
+    micro_batch_size: int = 1,
+    seq_length: Optional[int] = None,
+    eval_interval: int = 50,
+    save_interval: int = 100,
+    # Optimizer
+    finetune_lr: Optional[float] = None,
+    min_lr: float = 0.0,
+    lr_warmup_iters: int = 10,
+    lr_decay_iters: Optional[int] = None,
+    # W&B logging
+    wandb_project: Optional[str] = None,
+    wandb_entity: Optional[str] = None,
+    wandb_exp_name: Optional[str] = None,
+    # Precision
+    precision_config: Optional[Union[MixedPrecisionConfig, str]] = None,
+) -> ConfigContainer:
+    """
+    Create a finetuning configuration for Qwen3 MoE models using a given HuggingFace path.
+
+    Args:
+        hf_path (str): HuggingFace model path (e.g., "Qwen/Qwen3-30B-A3B", "Qwen/Qwen3-235B-A22B").
+        dir (Optional[str]): Base directory for saving logs and checkpoints.
+        name (str): Name of the finetuning run.
+        pretrained_checkpoint (Optional[str]): Path to pretrained checkpoint to load.
+        packed_sequence (bool): Whether to use packed sequences for training efficiency.
+        train_iters (int): Total number of training iterations.
+        global_batch_size (Optional[int]): Global batch size for training.
+        micro_batch_size (int): Micro batch size for training.
+        seq_length (Optional[int]): Sequence length for training data.
+        eval_interval (int): Evaluation interval.
+        save_interval (int): Checkpoint save interval.
+        finetune_lr (Optional[float]): Learning rate for finetuning.
+        min_lr (float): Minimum learning rate for cosine decay.
+        lr_warmup_iters (int): Number of warmup iterations for the learning rate.
+        lr_decay_iters (Optional[int]): Number of iterations over which to decay the LR.
+        wandb_project (Optional[str]): Weights & Biases project name.
+        wandb_entity (Optional[str]): Weights & Biases entity name.
+        wandb_exp_name (Optional[str]): Weights & Biases experiment name.
+        precision_config (Optional[Union[MixedPrecisionConfig, str]]): Precision configuration for the model.
+
+    Returns:
+        ConfigContainer: Configuration for finetuning.
+    """
+    # Default sequence length for finetuning
+    if seq_length is None:
+        seq_length = 2048 if packed_sequence else 4096
+
+    # Default global batch size
+    if global_batch_size is None:
+        global_batch_size = 32
+
+    base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments")
+    run_output_dir = os.path.join(base_output_dir, name)
+    checkpoint_dir = os.path.join(run_output_dir, "checkpoints")
+    tensorboard_dir = os.path.join(run_output_dir, "tb_logs")
+
+    bridge = AutoBridge.from_hf_pretrained(hf_path)
+    model_cfg = bridge.to_megatron_provider(load_weights=False)
+
+    # Precision configuration
+    if precision_config is None:
+        precision_config = bf16_mixed()
+    elif isinstance(precision_config, str):
+        precision_config = get_mixed_precision_config(precision_config)
+
+    # Sequence length
+    model_cfg.seq_length = seq_length
+    model_cfg.cross_entropy_fusion_impl = "te"
+
+    # Optimizer and scheduler
+    opt_config, scheduler = distributed_fused_adam_with_cosine_annealing(
+        lr_warmup_iters=lr_warmup_iters,
+        lr_decay_iters=lr_decay_iters if lr_decay_iters is not None else train_iters,
+        max_lr=finetune_lr if finetune_lr is not None else 1e-4,
+        min_lr=min_lr,
+    )
+
+    # Dataset configuration (SQuAD by default)
+    dataset_config = default_squad_config(seq_length=seq_length, packed_sequence=packed_sequence)
+
+    # W&B logger configuration
+    logger_config = LoggerConfig(
+        log_interval=10,
+        tensorboard_dir=tensorboard_dir,
+        log_timers_to_tensorboard=True,
+        wandb_project=wandb_project,
+        wandb_entity=wandb_entity,
+        wandb_exp_name=wandb_exp_name,
+    )
+
+    # Config Container
+    cfg = ConfigContainer(
+        model=model_cfg,
+        train=TrainingConfig(
+            train_iters=train_iters,
+            eval_interval=eval_interval,
+            eval_iters=10,
+            global_batch_size=global_batch_size,
+            micro_batch_size=micro_batch_size,
+            manual_gc=True,
+            manual_gc_interval=100,
+            manual_gc_eval=100,
+        ),
+        optimizer=opt_config,
+        scheduler=scheduler,
+        ddp=DistributedDataParallelConfig(
+            check_for_nan_in_grad=True,
+            grad_reduce_in_fp32=True,
+            overlap_grad_reduce=True,
+            overlap_param_gather=True,
+            average_in_collective=True,
+            use_distributed_optimizer=True,
+        ),
+        dataset=dataset_config,
+        logger=logger_config,
+        tokenizer=TokenizerConfig(
+            tokenizer_type="HuggingFaceTokenizer",
+            tokenizer_model=hf_path,
+        ),
+        checkpoint=CheckpointConfig(
+            save_interval=save_interval,
+            save=checkpoint_dir,
+            load=checkpoint_dir,
+            pretrained_checkpoint=pretrained_checkpoint,
+            ckpt_format="torch_dist",
+            fully_parallel_save=True,
+        ),
+        rng=RNGConfig(seed=5678),  # Different seed for finetuning
+        mixed_precision=precision_config,
+    )
+
+    return cfg
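
One more note on the common helper: precision can be passed either as a MixedPrecisionConfig instance or as a string that the helper resolves through get_mixed_precision_config, falling back to bf16_mixed() when nothing is given. A hedged sketch of both call styles; the checkpoint path and the "bf16_mixed" string are assumed examples, not values confirmed by this diff:

from megatron.bridge.recipes.qwen import qwen3_30b_a3b_finetune_config
from megatron.bridge.training.mixed_precision import bf16_mixed

# Explicit MixedPrecisionConfig object: used by the helper as-is.
cfg = qwen3_30b_a3b_finetune_config(
    name="qwen3_30b_lora_bf16",
    pretrained_checkpoint="/checkpoints/qwen3-30b-a3b",  # illustrative path
    precision_config=bf16_mixed(),
)

# String form: resolved via get_mixed_precision_config(); the set of recognized
# names lives in megatron.bridge.training.mixed_precision, and "bf16_mixed" here
# is an assumed key rather than one documented by this diff.
cfg = qwen3_30b_a3b_finetune_config(
    name="qwen3_30b_lora_bf16_str",
    precision_config="bf16_mixed",
)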
