Skip to content
This repository has been archived by the owner on Dec 23, 2024. It is now read-only.

fix: add missing gas costs for loading ciphertexts #118

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fhevm/ciphertext_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func loadCiphertext(env EVMEnvironment, handle common.Hash) (ct *tfhe.TfheCipher

metadataInt := newInt(env.GetState(ciphertextStorage, handle).Bytes())
if metadataInt.IsZero() {
return nil, 0
return nil, ColdSloadCostEIP2929
}
metadata := newCiphertextMetadata(metadataInt.Bytes32())
ctBytes := make([]byte, 0)
Expand All @@ -93,7 +93,7 @@ func loadCiphertext(env EVMEnvironment, handle common.Hash) (ct *tfhe.TfheCipher
err := ct.Deserialize(ctBytes, metadata.fheUintType)
if err != nil {
logger.Error("failed to deserialize ciphertext from storage", "err", err)
return nil, 0
return nil, ColdSloadCostEIP2929 + DeserializeCiphertextGas
}
env.FhevmData().loadedCiphertexts[handle] = ct
return ct, env.FhevmParams().GasCosts.FheStorageSloadGas[ct.Type()]
Expand Down
81 changes: 0 additions & 81 deletions fhevm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3474,23 +3474,6 @@ func FheArrayEqNoRhs(t *testing.T, fheUintType tfhe.FheUintType) {
}
}

func FheArrayEqNoRhsGas(t *testing.T, fheUintType tfhe.FheUintType) {
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth

lhs := make([]*big.Int, 3)
lhs[0] = loadCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big()
lhs[1] = loadCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big()
lhs[2] = loadCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big()
input, _ := arrayEqMethod.Inputs.Pack(lhs)

gas := fheArrayEqRequiredGas(environment, input)
if gas != 0 {
t.Fatalf("fheArrayEq expected 0 gas value")
}
}

func TestFheArrayEqUnverifiedCtInLhs(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
Expand Down Expand Up @@ -3519,28 +3502,6 @@ func TestFheArrayEqUnverifiedCtInLhs(t *testing.T) {
}
}

func TestFheArrayEqUnverifiedCtInLhsGas(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth

lhs := make([]*big.Int, 3)
lhs[0] = loadCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big()
lhs[0].Add(lhs[0], big.NewInt(1))
lhs[1] = loadCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big()
lhs[2] = loadCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big()
rhs := make([]*big.Int, 3)
rhs[0] = loadCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big()
rhs[1] = loadCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big()
rhs[2] = loadCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big()
input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs)

gas := fheArrayEqRequiredGas(environment, input)
if gas != 0 {
t.Fatalf("fheArrayEq expected 0 gas value")
}
}

func TestFheArrayEqUnverifiedCtInRhs(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
Expand Down Expand Up @@ -3570,28 +3531,6 @@ func TestFheArrayEqUnverifiedCtInRhs(t *testing.T) {
}
}

func TestFheArrayEqUnverifiedCtInRhsGas(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth

lhs := make([]*big.Int, 3)
lhs[0] = loadCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big()
lhs[1] = loadCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big()
lhs[2] = loadCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big()
rhs := make([]*big.Int, 3)
rhs[0] = loadCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big()
rhs[1] = loadCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big()
rhs[1].Add(lhs[0], big.NewInt(1))
rhs[2] = loadCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big()
input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs)

gas := fheArrayEqRequiredGas(environment, input)
if gas != 0 {
t.Fatalf("fheArrayEq expected 0 gas value")
}
}

func TestVerifyCiphertextInvalidType(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
Expand Down Expand Up @@ -5164,23 +5103,3 @@ func TesFheArrayEqNoRhs32(t *testing.T) {
func TestFheArrayEqNoRhs64(t *testing.T) {
FheArrayEqNoRhs(t, tfhe.FheUint64)
}

func TestFheArrayEqNoRhsGas4(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint4)
}

func TestFheArrayEqNoRhsGas8(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint8)
}

func TestFheArrayEqNoRhsGas16(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint16)
}

func TesFheArrayEqNoRhsGas32(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint32)
}

func TestFheArrayEqNoRhsGas64(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint64)
}
12 changes: 6 additions & 6 deletions fhevm/fhelib.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,11 @@ func load2Ciphertexts(environment EVMEnvironment, input []byte) (lhs *tfhe.TfheC
loadGasRhs := uint64(0)
lhs, loadGasLhs = loadCiphertext(environment, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, loadGasLhs, errors.New("unverified ciphertext handle")
}
rhs, loadGasRhs = loadCiphertext(environment, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, loadGasLhs + loadGasRhs, errors.New("unverified ciphertext handle")
}
err = nil
loadGas = loadGasLhs + loadGasRhs
Expand All @@ -337,15 +337,15 @@ func load3Ciphertexts(environment EVMEnvironment, input []byte) (first *tfhe.Tfh
loadGasThird := uint64(0)
first, loadGasFirst = loadCiphertext(environment, common.BytesToHash(input[0:32]))
if first == nil {
return nil, nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, nil, loadGasFirst, errors.New("unverified ciphertext handle")
}
second, loadGasSecond = loadCiphertext(environment, common.BytesToHash(input[32:64]))
if second == nil {
return nil, nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, nil, loadGasFirst + loadGasSecond, errors.New("unverified ciphertext handle")
}
third, loadGasThird = loadCiphertext(environment, common.BytesToHash(input[64:96]))
if third == nil {
return nil, nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, nil, loadGasFirst + loadGasSecond + loadGasThird, errors.New("unverified ciphertext handle")
}
err = nil
loadGas = loadGasFirst + loadGasSecond + loadGasThird
Expand All @@ -358,7 +358,7 @@ func getScalarOperands(environment EVMEnvironment, input []byte) (lhs *tfhe.Tfhe
}
lhs, loadGas = loadCiphertext(environment, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, nil, 0, errors.New("failed to load ciphertext")
return nil, nil, loadGas, errors.New("failed to load ciphertext")
}
rhs = &big.Int{}
rhs.SetBytes(input[32:64])
Expand Down
33 changes: 14 additions & 19 deletions fhevm/operators_arithmetic_gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ func fheAddSubRequiredGas(environment EVMEnvironment, input []byte) uint64 {
lhs, rhs, loadGas, err = load2Ciphertexts(environment, input)
if err != nil {
logger.Error("fheAdd/Sub RequiredGas() ciphertext failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
if lhs.Type() != rhs.Type() {
logger.Error("fheAdd/Sub RequiredGas() operand type mismatch", "lhs", lhs.Type(), "rhs", rhs.Type())
return 0
return loadGas
}
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
if err != nil {
logger.Error("fheAdd/Sub RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
}

Expand All @@ -47,24 +47,22 @@ func fheMulRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger.Error("fheMul RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return 0
}
loadGas := uint64(0)
var lhs, rhs *tfhe.TfheCiphertext
if !isScalar {
lhs, rhs, loadGas, err = load2Ciphertexts(environment, input)
lhs, rhs, loadGas, err := load2Ciphertexts(environment, input)
if err != nil {
logger.Error("fheMul RequiredGas() ciphertext failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
if lhs.Type() != rhs.Type() {
logger.Error("fheMul RequiredGas() operand type mismatch", "lhs", lhs.Type(), "rhs", rhs.Type())
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheMul[lhs.Type()]
return environment.FhevmParams().GasCosts.FheMul[lhs.Type()] + loadGas
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
lhs, _, loadGas, err := getScalarOperands(environment, input)
if err != nil {
logger.Error("fheMul RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheScalarMul[lhs.Type()] + loadGas
}
Expand All @@ -79,16 +77,15 @@ func fheDivRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger.Error("fheDiv RequiredGas() cannot detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return 0
}
loadGas := uint64(0)
var lhs *tfhe.TfheCiphertext

if !isScalar {
logger.Error("fheDiv RequiredGas() only scalar in division is supported, two ciphertexts received", "input", hex.EncodeToString(input))
return 0
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
lhs, _, loadGas, err := getScalarOperands(environment, input)
if err != nil {
logger.Error("fheDiv RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheScalarDiv[lhs.Type()] + loadGas
}
Expand All @@ -103,16 +100,14 @@ func fheRemRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger.Error("fheRem RequiredGas() cannot detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return 0
}
var lhs *tfhe.TfheCiphertext
loadGas := uint64(0)
if !isScalar {
logger.Error("fheRem RequiredGas() only scalar in division is supported, two ciphertexts received", "input", hex.EncodeToString(input))
return 0
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
lhs, _, loadGas, err := getScalarOperands(environment, input)
if err != nil {
logger.Error("fheRem RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheScalarRem[lhs.Type()] + loadGas
}
Expand Down
23 changes: 10 additions & 13 deletions fhevm/operators_bit_gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/hex"

"github.com/ethereum/go-ethereum/common"
"github.com/zama-ai/fhevm-go/fhevm/tfhe"
)

func fheShlRequiredGas(environment EVMEnvironment, input []byte) uint64 {
Expand All @@ -16,24 +15,22 @@ func fheShlRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger.Error("fheShift RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return 0
}
var lhs, rhs *tfhe.TfheCiphertext
loadGas := uint64(0)
if !isScalar {
lhs, rhs, loadGas, err = load2Ciphertexts(environment, input)
lhs, rhs, loadGas, err := load2Ciphertexts(environment, input)
if err != nil {
logger.Error("fheShift RequiredGas() ciphertext failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
if lhs.Type() != rhs.Type() {
logger.Error("fheShift RequiredGas() operand type mismatch", "lhs", lhs.Type(), "rhs", rhs.Type())
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheShift[lhs.Type()]
return environment.FhevmParams().GasCosts.FheShift[lhs.Type()] + loadGas
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
lhs, _, loadGas, err := getScalarOperands(environment, input)
if err != nil {
logger.Error("fheShift RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheScalarShift[lhs.Type()] + loadGas
}
Expand Down Expand Up @@ -65,7 +62,7 @@ func fheNegRequiredGas(environment EVMEnvironment, input []byte) uint64 {
ct, loadGas := loadCiphertext(environment, common.BytesToHash(input[0:32]))
if ct == nil {
logger.Error("fheNeg failed to load input", "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheNeg[ct.Type()] + loadGas
}
Expand All @@ -81,7 +78,7 @@ func fheNotRequiredGas(environment EVMEnvironment, input []byte) uint64 {
ct, loadGas := loadCiphertext(environment, common.BytesToHash(input[0:32]))
if ct == nil {
logger.Error("fheNot failed to load input", "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheNot[ct.Type()] + loadGas
}
Expand All @@ -106,11 +103,11 @@ func fheBitAndRequiredGas(environment EVMEnvironment, input []byte) uint64 {
lhs, rhs, loadGas, err := load2Ciphertexts(environment, input)
if err != nil {
logger.Error("Bitwise op RequiredGas() failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
if lhs.Type() != rhs.Type() {
logger.Error("Bitwise op RequiredGas() operand type mismatch", "lhs", lhs.Type(), "rhs", rhs.Type())
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheBitwiseOp[lhs.Type()] + loadGas
}
Expand Down
16 changes: 9 additions & 7 deletions fhevm/operators_comparison.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,20 +617,22 @@ func init() {
}
}

func getVerifiedCiphertexts(environment EVMEnvironment, unpacked interface{}) ([]*tfhe.TfheCiphertext, error) {
func getVerifiedCiphertexts(environment EVMEnvironment, unpacked interface{}) ([]*tfhe.TfheCiphertext, uint64, error) {
totalLoadGas := uint64(0)
big, ok := unpacked.([]*big.Int)
if !ok {
return nil, fmt.Errorf("fheArrayEq failed to cast to []*big.Int")
return nil, 0, fmt.Errorf("fheArrayEq failed to cast to []*big.Int")
}
ret := make([]*tfhe.TfheCiphertext, 0, len(big))
for _, b := range big {
ct, _ := loadCiphertext(environment, common.BigToHash(b))
ct, loadGas := loadCiphertext(environment, common.BigToHash(b))
if ct == nil {
return nil, fmt.Errorf("fheArrayEq unverified ciphertext")
return nil, totalLoadGas + loadGas, fmt.Errorf("fheArrayEq unverified ciphertext")
}
totalLoadGas += loadGas
ret = append(ret, ct)
}
return ret, nil
return ret, totalLoadGas, nil
}

func fheArrayEqRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) {
Expand All @@ -649,14 +651,14 @@ func fheArrayEqRun(environment EVMEnvironment, caller common.Address, addr commo
return nil, err
}

lhs, err := getVerifiedCiphertexts(environment, unpacked[0])
lhs, _, err := getVerifiedCiphertexts(environment, unpacked[0])
if err != nil {
msg := "fheArrayEqRun failed to get lhs to verified ciphertexts"
logger.Error(msg, "err", err)
return nil, err
}

rhs, err := getVerifiedCiphertexts(environment, unpacked[1])
rhs, _, err := getVerifiedCiphertexts(environment, unpacked[1])
if err != nil {
msg := "fheArrayEqRun failed to get rhs to verified ciphertexts"
logger.Error(msg, "err", err)
Expand Down
Loading
Loading