[mlir][xegpu] Add support for vector.extract_strided_slice XeGPU SIMT distribution with partial offsets.
#171512
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Changes
vector.extract_strided_slice can have two forms when specifying offsets: in case 1, offsets, sizes, and strides are spelled out for every dimension; in case 2, they are omitted for trailing, fully extracted dimensions. The two forms mean the same thing; case 2 is syntactic sugar that avoids specifying offsets for fully extracted dims. Currently, case 2 fails in XeGPU SIMT distribution. This PR fixes this issue.

Full diff: https://github.com/llvm/llvm-project/pull/171512.diff

2 Files Affected:
- mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
- mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
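As a minimal illustration of the two forms (adapted from the test added in this PR; the function name @two_forms and the value name %src are invented for this sketch):

func.func @two_forms(%src: vector<24x16xf32>) -> (vector<8x16xf32>, vector<8x16xf32>) {
  // Case 1: offsets, sizes, and strides are given for every dimension.
  %a = vector.extract_strided_slice %src { offsets = [8, 0], sizes = [8, 16], strides = [1, 1] }
    : vector<24x16xf32> to vector<8x16xf32>
  // Case 2: the trailing, fully extracted dimension is left implicit.
  %b = vector.extract_strided_slice %src { offsets = [8], sizes = [8], strides = [1] }
    : vector<24x16xf32> to vector<8x16xf32>
  return %a, %b : vector<8x16xf32>, vector<8x16xf32>
}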
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index ca81c3cd7be42..bbea93101c54e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -99,6 +99,7 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
if (i < distributionStart)
continue;
+
// Check if the dimension can be distributed evenly.
if (dim % effectiveLaneLayout[i - distributionStart] != 0)
return failure();
@@ -1673,6 +1674,19 @@ struct VectorExtractStridedSliceDistribution
extractOp.getSizes(), [](Attribute attr) { return attr; });
SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
extractOp.getOffsets(), [](Attribute attr) { return attr; });
+ SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
+ extractOp.getStrides(), [](Attribute attr) { return attr; });
+ // If the provided sizes, offsets, strides are less than the rank, pad them
+ // with full sizes, zero offsets, and unit strides. This makes it easier to
+ // adjust them later.
+ int64_t sourceRank = extractOp.getSourceVectorType().getRank();
+ for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
+ updatedSizes.push_back(rewriter.getI64IntegerAttr(
+ extractOp.getSourceVectorType().getDimSize(i)));
+ updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
+ updatedStrides.push_back(
+ rewriter.getI64IntegerAttr(1)); // stride is always 1.
+ }
// If the result is distributed, it must be distributed in exactly one
// dimension. In this case, we adjust the sourceDistType, distributedSizes
// and distributedOffsets accordingly.
@@ -1708,7 +1722,7 @@ struct VectorExtractStridedSliceDistribution
// The offsets in the distributed dimention must be a multiple of subgroup
// size.
int64_t distrDimOffset =
- cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+ cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
if (distrDimOffset % subgroupSize != 0)
return rewriter.notifyMatchFailure(
warpOp, "Offset along distributed dimension "
@@ -1737,7 +1751,7 @@ struct VectorExtractStridedSliceDistribution
rewriter, extractOp.getLoc(), distributedType, source,
ArrayAttr::get(rewriter.getContext(), updatedOffsets),
ArrayAttr::get(rewriter.getContext(), updatedSizes),
- extractOp.getStrides());
+ ArrayAttr::get(rewriter.getContext(), updatedStrides));
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 216f3d19cff94..7819a438057c4 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -753,6 +753,27 @@ gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) {
gpu.return
}
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_partial_offsets
+// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_partial_offsets(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+ %0 = "some_def"() : () -> (vector<24x16xf32>)
+ %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [8], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<24x16xf32> to vector<8x16xf32>
+ gpu.yield %1 : vector<8x16xf32>
+ }
+ "some_use"(%r) : (vector<8x1xf32>) -> ()
+ gpu.return
+}
// CHECK-LABEL: gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted
// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
@@ -880,6 +901,31 @@ gpu.func @vector_insert_strided_slice_1d(%laneid: index) {
gpu.return
}
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_different_ranks
+// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16xf32>, vector<64x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [13, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_different_ranks(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+ %0 = "some_def"() : () -> (vector<16xf32>)
+ %1 = "some_def"() : () -> (vector<64x16xf32>)
+ %2 = vector.insert_strided_slice %0, %1 { offsets = [13, 0], strides = [1],
+ layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16xf32> into vector<64x16xf32>
+ gpu.yield %2 : vector<64x16xf32>
+ }
+ "some_use"(%r) : (vector<64x1xf32>) -> ()
+ gpu.return
+}
+
// CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_source
// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
// CHECK: }
Review comment on the new test @vector_insert_strided_slice_different_ranks:

%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
  %0 = "some_def"() : () -> (vector<16xf32>)
  %1 = "some_def"() : () -> (vector<64x16xf32>)
  %2 = vector.insert_strided_slice %0, %1 { offsets = [13, 0], strides = [1],
nit: how do you make this work by only changing the extract_strided_slice?
This is a missed corner case; it is unrelated to the code changes in this PR.
Jianhui-Li left a comment:
LGTM