Skip to content

Commit

Permalink
feat(compiler/simu): support signed integers
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Apr 24, 2024
1 parent 8da333a commit 792e5cf
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,21 @@ uint64_t sim_neg_lwe_u64(uint64_t plaintext);
///
/// \param lhs left operand
/// \param rhs right operand
/// \param loc
/// \param loc location of the operation
/// \param is_signed tell if operands are known to be signed
/// \return uint64_t
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc);
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed);

/// \brief simulate the multiplication of a noisy plaintext with an integer
///
/// The function also checks for overflow and print a warning when it happens
///
/// \param lhs left operand
/// \param rhs right operand
/// \param loc
/// \param loc location of the operation
/// \param is_signed tell if operands are known to be signed
/// \return uint64_t
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc);
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed);

/// \brief simulate a keyswitch on a noisy plaintext
///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,47 @@ struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
op.getLoc(), adaptor.getB(),
op.getType().cast<FHE::FheIntegerInterface>().getWidth(), rewriter);

auto isSigned = op.getType().cast<FHE::FheIntegerInterface>().isSigned();
std::vector<mlir::NamedAttribute> attrs;
if (isSigned) {
auto signedAttr =
rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned));
attrs.push_back(signedAttr);
}

// Write the new op
auto newOp = rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
encodedInt);
op, getTypeConverter()->convertType(op.getType()),
mlir::ValueRange({adaptor.getA(), encodedInt}), attrs);
forwardOptimizerID(op, newOp);

return mlir::success();
}
};

/// Rewriter for the `FHE::add_eint` operation.
struct AddEintOpPattern : public mlir::OpConversionPattern<FHE::AddEintOp> {
AddEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpConversionPattern<FHE::AddEintOp>(converter, context, benefit) {
}

mlir::LogicalResult
matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

auto isSigned = op.getType().cast<FHE::FheIntegerInterface>().isSigned();
std::vector<mlir::NamedAttribute> attrs;
if (isSigned) {
auto signedAttr =
rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned));
attrs.push_back(signedAttr);
}

// Write the new op
auto newOp = rewriter.replaceOpWithNewOp<TFHE::AddGLWEOp>(
op, getTypeConverter()->convertType(op.getType()),
adaptor.getOperands(), attrs);
forwardOptimizerID(op, newOp);

return mlir::success();
Expand Down Expand Up @@ -220,10 +257,18 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
eintOperand.getType().cast<FHE::FheIntegerInterface>().getWidth(),
rewriter);

auto isSigned = op.getType().cast<FHE::FheIntegerInterface>().isSigned();
std::vector<mlir::NamedAttribute> attrs;
if (isSigned) {
auto signedAttr =
rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned));
attrs.push_back(signedAttr);
}

// Write the new op
auto newOp = rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
encodedInt);
op, getTypeConverter()->convertType(op.getType()),
mlir::ValueRange({adaptor.getA(), encodedInt}), attrs);
forwardOptimizerID(op, newOp);

return mlir::success();
Expand All @@ -247,10 +292,18 @@ struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
op.getB().getType().cast<FHE::FheIntegerInterface>().getWidth(),
rewriter);

auto isSigned = op.getType().cast<FHE::FheIntegerInterface>().isSigned();
std::vector<mlir::NamedAttribute> attrs;
if (isSigned) {
auto signedAttr =
rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned));
attrs.push_back(signedAttr);
}

// Write the new op
auto newOp = rewriter.replaceOpWithNewOp<TFHE::SubGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), encodedInt,
adaptor.getB());
op, getTypeConverter()->convertType(op.getType()),
mlir::ValueRange({encodedInt, adaptor.getB()}), attrs);
forwardOptimizerID(op, newOp);

return mlir::success();
Expand All @@ -276,10 +329,18 @@ struct SubEintOpPattern : public ScalarOpPattern<FHE::SubEintOp> {
location, rhsOperand.getType(), rhsOperand);
forwardOptimizerID(op, negative);

auto isSigned = op.getType().cast<FHE::FheIntegerInterface>().isSigned();
std::vector<mlir::NamedAttribute> attrs;
if (isSigned) {
auto signedAttr =
rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned));
attrs.push_back(signedAttr);
}

// Write new op.
auto newOp = rewriter.replaceOpWithNewOp<TFHE::AddGLWEOp>(
op, getTypeConverter()->convertType(op.getType()), lhsOperand,
negative.getResult());
op, getTypeConverter()->convertType(op.getType()),
mlir::ValueRange({lhsOperand, negative.getResult()}), attrs);
forwardOptimizerID(op, newOp);

return mlir::success();
Expand All @@ -305,10 +366,18 @@ struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
mlir::Value castedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
location, rewriter.getIntegerType(64), intOperand);

auto isSigned = op.getType().cast<FHE::FheIntegerInterface>().isSigned();
std::vector<mlir::NamedAttribute> attrs;
if (isSigned) {
auto signedAttr =
rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned));
attrs.push_back(signedAttr);
}

// Write the new op.
auto newOp = rewriter.replaceOpWithNewOp<TFHE::MulGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), eintOperand,
castedCleartext);
op, getTypeConverter()->convertType(op.getType()),
mlir::ValueRange({eintOperand, castedCleartext}), attrs);
forwardOptimizerID(op, newOp);

return mlir::success();
Expand Down Expand Up @@ -804,12 +873,11 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
FHE::NegEintOp, TFHE::NegGLWEOp, true>,
// |_ `FHE::not`
mlir::concretelang::GenericOneToOneOpConversionPattern<
FHE::BoolNotOp, TFHE::NegGLWEOp, true>,
// |_ `FHE::add_eint`
mlir::concretelang::GenericOneToOneOpConversionPattern<
FHE::AddEintOp, TFHE::AddGLWEOp, true>>(&getContext(), converter);
FHE::BoolNotOp, TFHE::NegGLWEOp, true>>(&getContext(), converter);
// |_ `FHE::add_eint_int`
patterns.add<lowering::AddEintIntOpPattern,
// |_ `FHE::add_eint`
lowering::AddEintOpPattern,
// |_ `FHE::sub_int_eint`
lowering::SubIntEintOpPattern,
// |_ `FHE::sub_eint_int`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,21 @@ struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {
}
};

int locationStringCtr = 0;
mlir::Value globalStringValueFromLoc(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc) {
std::string locString;
auto ros = llvm::raw_string_ostream(locString);
loc.print(ros);

std::string msgName;
std::stringstream stream;
stream << "loc_" << rand();
stream >> msgName;
return mlir::LLVM::createGlobalString(loc, rewriter, msgName, locString,
mlir::LLVM::linkage::Linkage::Linkonce,
false);
locString.append("\0");
auto locStrWithNullByte =
llvm::StringRef(locString.c_str(), locString.size() + 1);

std::stringstream msgName;
msgName << "str_loc_" << locationStringCtr++;
return mlir::LLVM::createGlobalString(
loc, rewriter, msgName.str(), locStrWithNullByte,
mlir::LLVM::linkage::Linkage::Linkonce, false);
}

template <typename AddOp, typename AddOpAdaptor>
Expand All @@ -122,20 +124,30 @@ struct AddOpPattern : public mlir::OpConversionPattern<AddOp> {
const std::string funcName = "sim_add_lwe_u64";

auto locString = globalStringValueFromLoc(rewriter, addOp.getLoc());
// check if operation has been tagged as signed
auto isSigned = false;
mlir::Attribute signedAttr = adaptor.getAttributes().get("signed");
if (signedAttr && signedAttr.cast<mlir::BoolAttr>().getValue()) {
isSigned = true;
}
mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
addOp.getLoc(), isSigned, 1);

if (insertForwardDeclaration(
addOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())},
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getIntegerType(1)},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString}));
mlir::ValueRange(
{adaptor.getA(), adaptor.getB(), locString, isSignedCst}));

return mlir::success();
}
Expand All @@ -156,19 +168,30 @@ struct MulOpPattern : public mlir::OpConversionPattern<TFHE::MulGLWEIntOp> {

auto locString = globalStringValueFromLoc(rewriter, mulOp.getLoc());

// check if operation has been tagged as signed
auto isSigned = false;
mlir::Attribute signedAttr = adaptor.getAttributes().get("signed");
if (signedAttr && signedAttr.cast<mlir::BoolAttr>().getValue()) {
isSigned = true;
}
mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
mulOp.getLoc(), isSigned, 1);

if (insertForwardDeclaration(
mulOp, rewriter, funcName,
rewriter.getFunctionType(
{rewriter.getIntegerType(64), rewriter.getIntegerType(64),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())},
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getIntegerType(1)},
{rewriter.getIntegerType(64)}))
.failed()) {
return mlir::failure();
}

rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString}));
mlir::ValueRange(
{adaptor.getA(), adaptor.getB(), locString, isSignedCst}));

return mlir::success();
}
Expand All @@ -186,8 +209,10 @@ struct SubIntGLWEOpPattern : public mlir::OpRewritePattern<TFHE::SubGLWEIntOp> {
mlir::Value negated = rewriter.create<TFHE::NegGLWEOp>(
subOp.getLoc(), subOp.getB().getType(), subOp.getB());

rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(subOp, subOp.getType(),
negated, subOp.getA());
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
subOp, subOp.getType(), mlir::ValueRange({negated, subOp.getA()}),
// to forward the signed attr if set
subOp.getOperation()->getAttrs());

return mlir::success();
}
Expand Down
16 changes: 11 additions & 5 deletions compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,22 @@ void sim_wop_pbs_crt(

uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; }

uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) {
if (lhs > UINT63_MAX - rhs) {
uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc,
bool is_signed) {
uint64_t result = lhs + rhs;
if (!is_signed && (lhs > UINT63_MAX - rhs || result > UINT63_MAX)) {
printf("WARNING at %s: overflow happened during addition in simulation\n",
loc);
}
return lhs + rhs;
return result;
}

uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) {
if (rhs != 0 && lhs > UINT63_MAX / rhs) {
uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc,
bool is_signed) {

uint64_t result = lhs * rhs;
if (!is_signed && rhs != 0 &&
(lhs > UINT63_MAX / rhs || result > UINT63_MAX)) {
printf("WARNING at %s: overflow happened during multiplication in "
"simulation\n",
loc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,10 @@ mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context,
if (fheContext) {
auto solution = fheContext.value().solution;
auto optCrt = getCrtDecompositionFromSolution(solution);
if (optCrt)
if (optCrt) {
enableOverflowDetection = false;
log_verbose() << "WARNING: overflow detection disabled since using CRT";
}
}

pipelinePrinting("TFHESimulation", pm, context);
Expand Down
Loading

0 comments on commit 792e5cf

Please sign in to comment.