Commit 3ece662

[mlir][xegpu] Add support for vector.extract_strided_slice XeGPU SIMT distribution with partial offsets. (#171512)
`vector.extract_strided_slice` can have two forms when specifying offsets.

Case 1:

```mlir
%1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1]}
  : vector<24x16xf32> to vector<8x16xf32>
```

Case 2:

```mlir
%1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [8], strides = [1]}
  : vector<24x16xf32> to vector<8x16xf32>
```

These two ops mean the same thing; case 2 is syntactic sugar that avoids specifying offsets (and sizes/strides) for fully extracted trailing dims. Case 2 currently fails in XeGPU SIMT distribution. This PR fixes that.
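Concretely, the fix pads the partial attribute lists out to the source rank before distributing: each omitted trailing dim gets the full dim size, a zero offset, and a unit stride. For the case 2 example above this recovers exactly the case 1 form (a sketch of the normalized attributes, not new IR produced by the pass):

```mlir
// offsets: [8] -> [8, 0]   sizes: [8] -> [8, 16]   strides: [1] -> [1, 1]
%1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1]}
  : vector<24x16xf32> to vector<8x16xf32>
```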
1 parent 097ac33 · commit 3ece662

File tree: 2 files changed (+61, −2)

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (15 additions & 2 deletions)

```diff
@@ -1673,6 +1673,19 @@ struct VectorExtractStridedSliceDistribution
         extractOp.getSizes(), [](Attribute attr) { return attr; });
     SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
         extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
+        extractOp.getStrides(), [](Attribute attr) { return attr; });
+    // If the provided sizes, offsets, strides are less than the rank, pad them
+    // with full sizes, zero offsets, and unit strides. This makes it easier to
+    // adjust them later.
+    int64_t sourceRank = extractOp.getSourceVectorType().getRank();
+    for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
+      updatedSizes.push_back(rewriter.getI64IntegerAttr(
+          extractOp.getSourceVectorType().getDimSize(i)));
+      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
+      updatedStrides.push_back(
+          rewriter.getI64IntegerAttr(1)); // stride is always 1.
+    }
     // If the result is distributed, it must be distributed in exactly one
     // dimension. In this case, we adjust the sourceDistType, distributedSizes
     // and distributedOffsets accordingly.
@@ -1708,7 +1721,7 @@ struct VectorExtractStridedSliceDistribution
     // The offsets in the distributed dimension must be a multiple of subgroup
     // size.
     int64_t distrDimOffset =
-        cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+        cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
     if (distrDimOffset % subgroupSize != 0)
       return rewriter.notifyMatchFailure(
           warpOp, "Offset along distributed dimension "
@@ -1737,7 +1750,7 @@ struct VectorExtractStridedSliceDistribution
         rewriter, extractOp.getLoc(), distributedType, source,
         ArrayAttr::get(rewriter.getContext(), updatedOffsets),
         ArrayAttr::get(rewriter.getContext(), updatedSizes),
-        extractOp.getStrides());
+        ArrayAttr::get(rewriter.getContext(), updatedStrides));
     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
     return success();
   }
```
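The padding loop is the heart of the fix. As a minimal standalone sketch of the same idea, using plain integers instead of MLIR `Attribute`s (the function name and signature here are hypothetical, not the patch's API):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Pad partial extract_strided_slice attributes out to the source rank:
// each omitted trailing dim is fully extracted, so it gets the full dim
// size, a zero offset, and a unit stride.
void padToSourceRank(const std::vector<int64_t> &sourceShape,
                     std::vector<int64_t> &sizes,
                     std::vector<int64_t> &offsets,
                     std::vector<int64_t> &strides) {
  for (std::size_t i = sizes.size(); i < sourceShape.size(); ++i) {
    sizes.push_back(sourceShape[i]); // Full size of the unspecified dim.
    offsets.push_back(0);            // Extraction starts at the beginning.
    strides.push_back(1);            // Stride is always 1.
  }
}

// For the commit-message example, sourceShape = {24, 16} pads
// sizes {8} -> {8, 16}, offsets {8} -> {8, 0}, strides {1} -> {1, 1}.
```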

mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir (46 additions & 0 deletions)

```diff
@@ -753,6 +753,27 @@ gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_extract_strided_slice_partial_offsets
+// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
+gpu.func @vector_extract_strided_slice_partial_offsets(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
+    %0 = "some_def"() : () -> (vector<24x16xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [8], strides = [1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<24x16xf32> to vector<8x16xf32>
+    gpu.yield %1 : vector<8x16xf32>
+  }
+  "some_use"(%r) : (vector<8x1xf32>) -> ()
+  gpu.return
+}
 
 // CHECK-LABEL: gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted
 // CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
```
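To sanity-check the shapes in this test: with 16 lanes distributing the inner dimension, the vector<24x16xf32> source becomes vector<24x1xf32> per lane. The partial attributes are first padded to offsets = [8, 0], sizes = [8, 16], strides = [1, 1]; the size along the distributed dim is then rescaled to 16 / 16 = 1, and its offset 0 trivially passes the multiple-of-subgroup-size check, yielding the op in the CHECK lines above (illustrative SSA names):

```mlir
%extracted = vector.extract_strided_slice %w { offsets = [8, 0], sizes = [8, 1], strides = [1, 1] }
  : vector<24x1xf32> to vector<8x1xf32>
```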
```diff
@@ -880,6 +901,31 @@ gpu.func @vector_insert_strided_slice_1d(%laneid: index) {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_insert_strided_slice_different_ranks
+// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<1xf32>, vector<64x1xf32>) {
+// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16xf32>
+// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16xf32>, vector<64x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
+// CHECK-SAME: {offsets = [13, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
+// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
+gpu.func @vector_insert_strided_slice_different_ranks(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<16xf32>)
+    %1 = "some_def"() : () -> (vector<64x16xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [13, 0], strides = [1],
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+      }
+      : vector<16xf32> into vector<64x16xf32>
+    gpu.yield %2 : vector<64x16xf32>
+  }
+  "some_use"(%r) : (vector<64x1xf32>) -> ()
+  gpu.return
+}
+
 // CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_source
 // CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) {
 // CHECK: }
```
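This second test covers mixed ranks: `vector.insert_strided_slice` carries one offset per destination dim (here [13, 0], rank 2) but strides only for the source dims (here [1], rank 1). Distribution shrinks only the types along the lane dimension and leaves those attribute lengths alone, per the CHECK lines above (illustrative SSA names):

```mlir
// Source vector<16xf32> -> vector<1xf32> per lane; dest vector<64x16xf32> -> vector<64x1xf32>.
%inserted = vector.insert_strided_slice %src, %dst { offsets = [13, 0], strides = [1] }
  : vector<1xf32> into vector<64x1xf32>
```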
