
Commit 8e83c57

adapt workflow in auto parallel
1 parent e603c11 commit 8e83c57

3 files changed: 47 additions and 172 deletions

paddleformers/cli/train/auto_parallel/workflow.py

Lines changed: 46 additions & 15 deletions
@@ -27,12 +27,12 @@
 from paddleformers.trainer.trainer import Trainer
 from paddleformers.trainer.trainer_utils import set_seed
 from paddleformers.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForCausalLMPipe,
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     LinearAnnealingWithWarmupDecay,
-    LlamaConfig,
-    LlamaForCausalLMNet,
-    LlamaPretrainingCriterionNet,
 )
 from paddleformers.transformers.configuration_utils import LlmMetaConfig
 from paddleformers.utils.log import logger
@@ -145,7 +145,6 @@ def __init__(self, *args, **kwargs):


 def run_auto_parallel(model_args, data_args, generating_args, training_args):
-
     do_enable_linear_fused_grad_add = training_args.enable_linear_fused_grad_add
     # do_enable_mp_async_allreduce = (
     #     training_args.enable_auto_parallel
@@ -203,14 +202,8 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
         )

-    # TODO: only support llama model now
-    config_class = LlamaConfig
-    model_class = LlamaForCausalLMNet
-    criterion_class = LlamaPretrainingCriterionNet
-
-    config = config_class.from_pretrained(model_args.model_name_or_path)
     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
-    # config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     LlmMetaConfig.set_llm_config(config, training_args)
     config.use_fast_layer_norm = model_args.use_fast_layer_norm

@@ -276,6 +269,13 @@
     if training_args.no_recompute_layers is not None:
         training_args.no_recompute_layers.sort()

+    if training_args.use_intermediate_api:
+        config.run_single_model = True
+        config.tensor_parallel_degree = 1
+        config.sharding_parallel_degree = 1
+        config.sep_parallel_degree = 1
+        config.context_parallel_degree = 1
+
     print("Final pre-training config:", config)

     # Set the dtype for loading model
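The new use_intermediate_api branch above resets every parallel degree on the model config to 1 and flags run_single_model, so the network is built as an ordinary single-card model and sharding is applied afterwards by the intermediate auto-parallel API. A minimal sketch of that reset as a standalone helper (the helper name is illustrative, not part of PaddleFormers):

def force_single_model(config):
    # Mirror the use_intermediate_api branch: build the layers unsharded and
    # leave distribution to the intermediate auto-parallel API.
    config.run_single_model = True
    for attr in (
        "tensor_parallel_degree",
        "sharding_parallel_degree",
        "sep_parallel_degree",
        "context_parallel_degree",
    ):
        setattr(config, attr, 1)
    return config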
@@ -286,9 +286,41 @@
     if training_args.bf16:
         dtype = "bfloat16"

-    with paddle.LazyGuard():
-        model = model_class.from_config(config, dtype=dtype)
-        criterion = criterion_class(config)
+    model_class = AutoModelForCausalLM
+
+    if not training_args.enable_auto_parallel and training_args.pipeline_parallel_degree > 1:
+        model_class = AutoModelForCausalLMPipe
+        if "LLama" in str(config.architectures):
+            try:
+                from utils.register_reshard import register_pp_reshard_information
+
+                register_pp_reshard_information(config.num_hidden_layers)
+            except:
+                print("Not register llama pp reshard information.")
+
+    architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
+    if (
+        any(architecture in str(config.architectures) for architecture in architectures_to_check)
+        and training_args.data_parallel_degree > 1
+    ):
+        training_args.use_expert_parallel = True
+
+    if model_args.continue_training:
+        # NOTE(gongenlei): new add
+        if training_args.autotuner_benchmark:
+            model = model_class.from_config(config, dtype=dtype)
+        else:
+            model = model_class.from_pretrained(
+                model_args.model_name_or_path,
+                config=config,
+                dtype=dtype,
+            )
+    else:
+        if training_args.enable_auto_parallel:
+            with paddle.LazyGuard():
+                model = model_class.from_config(config, dtype=dtype)
+        else:
+            model = model_class.from_config(config, dtype=dtype)

     if training_args.recompute:

@@ -344,7 +376,6 @@ def fn(layer):

     trainer = PretrainingTrainer(
         model=model,
-        criterion=criterion,
         args=training_args,
         data_collator=data_collator,
         train_dataset=train_dataset if training_args.do_train else None,
paddleformers/transformers/llama/auto_dist_config.py

Lines changed: 0 additions & 157 deletions
@@ -12,111 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle
 import paddle.distributed as dist
-from paddle.distributed.auto_parallel.intermediate.tensor_parallel import (
-    PrepareLayerInput,
-    PrepareLayerOutput,
-)
-
-
-def layer_input_parallel_row_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        for input in inputs:
-            if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()]))
-            else:
-                res_inputs.append(
-                    dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])
-                )
-        return tuple(res_inputs)
-
-    return hook
-
-
-def layer_input_parallel_row_and_col_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        for input in inputs:
-            if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]))
-            else:
-                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]))
-        return tuple(res_inputs)
-
-    return hook
-
-
-def layer_input_replicate_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        for input in inputs:
-            if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()])
-                res_inputs.append(
-                    dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()])
-                )
-            else:
-                res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()]))
-        return tuple(res_inputs)
-
-    return hook
-
-
-def layer_input_rope_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        batch_size = None
-        seq_length = None
-        process_mesh = None
-        placements = None
-        for index in range(len(inputs)):
-            if index == 0:
-                batch_size, seq_length, _, _ = inputs[index]._local_shape
-                process_mesh = inputs[index].process_mesh
-                placements = inputs[index].placements
-            # process position_ids
-            if index == len(inputs) - 1:
-                mesh = dist.auto_parallel.get_mesh()
-                assert "sep" in mesh.dim_names, f"mesh.dim_names:{mesh.dim_names} must contain sep"
-                group = mesh._get_group("sep")
-                chunk_size = seq_length // 2
-                chunk_num = group.nranks * 2
-                rank = group.rank
-                first_chunk_ids = paddle.arange(rank * chunk_size, (rank + 1) * chunk_size, dtype="int64")
-                second_chunk_ids = paddle.arange(
-                    (chunk_num - rank - 1) * chunk_size, (chunk_num - rank) * chunk_size, dtype="int64"
-                )
-                position_ids = paddle.concat([first_chunk_ids, second_chunk_ids]).expand((batch_size, seq_length))
-                mp_axis = process_mesh.dim_names.index("mp")
-                placements[mp_axis] = dist.Replicate()  # mp placament shard(2) -> replicate
-                position_ids = dist.auto_parallel.api.dtensor_from_local(position_ids, process_mesh, placements)
-                res_inputs.append(position_ids)
-            else:
-                res_inputs.append(inputs[index])
-        return tuple(res_inputs)
-
-    return hook
-
-
-def layer_output_rope_hook(process_mesh):
-    def hook(layer, inputs, outputs):
-        res_outputs = []
-        for output in outputs:
-            process_mesh = output.process_mesh
-            placements = output.placements
-            cp_index = process_mesh.dim_names.index("sep")  # get the axis for the split
-            cp_degree = process_mesh.shape[cp_index]
-            assert cp_degree > 1, f"cp_degree:{cp_degree} must > 1"
-            placements[cp_index] = dist.Shard(1)  # seq_dim:1
-            output = dist.reshard(output, process_mesh, placements)
-            res_outputs.append(output)
-        return tuple(res_outputs)
-
-    return hook


 def get_dist_config(model, prefix=""):
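The deleted hooks all followed the same pattern: intercept a layer's inputs, mark any plain tensor as a distributed tensor on the process mesh, then reshard it to the placement the layer expects. A minimal sketch of that shard/reshard pattern using Paddle's auto-parallel API (dist.ProcessMesh, dist.shard_tensor, dist.reshard); the 2x2 mesh and tensor shapes here are illustrative and need a multi-rank launch to exercise for real:

import paddle
import paddle.distributed as dist

# Illustrative 2x2 mesh with "dp" and "mp" axes (rank ids are placeholders).
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

x = paddle.randn([4, 8])
# Mark the local tensor as distributed: shard the batch dim over "dp" and
# replicate over "mp" -- the same first step the removed hooks performed
# for non-dist inputs.
x_dist = dist.shard_tensor(x, mesh, [dist.Shard(0), dist.Replicate()])
# Reshard to the placement the layer expects, e.g. also split dim 1 over "mp".
x_resharded = dist.reshard(x_dist, mesh, [dist.Shard(0), dist.Shard(1)])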
@@ -125,36 +21,9 @@ def get_dist_config(model, prefix=""):
         assert prefix.endswith(".")

     config = {
-        "sp_config": {
-            "parallelize_plan": {
-                f"{prefix}llama.embed_tokens": [
-                    dist.ColWiseParallel(),
-                    dist.SequenceParallelBegin(),
-                ],
-                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
-                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
-                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
-                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn": dist.SequenceParallelDisable(),
-                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
-                f"{prefix}llama.layers.*.mlp": dist.SequenceParallelDisable(need_transpose=False),
-                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
-                f"{prefix}lm_head": dist.SequenceParallelEnd(),
-            }
-        },
         "mp_config": {
             "parallelize_plan": {
                 f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True),
-                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
-                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
-                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
                 f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
                 f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
                 f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
@@ -167,31 +36,5 @@ def get_dist_config(model, prefix=""):
                 f"{prefix}lm_head.weight": dist.ColWiseParallel(),
             }
         },
-        "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"},
-        "cp_config": {
-            "parallelize_plan": {
-                f"{prefix}llama.layers.*.self_attn.sdpa": dist.ContextParallel(
-                    backend="p2p" if model.config.context_parallel_degree > 1 else "all2all"
-                ),
-            }
-        },
     }
-
-    if model.config.context_parallel_degree > 1:
-        config["cp_config"]["parallelize_plan"].update(
-            {
-                f"{prefix}llama.layers.*.self_attn.rope_func": [
-                    PrepareLayerInput(layer_input_rope_hook),
-                    PrepareLayerOutput(layer_output_rope_hook),
-                ]
-            }
-        )
-    elif model.config.sep_parallel_degree > 1:
-        # fuse_rope is not support dtensor spmd yet,thus need to extraly reshard sequence dim
-        config["cp_config"]["parallelize_plan"].update(
-            {
-                f"{prefix}llama.layers.*.self_attn.rope_func": PrepareLayerOutput(layer_output_rope_hook),
-            }
-        )
-
     return config
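After this commit get_dist_config returns only the mp_config parallelize_plan; the sp/pp/cp entries and their custom hooks are gone, and the model argument is no longer referenced in the body (only prefix is used). A minimal sketch of inspecting the returned plan; a placeholder is passed for the now-unused model argument, and in practice the plan is handed to Paddle's intermediate auto-parallel machinery together with the model:

from paddleformers.transformers.llama.auto_dist_config import get_dist_config

# The model argument is unused after this change, so a placeholder suffices here.
plan = get_dist_config(None, prefix="model.")["mp_config"]["parallelize_plan"]
for pattern, strategy in plan.items():
    # e.g. "model.llama.layers.*.self_attn.qkv_proj" -> ColWiseParallel
    print(pattern, type(strategy).__name__)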

paddleformers/transformers/llama/modeling.py

Lines changed: 1 addition & 0 deletions
@@ -1355,6 +1355,7 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:

     @classmethod
     def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
+
         from ..conversion_utils import split_or_merge_func

         fn = split_or_merge_func(
