@@ -179,23 +179,16 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list


-def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True, args=None):
-    is_fleet_init = True
-    tensor_parallel_degree = 1
-    if args is None or not args.run_single_model:
-        try:
-            hcg = fleet.get_hybrid_communicate_group()
-            model_parallel_group = hcg.get_model_parallel_group()
-            tensor_parallel_degree = hcg.get_model_parallel_world_size()
-        except:
-            is_fleet_init = False
-
+def parallel_matmul(
+    x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_degree=1, tensor_parallel_output=True, args=None
+):
     if paddle.in_dynamic_mode():
         y_is_distributed = y.is_distributed
     else:
         y_is_distributed = tensor_parallel_degree > 1
-
-    if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
+    if tensor_parallel_degree > 1 and y_is_distributed:
+        hcg = fleet.get_hybrid_communicate_group()
+        model_parallel_group = hcg.get_model_parallel_group()
         # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
         input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
         logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
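For orientation, the tensor-parallel branch kept above applies an identity on the input and then a local matmul against the weight shard held by each model-parallel rank. The single-process sketch below (illustrative only, not part of this commit; the shapes and degree value are made up, and paddle.split/concat stand in for the per-rank sharding and the cross-rank gather) shows why concatenating the per-shard results along the last axis recovers the full product when the weight is sharded along its output dimension.

# Single-process sketch of the column-parallel matmul pattern (assumptions noted above).
import paddle

paddle.seed(0)
degree = 2                                   # stand-in for tensor_parallel_degree
x = paddle.randn([4, 8])                     # [batch, hidden]
y = paddle.randn([8, 16])                    # [hidden, vocab]

full = paddle.matmul(x, y)                   # what a single card computes
shards = paddle.split(y, degree, axis=-1)    # one column shard per model-parallel rank
partials = [paddle.matmul(x, shard) for shard in shards]
gathered = paddle.concat(partials, axis=-1)  # stands in for the cross-rank gather

assert paddle.allclose(full, gathered, atol=1e-5)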
@@ -1328,8 +1321,6 @@ def _get_hardware_flops(self):
 
     @classmethod
     def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
-        if config.run_single_model:
-            return cls._get_name_mappings()
         mappings: list[StateDictNameMapping] = []
         model_mappings = [
             ["embed_tokens.weight"],
@@ -1364,8 +1355,6 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
-        if config.run_single_model:
-            return {}
         from ..conversion_utils import split_or_merge_func
 
         fn = split_or_merge_func(
@@ -1425,8 +1414,6 @@ def get_tensor_parallel_split_mappings(num_layers):
 
     @classmethod
     def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
-        if config.run_single_model:
-            return cls._get_fuse_or_split_param_mappings()
         # return parameter fuse utils
         from ..conversion_utils import split_or_fuse_func
 
@@ -1981,11 +1968,13 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
 
+        tensor_parallel_degree = self.config.tensor_parallel_degree
         if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
             logits = self.xpu_parallel_matmul(
                 hidden_states,
                 self.weight,
                 transpose_y=self.transpose_y,
+                tensor_parallel_degree=tensor_parallel_degree,
                 tensor_parallel_output=tensor_parallel_output,
                 training=self.training,
             )
@@ -1994,6 +1983,7 @@ def forward(self, hidden_states, tensor_parallel_output=None):
                 hidden_states,
                 self.weight,
                 transpose_y=self.transpose_y,
+                tensor_parallel_degree=tensor_parallel_degree,
                 tensor_parallel_output=tensor_parallel_output,
                 args=self.config,
             )
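Because the try/except that used to discover the degree inside parallel_matmul is gone, a caller that may run outside paddle.distributed.launch could resolve the degree itself before calling in. Below is a hedged sketch of such a helper; the helper name and the default fallback are hypothetical, and only fleet.get_hybrid_communicate_group() and get_model_parallel_world_size() come from the code above.

# Hypothetical helper (not part of this commit): resolve the model-parallel
# degree once at the call site, falling back to 1 when fleet is not initialized.
import paddle.distributed.fleet as fleet

def resolve_tensor_parallel_degree(default: int = 1) -> int:
    try:
        hcg = fleet.get_hybrid_communicate_group()
        return hcg.get_model_parallel_world_size()
    except AttributeError:
        # 'Fleet' object has no attribute '_hcg' when not launched via distributed.launch.
        return default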