2222
2323namespace mlir ::heir::orion {
2424
25+ using ckks::AddOp;
26+ using ckks::AddPlainOp;
27+ using ckks::BootstrapOp;
28+ using ckks::LevelReduceOp;
29+ using ckks::MulOp;
30+ using ckks::MulPlainOp;
31+ using ckks::RescaleOp;
32+ using ckks::SubOp;
33+ using ckks::SubPlainOp;
2534using kernel::ArithmeticDagNode;
2635using kernel::implementHaleviShoup;
2736using kernel::SSAValue;
@@ -114,8 +123,7 @@ struct ConvertChebyshevOp : public OpRewritePattern<ChebyshevOp> {
114123 // for evaluation purposes.
115124 auto encodedSplatRescale = encodeSplattedCleartextUsingCtAndScalingFactor (
116125 b, ctTy, logDefaultScale, rescale);
117- xInput =
118- ckks::MulPlainOp::create (b, xInput, encodedSplatRescale).getResult ();
126+ xInput = MulPlainOp::create (b, xInput, encodedSplatRescale).getResult ();
119127 }
120128
121129 if (!shift.isZero ()) {
@@ -125,8 +133,7 @@ struct ConvertChebyshevOp : public OpRewritePattern<ChebyshevOp> {
125133 ctTy.getPlaintextSpace ().getEncoding ());
126134 auto encodedSplatShift = encodeSplattedCleartextUsingCtAndScalingFactor (
127135 b, ctTy, scalingFactor, shift);
128- xInput =
129- ckks::AddPlainOp::create (b, xInput, encodedSplatShift).getResult ();
136+ xInput = AddPlainOp::create (b, xInput, encodedSplatShift).getResult ();
130137 }
131138
132139 SSAValue xNode (xInput);
@@ -150,8 +157,8 @@ struct ConvertChebyshevOp : public OpRewritePattern<ChebyshevOp> {
150157 return failure ();
151158 }
152159 auto moddedDownTy = ctTypeResult.value ();
153- auto rescaleOp = ckks:: RescaleOp::create (
154- b, moddedDownTy, finalOutput, ctTy.getCiphertextSpace ().getRing ());
160+ auto rescaleOp = RescaleOp::create (b, moddedDownTy, finalOutput,
161+ ctTy.getCiphertextSpace ().getRing ());
155162 finalOutput = rescaleOp.getResult ();
156163 }
157164
@@ -192,8 +199,8 @@ struct ConvertLinearTransformOp : public OpRewritePattern<LinearTransformOp> {
192199
193200 // ct-pt muls in the kernel didn't rescale, so rescale at the very end
194201 if (!rescaleAfterCtPtMul) {
195- auto rescaleOp = ckks:: RescaleOp::create (
196- b, ctTy, finalOutput, ctTy.getCiphertextSpace ().getRing ());
202+ auto rescaleOp = RescaleOp::create (b, ctTy, finalOutput,
203+ ctTy.getCiphertextSpace ().getRing ());
197204 finalOutput = rescaleOp.getResult ();
198205 }
199206
@@ -205,6 +212,56 @@ struct ConvertLinearTransformOp : public OpRewritePattern<LinearTransformOp> {
205212 std::string libraryTarget;
206213};
207214
215+ struct FixTypesForRescale : public OpRewritePattern <RescaleOp> {
216+ using OpRewritePattern<RescaleOp>::OpRewritePattern;
217+
218+ LogicalResult matchAndRewrite (RescaleOp op,
219+ PatternRewriter& rewriter) const override {
220+ LLVM_DEBUG (llvm::dbgs () << " Handling RescaleOp\n " );
221+ lwe::LWECiphertextType inputType =
222+ cast<lwe::LWECiphertextType>(op.getInput ().getType ());
223+ FailureOr<lwe::LWECiphertextType> outputTypeResult =
224+ applyModReduce (inputType);
225+ if (failed (outputTypeResult)) {
226+ op.emitError ()
227+ << " Cannot drop one limb from ciphertext type, inserting bootstrap\n " ;
228+ int64_t maxLevel = inputType.getModulusChain ().getElements ().size () - 1 ;
229+
230+ // Now we cheat a little bit: normally bootstrap itself would consume
231+ // some levels, which depends on the chosen backend. In our case, we're
232+ // lowering to library backends that handle this opaquely.
233+ //
234+ // TODO(#1207): fix if this pass still matters when lowering to
235+ // polynomial.
236+ FailureOr<lwe::LWECiphertextType> bootstrapResultTypeResult =
237+ cloneAtLevel (inputType, maxLevel);
238+ if (failed (bootstrapResultTypeResult)) {
239+ op.emitError () << " Failed to insert bootstrap" ;
240+ return failure ();
241+ }
242+ ImplicitLocOpBuilder builder (op.getLoc (), op->getContext ());
243+ builder.setInsertionPoint (op);
244+ auto bootstrapOp = BootstrapOp::create (
245+ builder, bootstrapResultTypeResult.value (), op.getInput ());
246+ op.setOperand (bootstrapOp.getResult ());
247+
248+ outputTypeResult = applyModReduce (bootstrapResultTypeResult.value ());
249+ if (failed (outputTypeResult)) {
250+ return op.emitError ()
251+ << " Failed to rescale even after inserting bootstrap\n " ;
252+ }
253+ }
254+ // FIXME: convert to rewriter style
255+ debugLevelAndScale (outputTypeResult.value ());
256+ op.getResult ().setType (outputTypeResult.value ());
257+ op.setToRingAttr (outputTypeResult.value ().getCiphertextSpace ().getRing ());
258+ return success ();
259+ }
260+
261+ private:
262+ std::string libraryTarget;
263+ };
264+
208265WalkResult handleInferTypeOpInterface (InferTypeOpInterface op) {
209266 LLVM_DEBUG (llvm::dbgs () << " Handling InferTypeOpInterface: " << op->getName ()
210267 << " \n " );
@@ -229,7 +286,7 @@ WalkResult handleInferTypeOpInterface(InferTypeOpInterface op) {
229286 return WalkResult::advance ();
230287}
231288
232- WalkResult handleRescaleOp (ckks:: RescaleOp op) {
289+ WalkResult handleRescaleOp (RescaleOp op) {
233290 LLVM_DEBUG (llvm::dbgs () << " Handling RescaleOp\n " );
234291 lwe::LWECiphertextType inputType =
235292 cast<lwe::LWECiphertextType>(op.getInput ().getType ());
@@ -254,7 +311,7 @@ WalkResult handleRescaleOp(ckks::RescaleOp op) {
254311 }
255312 ImplicitLocOpBuilder builder (op.getLoc (), op->getContext ());
256313 builder.setInsertionPoint (op);
257- auto bootstrapOp = ckks:: BootstrapOp::create (
314+ auto bootstrapOp = BootstrapOp::create (
258315 builder, bootstrapResultTypeResult.value (), op.getInput ());
259316 op.setOperand (bootstrapOp.getResult ());
260317
@@ -270,7 +327,7 @@ WalkResult handleRescaleOp(ckks::RescaleOp op) {
270327 return WalkResult::advance ();
271328}
272329
273- WalkResult handleBootstrap (ckks:: BootstrapOp op) {
330+ WalkResult handleBootstrap (BootstrapOp op) {
274331 LLVM_DEBUG (llvm::dbgs () << " Handling BootstrapOp\n " );
275332 // First, we need to find the maximum level in the modulus chain from the
276333 // ciphertext type.
@@ -297,12 +354,12 @@ WalkResult handleNonMulCtPtOp(CtPtOp op) {
297354 lwe::LWECiphertextType ctType;
298355 LLVM_DEBUG (llvm::dbgs () << " Handling CtPt op: " << op->getName () << " \n " );
299356 Value plaintextOperand;
300- if (isa<lwe::LWECiphertextType>(op.getLhs ( ).getType ())) {
301- ctType = cast<lwe::LWECiphertextType>(op.getLhs ( ).getType ());
302- plaintextOperand = op.getRhs ( );
357+ if (isa<lwe::LWECiphertextType>(op.getOperand ( 0 ).getType ())) {
358+ ctType = cast<lwe::LWECiphertextType>(op.getOperand ( 0 ).getType ());
359+ plaintextOperand = op.getOperand ( 1 );
303360 } else {
304- ctType = cast<lwe::LWECiphertextType>(op.getRhs ( ).getType ());
305- plaintextOperand = op.getLhs ( );
361+ ctType = cast<lwe::LWECiphertextType>(op.getOperand ( 1 ).getType ());
362+ plaintextOperand = op.getOperand ( 0 );
306363 }
307364
308365 if (auto encodeOp =
@@ -328,16 +385,16 @@ WalkResult handleNonMulCtPtOp(CtPtOp op) {
328385 return WalkResult::advance ();
329386}
330387
331- WalkResult handleMulPlain (ckks:: MulPlainOp op) {
388+ WalkResult handleMulPlain (MulPlainOp op) {
332389 LLVM_DEBUG (llvm::dbgs () << " Handling MulPlain op\n " );
333390 lwe::LWECiphertextType ctType;
334391 lwe::LWEPlaintextType ptType;
335- if (isa<lwe::LWECiphertextType>(op.getLhs ( ).getType ())) {
336- ctType = cast<lwe::LWECiphertextType>(op.getLhs ( ).getType ());
337- ptType = cast<lwe::LWEPlaintextType>(op.getRhs ( ).getType ());
392+ if (isa<lwe::LWECiphertextType>(op.getOperand ( 0 ).getType ())) {
393+ ctType = cast<lwe::LWECiphertextType>(op.getOperand ( 0 ).getType ());
394+ ptType = cast<lwe::LWEPlaintextType>(op.getOperand ( 1 ).getType ());
338395 } else {
339- ctType = cast<lwe::LWECiphertextType>(op.getRhs ( ).getType ());
340- ptType = cast<lwe::LWEPlaintextType>(op.getLhs ( ).getType ());
396+ ctType = cast<lwe::LWECiphertextType>(op.getOperand ( 1 ).getType ());
397+ ptType = cast<lwe::LWEPlaintextType>(op.getOperand ( 0 ).getType ());
341398 }
342399 auto newCtType = lwe::LWECiphertextType::get (
343400 op.getContext (), ctType.getApplicationData (),
@@ -360,33 +417,32 @@ WalkResult handleMulPlain(ckks::MulPlainOp op) {
360417template <typename CtCtOp>
361418WalkResult handleCtCtOp (CtCtOp op) {
362419 LLVM_DEBUG (llvm::dbgs () << " Handling ct-ct op: " << op->getName () << " \n " );
363- lwe::LWECiphertextType lhsType =
364- cast<lwe::LWECiphertextType>(op.getLhs ().getType ());
365- lwe::LWECiphertextType rhsType =
366- cast<lwe::LWECiphertextType>(op.getRhs ().getType ());
367420 ImplicitLocOpBuilder b (op.getLoc (), op->getContext ());
368421 b.setInsertionPoint (op);
369422
370423 // Determine if we need to reduce the level of one operand to match the
371424 // other, or rescale, or do both simultaneously.
372- int64_t lhsLevel = lhsType.getModulusChain ().getCurrent ();
373- int64_t rhsLevel = rhsType.getModulusChain ().getCurrent ();
374- int64_t lhsScalingFactor = lwe::getScalingFactorFromEncodingAttr (
425+ auto lhsType = cast<lwe::LWECiphertextType>(op.getOperand (0 ).getType ());
426+ auto rhsType = cast<lwe::LWECiphertextType>(op.getOperand (1 ).getType ());
427+ auto lhsLevel = lhsType.getModulusChain ().getCurrent ();
428+ auto rhsLevel = rhsType.getModulusChain ().getCurrent ();
429+ auto lhsScalingFactor = lwe::getScalingFactorFromEncodingAttr (
375430 lhsType.getPlaintextSpace ().getEncoding ());
376- int64_t rhsScalingFactor = lwe::getScalingFactorFromEncodingAttr (
431+ auto rhsScalingFactor = lwe::getScalingFactorFromEncodingAttr (
377432 rhsType.getPlaintextSpace ().getEncoding ());
378433
379- if (lhsLevel != rhsLevel && lhsScalingFactor != rhsScalingFactor) {
434+ // Rescale to adjust scale and level simultaneously
435+ while (lhsLevel != rhsLevel && lhsScalingFactor != rhsScalingFactor) {
380436 LLVM_DEBUG (llvm::dbgs ()
381437 << " lhs level = " << lhsLevel << " != rhs level = " << rhsLevel
382438 << " , and " << " lhs scale = " << lhsScalingFactor
383439 << " != rhs scale = " << rhsScalingFactor
384440 << " , applying rescale...\n " );
385441 Value operandToRescale;
386- if (lhsScalingFactor >= rhsScalingFactor) {
387- operandToRescale = op.getLhs ( );
442+ if (lhsScalingFactor > rhsScalingFactor) {
443+ operandToRescale = op.getOperand ( 0 );
388444 } else {
389- operandToRescale = op.getRhs ( );
445+ operandToRescale = op.getOperand ( 1 );
390446 }
391447
392448 auto rescaleInputTy =
@@ -398,33 +454,71 @@ WalkResult handleCtCtOp(CtCtOp op) {
398454 return WalkResult::interrupt ();
399455 }
400456 auto ctType = ctTypeResult.value ();
401- auto rescaleOp = ckks:: RescaleOp::create (
402- b, ctType, operandToRescale, ctType.getCiphertextSpace ().getRing ());
457+ auto rescaleOp = RescaleOp::create (b, ctType, operandToRescale,
458+ ctType.getCiphertextSpace ().getRing ());
403459 int64_t operandIndex = (lhsScalingFactor > rhsScalingFactor) ? 1 : 0 ;
404460 debugLevelAndScale (rescaleOp.getResult ().getType (), " operand" );
405461 op->setOperand (operandIndex, rescaleOp.getResult ());
406- return handleInferTypeOpInterface (op);
462+ auto result = handleInferTypeOpInterface (op);
463+ if (result.wasInterrupted ()) {
464+ return result;
465+ }
466+
467+ if (lhsScalingFactor > rhsScalingFactor) {
468+ lhsType = cast<lwe::LWECiphertextType>(rescaleOp.getResult ().getType ());
469+ } else {
470+ rhsType = cast<lwe::LWECiphertextType>(rescaleOp.getResult ().getType ());
471+ }
472+ lhsLevel = lhsType.getModulusChain ().getCurrent ();
473+ rhsLevel = rhsType.getModulusChain ().getCurrent ();
474+ lhsScalingFactor = lwe::getScalingFactorFromEncodingAttr (
475+ lhsType.getPlaintextSpace ().getEncoding ());
476+ rhsScalingFactor = lwe::getScalingFactorFromEncodingAttr (
477+ rhsType.getPlaintextSpace ().getEncoding ());
407478 }
408479
480+ // After rescaling, now the levels may still mismatch and we need to
481+ // drop levels (without rescaling) to align.
482+ //
483+ // Example:
484+ //
485+ // lhs level = 10 != rhs level = 8
486+ // lhs scale = 80 != rhs scale = 120
487+ //
488+ // Above we rescale the rhs from 120 to 80, which drops the level to 7, then
489+ // we level reduce the lhs from 10 to 7.
409490 if (lhsLevel != rhsLevel) {
410491 LLVM_DEBUG (llvm::dbgs () << " lhs level = " << lhsLevel << " != rhs level = "
411492 << rhsLevel << " , applying level_reduce...\n " );
412493 Value operandToReduce;
413494 int64_t levelsToDrop;
414495 if (lhsLevel < rhsLevel) {
415- operandToReduce = op.getRhs ( );
496+ operandToReduce = op.getOperand ( 1 );
416497 levelsToDrop = rhsLevel - lhsLevel;
417498 } else {
418- operandToReduce = op.getLhs ( );
499+ operandToReduce = op.getOperand ( 0 );
419500 levelsToDrop = lhsLevel - rhsLevel;
420501 }
421502
422- auto levelReduceOp = ckks::LevelReduceOp::create (
503+ LLVM_DEBUG (llvm::dbgs () << " dropping " << levelsToDrop << " levels\n " );
504+ auto levelReduceOp = LevelReduceOp::create (
423505 b, operandToReduce, b.getI64IntegerAttr (levelsToDrop));
506+ levelReduceOp.dump ();
424507 int64_t operandIndex = (lhsLevel < rhsLevel) ? 1 : 0 ;
425508 debugLevelAndScale (levelReduceOp.getResult ().getType (), " operand" );
426509 op->setOperand (operandIndex, levelReduceOp.getResult ());
427- return handleInferTypeOpInterface (op);
510+ auto result = handleInferTypeOpInterface (op);
511+ if (result.wasInterrupted ()) {
512+ return result;
513+ }
514+ lhsType = cast<lwe::LWECiphertextType>(op.getOperand (0 ).getType ());
515+ rhsType = cast<lwe::LWECiphertextType>(op.getOperand (1 ).getType ());
516+ lhsLevel = lhsType.getModulusChain ().getCurrent ();
517+ rhsLevel = rhsType.getModulusChain ().getCurrent ();
518+ lhsScalingFactor = lwe::getScalingFactorFromEncodingAttr (
519+ lhsType.getPlaintextSpace ().getEncoding ());
520+ rhsScalingFactor = lwe::getScalingFactorFromEncodingAttr (
521+ rhsType.getPlaintextSpace ().getEncoding ());
428522 }
429523
430524 if (lhsScalingFactor != rhsScalingFactor) {
@@ -434,35 +528,33 @@ WalkResult handleCtCtOp(CtCtOp op) {
434528 Value operandToRescale;
435529 int64_t targetScalingFactor;
436530 if (lhsScalingFactor < rhsScalingFactor) {
437- operandToRescale = op.getLhs ( );
531+ operandToRescale = op.getOperand ( 0 );
438532 targetScalingFactor = rhsScalingFactor;
439533 } else {
440- operandToRescale = op.getRhs ( );
534+ operandToRescale = op.getOperand ( 1 );
441535 targetScalingFactor = lhsScalingFactor;
442536 }
443537
444538 auto encodedSplatOne = encodeSplattedCleartextUsingCtAndScalingFactor (
445539 b, cast<lwe::LWECiphertextType>(operandToRescale.getType ()),
446540 targetScalingFactor, APFloat (1.0 ));
447- auto mulPlainOp =
448- ckks::MulPlainOp::create (b, operandToRescale, encodedSplatOne);
541+ auto mulPlainOp = MulPlainOp::create (b, operandToRescale, encodedSplatOne);
449542 int64_t operandIndex = (lhsScalingFactor < rhsScalingFactor) ? 0 : 1 ;
450543 debugLevelAndScale (mulPlainOp.getResult ().getType (), " operand" );
451544 op->setOperand (operandIndex, mulPlainOp.getResult ());
452- return handleInferTypeOpInterface (op);
453545 }
454546
455547 return handleInferTypeOpInterface (op);
456548}
457549
458- WalkResult handleMul (ckks:: MulOp op) {
550+ WalkResult handleMul (MulOp op) {
459551 ImplicitLocOpBuilder b (op.getLoc (), op->getContext ());
460552 b.setInsertionPoint (op);
461553 LLVM_DEBUG (llvm::dbgs () << " Handling Mul op\n " );
462554 lwe::LWECiphertextType lhsType =
463- cast<lwe::LWECiphertextType>(op.getLhs ( ).getType ());
555+ cast<lwe::LWECiphertextType>(op.getOperand ( 0 ).getType ());
464556 lwe::LWECiphertextType rhsType =
465- cast<lwe::LWECiphertextType>(op.getRhs ( ).getType ());
557+ cast<lwe::LWECiphertextType>(op.getOperand ( 1 ).getType ());
466558
467559 // Mul ops may have differing scales, but not differing levels
468560 int64_t lhsLevel = lhsType.getModulusChain ().getCurrent ();
@@ -474,15 +566,15 @@ WalkResult handleMul(ckks::MulOp op) {
474566 Value operandToReduce;
475567 int64_t levelsToDrop;
476568 if (lhsLevel < rhsLevel) {
477- operandToReduce = op.getRhs ( );
569+ operandToReduce = op.getOperand ( 1 );
478570 levelsToDrop = rhsLevel - lhsLevel;
479571 } else {
480- operandToReduce = op.getLhs ( );
572+ operandToReduce = op.getOperand ( 0 );
481573 levelsToDrop = lhsLevel - rhsLevel;
482574 }
483575 LLVM_DEBUG (llvm::dbgs () << " dropping " << levelsToDrop << " levels\n " );
484576
485- auto levelReduceOp = ckks:: LevelReduceOp::create (
577+ auto levelReduceOp = LevelReduceOp::create (
486578 b, operandToReduce, b.getI64IntegerAttr (levelsToDrop));
487579 int64_t operandIndex = (lhsLevel < rhsLevel) ? 1 : 0 ;
488580 debugLevelAndScale (levelReduceOp.getResult ().getType (), " operand" );
@@ -533,14 +625,15 @@ struct OrionToCKKS : public impl::OrionToCKKSBase<OrionToCKKS> {
533625 }
534626
535627 return llvm::TypeSwitch<Operation*, WalkResult>(op)
536- .Case <ckks::RescaleOp>(handleRescaleOp)
537- .Case <ckks::AddPlainOp>(handleNonMulCtPtOp<ckks::AddPlainOp>)
538- .Case <ckks::SubPlainOp>(handleNonMulCtPtOp<ckks::SubPlainOp>)
539- .Case <ckks::MulPlainOp>(handleMulPlain)
540- .Case <ckks::AddOp>(handleCtCtOp<ckks::AddOp>)
541- .Case <ckks::SubOp>(handleCtCtOp<ckks::SubOp>)
542- .Case <ckks::MulOp>(handleMul)
543- .Case <ckks::BootstrapOp>(handleBootstrap)
628+ // FIXME: convert these to patterns and use walk driver
629+ .Case <RescaleOp>(handleRescaleOp)
630+ .Case <AddPlainOp>(handleNonMulCtPtOp<AddPlainOp>)
631+ .Case <SubPlainOp>(handleNonMulCtPtOp<SubPlainOp>)
632+ .Case <MulPlainOp>(handleMulPlain)
633+ .Case <AddOp>(handleCtCtOp<AddOp>)
634+ .Case <SubOp>(handleCtCtOp<SubOp>)
635+ .Case <MulOp>(handleMul)
636+ .Case <BootstrapOp>(handleBootstrap)
544637 // Some ops above implement InferTypeOpInterface, but need special
545638 // cases, so this must come after them.
546639 .Case <InferTypeOpInterface>(handleInferTypeOpInterface)
0 commit comments