@@ -21,10 +21,11 @@ polynomial::RingAttr getRlweRNSRingWithLevel(polynomial::RingAttr ringAttr,
2121 int level);
2222
2323// Walks the arithmetic DAG and generates MLIR for it. This materializer is
24- // weird because some SSA values are tensors of ciphertexts while others are
25- // ciphertext-semantic tensors of cleartexts. The former requires simple ops
26- // like tensor.extract and ckks.rotate, while the latter requires
27- // tensor.extract_slice and tensor_ext.rotate.
24+ // meant to handle quirks of the Orion import process. For example it
25+ // handles scale encoding of cleartexts, and special cases for tensors of
26+ // ciphertexts vs ciphertext-semantic tensors of cleartexts. The former
27+ // requires simple ops like tensor.extract and ckks.rotate, while the latter
28+ // requires tensor.extract_slice and tensor_ext.rotate.
2829class IRMaterializingVisitor
2930 : public kernel::CachingVisitor<kernel::SSAValue, Value> {
3031 public:
@@ -39,16 +40,48 @@ class IRMaterializingVisitor
3940 builder(builder),
4041 plaintextType(ptTy) {}
4142
42- Value maybeRescale (Value value, lwe::LWECiphertextType resultType,
43- bool rescale) {
43+ // Relinearize and rescale the ciphertext if relinAndRescale is true.
44+ Value relinAndRescale (Value value, lwe::LWECiphertextType resultType,
45+ bool relinearize, bool rescale) {
4446 Value result = value;
47+ if (relinearize) {
48+ auto inputDimension = cast<lwe::LWECiphertextType>(value.getType ())
49+ .getCiphertextSpace ()
50+ .getSize ();
51+ SmallVector<int32_t > fromBasis;
52+ for (int i = 0 ; i < inputDimension; ++i) {
53+ fromBasis.push_back (i);
54+ }
55+ SmallVector<int32_t > toBasis = {0 , 1 };
56+ auto relinOp = ckks::RelinearizeOp::create (
57+ builder, result, builder.getDenseI32ArrayAttr (fromBasis),
58+ builder.getDenseI32ArrayAttr (toBasis));
59+ result = relinOp.getResult ();
60+ }
4561 if (rescale) {
4662 FailureOr<lwe::LWECiphertextType> ctTypeResult =
4763 applyModReduce (resultType);
4864 if (failed (ctTypeResult)) {
4965 emitError (result.getLoc ())
50- << " Cannot rescale ciphertext type: " << resultType;
51- return Value ();
66+ << " Cannot rescale ciphertext type, inserting extra bootstrap op" ;
67+ // sub 1 because the max level is the last index in the chain.
68+ int64_t maxLevel =
69+ resultType.getModulusChain ().getElements ().size () - 1 ;
70+
71+ // Now we cheat a little bit: normally bootstrap itself would consume
72+ // some levels, which depends on the chosen backend. In our case, we're
73+ // lowering to library backends that handle this opaquely.
74+ //
75+ // TODO(#1207): fix if this pass still matters when lowering to
76+ // polynomial.
77+ FailureOr<lwe::LWECiphertextType> outputTypeResult =
78+ cloneAtLevel (resultType, maxLevel);
79+ if (failed (outputTypeResult)) {
80+ emitError (result.getLoc ()) << " Failed to insert bootstrap" ;
81+ return Value ();
82+ }
83+ result = ckks::BootstrapOp::create (builder, outputTypeResult.value (),
84+ result);
5285 }
5386 auto ctType = ctTypeResult.value ();
5487 result = ckks::RescaleOp::create (builder, ctType, result,
@@ -68,26 +101,33 @@ class IRMaterializingVisitor
68101 }
69102
70103 template <typename T, typename CtCtOp, typename CtPtOp, typename CleartextOp>
71- Value binop (const T& node, bool rescale = false ) {
104+ Value binop (const T& node) {
72105 Value lhs = this ->process (node.left );
73106 Value rhs = this ->process (node.right );
74107
108+ bool relinearize =
109+ static_cast <bool >(std::is_same<CtCtOp, ckks::MulOp>::value);
110+ bool rescale =
111+ static_cast <bool >(std::is_same<CtCtOp, ckks::MulOp>::value) ||
112+ static_cast <bool >(std::is_same<CtPtOp, ckks::MulPlainOp>::value);
113+
75114 return TypeSwitch<Type, Value>(lhs.getType ())
76115 .template Case <lwe::LWECiphertextType>([&](auto ty) {
77116 if (isa<RankedTensorType>(rhs.getType ())) {
78- return maybeRescale (
117+ return relinAndRescale (
79118 CtPtOp::create (builder, lhs, encodeCleartextOperand (ty, rhs))
80119 .getResult (),
81- ty, rescale);
120+ ty, relinearize, rescale);
82121 }
83122
84123 if (isa<lwe::LWEPlaintextType>(rhs.getType ())) {
85- return maybeRescale (CtPtOp::create (builder, lhs, rhs).getResult (),
86- ty, rescale);
124+ return relinAndRescale (
125+ CtPtOp::create (builder, lhs, rhs).getResult (), ty, relinearize,
126+ rescale);
87127 }
88128
89- return maybeRescale (CtCtOp::create (builder, lhs, rhs).getResult (), ty ,
90- rescale);
129+ return relinAndRescale (CtCtOp::create (builder, lhs, rhs).getResult (),
130+ ty, relinearize, rescale);
91131 })
92132 .template Case <lwe::LWEPlaintextType>([&](auto ty) {
93133 if (isa<RankedTensorType>(rhs.getType ())) {
@@ -111,15 +151,15 @@ class IRMaterializingVisitor
111151 << " \n\n rhs=" << rhs << " \n " ;
112152 return Value ();
113153 }
114- return maybeRescale (CtPtOp::create (builder, lhs, rhs).getResult (),
115- ctTy, rescale);
154+ return relinAndRescale (CtPtOp::create (builder, lhs, rhs).getResult (),
155+ ctTy, relinearize , rescale);
116156 })
117157 .template Case <RankedTensorType>([&](auto ty) {
118158 auto ctTy = cast<lwe::LWECiphertextType>(rhs.getType ());
119- return maybeRescale (
159+ return relinAndRescale (
120160 CtPtOp::create (builder, encodeCleartextOperand (ctTy, lhs), rhs)
121161 .getResult (),
122- ctTy, rescale);
162+ ctTy, relinearize, rescale);
123163 })
124164 .Default ([&](Type) {
125165 emitError (lhs.getLoc ())
0 commit comments