Skip to content

Commit a2e617a

Browse files
PISA: switch from i32 to mod_arith.int< .. : i32>
1 parent aa805b2 commit a2e617a

File tree

10 files changed

+93
-71
lines changed

10 files changed

+93
-71
lines changed

lib/Dialect/PISA/IR/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ cc_library(
3737
deps = [
3838
":dialect_inc_gen",
3939
":ops_inc_gen",
40+
"@heir//lib/Dialect/ModArith/IR:Dialect",
4041
"@llvm-project//llvm:Support",
41-
"@llvm-project//mlir:ArithDialect",
4242
"@llvm-project//mlir:IR",
4343
"@llvm-project//mlir:InferTypeOpInterface",
4444
"@llvm-project//mlir:Support",

lib/Dialect/PISA/IR/PISAOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef LIB_DIALECT_PISA_IR_PISAOPS_H_
22
#define LIB_DIALECT_PISA_IR_PISAOPS_H_
33

4+
#include "lib/Dialect/ModArith/IR/ModArithTypes.h" // required for the type predicate we use
45
#include "lib/Dialect/PISA/IR/PISADialect.h"
56
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
67
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

lib/Dialect/PISA/IR/PISAOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@ include "mlir/IR/OpBase.td"
88
include "mlir/Interfaces/InferTypeOpInterface.td"
99
include "mlir/Interfaces/SideEffectInterfaces.td"
1010

11+
// We only accept tensors of mod_arith with 32-bit typed moduli.
12+
// Note that we do NOT allow moduli that are concretely less than 32 bits but have a larger type (e.g., I64)
13+
// as those allow the compiler to emit code that relies on temporarily using up to 64 bits before mod-reducing.
1114
def Tensor8192I32 : TypeConstraint<CPred<[{
1215
mlir::isa<mlir::RankedTensorType>($_self) &&
1316
mlir::cast<mlir::RankedTensorType>($_self).getRank() == 1 &&
1417
mlir::cast<mlir::RankedTensorType>($_self).getDimSize(0) == 8192 &&
15-
mlir::cast<mlir::RankedTensorType>($_self).getElementType().isInteger(32)
16-
}]>, "tensor<8192xi32>">;
18+
llvm::isa<mlir::heir::mod_arith::ModArithType>(mlir::cast<mlir::RankedTensorType>($_self).getElementType()) &&
19+
mlir::cast<mlir::heir::mod_arith::ModArithType>(mlir::cast<mlir::RankedTensorType>($_self).getElementType()).getModulus().getType().isInteger(32)
20+
}]>, "tensor<8192xmod_arith.int< ... : i32>>">;
1721

1822
class PISA_Op<string mnemonic, list<Trait> traits = [Pure]> :
1923
Op<PISA_Dialect, mnemonic, traits> {

lib/Target/PISA/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ cc_library(
1313
],
1414
deps = [
1515
"@heir//lib/Analysis/SelectVariableNames",
16+
"@heir//lib/Dialect/ModArith/IR:Dialect",
1617
"@heir//lib/Dialect/PISA/IR:Dialect",
1718
"@heir//lib/Utils:TargetUtils",
1819
"@llvm-project//llvm:Support",

lib/Target/PISA/PISAEmitter.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "lib/Target/PISA/PISAEmitter.h"
22

33
#include "lib/Analysis/SelectVariableNames/SelectVariableNames.h"
4+
#include "lib/Dialect/ModArith/IR/ModArithDialect.h"
45
#include "lib/Dialect/PISA/IR/PISAOps.h"
56
#include "lib/Utils/TargetUtils.h" // from @llvm-project
67
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
@@ -23,7 +24,8 @@ void registerToPISATranslation() {
2324
},
2425
[](DialectRegistry &registry) {
2526
registry.insert<arith::ArithDialect, func::FuncDialect,
26-
tensor::TensorDialect, pisa::PISADialect>();
27+
tensor::TensorDialect, pisa::PISADialect,
28+
mod_arith::ModArithDialect>();
2729
});
2830
}
2931

tests/Dialect/PISA/IR/invalid.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: heir-opt --split-input-file --verify-diagnostics %s 2>&1
2+
3+
// -----
4+
// CHECK-NOT: test_invalid_tensor_length
5+
func.func @test_invalid_tensor_length(%arg0 : tensor<1024x!mod_arith.int<33538049:i32>>, %arg1 : tensor<1024x!mod_arith.int<33538049:i32>>) -> tensor<1024x!mod_arith.int<33538049:i32>> {
6+
// expected-error@below {{'pisa.add' op operand #0 must be tensor<8192xmod_arith.int< ... : i32>>, but got 'tensor<1024x!mod_arith.int<33538049 : i32>>'}}
7+
%0 = pisa.add %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<1024x!mod_arith.int<33538049:i32>>
8+
return %0 : tensor<1024x!mod_arith.int<33538049:i32>>
9+
}
10+
11+
// -----
12+
// CHECK-NOT: test_invalid_tensor_modulus_type
13+
func.func @test_invalid_tensor_modulus_type(%arg0 : tensor<8192x!mod_arith.int<33538049:i64>>, %arg1 : tensor<8192x!mod_arith.int<33538049:i64>>) -> tensor<8192x!mod_arith.int<33538049:i64>> {
14+
// expected-error@below {{'pisa.add' op operand #0 must be tensor<8192xmod_arith.int< ... : i32>>, but got 'tensor<8192x!mod_arith.int<33538049 : i64>>'}}
15+
%0 = pisa.add %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!mod_arith.int<33538049:i64>>
16+
return %0 : tensor<8192x!mod_arith.int<33538049:i64>>
17+
}
18+
19+
// -----
20+
// CHECK-NOT: test_invalid_modulus
21+
func.func @test_invalid_modulus(%arg0 : tensor<8192x!mod_arith.int<33538049:i32>>, %arg1 : tensor<8192x!mod_arith.int<33538049:i32>>) -> tensor<8192x!mod_arith.int<33538049:i32>> {
22+
// expected-error@below {{custom op 'pisa.add' 'pisa.add' op attribute 'q' failed to satisfy constraint: 32-bit signless integer attribute}}
23+
%0 = pisa.add %arg0, %arg1 {q = 18446744073709551557, i = 0 : i32} : tensor<8192x!mod_arith.int<33538049:i32>>
24+
return %0 : tensor<8192x!mod_arith.int<33538049:i32>>
25+
}

tests/Dialect/PISA/IR/ops.mlir

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,58 @@
1-
// RUN: not heir-opt --verify-diagnostics --split-input-file %s 2>&1 | FileCheck %s
1+
// RUN: heir-opt %s | FileCheck %s
22

33
// This simply tests for syntax.
4+
!m32 = !mod_arith.int<33538049:i32>
45

56
// CHECK-LABEL: test_padd
6-
func.func @test_padd(%arg0 : tensor<8192xi32>, %arg1 : tensor<8192xi32>) -> tensor<8192xi32> {
7-
%0 = pisa.add %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
8-
return %0 : tensor<8192xi32>
7+
func.func @test_padd(%arg0 : tensor<8192x!m32>, %arg1 : tensor<8192x!m32>) -> tensor<8192x!m32> {
8+
%0 = pisa.add %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
9+
return %0 : tensor<8192x!m32>
910
}
1011

1112
// CHECK-LABEL: test_psub
12-
func.func @test_psub(%arg0 : tensor<8192xi32>, %arg1 : tensor<8192xi32>) -> tensor<8192xi32> {
13-
%0 = pisa.sub %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
14-
return %0 : tensor<8192xi32>
13+
func.func @test_psub(%arg0 : tensor<8192x!m32>, %arg1 : tensor<8192x!m32>) -> tensor<8192x!m32> {
14+
%0 = pisa.sub %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
15+
return %0 : tensor<8192x!m32>
1516
}
1617

1718
// CHECK-LABEL: test_pmul
18-
func.func @test_pmul(%arg0 : tensor<8192xi32>, %arg1 : tensor<8192xi32>) -> tensor<8192xi32> {
19-
%0 = pisa.mul %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
20-
return %0 : tensor<8192xi32>
19+
func.func @test_pmul(%arg0 : tensor<8192x!m32>, %arg1 : tensor<8192x!m32>) -> tensor<8192x!m32> {
20+
%0 = pisa.mul %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
21+
return %0 : tensor<8192x!m32>
2122
}
2223

2324
// CHECK-LABEL: test_pmuli
24-
func.func @test_pmuli(%arg0 : tensor<8192xi32>) -> tensor<8192xi32> {
25-
%0 = pisa.muli %arg0 {q = 2147483647 : i32, i = 0 : i32, imm = 5 : i32} : tensor<8192xi32>
26-
return %0 : tensor<8192xi32>
25+
func.func @test_pmuli(%arg0 : tensor<8192x!m32>) -> tensor<8192x!m32> {
26+
%0 = pisa.muli %arg0 {q = 2147483647 : i32, i = 0 : i32, imm = 5 : i32} : tensor<8192x!m32>
27+
return %0 : tensor<8192x!m32>
2728
}
2829

2930
// CHECK-LABEL: test_pmac
30-
func.func @test_pmac(%arg0 : tensor<8192xi32>, %arg1 : tensor<8192xi32>, %arg2 : tensor<8192xi32>) -> tensor<8192xi32> {
31-
%0 = pisa.mac %arg0, %arg1, %arg2 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
32-
return %0 : tensor<8192xi32>
31+
func.func @test_pmac(%arg0 : tensor<8192x!m32>, %arg1 : tensor<8192x!m32>, %arg2 : tensor<8192x!m32>) -> tensor<8192x!m32> {
32+
%0 = pisa.mac %arg0, %arg1, %arg2 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
33+
return %0 : tensor<8192x!m32>
3334
}
3435

3536
// CHECK-LABEL: test_pmaci
36-
func.func @test_pmaci(%arg0 : tensor<8192xi32>, %arg1 : tensor<8192xi32>) -> tensor<8192xi32> {
37-
%0 = pisa.maci %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32, imm = 5 : i32} : tensor<8192xi32>
38-
return %0 : tensor<8192xi32>
39-
}
40-
41-
// CHECK-LABEL: test_pntt
42-
func.func @test_pntt(%arg0 : tensor<8192xi32>) -> tensor<8192xi32> {
43-
//TODO: figure out how to best handle the twiddle factors here...
44-
%w = arith.constant dense<42> : tensor<8192xi32>
45-
%0 = pisa.ntt %arg0, %w {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
46-
return %0 : tensor<8192xi32>
47-
}
48-
49-
// CHECK-LABEL: test_pintt
50-
func.func @test_pintt(%arg0 : tensor<8192xi32>) -> tensor<8192xi32> {
51-
//TODO: figure out how to best handle the twiddle factors here...
52-
%w = arith.constant dense<42> : tensor<8192xi32>
53-
%0 = pisa.intt %arg0, %w {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
54-
return %0 : tensor<8192xi32>
55-
}
56-
57-
58-
// -----
59-
// CHECK-NOT: test_invalid_tensor
60-
func.func @test_invalid_tensor(%arg0 : tensor<1024xi32>, %arg1 : tensor<1024xi32>) -> tensor<1024xi32> {
61-
// expected-error@below {{custom op 'pisa.add' 'pisa.add' op operand #0 must be tensor<8192xi32>, but got 'tensor<1024xi32>'}}
62-
%0 = pisa.add %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<1024xi32>
63-
return %0 : tensor<1024xi32>
64-
}
65-
66-
// -----
67-
// CHECK-NOT: test_invalid_modulus
68-
func.func @test_invalid_modulus(%arg0 : tensor<8192xi32>, %arg1 : tensor<8192xi32>) -> tensor<8192xi32> {
69-
// expected-error@below {{custom op 'pisa.add' 'pisa.add' op attribute 'q' failed to satisfy constraint: 32-bit signless integer attribute}}
70-
%0 = pisa.add %arg0, %arg1 {q = 18446744073709551557, i = 0 : i32} : tensor<8192xi32>
71-
return %0 : tensor<8192xi32>
72-
}
37+
func.func @test_pmaci(%arg0 : tensor<8192x!m32>, %arg1 : tensor<8192x!m32>) -> tensor<8192x!m32> {
38+
%0 = pisa.maci %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32, imm = 5 : i32} : tensor<8192x!m32>
39+
return %0 : tensor<8192x!m32>
40+
}
41+
42+
// FIXME: re-enable check once mod_arith.constant works for tensors
43+
// func.func @test_pntt(%arg0 : tensor<8192x!m32>) -> tensor<8192x!m32> {
44+
// //TODO: figure out how to best handle the twiddle factors here...
45+
// // FIXME: cannot currently create a mod_arith.constant tensor? Below will silently fail and cause mlir-opt to produce no output?
46+
// %w = mod_arith.constant 42 : tensor<8192x!m32>
47+
// %0 = pisa.ntt %arg0, %w {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
48+
// return %0 : tensor<8192x!m32>
49+
// }
50+
51+
// FIXME: re-enable check once mod_arith.constant works for tensors
52+
// func.func @test_pintt(%arg0 : tensor<8192x!m32>) -> tensor<8192x!m32> {
53+
// //TODO: figure out how to best handle the twiddle factors here...
54+
// //FIXME: cannot currently create a mod_arith.constant tensor? Below will silently fail and cause mlir-opt to produce no output?
55+
// %w = mod_arith.constant 42 : tensor<8192x!m32>
56+
// %0 = pisa.intt %arg0, %w {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
57+
// return %0 : tensor<8192x!m32>
58+
// }

tests/Dialect/Polynomial/Conversions/polynomial_to_pisa/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ glob_lit_tests(
66
name = "all_tests",
77
data = ["@heir//tests:test_utilities"],
88
driver = "@heir//tests:run_lit.sh",
9+
exclude = ["end_to_end.mlir"], # TODO (#1199): re-enable after `--lwe-to-polynomial` is fixed
910
test_file_exts = ["mlir"],
1011
)

tests/Dialect/Polynomial/Conversions/polynomial_to_pisa/non_rns.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
!p = !polynomial.polynomial<ring=<coefficientType=!coeff_ty, polynomialModulus=#polynomial.int_polynomial<1 + x**8192>>>
55

66
//CHECK-LABEL: @test_add
7-
//CHECK: [[X:%.+]]: tensor<8192xi32>, [[Y:%.+]]: tensor<8192xi32>
7+
//CHECK: [[X:%.+]]: tensor<8192x!Z33538049_i32_>, [[Y:%.+]]: tensor<8192x!Z33538049_i32_>
88
func.func @test_add(%x : !p, %y : !p) -> !p {
9-
//CHECK: [[ADD:%.+]] = pisa.add [[X]], [[Y]] {i = 0 : i32, q = 33538049 : i32} : tensor<8192xi32>
9+
//CHECK: [[ADD:%.+]] = pisa.add [[X]], [[Y]] {i = 0 : i32, q = 33538049 : i32} : tensor<8192x!Z33538049_i32_>
1010
%0 = polynomial.add %x, %y : !p
11-
//CHECK: return [[ADD]] : tensor<8192xi32>
11+
//CHECK: return [[ADD]] : tensor<8192x!Z33538049_i32_>
1212
return %0 : !p
1313
}

tests/Emitter/PISA/emit.mlir

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: heir-translate --emit-pisa %s | FileCheck %s
22

3-
func.func @test_emit(%arg0 : tensor<8192xi32>, %arg1 : tensor<8192xi32>) -> tensor<8192xi32> {
3+
!m32 = !mod_arith.int<33538049:i32>
4+
5+
func.func @test_emit(%arg0 : tensor<8192x!m32>, %arg1 : tensor<8192x!m32>) -> tensor<8192x!m32> {
46
//CHECK: 13, add, [[ADD:.+]], [[INP0:.+]], [[INP1:.+]], 0
57
//CHECK: 13, sub, [[SUB:.+]], [[INP0]], [[INP1]], 0
68
//CHECK: 13, mul, [[MUL:.+]], [[INP0]], [[INP1]], 0
@@ -9,14 +11,14 @@ func.func @test_emit(%arg0 : tensor<8192xi32>, %arg1 : tensor<8192xi32>) -> ten
911
//CHECK: 13, mac, [[ACC1]], [[INP0]], [[INP1]], 0
1012
//CHECK: 13, copy, [[ACC2:.+]], [[INP1]]
1113
//CHECK: 13, mac, [[ACC2]], [[INP0]], [[ACC2]]_imm, 0
12-
%0 = pisa.add %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
13-
%1 = pisa.sub %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
14-
%2 = pisa.mul %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
15-
%3 = pisa.muli %arg0 {q = 2147483647 : i32, i = 0 : i32, imm = 5 : i32} : tensor<8192xi32>
16-
%4 = pisa.mac %arg0, %arg1, %arg0 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
17-
%5 = pisa.maci %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32, imm = 5 : i32} : tensor<8192xi32>
18-
%w = arith.constant dense<42> : tensor<8192xi32>
19-
// %6 = pisa.ntt %arg0, %w {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
20-
// %7 = pisa.intt %arg0, %w {q = 2147483647 : i32, i = 0 : i32} : tensor<8192xi32>
21-
return %0 : tensor<8192xi32>
14+
%0 = pisa.add %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
15+
%1 = pisa.sub %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
16+
%2 = pisa.mul %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
17+
%3 = pisa.muli %arg0 {q = 2147483647 : i32, i = 0 : i32, imm = 5 : i32} : tensor<8192x!m32>
18+
%4 = pisa.mac %arg0, %arg1, %arg0 {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
19+
%5 = pisa.maci %arg0, %arg1 {q = 2147483647 : i32, i = 0 : i32, imm = 5 : i32} : tensor<8192x!m32>
20+
// %w = mod_arith.constant 42 : tensor<8192x!m32> // FIXME: re-enable once mod_arith tensor constant generation is fixed
21+
// %6 = pisa.ntt %arg0, %w {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
22+
// %7 = pisa.intt %arg0, %w {q = 2147483647 : i32, i = 0 : i32} : tensor<8192x!m32>
23+
return %0 : tensor<8192x!m32>
2224
}

0 commit comments

Comments
 (0)