Skip to content

Commit 997af5f

Browse files
committed
keep going
1 parent 1e3f6e5 commit 997af5f

File tree

12 files changed

+279
-48
lines changed

12 files changed

+279
-48
lines changed

lib/Dialect/CKKS/IR/CKKSOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ LogicalResult RelinearizeOp::inferReturnTypes(
7979
return lwe::inferRelinearizeOpReturnTypes(ctx, adaptor, inferredReturnTypes);
8080
}
8181

82+
LogicalResult LevelReduceOp::inferReturnTypes(
83+
MLIRContext* ctx, std::optional<Location>, LevelReduceOp::Adaptor adaptor,
84+
SmallVectorImpl<Type>& inferredReturnTypes) {
85+
return lwe::inferLevelReduceOpReturnTypes(ctx, adaptor, inferredReturnTypes);
86+
}
87+
8288
void MulPlainOp::getCanonicalizationPatterns(RewritePatternSet& results,
8389
MLIRContext* context) {
8490
results.add<lwe::PutCiphertextInFirstOperand<MulPlainOp>>(context);

lib/Dialect/CKKS/IR/CKKSOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def CKKS_RescaleOp : CKKS_Op<"rescale", [ElementwiseMappable]> {
188188
let assemblyFormat = "operands attr-dict `:` qualified(type($input)) `->` qualified(type($output))" ;
189189
}
190190

191-
def CKKS_LevelReduceOp : CKKS_Op<"level_reduce", [ElementwiseMappable, SameOperandsAndResultPlaintextTypes]> {
191+
def CKKS_LevelReduceOp : CKKS_Op<"level_reduce", [ElementwiseMappable, SameOperandsAndResultPlaintextTypes, InferTypeOpAdaptor]> {
192192
let summary = "Lower the modulus level of the ciphertext via dropping RNS limbs.";
193193

194194
let arguments = (ins

lib/Dialect/LWE/IR/LWEAttributes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
22

33
#include <cstdint>
4+
#include <utility>
45

56
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
67
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
@@ -114,7 +115,7 @@ PlaintextSpaceAttr inferModulusSwitchOrRescaleOpPlaintextSpaceAttr(
114115
}
115116

116117
auto newScale = inferModulusSwitchOrRescaleOpScalingFactor(
117-
xEncoding, dividedModulus, plaintextModulus);
118+
xEncoding, std::move(dividedModulus), plaintextModulus);
118119
return PlaintextSpaceAttr::get(
119120
ctx, xRing, getEncodingAttrWithNewScalingFactor(xEncoding, newScale));
120121
}

lib/Dialect/LWE/IR/LWEOps.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,37 @@ LogicalResult inferRelinearizeOpReturnTypes(
366366
return success();
367367
}
368368

369+
template <typename Adaptor>
370+
LogicalResult inferLevelReduceOpReturnTypes(
371+
MLIRContext* ctx, Adaptor adaptor,
372+
SmallVectorImpl<Type>& inferredReturnTypes) {
373+
auto x = getCtTy(adaptor.getInput());
374+
auto levelToDrop = adaptor.getLevelToDrop();
375+
376+
ModulusChainAttr newModulusChain = lwe::ModulusChainAttr::get(
377+
ctx, x.getModulusChain().getElements(),
378+
x.getModulusChain().getCurrent() - levelToDrop);
379+
polynomial::RingAttr newRing = getRingFromModulusChain(
380+
newModulusChain, x.getCiphertextSpace().getRing().getPolynomialModulus());
381+
382+
auto newCtTy = lwe::LWECiphertextType::get(
383+
ctx, x.getApplicationData(), x.getPlaintextSpace(),
384+
lwe::CiphertextSpaceAttr::get(ctx, newRing,
385+
x.getCiphertextSpace().getEncryptionType(),
386+
x.getCiphertextSpace().getSize()),
387+
x.getKey(), newModulusChain);
388+
389+
if (auto tensorTy =
390+
dyn_cast<RankedTensorType>(adaptor.getInput().getType())) {
391+
inferredReturnTypes.push_back(
392+
RankedTensorType::get(tensorTy.getShape(), newCtTy));
393+
return success();
394+
}
395+
396+
inferredReturnTypes.push_back(newCtTy);
397+
return success();
398+
}
399+
369400
} // namespace lwe
370401
} // namespace heir
371402
} // namespace mlir

lib/Dialect/LWE/IR/LWETraits.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ class SameOperandsAndResultPlaintextTypes
6767
}
6868
if (plaintextTypes != ps) {
6969
op->emitOpError() << "requires all operands and results to have "
70-
"the same plaintextTypes";
70+
"the same plaintextTypes, but found "
71+
<< plaintextTypes << " and " << ps;
7172
return failure();
7273
}
7374
return success();

lib/Dialect/LWE/IR/LWETypes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
#include "lib/Dialect/RNS/IR/RNSTypes.h"
66
#include "lib/Utils/Polynomial/Polynomial.h"
77
#include "llvm/include/llvm/ADT/STLFunctionalExtras.h" // from @llvm-project
8+
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
89
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
910
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
1011
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1112

13+
#define DEBUG_TYPE "lwe-types"
14+
1215
namespace mlir {
1316
namespace heir {
1417
namespace lwe {

lib/Dialect/Orion/Conversions/OrionToCKKS/IRMaterializingVisitor.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ Value IRMaterializingVisitor::operator()(const SubtractNode<SSAValue>& node) {
6767

6868
Value IRMaterializingVisitor::operator()(const MultiplyNode<SSAValue>& node) {
6969
return binop<MultiplyNode<SSAValue>, ckks::MulOp, ckks::MulPlainOp,
70-
arith::MulFOp>(node,
71-
/*rescale=*/true);
70+
arith::MulFOp>(node);
7271
}
7372

7473
Value IRMaterializingVisitor::operator()(const LeftRotateNode<SSAValue>& node) {

lib/Dialect/Orion/Conversions/OrionToCKKS/IRMaterializingVisitor.h

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ polynomial::RingAttr getRlweRNSRingWithLevel(polynomial::RingAttr ringAttr,
2121
int level);
2222

2323
// Walks the arithmetic DAG and generates MLIR for it. This materializer is
24-
// weird because some SSA values are tensors of ciphertexts while others are
25-
// ciphertext-semantic tensors of cleartexts. The former requires simple ops
26-
// like tensor.extract and ckks.rotate, while the latter requires
27-
// tensor.extract_slice and tensor_ext.rotate.
24+
// meant to handle quirks of the Orion import process. For example it
25+
// handles scale encoding of cleartexts, and special cases for tensors of
26+
// ciphertexts vs ciphertext-semantic tensors of cleartexts. The former
27+
// requires simple ops like tensor.extract and ckks.rotate, while the latter
28+
// requires tensor.extract_slice and tensor_ext.rotate.
2829
class IRMaterializingVisitor
2930
: public kernel::CachingVisitor<kernel::SSAValue, Value> {
3031
public:
@@ -39,16 +40,48 @@ class IRMaterializingVisitor
3940
builder(builder),
4041
plaintextType(ptTy) {}
4142

42-
Value maybeRescale(Value value, lwe::LWECiphertextType resultType,
43-
bool rescale) {
43+
// Relinearize and rescale the ciphertext if relinAndRescale is true.
44+
Value relinAndRescale(Value value, lwe::LWECiphertextType resultType,
45+
bool relinearize, bool rescale) {
4446
Value result = value;
47+
if (relinearize) {
48+
auto inputDimension = cast<lwe::LWECiphertextType>(value.getType())
49+
.getCiphertextSpace()
50+
.getSize();
51+
SmallVector<int32_t> fromBasis;
52+
for (int i = 0; i < inputDimension; ++i) {
53+
fromBasis.push_back(i);
54+
}
55+
SmallVector<int32_t> toBasis = {0, 1};
56+
auto relinOp = ckks::RelinearizeOp::create(
57+
builder, result, builder.getDenseI32ArrayAttr(fromBasis),
58+
builder.getDenseI32ArrayAttr(toBasis));
59+
result = relinOp.getResult();
60+
}
4561
if (rescale) {
4662
FailureOr<lwe::LWECiphertextType> ctTypeResult =
4763
applyModReduce(resultType);
4864
if (failed(ctTypeResult)) {
4965
emitError(result.getLoc())
50-
<< "Cannot rescale ciphertext type: " << resultType;
51-
return Value();
66+
<< "Cannot rescale ciphertext type, inserting extra bootstrap op";
67+
// sub 1 because the max level is the last index in the chain.
68+
int64_t maxLevel =
69+
resultType.getModulusChain().getElements().size() - 1;
70+
71+
// Now we cheat a little bit: normally bootstrap itself would consume
72+
// some levels, which depends on the chosen backend. In our case, we're
73+
// lowering to library backends that handle this opaquely.
74+
//
75+
// TODO(#1207): fix if this pass still matters when lowering to
76+
// polynomial.
77+
FailureOr<lwe::LWECiphertextType> outputTypeResult =
78+
cloneAtLevel(resultType, maxLevel);
79+
if (failed(outputTypeResult)) {
80+
emitError(result.getLoc()) << "Failed to insert bootstrap";
81+
return Value();
82+
}
83+
result = ckks::BootstrapOp::create(builder, outputTypeResult.value(),
84+
result);
5285
}
5386
auto ctType = ctTypeResult.value();
5487
result = ckks::RescaleOp::create(builder, ctType, result,
@@ -68,26 +101,33 @@ class IRMaterializingVisitor
68101
}
69102

70103
template <typename T, typename CtCtOp, typename CtPtOp, typename CleartextOp>
71-
Value binop(const T& node, bool rescale = false) {
104+
Value binop(const T& node) {
72105
Value lhs = this->process(node.left);
73106
Value rhs = this->process(node.right);
74107

108+
bool relinearize =
109+
static_cast<bool>(std::is_same<CtCtOp, ckks::MulOp>::value);
110+
bool rescale =
111+
static_cast<bool>(std::is_same<CtCtOp, ckks::MulOp>::value) ||
112+
static_cast<bool>(std::is_same<CtPtOp, ckks::MulPlainOp>::value);
113+
75114
return TypeSwitch<Type, Value>(lhs.getType())
76115
.template Case<lwe::LWECiphertextType>([&](auto ty) {
77116
if (isa<RankedTensorType>(rhs.getType())) {
78-
return maybeRescale(
117+
return relinAndRescale(
79118
CtPtOp::create(builder, lhs, encodeCleartextOperand(ty, rhs))
80119
.getResult(),
81-
ty, rescale);
120+
ty, relinearize, rescale);
82121
}
83122

84123
if (isa<lwe::LWEPlaintextType>(rhs.getType())) {
85-
return maybeRescale(CtPtOp::create(builder, lhs, rhs).getResult(),
86-
ty, rescale);
124+
return relinAndRescale(
125+
CtPtOp::create(builder, lhs, rhs).getResult(), ty, relinearize,
126+
rescale);
87127
}
88128

89-
return maybeRescale(CtCtOp::create(builder, lhs, rhs).getResult(), ty,
90-
rescale);
129+
return relinAndRescale(CtCtOp::create(builder, lhs, rhs).getResult(),
130+
ty, relinearize, rescale);
91131
})
92132
.template Case<lwe::LWEPlaintextType>([&](auto ty) {
93133
if (isa<RankedTensorType>(rhs.getType())) {
@@ -111,15 +151,15 @@ class IRMaterializingVisitor
111151
<< "\n\nrhs=" << rhs << "\n";
112152
return Value();
113153
}
114-
return maybeRescale(CtPtOp::create(builder, lhs, rhs).getResult(),
115-
ctTy, rescale);
154+
return relinAndRescale(CtPtOp::create(builder, lhs, rhs).getResult(),
155+
ctTy, relinearize, rescale);
116156
})
117157
.template Case<RankedTensorType>([&](auto ty) {
118158
auto ctTy = cast<lwe::LWECiphertextType>(rhs.getType());
119-
return maybeRescale(
159+
return relinAndRescale(
120160
CtPtOp::create(builder, encodeCleartextOperand(ctTy, lhs), rhs)
121161
.getResult(),
122-
ctTy, rescale);
162+
ctTy, relinearize, rescale);
123163
})
124164
.Default([&](Type) {
125165
emitError(lhs.getLoc())

0 commit comments

Comments
 (0)