Skip to content

Commit

Permalink
feat: consolidate ClosestProof verification and remove the NilPathHas…
Browse files Browse the repository at this point in the history
…her method
  • Loading branch information
h5law committed Mar 19, 2024
1 parent 9e9d9b3 commit 4831b52
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 53 deletions.
18 changes: 0 additions & 18 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package smt

import (
"hash"
)

// Option is a function that configures SparseMerkleTrie.
type Option func(*TrieSpec)

Expand All @@ -16,17 +12,3 @@ func WithPathHasher(ph PathHasher) Option {
func WithValueHasher(vh ValueHasher) Option {
return func(ts *TrieSpec) { ts.vh = vh }
}

// NoPrehashSpec returns a new TrieSpec that has a nil Value Hasher and a nil
// Path Hasher
// NOTE: This should only be used when values are already hashed and a path is
// used instead of a key during proof verification, otherwise these will be
// double hashed and produce an incorrect leaf digest invalidating the proof.
func NoPrehashSpec(hasher hash.Hash, sumTrie bool) *TrieSpec {
spec := newTrieSpec(hasher, sumTrie)
opt := WithPathHasher(newNilPathHasher(hasher.Size()))
opt(&spec)
opt = WithValueHasher(nil)
opt(&spec)
return &spec
}
66 changes: 54 additions & 12 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error {
// Check that leaf data for non-membership proofs is a valid size.
lps := len(leafPrefix) + spec.ph.PathSize()
if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps {
return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), lps)
return fmt.Errorf(
"invalid non-membership leaf data size: got %d but min is %d",
len(proof.NonMembershipLeafData),
lps,
)
}

// Check that all supplied sidenodes are the correct size.
Expand Down Expand Up @@ -133,7 +137,11 @@ func (proof *SparseCompactMerkleProof) validateBasic(spec *TrieSpec) error {

// Compact proofs: check that NumSideNodes is within the right range.
if proof.NumSideNodes < 0 || proof.NumSideNodes > spec.ph.PathSize()*8 {
return fmt.Errorf("invalid number of side nodes: got %d, min is 0 and max is %d", len(proof.SideNodes), spec.ph.PathSize()*8)
return fmt.Errorf(
"invalid number of side nodes: got %d, min is 0 and max is %d",
len(proof.SideNodes),
spec.ph.PathSize()*8,
)
}

// Compact proofs: check that the length of the bit mask is as expected
Expand Down Expand Up @@ -185,6 +193,17 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error {
return dec.Decode(proof)
}

// GetValueHash returns the value hash of the closest proof.
func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte {
if proof.ClosestValueHash == nil {
return nil
}
if spec.sumTrie {
return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize]
}
return proof.ClosestValueHash
}

func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error {
// ensure the depth of the leaf node being proven is within the path size
if proof.Depth < 0 || proof.Depth > spec.ph.PathSize()*8 {
Expand Down Expand Up @@ -246,7 +265,12 @@ func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TrieSpec) erro
}
for i, b := range proof.FlippedBits {
if len(b) > maxSliceLen {
return fmt.Errorf("invalid compressed flipped bit index %d: got length %d, max is %d]", i, bytesToInt(b), maxSliceLen)
return fmt.Errorf(
"invalid compressed flipped bit index %d: got length %d, max is %d]",
i,
bytesToInt(b),
maxSliceLen,
)
}
}
// perform a sanity check on the closest proof
Expand Down Expand Up @@ -301,26 +325,38 @@ func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint6

// VerifyClosestProof verifies a Merkle proof for a proof of inclusion for a leaf
// found to have the closest path to the one provided to the proof structure
//
// TO_AUDITOR: This is akin to an inclusion proof with N (num flipped bits) exclusion
// proof wrapped into one and needs to be reviewed from an algorithm POV.
func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *TrieSpec) (bool, error) {
if err := proof.validateBasic(spec); err != nil {
return false, errors.Join(ErrBadProof, err)
}
if !spec.sumTrie {
return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, spec)
// Create a new TrieSpec with a nil path hasher - as the ClosestProof
// already contains a hashed path, double hashing it will invalidate the
// proof.
nilSpec := &TrieSpec{
th: spec.th,
ph: newNilPathHasher(spec.ph.PathSize()),
vh: spec.vh,
sumTrie: spec.sumTrie,
}
if !nilSpec.sumTrie {
return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, nilSpec)
}
if proof.ClosestValueHash == nil {
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, spec)
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, nilSpec)
}
sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSize:]
sum := binary.BigEndian.Uint64(sumBz)
valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize]
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec)
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, nilSpec)
}

func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, value []byte, spec *TrieSpec) (bool, [][][]byte, error) {
func verifyProofWithUpdates(
proof *SparseMerkleProof,
root []byte,
key []byte,
value []byte,
spec *TrieSpec,
) (bool, [][][]byte, error) {
path := spec.ph.Path(key)

if err := proof.validateBasic(spec); err != nil {
Expand Down Expand Up @@ -384,7 +420,13 @@ func VerifyCompactProof(proof *SparseCompactMerkleProof, root []byte, key, value
}

// VerifyCompactSumProof is similar to VerifySumProof but for a compacted Merkle proof.
func VerifyCompactSumProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) {
func VerifyCompactSumProof(
proof *SparseCompactMerkleProof,
root []byte,
key, value []byte,
sum uint64,
spec *TrieSpec,
) (bool, error) {
decompactedProof, err := DecompactProof(proof, spec)
if err != nil {
return false, errors.Join(ErrBadProof, err)
Expand Down
46 changes: 33 additions & 13 deletions smst_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ func TestSMST_Proof_Operations(t *testing.T) {
result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum
require.NoError(t, err)
require.False(t, result)
result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey"), []byte("testValue"), 5, base) // invalid proof
result, err = VerifySumProof(
randomiseSumProof(proof),
root,
[]byte("testKey"),
[]byte("testValue"),
5,
base,
) // invalid proof
require.NoError(t, err)
require.False(t, result)

Expand All @@ -98,7 +105,14 @@ func TestSMST_Proof_Operations(t *testing.T) {
result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 10, base) // wrong value and sum
require.NoError(t, err)
require.False(t, result)
result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey2"), []byte("testValue"), 5, base) // invalid proof
result, err = VerifySumProof(
randomiseSumProof(proof),
root,
[]byte("testKey2"),
[]byte("testValue"),
5,
base,
) // invalid proof
require.NoError(t, err)
require.False(t, result)

Expand Down Expand Up @@ -129,7 +143,14 @@ func TestSMST_Proof_Operations(t *testing.T) {
result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 5, base) // wrong sum
require.NoError(t, err)
require.False(t, result)
result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey3"), defaultValue, 0, base) // invalid proof
result, err = VerifySumProof(
randomiseSumProof(proof),
root,
[]byte("testKey3"),
defaultValue,
0,
base,
) // invalid proof
require.NoError(t, err)
require.False(t, result)
}
Expand Down Expand Up @@ -204,7 +225,6 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) {
func TestSMST_ClosestProof_ValidateBasic(t *testing.T) {
smn := simplemap.NewSimpleMap()
smst := NewSparseMerkleSumTrie(smn, sha256.New())
np := NoPrehashSpec(sha256.New(), true)
base := smst.Spec()
path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
Expand All @@ -227,14 +247,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) {
require.NoError(t, err)
proof.Depth = -1
require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]")
result, err := VerifyClosestProof(proof, root, np)
result, err := VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
require.Error(t, err)
proof.Depth = 257
require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
Expand All @@ -244,14 +264,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) {
require.NoError(t, err)
proof.FlippedBits[0] = -1
require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
require.Error(t, err)
proof.FlippedBits[0] = 9
require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
Expand All @@ -265,7 +285,7 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) {
proof.validateBasic(base),
"invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)",
)
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
Expand Down Expand Up @@ -326,7 +346,7 @@ func TestSMST_ProveClosest(t *testing.T) {
ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields
})

result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true))
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.NoError(t, err)
require.True(t, result)

Expand All @@ -352,7 +372,7 @@ func TestSMST_ProveClosest(t *testing.T) {
ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields
})

result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true))
result, err = VerifyClosestProof(proof, root, smst.Spec())
require.NoError(t, err)
require.True(t, result)
}
Expand Down Expand Up @@ -381,7 +401,7 @@ func TestSMST_ProveClosest_Empty(t *testing.T) {
ClosestProof: &SparseMerkleProof{},
})

result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true))
result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec())
require.NoError(t, err)
require.True(t, result)
}
Expand Down Expand Up @@ -419,7 +439,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) {
ClosestProof: &SparseMerkleProof{},
})

result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true))
result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec())
require.NoError(t, err)
require.True(t, result)
}
Expand Down
19 changes: 9 additions & 10 deletions smt_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ func TestSMT_Proof_ValidateBasic(t *testing.T) {
func TestSMT_ClosestProof_ValidateBasic(t *testing.T) {
smn := simplemap.NewSimpleMap()
smt := NewSparseMerkleTrie(smn, sha256.New())
np := NoPrehashSpec(sha256.New(), false)
base := smt.Spec()
path := sha256.Sum256([]byte("testKey2"))
flipPathBit(path[:], 3)
Expand All @@ -201,14 +200,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) {
require.NoError(t, err)
proof.Depth = -1
require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]")
result, err := VerifyClosestProof(proof, root, np)
result, err := VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
require.Error(t, err)
proof.Depth = 257
require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
Expand All @@ -218,14 +217,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) {
require.NoError(t, err)
proof.FlippedBits[0] = -1
require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
require.Error(t, err)
proof.FlippedBits[0] = 9
require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]")
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
Expand All @@ -239,7 +238,7 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) {
proof.validateBasic(base),
"invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)",
)
result, err = VerifyClosestProof(proof, root, np)
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.ErrorIs(t, err, ErrBadProof)
require.False(t, result)
_, err = CompactClosestProof(proof, base)
Expand Down Expand Up @@ -287,7 +286,7 @@ func TestSMT_ProveClosest(t *testing.T) {
checkClosestCompactEquivalence(t, proof, smt.Spec())
require.NotEqual(t, proof, &SparseMerkleClosestProof{})

result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false))
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.NoError(t, err)
require.True(t, result)
closestPath := sha256.Sum256([]byte("testKey2"))
Expand All @@ -304,7 +303,7 @@ func TestSMT_ProveClosest(t *testing.T) {
checkClosestCompactEquivalence(t, proof, smt.Spec())
require.NotEqual(t, proof, &SparseMerkleClosestProof{})

result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false))
result, err = VerifyClosestProof(proof, root, smt.Spec())
require.NoError(t, err)
require.True(t, result)
closestPath = sha256.Sum256([]byte("testKey4"))
Expand Down Expand Up @@ -336,7 +335,7 @@ func TestSMT_ProveClosest_Empty(t *testing.T) {
ClosestProof: &SparseMerkleProof{},
})

result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false))
result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec())
require.NoError(t, err)
require.True(t, result)
}
Expand Down Expand Up @@ -368,7 +367,7 @@ func TestSMT_ProveClosest_OneNode(t *testing.T) {
ClosestProof: &SparseMerkleProof{},
})

result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false))
result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec())
require.NoError(t, err)
require.True(t, result)
}
Expand Down

0 comments on commit 4831b52

Please sign in to comment.