Skip to content

Conversation

@vivekkhandelwal1
Copy link
Collaborator

@vivekkhandelwal1 vivekkhandelwal1 commented Dec 11, 2025

This commit adds the support to handle output tensor type consistency in case of dynamic dims for AtenTransposeInt op lowering by adding a cast operation to ensure the output tensor type matches the expected type.

Also, update tests in datamovement.mlir to reflect changes in tensor handling for dynamic dimensions in transpose operations.

Example IR:

  %int1 = torch.constant.int 1
  %int8 = torch.constant.int 8
  %int7 = torch.constant.int 7
  %int96 = torch.constant.int 96
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %0 = torch.prim.ListConstruct %int1, %int8, %int7, %int8, %int7, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?,?,?],f32>
  %2 = torch.aten.transpose.int %1, %int2, %int3 : !torch.vtensor<[?,?,?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?,?,?],f32>

…oseInt op lowering

This commit adds the support to handle output tensor type consistency in case of dynamic dims for AtenTransposeInt op lowering by adding a cast operation to ensure the output tensor type matches the expected type.

Also, update tests in datamovement.mlir to reflect changes in tensor handling for dynamic dimensions in transpose operations.

Example IR:
"""
  %0 = torch.prim.ListConstruct %int1, %int8, %int7, %int8, %int7, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?,?,?],f32>
  %2 = torch.aten.transpose.int %1, %int2, %int3 : !torch.vtensor<[?,?,?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?,?,?],f32>
"""
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int1, %int8, %int7, %int8, %int7, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?,?,?],f32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarification: Is a similar change required for view op as well since the view output is dynamic shape in torch IR but fixed shape in the linalg IR as per the added test?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants