From 792e5cf26cb2d2d9a5c8e80eb9d9e6076ee87738 Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 24 Apr 2024 16:17:38 +0100 Subject: [PATCH] feat(compiler/simu): support signed integers --- .../include/concretelang/Runtime/simulation.h | 10 +- .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 96 ++++++++++++--- .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 53 ++++++--- .../compiler/lib/Runtime/simulation.cpp | 16 ++- .../compiler/lib/Support/Pipeline.cpp | 4 +- .../compiler/tests/python/test_simulation.py | 109 ++++++++++++++++++ 6 files changed, 250 insertions(+), 38 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h index 53afd32e13..ca661ef0b2 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h @@ -31,9 +31,10 @@ uint64_t sim_neg_lwe_u64(uint64_t plaintext); /// /// \param lhs left operand /// \param rhs right operand -/// \param loc +/// \param loc location of the operation +/// \param is_signed tell if operands are known to be signed /// \return uint64_t -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed); /// \brief simulate the multiplication of a noisy plaintext with an integer /// @@ -41,9 +42,10 @@ uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); /// /// \param lhs left operand /// \param rhs right operand -/// \param loc +/// \param loc location of the operation +/// \param is_signed tell if operands are known to be signed /// \return uint64_t -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc); +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, bool is_signed); /// \brief simulate a keyswitch on a noisy plaintext /// diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index ae88c00c32..cd0f98918d 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -180,10 +180,47 @@ struct AddEintIntOpPattern : public ScalarOpPattern { op.getLoc(), adaptor.getB(), op.getType().cast().getWidth(), rewriter); + auto isSigned = op.getType().cast().isSigned(); + std::vector attrs; + if (isSigned) { + auto signedAttr = + rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned)); + attrs.push_back(signedAttr); + } + // Write the new op auto newOp = rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), adaptor.getA(), - encodedInt); + op, getTypeConverter()->convertType(op.getType()), + mlir::ValueRange({adaptor.getA(), encodedInt}), attrs); + forwardOptimizerID(op, newOp); + + return mlir::success(); + } +}; + +/// Rewriter for the `FHE::add_eint` operation. +struct AddEintOpPattern : public mlir::OpConversionPattern { + AddEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpConversionPattern(converter, context, benefit) { + } + + mlir::LogicalResult + matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + auto isSigned = op.getType().cast().isSigned(); + std::vector attrs; + if (isSigned) { + auto signedAttr = + rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned)); + attrs.push_back(signedAttr); + } + + // Write the new op + auto newOp = rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + adaptor.getOperands(), attrs); forwardOptimizerID(op, newOp); return mlir::success(); @@ -220,10 +257,18 @@ struct SubEintIntOpPattern : public ScalarOpPattern { eintOperand.getType().cast().getWidth(), rewriter); + auto isSigned = op.getType().cast().isSigned(); + std::vector attrs; + if (isSigned) { + auto signedAttr = + rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned)); + attrs.push_back(signedAttr); + } + // Write the new op auto newOp = rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), adaptor.getA(), - encodedInt); + op, getTypeConverter()->convertType(op.getType()), + mlir::ValueRange({adaptor.getA(), encodedInt}), attrs); forwardOptimizerID(op, newOp); return mlir::success(); @@ -247,10 +292,18 @@ struct SubIntEintOpPattern : public ScalarOpPattern { op.getB().getType().cast().getWidth(), rewriter); + auto isSigned = op.getType().cast().isSigned(); + std::vector attrs; + if (isSigned) { + auto signedAttr = + rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned)); + attrs.push_back(signedAttr); + } + // Write the new op auto newOp = rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), encodedInt, - adaptor.getB()); + op, getTypeConverter()->convertType(op.getType()), + mlir::ValueRange({encodedInt, adaptor.getB()}), attrs); forwardOptimizerID(op, newOp); return mlir::success(); @@ -276,10 +329,18 @@ struct SubEintOpPattern : public ScalarOpPattern { location, rhsOperand.getType(), rhsOperand); forwardOptimizerID(op, negative); + auto isSigned = op.getType().cast().isSigned(); + std::vector attrs; + if (isSigned) { + auto signedAttr = + rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned)); + attrs.push_back(signedAttr); + } + // Write new op. auto newOp = rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), lhsOperand, - negative.getResult()); + op, getTypeConverter()->convertType(op.getType()), + mlir::ValueRange({lhsOperand, negative.getResult()}), attrs); forwardOptimizerID(op, newOp); return mlir::success(); @@ -305,10 +366,18 @@ struct MulEintIntOpPattern : public ScalarOpPattern { mlir::Value castedCleartext = rewriter.create( location, rewriter.getIntegerType(64), intOperand); + auto isSigned = op.getType().cast().isSigned(); + std::vector attrs; + if (isSigned) { + auto signedAttr = + rewriter.getNamedAttr("signed", rewriter.getBoolAttr(isSigned)); + attrs.push_back(signedAttr); + } + // Write the new op. auto newOp = rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), eintOperand, - castedCleartext); + op, getTypeConverter()->convertType(op.getType()), + mlir::ValueRange({eintOperand, castedCleartext}), attrs); forwardOptimizerID(op, newOp); return mlir::success(); @@ -804,12 +873,11 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { FHE::NegEintOp, TFHE::NegGLWEOp, true>, // |_ `FHE::not` mlir::concretelang::GenericOneToOneOpConversionPattern< - FHE::BoolNotOp, TFHE::NegGLWEOp, true>, - // |_ `FHE::add_eint` - mlir::concretelang::GenericOneToOneOpConversionPattern< - FHE::AddEintOp, TFHE::AddGLWEOp, true>>(&getContext(), converter); + FHE::BoolNotOp, TFHE::NegGLWEOp, true>>(&getContext(), converter); // |_ `FHE::add_eint_int` patterns.add { } }; +int locationStringCtr = 0; mlir::Value globalStringValueFromLoc(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc) { std::string locString; auto ros = llvm::raw_string_ostream(locString); loc.print(ros); - - std::string msgName; - std::stringstream stream; - stream << "loc_" << rand(); - stream >> msgName; - return mlir::LLVM::createGlobalString(loc, rewriter, msgName, locString, - mlir::LLVM::linkage::Linkage::Linkonce, - false); + locString.append("\0"); + auto locStrWithNullByte = + llvm::StringRef(locString.c_str(), locString.size() + 1); + + std::stringstream msgName; + msgName << "str_loc_" << locationStringCtr++; + return mlir::LLVM::createGlobalString( + loc, rewriter, msgName.str(), locStrWithNullByte, + mlir::LLVM::linkage::Linkage::Linkonce, false); } template @@ -122,12 +124,21 @@ struct AddOpPattern : public mlir::OpConversionPattern { const std::string funcName = "sim_add_lwe_u64"; auto locString = globalStringValueFromLoc(rewriter, addOp.getLoc()); + // check if operation has been tagged as signed + auto isSigned = false; + mlir::Attribute signedAttr = adaptor.getAttributes().get("signed"); + if (signedAttr && signedAttr.cast().getValue()) { + isSigned = true; + } + mlir::Value isSignedCst = rewriter.create( + addOp.getLoc(), isSigned, 1); if (insertForwardDeclaration( addOp, rewriter, funcName, rewriter.getFunctionType( {rewriter.getIntegerType(64), rewriter.getIntegerType(64), - mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getIntegerType(1)}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -135,7 +146,8 @@ struct AddOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( addOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); + mlir::ValueRange( + {adaptor.getA(), adaptor.getB(), locString, isSignedCst})); return mlir::success(); } @@ -156,11 +168,21 @@ struct MulOpPattern : public mlir::OpConversionPattern { auto locString = globalStringValueFromLoc(rewriter, mulOp.getLoc()); + // check if operation has been tagged as signed + auto isSigned = false; + mlir::Attribute signedAttr = adaptor.getAttributes().get("signed"); + if (signedAttr && signedAttr.cast().getValue()) { + isSigned = true; + } + mlir::Value isSignedCst = rewriter.create( + mulOp.getLoc(), isSigned, 1); + if (insertForwardDeclaration( mulOp, rewriter, funcName, rewriter.getFunctionType( {rewriter.getIntegerType(64), rewriter.getIntegerType(64), - mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type())}, + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getIntegerType(1)}, {rewriter.getIntegerType(64)})) .failed()) { return mlir::failure(); @@ -168,7 +190,8 @@ struct MulOpPattern : public mlir::OpConversionPattern { rewriter.replaceOpWithNewOp( mulOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)}, - mlir::ValueRange({adaptor.getA(), adaptor.getB(), locString})); + mlir::ValueRange( + {adaptor.getA(), adaptor.getB(), locString, isSignedCst})); return mlir::success(); } @@ -186,8 +209,10 @@ struct SubIntGLWEOpPattern : public mlir::OpRewritePattern { mlir::Value negated = rewriter.create( subOp.getLoc(), subOp.getB().getType(), subOp.getB()); - rewriter.replaceOpWithNewOp(subOp, subOp.getType(), - negated, subOp.getA()); + rewriter.replaceOpWithNewOp( + subOp, subOp.getType(), mlir::ValueRange({negated, subOp.getA()}), + // to forward the signed attr if set + subOp.getOperation()->getAttrs()); return mlir::success(); } diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index f537483141..d42d7b6c46 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -189,16 +189,22 @@ void sim_wop_pbs_crt( uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; } -uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { - if (lhs > UINT63_MAX - rhs) { +uint64_t sim_add_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, + bool is_signed) { + uint64_t result = lhs + rhs; + if (!is_signed && (lhs > UINT63_MAX - rhs || result > UINT63_MAX)) { printf("WARNING at %s: overflow happened during addition in simulation\n", loc); } - return lhs + rhs; + return result; } -uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc) { - if (rhs != 0 && lhs > UINT63_MAX / rhs) { +uint64_t sim_mul_lwe_u64(uint64_t lhs, uint64_t rhs, char *loc, + bool is_signed) { + + uint64_t result = lhs * rhs; + if (!is_signed && rhs != 0 && + (lhs > UINT63_MAX / rhs || result > UINT63_MAX)) { printf("WARNING at %s: overflow happened during multiplication in " "simulation\n", loc); diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 22a1bb08b3..45bd663f28 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -440,8 +440,10 @@ mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context, if (fheContext) { auto solution = fheContext.value().solution; auto optCrt = getCrtDecompositionFromSolution(solution); - if (optCrt) + if (optCrt) { enableOverflowDetection = false; + log_verbose() << "WARNING: overflow detection disabled since using CRT"; + } } pipelinePrinting("TFHESimulation", pm, context); diff --git a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py index c083d94d2d..d31bab76c9 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py @@ -258,6 +258,90 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', id="add_eint_int", ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-1, -2), + -3, + b"", + id="add_eint_int_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (81, 73), + 154, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="add_eint", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (-81, 73), + -8, + b"", + id="add_eint_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (4, 7), + 256 - 3, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint_int", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (4, 7), + -3, + b"", + id="sub_eint_int_signed", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (11, 18), + 256 - 7, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\n', + id="sub_eint", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: !FHE.esint<7>) -> !FHE.esint<7> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.esint<7>, !FHE.esint<7>) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (11, 18), + -7, + b"", + id="sub_eint_signed", + ), pytest.param( """ func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { @@ -270,6 +354,31 @@ def test_lib_compile_and_run_simulation(mlir_input, args, expected_result): b'WARNING at loc("-":3:22): overflow happened during multiplication in simulation\n', id="mul_eint_int", ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + %2 = "FHE.mul_eint_int"(%1, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %2: !FHE.eint<7> + } + """, + (5, 10), + 256 - 50, + b'WARNING at loc("-":3:22): overflow happened during addition in simulation\nWARNING at loc("-":4:22): overflow happened during multiplication in simulation\n', + id="sub_mul_eint_int", + ), + pytest.param( + """ + func.func @main(%arg0: !FHE.esint<7>, %arg1: i8) -> !FHE.esint<7> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.esint<7>, i8) -> (!FHE.esint<7>) + return %1: !FHE.esint<7> + } + """, + (5, -2), + -10, + b"", + id="mul_eint_int_signed", + ), pytest.param( """ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {