
Commit e603c11

fix parallel_matmul
1 parent 654a0a1 commit e603c11

File tree: 1 file changed (+9, -19 lines)


paddleformers/transformers/llama/modeling.py

Lines changed: 9 additions & 19 deletions
@@ -179,23 +179,16 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list
 
 
-def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True, args=None):
-    is_fleet_init = True
-    tensor_parallel_degree = 1
-    if args is None or not args.run_single_model:
-        try:
-            hcg = fleet.get_hybrid_communicate_group()
-            model_parallel_group = hcg.get_model_parallel_group()
-            tensor_parallel_degree = hcg.get_model_parallel_world_size()
-        except:
-            is_fleet_init = False
-
+def parallel_matmul(
+    x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_degree=1, tensor_parallel_output=True, args=None
+):
     if paddle.in_dynamic_mode():
         y_is_distributed = y.is_distributed
     else:
         y_is_distributed = tensor_parallel_degree > 1
-
-    if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
+    if tensor_parallel_degree > 1 and y_is_distributed:
+        hcg = fleet.get_hybrid_communicate_group()
+        model_parallel_group = hcg.get_model_parallel_group()
         # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
         input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
         logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
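
For context, a minimal sketch of what parallel_matmul looks like after this hunk, assembled from the lines above. The hunk only touches the signature and the distributed branch; the output handling and the non-parallel fallback sit outside the hunk, so the parts marked as assumptions below are illustrative rather than verbatim.

# Sketch of the refactored helper, assembled from the hunk above.
# The tail of the function (output gathering and the non-parallel path)
# is outside the hunk, so those parts are assumptions.
def parallel_matmul(
    x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_degree=1, tensor_parallel_output=True, args=None
):
    if paddle.in_dynamic_mode():
        y_is_distributed = y.is_distributed
    else:
        y_is_distributed = tensor_parallel_degree > 1
    if tensor_parallel_degree > 1 and y_is_distributed:
        # The hybrid-communicate group is now fetched only when tensor
        # parallelism is actually in use, replacing the old try/except
        # probe and the is_fleet_init flag.
        hcg = fleet.get_hybrid_communicate_group()
        model_parallel_group = hcg.get_model_parallel_group()
        input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
        logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
        # ... gather or keep the parallel logits depending on
        # tensor_parallel_output (unchanged, not shown in this hunk)
        return logits
    # Assumed fallback when no tensor parallelism is active: a plain matmul.
    return paddle.matmul(x, y, transpose_y=transpose_y)

The practical consequence is that callers must now pass tensor_parallel_degree explicitly, and the fleet lookup happens only inside the tensor-parallel branch, so single-process runs never touch it.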
@@ -1328,8 +1321,6 @@ def _get_hardware_flops(self):
 
     @classmethod
     def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
-        if config.run_single_model:
-            return cls._get_name_mappings()
         mappings: list[StateDictNameMapping] = []
         model_mappings = [
             ["embed_tokens.weight"],
@@ -1364,8 +1355,6 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
-        if config.run_single_model:
-            return {}
         from ..conversion_utils import split_or_merge_func
 
         fn = split_or_merge_func(
@@ -1425,8 +1414,6 @@ def get_tensor_parallel_split_mappings(num_layers):
 
     @classmethod
     def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
-        if config.run_single_model:
-            return cls._get_fuse_or_split_param_mappings()
         # return parameter fuse utils
         from ..conversion_utils import split_or_fuse_func
 
@@ -1981,11 +1968,13 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
 
+        tensor_parallel_degree = self.config.tensor_parallel_degree
         if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
             logits = self.xpu_parallel_matmul(
                 hidden_states,
                 self.weight,
                 transpose_y=self.transpose_y,
+                tensor_parallel_degree=tensor_parallel_degree,
                 tensor_parallel_output=tensor_parallel_output,
                 training=self.training,
             )
@@ -1994,6 +1983,7 @@ def forward(self, hidden_states, tensor_parallel_output=None):
                 hidden_states,
                 self.weight,
                 transpose_y=self.transpose_y,
+                tensor_parallel_degree=tensor_parallel_degree,
                 tensor_parallel_output=tensor_parallel_output,
                 args=self.config,
             )
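
Putting the last two hunks together, the caller side of this change in the LM-head forward reads roughly as the sketch below. The attribute names come from the hunks themselves; the consolidation, and the shape of the non-XPU branch around its arguments, are assumptions.

# Consolidated view of the forward() call sites after this commit.
tensor_parallel_degree = self.config.tensor_parallel_degree  # new: read once from config
if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
    logits = self.xpu_parallel_matmul(
        hidden_states,
        self.weight,
        transpose_y=self.transpose_y,
        tensor_parallel_degree=tensor_parallel_degree,  # newly threaded through
        tensor_parallel_output=tensor_parallel_output,
        training=self.training,
    )
else:
    # Assumed shape of the non-XPU branch; the hunk only shows its arguments.
    logits = parallel_matmul(
        hidden_states,
        self.weight,
        transpose_y=self.transpose_y,
        tensor_parallel_degree=tensor_parallel_degree,  # newly threaded through
        tensor_parallel_output=tensor_parallel_output,
        args=self.config,
    )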
