Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions lib/Dialect/Secret/Conversions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ cc_library(
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Polynomial/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Utils:ContextAwareDialectConversion",
"@heir//lib/Utils:ContextAwareTypeConversion",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down
12 changes: 6 additions & 6 deletions lib/Dialect/Secret/Conversions/Patterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/ModuleAttributes.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Utils/ContextAwareDialectConversion.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand All @@ -18,12 +17,13 @@
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir {
namespace heir {

Value insertKeyArgument(func::FuncOp parentFunc, Type encryptionKeyType,
ContextAwareConversionPatternRewriter& rewriter) {
ConversionPatternRewriter& rewriter) {
// The new key type is inserted as the last argument of the parent function.
auto oldFunctionType = parentFunc.getFunctionType();
SmallVector<Type, 4> newInputTypes;
Expand All @@ -45,7 +45,7 @@ Value insertKeyArgument(func::FuncOp parentFunc, Type encryptionKeyType,

LogicalResult ConvertClientConceal::matchAndRewrite(
secret::ConcealOp op, OpAdaptor adaptor,
ContextAwareConversionPatternRewriter& rewriter) const {
ConversionPatternRewriter& rewriter) const {
func::FuncOp parentFunc = op->getParentOfType<func::FuncOp>();
if (!parentFunc || !parentFunc->hasAttr(kClientEncFuncAttrName)) {
return op->emitError() << "expected to be inside a function with attribute "
Expand All @@ -54,7 +54,7 @@ LogicalResult ConvertClientConceal::matchAndRewrite(

// The encryption func encrypts a single value, so it must have a single
// return type. This return type may be split over multiple ciphertexts. This
// relies on the ContextAwareFuncConversion to have already run, so that the
// relies on the FuncConversion to have already run, so that the
// result type is type converted in-place.
auto resultCtTy = dyn_cast<lwe::LWECiphertextType>(
getElementTypeOrSelf(parentFunc.getResultTypes()[0]));
Expand Down Expand Up @@ -140,7 +140,7 @@ LogicalResult ConvertClientConceal::matchAndRewrite(

LogicalResult ConvertClientReveal::matchAndRewrite(
secret::RevealOp op, OpAdaptor adaptor,
ContextAwareConversionPatternRewriter& rewriter) const {
ConversionPatternRewriter& rewriter) const {
func::FuncOp parentFunc = op->getParentOfType<func::FuncOp>();
if (!parentFunc || !parentFunc->hasAttr(kClientDecFuncAttrName)) {
return op->emitError() << "expected to be inside a function with attribute "
Expand All @@ -149,7 +149,7 @@ LogicalResult ConvertClientReveal::matchAndRewrite(

// The decryption func decrypts a single value, so it must have a single
// argument that may be split over multiple ciphertexts. This relies on the
// ContextAwareFuncConversion to have already run, so that the argument type
// FuncConversion to have already run, so that the argument type
// is type converted in-place.
auto argCtTy = dyn_cast<lwe::LWECiphertextType>(
getElementTypeOrSelf(parentFunc.getArgumentTypes()[0]));
Expand Down
22 changes: 8 additions & 14 deletions lib/Dialect/Secret/Conversions/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Utils/ContextAwareDialectConversion.h"
#include "lib/Utils/ContextAwareTypeConversion.h"
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
Expand All @@ -17,19 +15,17 @@ namespace heir {
// lwe.rlwe_encrypt. Modifies the containing function to add new secret key
// material args.
// TODO(#1875): support trivial encryptions
struct ConvertClientConceal
: public ContextAwareOpConversionPattern<secret::ConcealOp> {
ConvertClientConceal(const ContextAwareTypeConverter& typeConverter,
struct ConvertClientConceal : public OpConversionPattern<secret::ConcealOp> {
ConvertClientConceal(const TypeConverter& typeConverter_,
mlir::MLIRContext* context, bool usePublicKey,
polynomial::RingAttr ring)
: ContextAwareOpConversionPattern<secret::ConcealOp>(typeConverter,
context),
: OpConversionPattern<secret::ConcealOp>(typeConverter_, context),
usePublicKey(usePublicKey),
ring(ring) {}

LogicalResult matchAndRewrite(
secret::ConcealOp op, OpAdaptor adaptor,
ContextAwareConversionPatternRewriter& rewriter) const override;
ConversionPatternRewriter& rewriter) const override;

private:
bool usePublicKey;
Expand All @@ -39,17 +35,15 @@ struct ConvertClientConceal
// Lower a client decryption function's secret.reveal op to lwe.rlwe_decrypt +
// lwe.rlwe_decode. Modifies the containing function to add new secret key
// material args.
struct ConvertClientReveal
: public ContextAwareOpConversionPattern<secret::RevealOp> {
ConvertClientReveal(const ContextAwareTypeConverter& typeConverter,
struct ConvertClientReveal : public OpConversionPattern<secret::RevealOp> {
ConvertClientReveal(const TypeConverter& typeConverter_,
mlir::MLIRContext* context, polynomial::RingAttr ring)
: ContextAwareOpConversionPattern<secret::RevealOp>(typeConverter,
context),
: OpConversionPattern<secret::RevealOp>(typeConverter_, context),
ring(ring) {}

LogicalResult matchAndRewrite(
secret::RevealOp op, OpAdaptor adaptor,
ContextAwareConversionPatternRewriter& rewriter) const override;
ConversionPatternRewriter& rewriter) const override;

private:
polynomial::RingAttr ring;
Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/Secret/Conversions/SecretToBGV/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ cc_library(
"@heir//lib/Utils",
"@heir//lib/Utils:AttributeUtils",
"@heir//lib/Utils:ContextAwareConversionUtils",
"@heir//lib/Utils:ContextAwareDialectConversion",
"@heir//lib/Utils:ContextAwareTypeConversion",
"@heir//lib/Utils:ConversionUtils",
"@heir//lib/Utils/Polynomial",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
Expand Down
54 changes: 38 additions & 16 deletions lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

Expand All @@ -25,11 +26,11 @@
#include "lib/Dialect/Secret/IR/SecretTypes.h"
#include "lib/Utils/AttributeUtils.h"
#include "lib/Utils/ContextAwareConversionUtils.h"
#include "lib/Utils/ContextAwareDialectConversion.h"
#include "lib/Utils/ContextAwareTypeConversion.h"
#include "lib/Utils/ConversionUtils.h"
#include "lib/Utils/Polynomial/Polynomial.h"
#include "lib/Utils/Utils.h"
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand All @@ -44,6 +45,8 @@
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

#define DEBUG_TYPE "secret-to-bgv"

namespace mlir::heir {

#define GEN_PASS_DEF_SECRETTOBGV
Expand Down Expand Up @@ -90,22 +93,34 @@ polynomial::RingAttr getRlweRNSRingWithLevel(polynomial::RingAttr ringAttr,

} // namespace

class SecretToBGVTypeConverter
: public UniquelyNamedAttributeAwareTypeConverter {
class SecretToBGVTypeConverter : public TypeConverter {
public:
SecretToBGVTypeConverter(MLIRContext* ctx, polynomial::RingAttr rlweRing,
int64_t ptm, bool isBFV)
: UniquelyNamedAttributeAwareTypeConverter(
mgmt::MgmtDialect::kArgMgmtAttrName),
ring(rlweRing),
plaintextModulus(ptm),
isBFV(isBFV) {
addConversion([](Type type, Attribute attr) { return type; });
addConversion([this](secret::SecretType type, mgmt::MgmtAttr mgmtAttr) {
return convertSecretTypeWithMgmtAttr(type, mgmtAttr);
: ring(rlweRing), plaintextModulus(ptm), isBFV(isBFV) {
addConversion([this](Value value) -> std::optional<Type> {
LLVM_DEBUG(llvm::dbgs() << "Converting type for value " << value << "\n");
FailureOr<Attribute> attr = findAttributeAssociatedWith(
value, mgmt::MgmtDialect::kArgMgmtAttrName);
if (failed(attr)) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to find context attribute for " << value);
return std::nullopt;
}
LLVM_DEBUG(llvm::dbgs() << "found attribute " << attr.value() << "\n");
return convertTypeWithAttr(value.getType(), attr.value());
});
}

std::optional<Type> convertTypeWithAttr(Type type, Attribute attr) const {
auto secretType = dyn_cast<secret::SecretType>(type);
auto mgmtAttr = dyn_cast<mgmt::MgmtAttr>(attr);
if (secretType && mgmtAttr)
return convertSecretTypeWithMgmtAttr(secretType, mgmtAttr);
LLVM_DEBUG(llvm::dbgs() << "Only supported secret types with mgmt attr");
return std::nullopt;
}

Type convertSecretTypeWithMgmtAttr(secret::SecretType type,
mgmt::MgmtAttr mgmtAttr) const {
auto level = mgmtAttr.getLevel();
Expand Down Expand Up @@ -262,10 +277,17 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
rlweRing.value());
patterns.add<ConvertClientReveal>(typeConverter, context, rlweRing.value());

addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyContextAwarePartialConversion(module, target,
std::move(patterns)))) {
addContextAwareStructuralConversionPatterns(
typeConverter, patterns, target,
std::string(mgmt::MgmtDialect::kArgMgmtAttrName),
[&](Type type, Attribute attr) {
return typeConverter.convertTypeWithAttr(type, attr);
});

ConversionConfig config;
config.allowPatternRollback = false;
if (failed(applyPartialConversion(module, target, std::move(patterns),
config))) {
return signalPassFailure();
}

Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/Secret/Conversions/SecretToCGGI/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ cc_library(
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Transforms/MemrefToArith:Utils",
"@heir//lib/Utils:ContextAwareConversionUtils",
"@heir//lib/Utils:ContextAwareDialectConversion",
"@heir//lib/Utils:ContextAwareTypeConversion",
"@heir//lib/Utils:ConversionUtils",
"@heir//lib/Utils/Polynomial",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
Expand Down
Loading
Loading