Skip to content

Conversation

@charithaintc
Copy link
Contributor

vector.extract_strided_slice can have two forms when specifying offsets.

Case 1:

%1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1]}
      : vector<24x16xf32> to vector<8x16xf32>

Case 2:

%1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [8], strides = [1]}
      : vector<24x16xf32> to vector<8x16xf32>

These two ops means the same thing, but case 2 is syntactic sugar to avoid specifying offsets for fully extracted dims. Currently case 2 fails in XeGPU SIMT distribution. This PR fixes this issue.

@llvmbot
Copy link
Member

llvmbot commented Dec 9, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Changes

vector.extract_strided_slice can have two forms when specifying offsets.

Case 1:

%1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1]}
      : vector&lt;24x16xf32&gt; to vector&lt;8x16xf32&gt;

Case 2:

%1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [8], strides = [1]}
      : vector&lt;24x16xf32&gt; to vector&lt;8x16xf32&gt;

These two ops means the same thing, but case 2 is syntactic sugar to avoid specifying offsets for fully extracted dims. Currently case 2 fails in XeGPU SIMT distribution. This PR fixes this issue.


Full diff: https://github.com/llvm/llvm-project/pull/171512.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+16-2)
  • (modified) mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir (+46)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index ca81c3cd7be42..bbea93101c54e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -99,6 +99,7 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
     if (i < distributionStart)
       continue;
+
     // Check if the dimension can be distributed evenly.
     if (dim % effectiveLaneLayout[i - distributionStart] != 0)
       return failure();
@@ -1673,6 +1674,19 @@ struct VectorExtractStridedSliceDistribution
         extractOp.getSizes(), [](Attribute attr) { return attr; });
     SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
         extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
+        extractOp.getStrides(), [](Attribute attr) { return attr; });
+    // If the provided sizes, offsets, strides are less than the rank, pad them
+    // with full sizes, zero offsets, and unit strides. This makes it easier to
+    // adjust them later.
+    int64_t sourceRank = extractOp.getSourceVectorType().getRank();
+    for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
+      updatedSizes.push_back(rewriter.getI64IntegerAttr(
+          extractOp.getSourceVectorType().getDimSize(i)));
+      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
+      updatedStrides.push_back(
+          rewriter.getI64IntegerAttr(1)); // stride is always 1.
+    }
     // If the result is distributed, it must be distributed in exactly one
     // dimension. In this case, we adjust the sourceDistType, distributedSizes
     // and distributedOffsets accordingly.
@@ -1708,7 +1722,7 @@ struct VectorExtractStridedSliceDistribution
       // The offsets in the distributed dimention must be a multiple of subgroup
       // size.
       int64_t distrDimOffset =
-          cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+          cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
       if (distrDimOffset % subgroupSize != 0)
         return rewriter.notifyMatchFailure(
             warpOp, "Offset along distributed dimension "
@@ -1737,7 +1751,7 @@ struct VectorExtractStridedSliceDistribution
         rewriter, extractOp.getLoc(), distributedType, source,
         ArrayAttr::get(rewriter.getContext(), updatedOffsets),
         ArrayAttr::get(rewriter.getContext(), updatedSizes),
-        extractOp.getStrides());
+        ArrayAttr::get(rewriter.getContext(), updatedStrides));
     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
     return success();
   }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 216f3d19cff94..7819a438057c4 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -753,6 +753,27 @@ gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) {
   gpu.return
 }
 
+// CHECK-LABEL:  gpu.func @vector_extract_strided_slice_partial_offsets
+// CHECK-NEXT:      %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT:        %[[S:.*]] = "some_def"() : () -> vector<24x16xf32>
+// CHECK:             gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:        {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT:      "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_partial_offsets(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_def"() : () -> (vector<24x16xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [8], strides = [1],
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<24x16xf32> to vector<8x16xf32>
+    gpu.yield %1 : vector<8x16xf32>
+  }
+  "some_use"(%r) : (vector<8x1xf32>) -> ()
+  gpu.return
+}
 
 // CHECK-LABEL:  gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted
 // CHECK-NEXT:      %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
@@ -880,6 +901,31 @@ gpu.func @vector_insert_strided_slice_1d(%laneid: index) {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_different_ranks
+// CHECK-NEXT:      %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT:        %[[S:.*]] = "some_def"() : () -> vector<16xf32>
+// CHECK-NEXT:        %[[D:.*]] = "some_def"() : () -> vector<64x16xf32>
+// CHECK:             gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16xf32>, vector<64x16xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME:        {offsets = [13, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
+// CHECK-NEXT:      "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_different_ranks(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<16xf32>)
+    %1 = "some_def"() : () -> (vector<64x16xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [13, 0],  strides = [1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+      : vector<16xf32> into vector<64x16xf32>
+    gpu.yield %2 : vector<64x16xf32>
+  }
+  "some_use"(%r) : (vector<64x1xf32>) -> ()
+  gpu.return
+}
+
 // CHECK-LABEL:  gpu.func @vector_insert_strided_slice_unsupported_source
 // CHECK:          %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
 // CHECK:          }

@charithaintc charithaintc requested a review from nbpatel December 9, 2025 21:49
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
%0 = "some_def"() : () -> (vector<16xf32>)
%1 = "some_def"() : () -> (vector<64x16xf32>)
%2 = vector.insert_strided_slice %0, %1 { offsets = [13, 0], strides = [1],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: how you make this work by only changing the extract_strided_slice?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a missed corner case. it is unrelated to the code changes in PR.

Copy link
Contributor

@Jianhui-Li Jianhui-Li left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@charithaintc charithaintc merged commit 3ece662 into llvm:main Dec 10, 2025
8 of 9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants