Skip to content

Commit 354011d

Browse files
WIP: create PISA dialect (+ emitter and passes)
1 parent 7ca72cb commit 354011d

38 files changed

+1112
-0
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ repos:
3838
rev: "v2.2.5"
3939
hooks:
4040
- id: codespell
41+
# The PISA dialect contains operation names that look like misspellings.
42+
exclude: >
43+
(?x)^(
44+
.*\/pisa\/.*\.mlir|
45+
.*\/PISA\/.*\.td|
46+
.*\/PISA\/.*\.cpp
47+
)$
4148
4249
# Changes tabs to spaces
4350
- repo: https://github.com/Lucas-C/pre-commit-hooks

lib/Dialect/PISA/IR/BUILD

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# PISA dialect implementation
2+
3+
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
4+
5+
package(
6+
default_applicable_licenses = ["@heir//:license"],
7+
default_visibility = ["//visibility:public"],
8+
)
9+
10+
cc_library(
11+
name = "Dialect",
12+
srcs = [
13+
"PISADialect.cpp",
14+
],
15+
hdrs = [
16+
"PISADialect.h",
17+
"PISAOps.h",
18+
],
19+
deps = [
20+
"dialect_inc_gen",
21+
"ops_inc_gen",
22+
":PISAOps",
23+
"@llvm-project//llvm:Support",
24+
"@llvm-project//mlir:IR",
25+
],
26+
)
27+
28+
cc_library(
29+
name = "PISAOps",
30+
srcs = [
31+
"PISAOps.cpp",
32+
],
33+
hdrs = [
34+
"PISADialect.h",
35+
"PISAOps.h",
36+
],
37+
deps = [
38+
":dialect_inc_gen",
39+
":ops_inc_gen",
40+
"@llvm-project//llvm:Support",
41+
"@llvm-project//mlir:ArithDialect",
42+
"@llvm-project//mlir:IR",
43+
"@llvm-project//mlir:InferTypeOpInterface",
44+
"@llvm-project//mlir:Support",
45+
],
46+
)
47+
48+
td_library(
49+
name = "td_files",
50+
srcs = [
51+
"PISADialect.td",
52+
"PISAOps.td",
53+
],
54+
# include from the heir - root to enable fully - qualified include - paths
55+
includes = ["../../../.."],
56+
deps = [
57+
"@heir//lib/Utils/DRR",
58+
"@llvm-project//mlir:BuiltinDialectTdFiles",
59+
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
60+
"@llvm-project//mlir:OpBaseTdFiles",
61+
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
62+
],
63+
)
64+
65+
gentbl_cc_library(
66+
name = "dialect_inc_gen",
67+
tbl_outs = [
68+
(
69+
[
70+
"-gen-dialect-decls",
71+
],
72+
"PISADialect.h.inc",
73+
),
74+
(
75+
[
76+
"-gen-dialect-defs",
77+
],
78+
"PISADialect.cpp.inc",
79+
),
80+
],
81+
tblgen = "@llvm-project//mlir:mlir-tblgen",
82+
td_file = "PISADialect.td",
83+
deps = [
84+
":td_files",
85+
],
86+
)
87+
88+
gentbl_cc_library(
89+
name = "ops_inc_gen",
90+
tbl_outs = [
91+
(
92+
["-gen-op-decls"],
93+
"PISAOps.h.inc",
94+
),
95+
(
96+
["-gen-op-defs"],
97+
"PISAOps.cpp.inc",
98+
),
99+
(
100+
["-gen-op-doc"],
101+
"PISAOps.md",
102+
),
103+
],
104+
tblgen = "@llvm-project//mlir:mlir-tblgen",
105+
td_file = "PISAOps.td",
106+
deps = [
107+
":dialect_inc_gen",
108+
":td_files",
109+
],
110+
)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include "lib/Dialect/PISA/IR/PISADialect.h"
2+
3+
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
4+
5+
// NOLINTNEXTLINE(misc-include-cleaner): Required to define PISAOps
6+
7+
#include "lib/Dialect/PISA/IR/PISAOps.h"
8+
9+
// Generated definitions
10+
#include "lib/Dialect/PISA/IR/PISADialect.cpp.inc"
11+
12+
#define GET_OP_CLASSES
13+
#include "lib/Dialect/PISA/IR/PISAOps.cpp.inc"
14+
15+
namespace mlir {
16+
namespace heir {
17+
namespace pisa {
18+
19+
void PISADialect::initialize() {
20+
addOperations<
21+
#define GET_OP_LIST
22+
#include "lib/Dialect/PISA/IR/PISAOps.cpp.inc"
23+
>();
24+
}
25+
26+
} // namespace pisa
27+
} // namespace heir
28+
} // namespace mlir

lib/Dialect/PISA/IR/PISADialect.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#ifndef LIB_DIALECT_PISA_IR_PISADIALECT_H_
2+
#define LIB_DIALECT_PISA_IR_PISADIALECT_H_
3+
4+
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
5+
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
6+
7+
// Generated headers (block clang-format from messing up order)
8+
#include "lib/Dialect/PISA/IR/PISADialect.h.inc"
9+
10+
#endif // LIB_DIALECT_PISA_IR_PISADIALECT_H_

lib/Dialect/PISA/IR/PISADialect.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#ifndef LIB_DIALECT_PISA_IR_PISADIALECT_TD_
2+
#define LIB_DIALECT_PISA_IR_PISADIALECT_TD_
3+
4+
include "mlir/IR/DialectBase.td"
5+
6+
def PISA_Dialect : Dialect {
7+
let name = "pisa";
8+
let description = [{
9+
// FIXME: add documentation
10+
The `pisa` dialect is ...
11+
}];
12+
13+
let cppNamespace = "::mlir::heir::pisa";
14+
}
15+
16+
#endif // LIB_DIALECT_PISA_IR_PISADIALECT_TD_

lib/Dialect/PISA/IR/PISAOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#include "lib/Dialect/PISA/IR/PISAOps.h"
2+
3+
namespace mlir {
4+
namespace heir {
5+
namespace pisa {} // namespace pisa
6+
} // namespace heir
7+
} // namespace mlir

lib/Dialect/PISA/IR/PISAOps.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef LIB_DIALECT_PISA_IR_PISAOPS_H_
2+
#define LIB_DIALECT_PISA_IR_PISAOPS_H_
3+
4+
#include "lib/Dialect/PISA/IR/PISADialect.h"
5+
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
6+
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
7+
8+
#define GET_OP_CLASSES
9+
#include "lib/Dialect/PISA/IR/PISAOps.h.inc"
10+
11+
#endif // LIB_DIALECT_PISA_IR_PISAOPS_H_

lib/Dialect/PISA/IR/PISAOps.td

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#ifndef LIB_DIALECT_PISA_IR_PISAOPS_TD_
2+
#define LIB_DIALECT_PISA_IR_PISAOPS_TD_
3+
4+
include "lib/Dialect/PISA/IR/PISADialect.td"
5+
include "mlir/IR/BuiltinAttributes.td"
6+
include "mlir/IR/CommonTypeConstraints.td"
7+
include "mlir/IR/OpBase.td"
8+
include "mlir/Interfaces/InferTypeOpInterface.td"
9+
include "mlir/Interfaces/SideEffectInterfaces.td"
10+
11+
def Tensor8192I32 : TypeConstraint<CPred<[{
12+
mlir::isa<mlir::RankedTensorType>($_self) &&
13+
mlir::cast<mlir::RankedTensorType>($_self).getRank() == 1 &&
14+
mlir::cast<mlir::RankedTensorType>($_self).getDimSize(0) == 8192 &&
15+
mlir::cast<mlir::RankedTensorType>($_self).getElementType().isInteger(32)
16+
}]>, "tensor<8192xi32>">;
17+
18+
class PISA_Op<string mnemonic, list<Trait> traits = [Pure]> :
19+
Op<PISA_Dialect, mnemonic, traits> {
20+
let cppNamespace = "::mlir::heir::pisa";
21+
}
22+
23+
class PISA_BinaryOp<string mnemonic, list<Trait> traits = []> :
24+
PISA_Op<mnemonic, traits # [SameOperandsAndResultType]>,
25+
Arguments<(ins Tensor8192I32:$lhs, Tensor8192I32:$rhs, I32Attr:$q, I32Attr:$i)>,
26+
Results<(outs Tensor8192I32:$output)> {
27+
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))";
28+
}
29+
30+
def PISA_AddOp : PISA_BinaryOp<"add", [Commutative]> {
31+
let summary = "addition operation";
32+
let description = [{
33+
Computes addition of two polynomials (irrespective of ntt/coefficient representation).
34+
}];
35+
}
36+
37+
def PISA_SubOp : PISA_BinaryOp<"sub", []> {
38+
let summary = "subtraction operation";
39+
let description = [{
40+
Computes subtraction of two polynomials (irrespective of ntt/coefficient representation).
41+
}];
42+
}
43+
44+
def PISA_MulOp : PISA_BinaryOp<"mul", [Commutative]> {
45+
let summary = "multiplication operation";
46+
let description = [{
47+
Computes addition of two polynomials (in ntt representation).
48+
}];
49+
}
50+
51+
def PISA_MuliOp : PISA_Op<"muli", [SameOperandsAndResultType]> {
52+
let summary = "multiplication-with-immediate operation";
53+
let description = [{
54+
Computes multiplication of a polynomial (in ntt representation) with a constant.
55+
}];
56+
let arguments = (ins Tensor8192I32:$lhs, I32Attr:$q, I32Attr:$i, I32Attr:$imm);
57+
let results = (outs Tensor8192I32:$output);
58+
let assemblyFormat = "$lhs attr-dict `:` qualified(type($output))";
59+
}
60+
61+
def PISA_MacOp : PISA_Op<"mac", [SameOperandsAndResultType]> {
62+
let summary = "multiply-and-accumulate operation";
63+
let description = [{
64+
Computes multiplication of two polynomials (in ntt representation) and adds the result to a third polynomial.
65+
}];
66+
let arguments = (ins Tensor8192I32:$lhs, Tensor8192I32:$rhs, Tensor8192I32:$acc, I32Attr:$q, I32Attr:$i);
67+
let results = (outs Tensor8192I32:$output);
68+
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` qualified(type($output))";
69+
}
70+
71+
def PISA_MaciOp : PISA_Op<"maci", [SameOperandsAndResultType]> {
72+
let summary = "multiply-and-accumulate-with-immediate operation";
73+
let description = [{
74+
Computes multiplication of a polynomial (in ntt representation) with a constant and adds the result to a third polynomial.
75+
}];
76+
let arguments = (ins Tensor8192I32:$lhs, Tensor8192I32:$acc, I32Attr:$q, I32Attr:$i, I32Attr:$imm);
77+
let results = (outs Tensor8192I32:$output);
78+
let assemblyFormat = "$lhs `,` $acc attr-dict `:` qualified(type($output))";
79+
}
80+
81+
def PISA_NTTOp : PISA_Op<"ntt", [SameOperandsAndResultType]> {
82+
let summary = "number-theoretic-transform operation";
83+
let description = [{
84+
Computes number-theoretic-transform of a polynomial.
85+
}];
86+
let arguments = (ins Tensor8192I32:$poly, Tensor8192I32:$w, I32Attr:$q, I32Attr:$i);
87+
let results = (outs Tensor8192I32:$output);
88+
let assemblyFormat = "$poly `,` $w attr-dict `:` qualified(type($output))";
89+
}
90+
91+
def PISA_INTTOp : PISA_Op<"intt", [SameOperandsAndResultType]> {
92+
let summary = "inverse number-theoretic-transform operation";
93+
let description = [{
94+
Computes inverse number-theoretic-transform of a polynomial.
95+
}];
96+
let arguments = (ins Tensor8192I32:$poly, Tensor8192I32:$w, I32Attr:$q, I32Attr:$i);
97+
let results = (outs Tensor8192I32:$output);
98+
let assemblyFormat = "$poly `,` $w attr-dict `:` qualified(type($output))";
99+
}
100+
101+
102+
#endif // LIB_DIALECT_PISA_IR_PISAOPS_TD_
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
add_subdirectory(PolynomialToPISA)
12
add_subdirectory(PolynomialToStandard)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
2+
3+
package(
4+
default_applicable_licenses = ["@heir//:license"],
5+
default_visibility = ["//visibility:public"],
6+
)
7+
8+
cc_library(
9+
name = "PolynomialToPISA",
10+
srcs = ["PolynomialToPISA.cpp"],
11+
hdrs = ["PolynomialToPISA.h"],
12+
deps = [
13+
":pass_inc_gen",
14+
"@heir//lib/Dialect/ModArith/IR:Dialect",
15+
"@heir//lib/Dialect/PISA/IR:Dialect",
16+
"@heir//lib/Utils/ConversionUtils",
17+
"@llvm-project//mlir:IR",
18+
"@llvm-project//mlir:Pass",
19+
"@llvm-project//mlir:PolynomialDialect",
20+
"@llvm-project//mlir:Transforms",
21+
],
22+
)
23+
24+
gentbl_cc_library(
25+
name = "pass_inc_gen",
26+
tbl_outs = [
27+
(
28+
[
29+
"-gen-pass-decls",
30+
"-name=PolynomialToPISA",
31+
],
32+
"PolynomialToPISA.h.inc",
33+
),
34+
(
35+
["-gen-pass-doc"],
36+
"PolynomialToPISA.md",
37+
),
38+
],
39+
tblgen = "@llvm-project//mlir:mlir-tblgen",
40+
td_file = "PolynomialToPISA.td",
41+
deps = [
42+
"@llvm-project//mlir:OpBaseTdFiles",
43+
"@llvm-project//mlir:PassBaseTdFiles",
44+
],
45+
)

0 commit comments

Comments
 (0)