 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle
 import paddle.distributed as dist
-from paddle.distributed.auto_parallel.intermediate.tensor_parallel import (
-    PrepareLayerInput,
-    PrepareLayerOutput,
-)
-
-
-def layer_input_parallel_row_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        for input in inputs:
-            if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()]))
-            else:
-                res_inputs.append(
-                    dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])
-                )
-        return tuple(res_inputs)
-
-    return hook
-
-
-def layer_input_parallel_row_and_col_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        for input in inputs:
-            if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]))
-            else:
-                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]))
-        return tuple(res_inputs)
-
-    return hook
-
-
-def layer_input_replicate_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        for input in inputs:
-            if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()])
-                res_inputs.append(
-                    dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()])
-                )
-            else:
-                res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()]))
-        return tuple(res_inputs)
-
-    return hook
-
-
-def layer_input_rope_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        batch_size = None
-        seq_length = None
-        process_mesh = None
-        placements = None
-        for index in range(len(inputs)):
-            if index == 0:
-                batch_size, seq_length, _, _ = inputs[index]._local_shape
-                process_mesh = inputs[index].process_mesh
-                placements = inputs[index].placements
-            # process position_ids
-            if index == len(inputs) - 1:
-                mesh = dist.auto_parallel.get_mesh()
-                assert "sep" in mesh.dim_names, f"mesh.dim_names:{mesh.dim_names} must contain sep"
-                group = mesh._get_group("sep")
-                chunk_size = seq_length // 2
-                chunk_num = group.nranks * 2
-                rank = group.rank
-                first_chunk_ids = paddle.arange(rank * chunk_size, (rank + 1) * chunk_size, dtype="int64")
-                second_chunk_ids = paddle.arange(
-                    (chunk_num - rank - 1) * chunk_size, (chunk_num - rank) * chunk_size, dtype="int64"
-                )
-                position_ids = paddle.concat([first_chunk_ids, second_chunk_ids]).expand((batch_size, seq_length))
-                mp_axis = process_mesh.dim_names.index("mp")
-                placements[mp_axis] = dist.Replicate()  # mp placement shard(2) -> replicate
-                position_ids = dist.auto_parallel.api.dtensor_from_local(position_ids, process_mesh, placements)
-                res_inputs.append(position_ids)
-            else:
-                res_inputs.append(inputs[index])
-        return tuple(res_inputs)
-
-    return hook
-
-
-def layer_output_rope_hook(process_mesh):
-    def hook(layer, inputs, outputs):
-        res_outputs = []
-        for output in outputs:
-            process_mesh = output.process_mesh
-            placements = output.placements
-            cp_index = process_mesh.dim_names.index("sep")  # get the axis for the split
-            cp_degree = process_mesh.shape[cp_index]
-            assert cp_degree > 1, f"cp_degree:{cp_degree} must > 1"
-            placements[cp_index] = dist.Shard(1)  # seq_dim:1
-            output = dist.reshard(output, process_mesh, placements)
-            res_outputs.append(output)
-        return tuple(res_outputs)
-
-    return hook


 def get_dist_config(model, prefix=""):
@@ -125,36 +21,9 @@ def get_dist_config(model, prefix=""):
         assert prefix.endswith(".")

     config = {
-        "sp_config": {
-            "parallelize_plan": {
-                f"{prefix}llama.embed_tokens": [
-                    dist.ColWiseParallel(),
-                    dist.SequenceParallelBegin(),
-                ],
-                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
-                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
-                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
-                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
-                f"{prefix}llama.layers.*.self_attn": dist.SequenceParallelDisable(),
-                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
-                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
-                f"{prefix}llama.layers.*.mlp": dist.SequenceParallelDisable(need_transpose=False),
-                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
-                f"{prefix}lm_head": dist.SequenceParallelEnd(),
-            }
-        },
         "mp_config": {
             "parallelize_plan": {
                 f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True),
-                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
-                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
-                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
                 f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
                 f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
                 f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
@@ -167,31 +36,5 @@ def get_dist_config(model, prefix=""):
                 f"{prefix}lm_head.weight": dist.ColWiseParallel(),
             }
         },
-        "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"},
-        "cp_config": {
-            "parallelize_plan": {
-                f"{prefix}llama.layers.*.self_attn.sdpa": dist.ContextParallel(
-                    backend="p2p" if model.config.context_parallel_degree > 1 else "all2all"
-                ),
-            }
-        },
     }
-
-    if model.config.context_parallel_degree > 1:
-        config["cp_config"]["parallelize_plan"].update(
-            {
-                f"{prefix}llama.layers.*.self_attn.rope_func": [
-                    PrepareLayerInput(layer_input_rope_hook),
-                    PrepareLayerOutput(layer_output_rope_hook),
-                ]
-            }
-        )
-    elif model.config.sep_parallel_degree > 1:
-        # fuse_rope does not support dtensor spmd yet, so the sequence dim needs an extra reshard
-        config["cp_config"]["parallelize_plan"].update(
-            {
-                f"{prefix}llama.layers.*.self_attn.rope_func": PrepareLayerOutput(layer_output_rope_hook),
-            }
-        )
-
     return config
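
For reference, a minimal usage sketch of the helper kept by this diff. The call to dist.parallelize (its exact signature and return value) and the surrounding model/optimizer setup are assumptions for illustration, not part of this change:

    import paddle.distributed as dist

    # Build the name-pattern -> parallel-style plan for the model.
    # A non-empty prefix must end with "."; keys then become e.g.
    # "model.llama.layers.*.self_attn.qkv_proj".
    plan = get_dist_config(model, prefix="")

    # Hypothetical wiring into Paddle's intermediate auto-parallel API.
    model, optimizer = dist.parallelize(model, optimizer, config=plan)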