Skip to content

Commit 57ba0cc

Browse files
committed
start refactoring orion-to-ckks to walk driver
1 parent 2ba78b8 commit 57ba0cc

File tree

2 files changed

+183
-60
lines changed

2 files changed

+183
-60
lines changed

lib/Dialect/Orion/Conversions/OrionToCKKS/OrionToCKKS.cpp

Lines changed: 153 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@
2222

2323
namespace mlir::heir::orion {
2424

25+
using ckks::AddOp;
26+
using ckks::AddPlainOp;
27+
using ckks::BootstrapOp;
28+
using ckks::LevelReduceOp;
29+
using ckks::MulOp;
30+
using ckks::MulPlainOp;
31+
using ckks::RescaleOp;
32+
using ckks::SubOp;
33+
using ckks::SubPlainOp;
2534
using kernel::ArithmeticDagNode;
2635
using kernel::implementHaleviShoup;
2736
using kernel::SSAValue;
@@ -114,8 +123,7 @@ struct ConvertChebyshevOp : public OpRewritePattern<ChebyshevOp> {
114123
// for evaluation purposes.
115124
auto encodedSplatRescale = encodeSplattedCleartextUsingCtAndScalingFactor(
116125
b, ctTy, logDefaultScale, rescale);
117-
xInput =
118-
ckks::MulPlainOp::create(b, xInput, encodedSplatRescale).getResult();
126+
xInput = MulPlainOp::create(b, xInput, encodedSplatRescale).getResult();
119127
}
120128

121129
if (!shift.isZero()) {
@@ -125,8 +133,7 @@ struct ConvertChebyshevOp : public OpRewritePattern<ChebyshevOp> {
125133
ctTy.getPlaintextSpace().getEncoding());
126134
auto encodedSplatShift = encodeSplattedCleartextUsingCtAndScalingFactor(
127135
b, ctTy, scalingFactor, shift);
128-
xInput =
129-
ckks::AddPlainOp::create(b, xInput, encodedSplatShift).getResult();
136+
xInput = AddPlainOp::create(b, xInput, encodedSplatShift).getResult();
130137
}
131138

132139
SSAValue xNode(xInput);
@@ -150,8 +157,8 @@ struct ConvertChebyshevOp : public OpRewritePattern<ChebyshevOp> {
150157
return failure();
151158
}
152159
auto moddedDownTy = ctTypeResult.value();
153-
auto rescaleOp = ckks::RescaleOp::create(
154-
b, moddedDownTy, finalOutput, ctTy.getCiphertextSpace().getRing());
160+
auto rescaleOp = RescaleOp::create(b, moddedDownTy, finalOutput,
161+
ctTy.getCiphertextSpace().getRing());
155162
finalOutput = rescaleOp.getResult();
156163
}
157164

@@ -192,8 +199,8 @@ struct ConvertLinearTransformOp : public OpRewritePattern<LinearTransformOp> {
192199

193200
// ct-pt muls in the kernel didn't rescale, so rescale at the very end
194201
if (!rescaleAfterCtPtMul) {
195-
auto rescaleOp = ckks::RescaleOp::create(
196-
b, ctTy, finalOutput, ctTy.getCiphertextSpace().getRing());
202+
auto rescaleOp = RescaleOp::create(b, ctTy, finalOutput,
203+
ctTy.getCiphertextSpace().getRing());
197204
finalOutput = rescaleOp.getResult();
198205
}
199206

@@ -205,6 +212,56 @@ struct ConvertLinearTransformOp : public OpRewritePattern<LinearTransformOp> {
205212
std::string libraryTarget;
206213
};
207214

215+
struct FixTypesForRescale : public OpRewritePattern<RescaleOp> {
216+
using OpRewritePattern<RescaleOp>::OpRewritePattern;
217+
218+
LogicalResult matchAndRewrite(RescaleOp op,
219+
PatternRewriter& rewriter) const override {
220+
LLVM_DEBUG(llvm::dbgs() << "Handling RescaleOp\n");
221+
lwe::LWECiphertextType inputType =
222+
cast<lwe::LWECiphertextType>(op.getInput().getType());
223+
FailureOr<lwe::LWECiphertextType> outputTypeResult =
224+
applyModReduce(inputType);
225+
if (failed(outputTypeResult)) {
226+
op.emitError()
227+
<< "Cannot drop one limb from ciphertext type, inserting bootstrap\n";
228+
int64_t maxLevel = inputType.getModulusChain().getElements().size() - 1;
229+
230+
// Now we cheat a little bit: normally bootstrap itself would consume
231+
// some levels, which depends on the chosen backend. In our case, we're
232+
// lowering to library backends that handle this opaquely.
233+
//
234+
// TODO(#1207): fix if this pass still matters when lowering to
235+
// polynomial.
236+
FailureOr<lwe::LWECiphertextType> bootstrapResultTypeResult =
237+
cloneAtLevel(inputType, maxLevel);
238+
if (failed(bootstrapResultTypeResult)) {
239+
op.emitError() << "Failed to insert bootstrap";
240+
return failure();
241+
}
242+
ImplicitLocOpBuilder builder(op.getLoc(), op->getContext());
243+
builder.setInsertionPoint(op);
244+
auto bootstrapOp = BootstrapOp::create(
245+
builder, bootstrapResultTypeResult.value(), op.getInput());
246+
op.setOperand(bootstrapOp.getResult());
247+
248+
outputTypeResult = applyModReduce(bootstrapResultTypeResult.value());
249+
if (failed(outputTypeResult)) {
250+
return op.emitError()
251+
<< "Failed to rescale even after inserting bootstrap\n";
252+
}
253+
}
254+
// FIXME: convert to rewriter style
255+
debugLevelAndScale(outputTypeResult.value());
256+
op.getResult().setType(outputTypeResult.value());
257+
op.setToRingAttr(outputTypeResult.value().getCiphertextSpace().getRing());
258+
return success();
259+
}
260+
261+
private:
262+
std::string libraryTarget;
263+
};
264+
208265
WalkResult handleInferTypeOpInterface(InferTypeOpInterface op) {
209266
LLVM_DEBUG(llvm::dbgs() << "Handling InferTypeOpInterface: " << op->getName()
210267
<< "\n");
@@ -229,7 +286,7 @@ WalkResult handleInferTypeOpInterface(InferTypeOpInterface op) {
229286
return WalkResult::advance();
230287
}
231288

232-
WalkResult handleRescaleOp(ckks::RescaleOp op) {
289+
WalkResult handleRescaleOp(RescaleOp op) {
233290
LLVM_DEBUG(llvm::dbgs() << "Handling RescaleOp\n");
234291
lwe::LWECiphertextType inputType =
235292
cast<lwe::LWECiphertextType>(op.getInput().getType());
@@ -254,7 +311,7 @@ WalkResult handleRescaleOp(ckks::RescaleOp op) {
254311
}
255312
ImplicitLocOpBuilder builder(op.getLoc(), op->getContext());
256313
builder.setInsertionPoint(op);
257-
auto bootstrapOp = ckks::BootstrapOp::create(
314+
auto bootstrapOp = BootstrapOp::create(
258315
builder, bootstrapResultTypeResult.value(), op.getInput());
259316
op.setOperand(bootstrapOp.getResult());
260317

@@ -270,7 +327,7 @@ WalkResult handleRescaleOp(ckks::RescaleOp op) {
270327
return WalkResult::advance();
271328
}
272329

273-
WalkResult handleBootstrap(ckks::BootstrapOp op) {
330+
WalkResult handleBootstrap(BootstrapOp op) {
274331
LLVM_DEBUG(llvm::dbgs() << "Handling BootstrapOp\n");
275332
// First, we need to find the maximum level in the modulus chain from the
276333
// ciphertext type.
@@ -297,12 +354,12 @@ WalkResult handleNonMulCtPtOp(CtPtOp op) {
297354
lwe::LWECiphertextType ctType;
298355
LLVM_DEBUG(llvm::dbgs() << "Handling CtPt op: " << op->getName() << "\n");
299356
Value plaintextOperand;
300-
if (isa<lwe::LWECiphertextType>(op.getLhs().getType())) {
301-
ctType = cast<lwe::LWECiphertextType>(op.getLhs().getType());
302-
plaintextOperand = op.getRhs();
357+
if (isa<lwe::LWECiphertextType>(op.getOperand(0).getType())) {
358+
ctType = cast<lwe::LWECiphertextType>(op.getOperand(0).getType());
359+
plaintextOperand = op.getOperand(1);
303360
} else {
304-
ctType = cast<lwe::LWECiphertextType>(op.getRhs().getType());
305-
plaintextOperand = op.getLhs();
361+
ctType = cast<lwe::LWECiphertextType>(op.getOperand(1).getType());
362+
plaintextOperand = op.getOperand(0);
306363
}
307364

308365
if (auto encodeOp =
@@ -328,16 +385,16 @@ WalkResult handleNonMulCtPtOp(CtPtOp op) {
328385
return WalkResult::advance();
329386
}
330387

331-
WalkResult handleMulPlain(ckks::MulPlainOp op) {
388+
WalkResult handleMulPlain(MulPlainOp op) {
332389
LLVM_DEBUG(llvm::dbgs() << "Handling MulPlain op\n");
333390
lwe::LWECiphertextType ctType;
334391
lwe::LWEPlaintextType ptType;
335-
if (isa<lwe::LWECiphertextType>(op.getLhs().getType())) {
336-
ctType = cast<lwe::LWECiphertextType>(op.getLhs().getType());
337-
ptType = cast<lwe::LWEPlaintextType>(op.getRhs().getType());
392+
if (isa<lwe::LWECiphertextType>(op.getOperand(0).getType())) {
393+
ctType = cast<lwe::LWECiphertextType>(op.getOperand(0).getType());
394+
ptType = cast<lwe::LWEPlaintextType>(op.getOperand(1).getType());
338395
} else {
339-
ctType = cast<lwe::LWECiphertextType>(op.getRhs().getType());
340-
ptType = cast<lwe::LWEPlaintextType>(op.getLhs().getType());
396+
ctType = cast<lwe::LWECiphertextType>(op.getOperand(1).getType());
397+
ptType = cast<lwe::LWEPlaintextType>(op.getOperand(0).getType());
341398
}
342399
auto newCtType = lwe::LWECiphertextType::get(
343400
op.getContext(), ctType.getApplicationData(),
@@ -360,33 +417,32 @@ WalkResult handleMulPlain(ckks::MulPlainOp op) {
360417
template <typename CtCtOp>
361418
WalkResult handleCtCtOp(CtCtOp op) {
362419
LLVM_DEBUG(llvm::dbgs() << "Handling ct-ct op: " << op->getName() << "\n");
363-
lwe::LWECiphertextType lhsType =
364-
cast<lwe::LWECiphertextType>(op.getLhs().getType());
365-
lwe::LWECiphertextType rhsType =
366-
cast<lwe::LWECiphertextType>(op.getRhs().getType());
367420
ImplicitLocOpBuilder b(op.getLoc(), op->getContext());
368421
b.setInsertionPoint(op);
369422

370423
// Determine if we need to reduce the level of one operand to match the
371424
// other, or rescale, or do both simultaneously.
372-
int64_t lhsLevel = lhsType.getModulusChain().getCurrent();
373-
int64_t rhsLevel = rhsType.getModulusChain().getCurrent();
374-
int64_t lhsScalingFactor = lwe::getScalingFactorFromEncodingAttr(
425+
auto lhsType = cast<lwe::LWECiphertextType>(op.getOperand(0).getType());
426+
auto rhsType = cast<lwe::LWECiphertextType>(op.getOperand(1).getType());
427+
auto lhsLevel = lhsType.getModulusChain().getCurrent();
428+
auto rhsLevel = rhsType.getModulusChain().getCurrent();
429+
auto lhsScalingFactor = lwe::getScalingFactorFromEncodingAttr(
375430
lhsType.getPlaintextSpace().getEncoding());
376-
int64_t rhsScalingFactor = lwe::getScalingFactorFromEncodingAttr(
431+
auto rhsScalingFactor = lwe::getScalingFactorFromEncodingAttr(
377432
rhsType.getPlaintextSpace().getEncoding());
378433

379-
if (lhsLevel != rhsLevel && lhsScalingFactor != rhsScalingFactor) {
434+
// Rescale to adjust scale and level simultaneously
435+
while (lhsLevel != rhsLevel && lhsScalingFactor != rhsScalingFactor) {
380436
LLVM_DEBUG(llvm::dbgs()
381437
<< "lhs level = " << lhsLevel << " != rhs level = " << rhsLevel
382438
<< ", and " << "lhs scale = " << lhsScalingFactor
383439
<< " != rhs scale = " << rhsScalingFactor
384440
<< ", applying rescale...\n");
385441
Value operandToRescale;
386-
if (lhsScalingFactor >= rhsScalingFactor) {
387-
operandToRescale = op.getLhs();
442+
if (lhsScalingFactor > rhsScalingFactor) {
443+
operandToRescale = op.getOperand(0);
388444
} else {
389-
operandToRescale = op.getRhs();
445+
operandToRescale = op.getOperand(1);
390446
}
391447

392448
auto rescaleInputTy =
@@ -398,33 +454,71 @@ WalkResult handleCtCtOp(CtCtOp op) {
398454
return WalkResult::interrupt();
399455
}
400456
auto ctType = ctTypeResult.value();
401-
auto rescaleOp = ckks::RescaleOp::create(
402-
b, ctType, operandToRescale, ctType.getCiphertextSpace().getRing());
457+
auto rescaleOp = RescaleOp::create(b, ctType, operandToRescale,
458+
ctType.getCiphertextSpace().getRing());
403459
int64_t operandIndex = (lhsScalingFactor > rhsScalingFactor) ? 1 : 0;
404460
debugLevelAndScale(rescaleOp.getResult().getType(), "operand");
405461
op->setOperand(operandIndex, rescaleOp.getResult());
406-
return handleInferTypeOpInterface(op);
462+
auto result = handleInferTypeOpInterface(op);
463+
if (result.wasInterrupted()) {
464+
return result;
465+
}
466+
467+
if (lhsScalingFactor > rhsScalingFactor) {
468+
lhsType = cast<lwe::LWECiphertextType>(rescaleOp.getResult().getType());
469+
} else {
470+
rhsType = cast<lwe::LWECiphertextType>(rescaleOp.getResult().getType());
471+
}
472+
lhsLevel = lhsType.getModulusChain().getCurrent();
473+
rhsLevel = rhsType.getModulusChain().getCurrent();
474+
lhsScalingFactor = lwe::getScalingFactorFromEncodingAttr(
475+
lhsType.getPlaintextSpace().getEncoding());
476+
rhsScalingFactor = lwe::getScalingFactorFromEncodingAttr(
477+
rhsType.getPlaintextSpace().getEncoding());
407478
}
408479

480+
// After rescaling, now the levels may still mismatch and we need to
481+
// drop levels (without rescaling) to align.
482+
//
483+
// Example:
484+
//
485+
// lhs level = 10 != rhs level = 8
486+
// lhs scale = 80 != rhs scale = 120
487+
//
488+
// Above we rescale the rhs from 120 to 80, which drops the level to 7, then
489+
// we level reduce the lhs from 10 to 7.
409490
if (lhsLevel != rhsLevel) {
410491
LLVM_DEBUG(llvm::dbgs() << "lhs level = " << lhsLevel << " != rhs level = "
411492
<< rhsLevel << ", applying level_reduce...\n");
412493
Value operandToReduce;
413494
int64_t levelsToDrop;
414495
if (lhsLevel < rhsLevel) {
415-
operandToReduce = op.getRhs();
496+
operandToReduce = op.getOperand(1);
416497
levelsToDrop = rhsLevel - lhsLevel;
417498
} else {
418-
operandToReduce = op.getLhs();
499+
operandToReduce = op.getOperand(0);
419500
levelsToDrop = lhsLevel - rhsLevel;
420501
}
421502

422-
auto levelReduceOp = ckks::LevelReduceOp::create(
503+
LLVM_DEBUG(llvm::dbgs() << "dropping " << levelsToDrop << " levels\n");
504+
auto levelReduceOp = LevelReduceOp::create(
423505
b, operandToReduce, b.getI64IntegerAttr(levelsToDrop));
506+
levelReduceOp.dump();
424507
int64_t operandIndex = (lhsLevel < rhsLevel) ? 1 : 0;
425508
debugLevelAndScale(levelReduceOp.getResult().getType(), "operand");
426509
op->setOperand(operandIndex, levelReduceOp.getResult());
427-
return handleInferTypeOpInterface(op);
510+
auto result = handleInferTypeOpInterface(op);
511+
if (result.wasInterrupted()) {
512+
return result;
513+
}
514+
lhsType = cast<lwe::LWECiphertextType>(op.getOperand(0).getType());
515+
rhsType = cast<lwe::LWECiphertextType>(op.getOperand(1).getType());
516+
lhsLevel = lhsType.getModulusChain().getCurrent();
517+
rhsLevel = rhsType.getModulusChain().getCurrent();
518+
lhsScalingFactor = lwe::getScalingFactorFromEncodingAttr(
519+
lhsType.getPlaintextSpace().getEncoding());
520+
rhsScalingFactor = lwe::getScalingFactorFromEncodingAttr(
521+
rhsType.getPlaintextSpace().getEncoding());
428522
}
429523

430524
if (lhsScalingFactor != rhsScalingFactor) {
@@ -434,35 +528,33 @@ WalkResult handleCtCtOp(CtCtOp op) {
434528
Value operandToRescale;
435529
int64_t targetScalingFactor;
436530
if (lhsScalingFactor < rhsScalingFactor) {
437-
operandToRescale = op.getLhs();
531+
operandToRescale = op.getOperand(0);
438532
targetScalingFactor = rhsScalingFactor;
439533
} else {
440-
operandToRescale = op.getRhs();
534+
operandToRescale = op.getOperand(1);
441535
targetScalingFactor = lhsScalingFactor;
442536
}
443537

444538
auto encodedSplatOne = encodeSplattedCleartextUsingCtAndScalingFactor(
445539
b, cast<lwe::LWECiphertextType>(operandToRescale.getType()),
446540
targetScalingFactor, APFloat(1.0));
447-
auto mulPlainOp =
448-
ckks::MulPlainOp::create(b, operandToRescale, encodedSplatOne);
541+
auto mulPlainOp = MulPlainOp::create(b, operandToRescale, encodedSplatOne);
449542
int64_t operandIndex = (lhsScalingFactor < rhsScalingFactor) ? 0 : 1;
450543
debugLevelAndScale(mulPlainOp.getResult().getType(), "operand");
451544
op->setOperand(operandIndex, mulPlainOp.getResult());
452-
return handleInferTypeOpInterface(op);
453545
}
454546

455547
return handleInferTypeOpInterface(op);
456548
}
457549

458-
WalkResult handleMul(ckks::MulOp op) {
550+
WalkResult handleMul(MulOp op) {
459551
ImplicitLocOpBuilder b(op.getLoc(), op->getContext());
460552
b.setInsertionPoint(op);
461553
LLVM_DEBUG(llvm::dbgs() << "Handling Mul op\n");
462554
lwe::LWECiphertextType lhsType =
463-
cast<lwe::LWECiphertextType>(op.getLhs().getType());
555+
cast<lwe::LWECiphertextType>(op.getOperand(0).getType());
464556
lwe::LWECiphertextType rhsType =
465-
cast<lwe::LWECiphertextType>(op.getRhs().getType());
557+
cast<lwe::LWECiphertextType>(op.getOperand(1).getType());
466558

467559
// Mul ops may have differing scales, but not differing levels
468560
int64_t lhsLevel = lhsType.getModulusChain().getCurrent();
@@ -474,15 +566,15 @@ WalkResult handleMul(ckks::MulOp op) {
474566
Value operandToReduce;
475567
int64_t levelsToDrop;
476568
if (lhsLevel < rhsLevel) {
477-
operandToReduce = op.getRhs();
569+
operandToReduce = op.getOperand(1);
478570
levelsToDrop = rhsLevel - lhsLevel;
479571
} else {
480-
operandToReduce = op.getLhs();
572+
operandToReduce = op.getOperand(0);
481573
levelsToDrop = lhsLevel - rhsLevel;
482574
}
483575
LLVM_DEBUG(llvm::dbgs() << "dropping " << levelsToDrop << " levels\n");
484576

485-
auto levelReduceOp = ckks::LevelReduceOp::create(
577+
auto levelReduceOp = LevelReduceOp::create(
486578
b, operandToReduce, b.getI64IntegerAttr(levelsToDrop));
487579
int64_t operandIndex = (lhsLevel < rhsLevel) ? 1 : 0;
488580
debugLevelAndScale(levelReduceOp.getResult().getType(), "operand");
@@ -533,14 +625,15 @@ struct OrionToCKKS : public impl::OrionToCKKSBase<OrionToCKKS> {
533625
}
534626

535627
return llvm::TypeSwitch<Operation*, WalkResult>(op)
536-
.Case<ckks::RescaleOp>(handleRescaleOp)
537-
.Case<ckks::AddPlainOp>(handleNonMulCtPtOp<ckks::AddPlainOp>)
538-
.Case<ckks::SubPlainOp>(handleNonMulCtPtOp<ckks::SubPlainOp>)
539-
.Case<ckks::MulPlainOp>(handleMulPlain)
540-
.Case<ckks::AddOp>(handleCtCtOp<ckks::AddOp>)
541-
.Case<ckks::SubOp>(handleCtCtOp<ckks::SubOp>)
542-
.Case<ckks::MulOp>(handleMul)
543-
.Case<ckks::BootstrapOp>(handleBootstrap)
628+
// FIXME: convert these to patterns and use walk driver
629+
.Case<RescaleOp>(handleRescaleOp)
630+
.Case<AddPlainOp>(handleNonMulCtPtOp<AddPlainOp>)
631+
.Case<SubPlainOp>(handleNonMulCtPtOp<SubPlainOp>)
632+
.Case<MulPlainOp>(handleMulPlain)
633+
.Case<AddOp>(handleCtCtOp<AddOp>)
634+
.Case<SubOp>(handleCtCtOp<SubOp>)
635+
.Case<MulOp>(handleMul)
636+
.Case<BootstrapOp>(handleBootstrap)
544637
// Some ops above implement InferTypeOpInterface, but need special
545638
// cases, so this must come after them.
546639
.Case<InferTypeOpInterface>(handleInferTypeOpInterface)

0 commit comments

Comments
 (0)