Skip to content

Commit ac3913e

Browse files
tissue3facebook-github-bot
authored andcommitted
Add deduce_jagged_tensor_with_graph_analysis flag for batch dim distinguish
Summary: For vdd, it seems that the jagged tensor batch dim is identical to dense tensor batch dim, which caused issue in bmm kernel, that it cannot handle batch size as large as 2^16. This fix adds a flag `deduce_jagged_tensor_with_graph_analysis` so that when it is turnt on, we depend on graph analysis, i.e. `try_getting_jagged_tensor_map`, to deduce batch dim for jagged tensor. This can be more reliable than deducing based on value. Differential Revision: D49262422
1 parent 4ca3435 commit ac3913e

File tree

1 file changed

+4
-17
lines changed

1 file changed

+4
-17
lines changed

fx2ait/fx2ait/tensor_spec.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,9 @@ def _get_max_seq_lens_from_offsets(
306306

307307
return max_seq_lens
308308

309-
@staticmethod
310-
def _try_getting_jagged_tensor_map(
309+
@classmethod
310+
def try_getting_jagged_tensor_map(
311+
cls,
311312
inputs: List[torch.Tensor],
312313
jagged_tensor_batch_dims: Set[int],
313314
fx_inputs: Optional[List[torch.fx.Node]] = None,
@@ -371,6 +372,7 @@ def from_input_list_with_batch_size_jagged_tensor(
371372
additional_inputs: List[torch.Tensor] = None,
372373
infer_max_seq_lens_from_offsets: bool = False,
373374
fx_inputs: List[torch.fx.Node] = None,
375+
jagged_tensor_map: Optional[Dict[int, int]] = None,
374376
) -> List["TensorSpec"]:
375377
"""
376378
Most of the recommendation models will work fine using this function.
@@ -385,21 +387,6 @@ def from_input_list_with_batch_size_jagged_tensor(
385387
jagged_offsets_batch_dims=jagged_offsets_batch_dims,
386388
)
387389

388-
jagged_tensor_map = cls._try_getting_jagged_tensor_map(
389-
inputs=inputs,
390-
jagged_tensor_batch_dims=jagged_tensor_batch_dims,
391-
fx_inputs=fx_inputs,
392-
)
393-
if jagged_tensor_map:
394-
logger.info("Successfully detected a jagged_tensor_map:")
395-
for input_id, jagged_tensor_id in jagged_tensor_map.items():
396-
logger.info(f"{input_id=}, {jagged_tensor_id=}")
397-
else:
398-
logger.info(
399-
"Unable to detect a jagged_tensor_map: falling back "
400-
"to the batch dim-based jagged tensor detection."
401-
)
402-
403390
result: List = []
404391
result_unsorted: List = []
405392
left_inputs: List = []

0 commit comments

Comments
 (0)