Skip to content

Commit d4857c9

Browse files
XXX: Implement histogram as showcase for AccumulateOp.
1 parent: fca969e; commit: d4857c9

File tree

3 files changed

+64
-6
lines changed

3 files changed

+64
-6
lines changed

experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -365,9 +365,10 @@ buildNextBody(AccumulateOp op, OpBuilder &builder, Value initialState,
365365
[&](OpBuilder &builder, Location loc) {
366366
ImplicitLocOpBuilder b(loc, builder);
367367

368-
// Don't modify state; return undef element.
369-
Value nextElement = b.create<UndefOp>(elementType);
370-
b.create<scf::YieldOp>(ValueRange{initialUpstreamState, nextElement});
368+
// Don't modify state; return init element.
369+
FuncOp initFunc = op.getInitFunc();
370+
Value initValue = b.create<func::CallOp>(initFunc)->getResult(0);
371+
b.create<scf::YieldOp>(ValueRange{initialUpstreamState, initValue});
371372
},
372373
/*elseBuilder=*/
373374
[&](OpBuilder &builder, Location loc) {

experimental/iterators/test/Conversion/IteratorsToLLVM/accumulate.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ func.func private @sum_struct(%lhs : !element_type, %rhs : !element_type) -> !el
2222
// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<!iterators.state<i32>, i1>
2323
// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[arg0]][1] : !iterators.state<!iterators.state<i32>, i1>
2424
// CHECK-NEXT: %[[V3:.*]]:2 = scf.if %[[V2]] -> (!iterators.state<i32>, !llvm.struct<(i32)>) {
25-
// CHECK-NEXT: %[[V4:.*]] = llvm.mlir.undef : !llvm.struct<(i32)>
25+
// CHECK-NEXT: %[[V4:.*]] = func.call @zero_struct() : () -> !llvm.struct<(i32)>
2626
// CHECK-NEXT: scf.yield %[[V1]], %[[V4]] : !iterators.state<i32>, !llvm.struct<(i32)>
2727
// CHECK-NEXT: } else {
2828
// CHECK-NEXT: %[[V4:.*]] = func.call @zero_struct() : () -> !llvm.struct<(i32)>

experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir

Lines changed: 59 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,16 @@
11
// RUN: mlir-proto-opt %s \
22
// RUN: -convert-iterators-to-llvm \
33
// RUN: -convert-states-to-llvm \
4-
// RUN: -convert-func-to-llvm \
5-
// RUN: -convert-scf-to-cf -convert-cf-to-llvm \
4+
// RUN: -convert-scf-to-cf \
5+
// RUN: -arith-bufferize -func-bufferize -tensor-bufferize \
6+
// RUN: -convert-func-to-llvm -convert-memref-to-llvm \
7+
// RUN: -cse -reconcile-unrealized-casts \
68
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
79
// RUN: | FileCheck %s
810

911
!struct_i32 = !llvm.struct<(i32)>
1012
!struct_i32i32 = !llvm.struct<(i32, i32)>
13+
!struct_i32i32i32i32 = !llvm.struct<(i32, i32, i32, i32)>
1114
!struct_f32 = !llvm.struct<(f32)>
1215

1316
func.func private @init_sum_struct() -> !struct_i32 {
@@ -84,8 +87,62 @@ func.func @test_accumulate_avg_struct() {
8487
return
8588
}
8689

90+
func.func private @unpack_i32(%input : !struct_i32) -> i32 {
91+
%i = llvm.extractvalue %input[0 : index] : !struct_i32
92+
return %i : i32
93+
}
94+
95+
func.func private @init_histogram() -> tensor<4xi32> {
96+
%init = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi32>
97+
return %init : tensor<4xi32>
98+
}
99+
100+
func.func private @accumulate_histogram(
101+
%hist : tensor<4xi32>, %val : i32) -> tensor<4xi32> {
102+
%idx = arith.index_cast %val : i32 to index
103+
%oldCount = tensor.extract %hist[%idx] : tensor<4xi32>
104+
%one = arith.constant 1 : i32
105+
%newCount = arith.addi %oldCount, %one : i32
106+
%newHist = tensor.insert %newCount into %hist[%idx] : tensor<4xi32>
107+
return %newHist : tensor<4xi32>
108+
}
109+
110+
func.func private @tensor_to_struct(%input : tensor<4xi32>) -> !struct_i32i32i32i32 {
111+
%idx0 = arith.constant 0 : index
112+
%idx1 = arith.constant 1 : index
113+
%idx2 = arith.constant 2 : index
114+
%idx3 = arith.constant 3 : index
115+
%i0 = tensor.extract %input[%idx0] : tensor<4xi32>
116+
%i1 = tensor.extract %input[%idx1] : tensor<4xi32>
117+
%i2 = tensor.extract %input[%idx2] : tensor<4xi32>
118+
%i3 = tensor.extract %input[%idx3] : tensor<4xi32>
119+
%structu = llvm.mlir.undef : !struct_i32i32i32i32
120+
%struct0 = llvm.insertvalue %i0, %structu[0 : index] : !struct_i32i32i32i32
121+
%struct1 = llvm.insertvalue %i1, %struct0[1 : index] : !struct_i32i32i32i32
122+
%struct2 = llvm.insertvalue %i2, %struct1[2 : index] : !struct_i32i32i32i32
123+
%struct3 = llvm.insertvalue %i3, %struct2[3 : index] : !struct_i32i32i32i32
124+
return %struct3 : !struct_i32i32i32i32
125+
}
126+
127+
func.func @test_accumulate_histogram() {
128+
%input = "iterators.constantstream"()
129+
{ value = [[0 : i32], [1 : i32], [1 : i32], [2 : i32]] }
130+
: () -> (!iterators.stream<!struct_i32>)
131+
%unpacked = "iterators.map"(%input) {mapFuncRef = @unpack_i32}
132+
: (!iterators.stream<!struct_i32>) -> (!iterators.stream<i32>)
133+
%accumulated = iterators.accumulate(%unpacked, @init_histogram,
134+
@accumulate_histogram)
135+
: (!iterators.stream<i32>) -> !iterators.stream<tensor<4xi32>>
136+
%transposed = "iterators.map"(%accumulated) {mapFuncRef = @tensor_to_struct}
137+
: (!iterators.stream<tensor<4xi32>>) -> (!iterators.stream<!struct_i32i32i32i32>)
138+
"iterators.sink"(%transposed) : (!iterators.stream<!struct_i32i32i32i32>) -> ()
139+
// CHECK: (1, 2, 1, 0)
140+
return
141+
}
142+
87143
func.func @main() {
88144
call @test_accumulate_sum_struct() : () -> ()
89145
call @test_accumulate_avg_struct() : () -> ()
146+
call @test_accumulate_histogram() : () -> ()
90147
return
91148
}

0 commit comments

Comments
 (0)