Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions paddleformers/cli/train/auto_parallel/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
from paddleformers.trainer.trainer import Trainer
from paddleformers.trainer.trainer_utils import set_seed
from paddleformers.transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForCausalLMPipe,
AutoTokenizer,
CosineAnnealingWithWarmupDecay,
LinearAnnealingWithWarmupDecay,
LlamaConfig,
LlamaForCausalLMNet,
LlamaPretrainingCriterionNet,
)
from paddleformers.transformers.configuration_utils import LlmMetaConfig
from paddleformers.utils.log import logger
Expand Down Expand Up @@ -145,7 +145,6 @@ def __init__(self, *args, **kwargs):


def run_auto_parallel(model_args, data_args, generating_args, training_args):

do_enable_linear_fused_grad_add = training_args.enable_linear_fused_grad_add
# do_enable_mp_async_allreduce = (
# training_args.enable_auto_parallel
Expand Down Expand Up @@ -203,14 +202,8 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)

# TODO: only support llama model now
config_class = LlamaConfig
model_class = LlamaForCausalLMNet
criterion_class = LlamaPretrainingCriterionNet

config = config_class.from_pretrained(model_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
# config = AutoConfig.from_pretrained(model_args.model_name_or_path)
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
LlmMetaConfig.set_llm_config(config, training_args)
config.use_fast_layer_norm = model_args.use_fast_layer_norm

Expand Down Expand Up @@ -276,6 +269,13 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
if training_args.no_recompute_layers is not None:
training_args.no_recompute_layers.sort()

if training_args.use_intermediate_api:
config.run_single_model = True
config.tensor_parallel_degree = 1
config.sharding_parallel_degree = 1
config.sep_parallel_degree = 1
config.context_parallel_degree = 1

print("Final pre-training config:", config)

# Set the dtype for loading model
Expand All @@ -286,9 +286,41 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
if training_args.bf16:
dtype = "bfloat16"

with paddle.LazyGuard():
model = model_class.from_config(config, dtype=dtype)
criterion = criterion_class(config)
model_class = AutoModelForCausalLM

if not training_args.enable_auto_parallel and training_args.pipeline_parallel_degree > 1:
model_class = AutoModelForCausalLMPipe
if "LLama" in str(config.architectures):
try:
from utils.register_reshard import register_pp_reshard_information

register_pp_reshard_information(config.num_hidden_layers)
except:
print("Not register llama pp reshard information.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

什么情况下会执行失败?不注册这个reshard会造成什么影响


architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
if (
any(architecture in str(config.architectures) for architecture in architectures_to_check)
and training_args.data_parallel_degree > 1
):
training_args.use_expert_parallel = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单卡模式下允许EP吗?


if model_args.continue_training:
# NOTE(gongenlei): new add
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note去掉

if training_args.autotuner_benchmark:
model = model_class.from_config(config, dtype=dtype)
else:
model = model_class.from_pretrained(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

热启时不需要参考下面 paddle.lazyGuard写法吗?

model_args.model_name_or_path,
config=config,
dtype=dtype,
)
else:
if training_args.enable_auto_parallel:
with paddle.LazyGuard():
model = model_class.from_config(config, dtype=dtype)
else:
model = model_class.from_config(config, dtype=dtype)

if training_args.recompute:

Expand Down Expand Up @@ -344,7 +376,6 @@ def fn(layer):

trainer = PretrainingTrainer(
model=model,
criterion=criterion,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset if training_args.do_train else None,
Expand Down
35 changes: 0 additions & 35 deletions paddleformers/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,41 +214,6 @@
"LlamaPretrainingCriterion",
"LlamaNTKScalingRotaryEmbedding",
],
"llama.modeling_auto": [
"enable_fuse_ffn_qkv_pass",
"LlamaDecoderLayerAuto",
"LlamaAttentionAuto",
"LlamaPretrainedModelAuto",
"LlamaLMHeadAuto",
"LlamaModelAuto",
"LlamaForCausalLM3DAuto",
"LlamaMLPAuto",
"get_mesh",
"LlamaRMSNormAuto",
"is_pp_enable",
"LlamaPretrainingCriterion3DAuto",
"global_mesh_starts_with_pp",
"scaled_dot_product_attention",
],
"llama.modeling_network": [
"LlamaPretrainedModelNet",
"layer_input_parallel_row_and_col_hook",
"LlamaModelNet",
"LlamaPretrainingCriterionNet",
"layer_input_replicate_hook",
"LlamaLMHeadNet",
"LlamaForCausalLMNetDPO",
"GlobalOutputNet",
"layer_input_parallel_row_hook",
"LlamaRMSNormNet",
"LlamaAttentionNet",
"scaled_dot_product_attention",
"ReshardLayer",
"LlamaForCausalLMNet",
"enable_fuse_ffn_qkv_pass",
"LlamaMLPNet",
"LlamaDecoderLayerNet",
],
"llama.modeling_pp": ["LlamaForCausalLMPipe"],
"llama.tokenizer": ["LlamaTokenizer", "Llama3Tokenizer"],
"llama.tokenizer_fast": ["LlamaTokenizerFast"],
Expand Down
10 changes: 10 additions & 0 deletions paddleformers/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,9 @@ class PretrainedConfig:
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
model has a output word embedding layer.

run_single_model (`bool`, *optional*, defaults to `False`):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果这个是想表达非并行模式的话,名字并不表意,建议替换下,开发者可以更好理解,例如:run_without_parallelismrun_in_non_parallel_mode,如果这种模式下还允许dp和sharding的话,看看是否有更合适的名字

Whether to run the model in single card mode. When enabled, all parallel degree configurations will be disabled.

dtype (`str`, *optional*):
The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
(which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
Expand Down Expand Up @@ -601,6 +604,13 @@ def __init__(self, **kwargs):
self.use_cache = kwargs.pop("use_cache", False)
self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)

# for run model in single card mode
self.run_single_model = kwargs.pop("run_single_model", False)
if self.run_single_model:
self.tensor_parallel_degree = 1
self.sep_parallel_degree = 1
self.context_parallel_degree = 1

# for transformers fuse
self.fuse_linear = kwargs.pop("fuse_linear", False)
self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
Expand Down
35 changes: 0 additions & 35 deletions paddleformers/transformers/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,41 +50,6 @@
"LlamaPretrainingCriterion",
"LlamaNTKScalingRotaryEmbedding",
],
"modeling_auto": [
"enable_fuse_ffn_qkv_pass",
"LlamaDecoderLayerAuto",
"LlamaAttentionAuto",
"LlamaPretrainedModelAuto",
"LlamaLMHeadAuto",
"LlamaModelAuto",
"LlamaForCausalLM3DAuto",
"LlamaMLPAuto",
"get_mesh",
"LlamaRMSNormAuto",
"is_pp_enable",
"LlamaPretrainingCriterion3DAuto",
"global_mesh_starts_with_pp",
"scaled_dot_product_attention",
],
"modeling_network": [
"LlamaPretrainedModelNet",
"layer_input_parallel_row_and_col_hook",
"LlamaModelNet",
"LlamaPretrainingCriterionNet",
"layer_input_replicate_hook",
"LlamaLMHeadNet",
"LlamaForCausalLMNetDPO",
"GlobalOutputNet",
"layer_input_parallel_row_hook",
"LlamaRMSNormNet",
"LlamaAttentionNet",
"scaled_dot_product_attention",
"ReshardLayer",
"LlamaForCausalLMNet",
"enable_fuse_ffn_qkv_pass",
"LlamaMLPNet",
"LlamaDecoderLayerNet",
],
"modeling_pp": ["LlamaForCausalLMPipe"],
"tokenizer": ["LlamaTokenizer", "Llama3Tokenizer"],
"tokenizer_fast": ["LlamaTokenizerFast"],
Expand Down
40 changes: 40 additions & 0 deletions paddleformers/transformers/llama/auto_dist_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.distributed as dist


def get_dist_config(model, prefix=""):
"""Generate distributed configuration for Llama model"""
if prefix != "":
assert prefix.endswith(".")

config = {
"mp_config": {
"parallelize_plan": {
f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True),
f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
f"{prefix}lm_head.weight": dist.ColWiseParallel(),
}
},
}
return config
32 changes: 19 additions & 13 deletions paddleformers/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_skip_recompute_ops,
)
from ..refined_recompute import recompute as rr_recompute
from .auto_dist_config import get_dist_config

try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
Expand Down Expand Up @@ -178,22 +179,16 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
return assignment_list


def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
is_fleet_init = True
tensor_parallel_degree = 1
try:
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
tensor_parallel_degree = hcg.get_model_parallel_world_size()
except:
is_fleet_init = False

def parallel_matmul(
x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_degree=1, tensor_parallel_output=True, args=None
):
if paddle.in_dynamic_mode():
y_is_distributed = y.is_distributed
else:
y_is_distributed = tensor_parallel_degree > 1

if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
if tensor_parallel_degree > 1 and y_is_distributed:
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
# if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
Expand Down Expand Up @@ -1974,17 +1969,24 @@ def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1

tensor_parallel_degree = self.config.tensor_parallel_degree
if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
logits = self.xpu_parallel_matmul(
hidden_states,
self.weight,
transpose_y=self.transpose_y,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_output=tensor_parallel_output,
training=self.training,
)
else:
logits = parallel_matmul(
hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
hidden_states,
self.weight,
transpose_y=self.transpose_y,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_output=tensor_parallel_output,
args=self.config,
)
return logits

Expand Down Expand Up @@ -2156,3 +2158,7 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

def auto_dist_config(self, prefix=""):
assert self.config.run_single_model, "Use `get_dist_config` only in single card mode."
return get_dist_config(self, prefix)
Loading