|
20 | 20 | from typing_extensions import TypedDict, Unpack |
21 | 21 |
|
22 | 22 | from megatron.bridge import AutoBridge |
| 23 | +from megatron.bridge.peft.base import PEFT |
23 | 24 | from megatron.bridge.recipes.utils.dataset_utils import get_blend_fields_from_data_paths |
| 25 | +from megatron.bridge.recipes.utils.finetune_utils import default_peft_config, default_squad_config |
24 | 26 | from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing |
25 | 27 | from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE |
26 | 28 | from megatron.bridge.training.comm_overlap import CommOverlapConfig |
|
33 | 35 | TokenizerConfig, |
34 | 36 | TrainingConfig, |
35 | 37 | ) |
36 | | -from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, bf16_mixed |
| 38 | +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, bf16_mixed, get_mixed_precision_config |
37 | 39 |
|
38 | 40 |
|
39 | 41 | class Qwen3MoeCommonKwargs(TypedDict, total=False): |
@@ -81,6 +83,45 @@ class Qwen3MoeCommonKwargs(TypedDict, total=False): |
81 | 83 | comm_overlap_config: Optional[CommOverlapConfig] |
82 | 84 |
|
83 | 85 |
|
| 86 | +class Qwen3MoeFinetuneKwargs(TypedDict, total=False): |
| 87 | + """Typed options accepted by Qwen3 MoE finetuning recipe helper functions. |
| 88 | +
|
| 89 | + This is kept separate from Qwen3MoeCommonKwargs to avoid confusion: finetuning |
| 90 | + uses the SQuAD dataset by default rather than the pretraining data path fields. |
| 91 | + """ |
| 92 | + |
| 93 | + # Core identifiers |
| 94 | + dir: Optional[str] |
| 95 | + name: str |
| 96 | + |
| 97 | + # Finetuning-specific |
| 98 | + pretrained_checkpoint: Optional[str] |
| 99 | + peft: Union[str, PEFT, None] |
| 100 | + packed_sequence: bool |
| 101 | + |
| 102 | + # Training hyperparameters |
| 103 | + train_iters: int |
| 104 | + global_batch_size: Optional[int] |
| 105 | + micro_batch_size: int |
| 106 | + seq_length: Optional[int] |
| 107 | + eval_interval: int |
| 108 | + save_interval: int |
| 109 | + |
| 110 | + # Optimizer |
| 111 | + finetune_lr: Optional[float] |
| 112 | + min_lr: float |
| 113 | + lr_warmup_iters: int |
| 114 | + lr_decay_iters: Optional[int] |
| 115 | + |
| 116 | + # W&B logging |
| 117 | + wandb_project: Optional[str] |
| 118 | + wandb_entity: Optional[str] |
| 119 | + wandb_exp_name: Optional[str] |
| 120 | + |
| 121 | + # Precision |
| 122 | + precision_config: Optional[Union[MixedPrecisionConfig, str]] |
| 123 | + |
| 124 | + |
84 | 125 | def qwen3_30b_a3b_pretrain_config(**user_kwargs: Unpack[Qwen3MoeCommonKwargs]) -> ConfigContainer: |
85 | 126 | """Return a pre-training config for Qwen3-30B-A3B MoE. |
86 | 127 |
|
@@ -310,3 +351,255 @@ def _qwen3_moe_common( |
310 | 351 | ) |
311 | 352 |
|
312 | 353 | return cfg |
| 354 | + |
| 355 | + |
| 356 | +def qwen3_30b_a3b_finetune_config(**user_kwargs: Unpack[Qwen3MoeFinetuneKwargs]) -> ConfigContainer: |
| 357 | + """Return a finetuning config for Qwen3-30B-A3B MoE. |
| 358 | +
|
| 359 | + Default configuration: 1 node, 8 GPUs, LoRA |
| 360 | + - LoRA (default): TP=4, PP=1, EP=4, LR=1e-4, dim=8, alpha=16, target_modules=['linear_qkv', 'linear_proj'] |
| 361 | + - DoRA: TP=4, PP=1, EP=4, LR=1e-4, dim=8, alpha=16, target_modules=['linear_qkv', 'linear_proj'] |
| 362 | + - Full SFT (peft=None): TP=4, PP=2, EP=4, LR=5e-6, SP=True |
| 363 | +
|
| 364 | + Matches NeMo2 recipe at nemo/collections/llm/recipes/qwen3_30b_a3b.py |
| 365 | + """ |
| 366 | + peft = user_kwargs.pop("peft", "lora") |
| 367 | + is_full_sft = peft is None or (isinstance(peft, str) and peft.lower() == "none") |
| 368 | + |
| 369 | + # Auto-select LR if not specified |
| 370 | + finetune_lr = user_kwargs.get("finetune_lr") |
| 371 | + if finetune_lr is None: |
| 372 | + finetune_lr = 5e-6 if is_full_sft else 1e-4 |
| 373 | + user_kwargs["finetune_lr"] = finetune_lr |
| 374 | + |
| 375 | + # Build base config |
| 376 | + config = _qwen3_moe_finetune_common(hf_path="Qwen/Qwen3-30B-A3B", **user_kwargs) |
| 377 | + |
| 378 | + # Model-specific parallelism settings (match NeMo pattern) |
| 379 | + if is_full_sft: |
| 380 | + config.model.tensor_model_parallel_size = 4 |
| 381 | + config.model.expert_model_parallel_size = 4 |
| 382 | + config.model.pipeline_model_parallel_size = 2 |
| 383 | + config.model.expert_tensor_parallel_size = 1 |
| 384 | + config.model.sequence_parallel = True |
| 385 | + config.peft = None |
| 386 | + else: |
| 387 | + # PEFT (LoRA, DoRA, or custom) |
| 388 | + config.model.tensor_model_parallel_size = 4 |
| 389 | + config.model.expert_model_parallel_size = 4 |
| 390 | + config.model.pipeline_model_parallel_size = 1 |
| 391 | + config.model.expert_tensor_parallel_size = 1 |
| 392 | + config.model.sequence_parallel = True |
| 393 | + |
| 394 | + if isinstance(peft, str) and peft.lower() in ["lora", "dora"]: |
| 395 | + config.peft = default_peft_config(peft) |
| 396 | + config.peft.dim = 8 |
| 397 | + config.peft.alpha = 16 |
| 398 | + config.peft.target_modules = ["linear_qkv", "linear_proj"] |
| 399 | + else: |
| 400 | + config.peft = peft |
| 401 | + |
| 402 | + return config |
| 403 | + |
| 404 | + |
| 405 | +def qwen3_235b_a22b_finetune_config(**user_kwargs: Unpack[Qwen3MoeFinetuneKwargs]) -> ConfigContainer: |
| 406 | + """Return a finetuning config for Qwen3-235B-A22B MoE. |
| 407 | +
|
| 408 | + Default configuration: 8 nodes (LoRA) or 16 nodes (Full SFT), 8 GPUs per node |
| 409 | + - LoRA (default): TP=4, PP=4, EP=4, LR=1e-4, dim=8, alpha=16, target_modules=['linear_qkv', 'linear_proj'] |
| 410 | + Total: 64 GPUs (8 nodes) |
| 411 | + - DoRA: TP=4, PP=4, EP=4, LR=1e-4, dim=8, alpha=16, target_modules=['linear_qkv', 'linear_proj'] |
| 412 | + Total: 64 GPUs (8 nodes) |
| 413 | + - Full SFT (peft=None): TP=4, PP=16, EP=4, LR=5e-6, SP=True |
| 414 | + Total: 128 GPUs (16 nodes) |
| 415 | +
|
| 416 | + Matches NeMo2 recipe at nemo/collections/llm/recipes/qwen3_235b_a22b.py |
| 417 | +
|
| 418 | + Note: Uses account_for_embedding_in_pipeline_split and account_for_loss_in_pipeline_split |
| 419 | + for proper layer distribution in pipeline parallelism. |
| 420 | + """ |
| 421 | + peft = user_kwargs.pop("peft", "lora") |
| 422 | + is_full_sft = peft is None or (isinstance(peft, str) and peft.lower() == "none") |
| 423 | + |
| 424 | + # Auto-select LR if not specified |
| 425 | + finetune_lr = user_kwargs.get("finetune_lr") |
| 426 | + if finetune_lr is None: |
| 427 | + finetune_lr = 5e-6 if is_full_sft else 1e-4 |
| 428 | + user_kwargs["finetune_lr"] = finetune_lr |
| 429 | + |
| 430 | + # Build base config |
| 431 | + config = _qwen3_moe_finetune_common(hf_path="Qwen/Qwen3-235B-A22B", **user_kwargs) |
| 432 | + |
| 433 | + # Enable pipeline split accounting (required for 235B model) |
| 434 | + config.model.account_for_embedding_in_pipeline_split = True |
| 435 | + config.model.account_for_loss_in_pipeline_split = True |
| 436 | + |
| 437 | + # Model-specific parallelism settings (match NeMo pattern) |
| 438 | + if is_full_sft: |
| 439 | + config.model.tensor_model_parallel_size = 4 |
| 440 | + config.model.pipeline_model_parallel_size = 16 |
| 441 | + config.model.expert_model_parallel_size = 4 |
| 442 | + config.model.expert_tensor_parallel_size = 1 |
| 443 | + config.model.sequence_parallel = True |
| 444 | + config.peft = None |
| 445 | + else: |
| 446 | + # PEFT (LoRA, DoRA, or custom) |
| 447 | + config.model.tensor_model_parallel_size = 4 |
| 448 | + config.model.pipeline_model_parallel_size = 4 |
| 449 | + config.model.expert_model_parallel_size = 4 |
| 450 | + config.model.expert_tensor_parallel_size = 1 |
| 451 | + config.model.sequence_parallel = True |
| 452 | + |
| 453 | + if isinstance(peft, str) and peft.lower() in ["lora", "dora"]: |
| 454 | + config.peft = default_peft_config(peft) |
| 455 | + config.peft.dim = 8 |
| 456 | + config.peft.alpha = 16 |
| 457 | + config.peft.target_modules = ["linear_qkv", "linear_proj"] |
| 458 | + else: |
| 459 | + config.peft = peft |
| 460 | + |
| 461 | + return config |
| 462 | + |
| 463 | + |
| 464 | +def _qwen3_moe_finetune_common( |
| 465 | + hf_path: str, |
| 466 | + dir: Optional[str] = None, |
| 467 | + name: str = "default", |
| 468 | + # Finetuning-specific |
| 469 | + pretrained_checkpoint: Optional[str] = None, |
| 470 | + packed_sequence: bool = False, |
| 471 | + # Training hyperparameters |
| 472 | + train_iters: int = 100, |
| 473 | + global_batch_size: Optional[int] = None, |
| 474 | + micro_batch_size: int = 1, |
| 475 | + seq_length: Optional[int] = None, |
| 476 | + eval_interval: int = 50, |
| 477 | + save_interval: int = 100, |
| 478 | + # Optimizer |
| 479 | + finetune_lr: Optional[float] = None, |
| 480 | + min_lr: float = 0.0, |
| 481 | + lr_warmup_iters: int = 10, |
| 482 | + lr_decay_iters: Optional[int] = None, |
| 483 | + # W&B logging |
| 484 | + wandb_project: Optional[str] = None, |
| 485 | + wandb_entity: Optional[str] = None, |
| 486 | + wandb_exp_name: Optional[str] = None, |
| 487 | + # Precision |
| 488 | + precision_config: Optional[Union[MixedPrecisionConfig, str]] = None, |
| 489 | +) -> ConfigContainer: |
| 490 | + """ |
| 491 | + Create a finetuning configuration for Qwen3 MoE models using a given HuggingFace path. |
| 492 | +
|
| 493 | + Args: |
| 494 | + hf_path (str): HuggingFace model path (e.g., "Qwen/Qwen3-30B-A3B", "Qwen/Qwen3-235B-A22B"). |
| 495 | + dir (Optional[str]): Base directory for saving logs and checkpoints. |
| 496 | + name (str): Name of the finetuning run. |
| 497 | + pretrained_checkpoint (Optional[str]): Path to pretrained checkpoint to load. |
| 498 | + packed_sequence (bool): Whether to use packed sequences for training efficiency. |
| 499 | + train_iters (int): Total number of training iterations. |
| 500 | + global_batch_size (Optional[int]): Global batch size for training. |
| 501 | + micro_batch_size (int): Micro batch size for training. |
| 502 | + seq_length (Optional[int]): Sequence length for training data. |
| 503 | + eval_interval (int): Evaluation interval. |
| 504 | + save_interval (int): Checkpoint save interval. |
| 505 | + finetune_lr (Optional[float]): Learning rate for finetuning. |
| 506 | + min_lr (float): Minimum learning rate for cosine decay. |
| 507 | + lr_warmup_iters (int): Number of warmup iterations for the learning rate. |
| 508 | + lr_decay_iters (Optional[int]): Number of iterations over which to decay the LR. |
| 509 | + wandb_project (Optional[str]): Weights & Biases project name. |
| 510 | + wandb_entity (Optional[str]): Weights & Biases entity name. |
| 511 | + wandb_exp_name (Optional[str]): Weights & Biases experiment name. |
| 512 | + precision_config (Optional[Union[MixedPrecisionConfig, str]]): Precision configuration for the model. |
| 513 | +
|
| 514 | + Returns: |
| 515 | + ConfigContainer: Configuration for finetuning. |
| 516 | + """ |
| 517 | + # Default sequence length for finetuning (packed sequences use a longer default) |
| 518 | + if seq_length is None: |
| 519 | + seq_length = 4096 if packed_sequence else 2048 |
| 520 | + |
| 521 | + # Default global batch size |
| 522 | + if global_batch_size is None: |
| 523 | + global_batch_size = 32 |
| 524 | + |
| 525 | + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") |
| 526 | + run_output_dir = os.path.join(base_output_dir, name) |
| 527 | + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") |
| 528 | + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") |
| 529 | + |
| 530 | + bridge = AutoBridge.from_hf_pretrained(hf_path) |
| 531 | + model_cfg = bridge.to_megatron_provider(load_weights=False) |
| 532 | + |
| 533 | + # Precision configuration |
| 534 | + if precision_config is None: |
| 535 | + precision_config = bf16_mixed() |
| 536 | + elif isinstance(precision_config, str): |
| 537 | + precision_config = get_mixed_precision_config(precision_config) |
| 538 | + |
| 539 | + # Sequence length and fused cross-entropy implementation |
| 540 | + model_cfg.seq_length = seq_length |
| 541 | + model_cfg.cross_entropy_fusion_impl = "te" |
| 542 | + |
| 543 | + # Optimizer and scheduler |
| 544 | + opt_config, scheduler = distributed_fused_adam_with_cosine_annealing( |
| 545 | + lr_warmup_iters=lr_warmup_iters, |
| 546 | + lr_decay_iters=lr_decay_iters if lr_decay_iters is not None else train_iters, |
| 547 | + max_lr=finetune_lr if finetune_lr is not None else 1e-4, |
| 548 | + min_lr=min_lr, |
| 549 | + ) |
| 550 | + |
| 551 | + # Dataset configuration (SQuAD by default) |
| 552 | + dataset_config = default_squad_config(seq_length=seq_length, packed_sequence=packed_sequence) |
| 553 | + |
| 554 | + # W&B logger configuration |
| 555 | + logger_config = LoggerConfig( |
| 556 | + log_interval=10, |
| 557 | + tensorboard_dir=tensorboard_dir, |
| 558 | + log_timers_to_tensorboard=True, |
| 559 | + wandb_project=wandb_project, |
| 560 | + wandb_entity=wandb_entity, |
| 561 | + wandb_exp_name=wandb_exp_name, |
| 562 | + ) |
| 563 | + |
| 564 | + # Config Container |
| 565 | + cfg = ConfigContainer( |
| 566 | + model=model_cfg, |
| 567 | + train=TrainingConfig( |
| 568 | + train_iters=train_iters, |
| 569 | + eval_interval=eval_interval, |
| 570 | + eval_iters=10, |
| 571 | + global_batch_size=global_batch_size, |
| 572 | + micro_batch_size=micro_batch_size, |
| 573 | + manual_gc=True, |
| 574 | + manual_gc_interval=100, |
| 575 | + manual_gc_eval=100, |
| 576 | + ), |
| 577 | + optimizer=opt_config, |
| 578 | + scheduler=scheduler, |
| 579 | + ddp=DistributedDataParallelConfig( |
| 580 | + check_for_nan_in_grad=True, |
| 581 | + grad_reduce_in_fp32=True, |
| 582 | + overlap_grad_reduce=True, |
| 583 | + overlap_param_gather=True, |
| 584 | + average_in_collective=True, |
| 585 | + use_distributed_optimizer=True, |
| 586 | + ), |
| 587 | + dataset=dataset_config, |
| 588 | + logger=logger_config, |
| 589 | + tokenizer=TokenizerConfig( |
| 590 | + tokenizer_type="HuggingFaceTokenizer", |
| 591 | + tokenizer_model=hf_path, |
| 592 | + ), |
| 593 | + checkpoint=CheckpointConfig( |
| 594 | + save_interval=save_interval, |
| 595 | + save=checkpoint_dir, |
| 596 | + load=checkpoint_dir, |
| 597 | + pretrained_checkpoint=pretrained_checkpoint, |
| 598 | + ckpt_format="torch_dist", |
| 599 | + fully_parallel_save=True, |
| 600 | + ), |
| 601 | + rng=RNGConfig(seed=5678), # Different seed for finetuning |
| 602 | + mixed_precision=precision_config, |
| 603 | + ) |
| 604 | + |
| 605 | + return cfg |
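
Usage sketch (not part of the diff): the snippet below shows how the new finetune helpers might be called, mirroring how the existing pretrain configs are used. The import path, checkpoint path, and run name are illustrative placeholders, not values confirmed by this change.

# Hypothetical import path for this recipe module; adjust to wherever the file actually lives.
from megatron.bridge.recipes.qwen.qwen3_moe import (
    qwen3_30b_a3b_finetune_config,
    qwen3_235b_a22b_finetune_config,
)

# LoRA finetune of Qwen3-30B-A3B (1 node x 8 GPUs by default).
# When finetune_lr is omitted, the recipe auto-selects 1e-4 for PEFT and 5e-6 for full SFT.
cfg = qwen3_30b_a3b_finetune_config(
    name="qwen3_30b_a3b_lora",                       # placeholder run name
    pretrained_checkpoint="/path/to/megatron_ckpt",  # placeholder path
    peft="lora",        # "lora" (default), "dora", a PEFT instance, or None for full SFT
    train_iters=1000,
    global_batch_size=32,
)

# Full SFT of Qwen3-235B-A22B; the helper switches to TP=4, PP=16, EP=4, SP=True.
sft_cfg = qwen3_235b_a22b_finetune_config(peft=None, name="qwen3_235b_a22b_sft")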