diff --git a/options.go b/options.go index c4eb422..884d559 100644 --- a/options.go +++ b/options.go @@ -1,9 +1,5 @@ package smt -import ( - "hash" -) - // Option is a function that configures SparseMerkleTrie. type Option func(*TrieSpec) @@ -16,17 +12,3 @@ func WithPathHasher(ph PathHasher) Option { func WithValueHasher(vh ValueHasher) Option { return func(ts *TrieSpec) { ts.vh = vh } } - -// NoPrehashSpec returns a new TrieSpec that has a nil Value Hasher and a nil -// Path Hasher -// NOTE: This should only be used when values are already hashed and a path is -// used instead of a key during proof verification, otherwise these will be -// double hashed and produce an incorrect leaf digest invalidating the proof. -func NoPrehashSpec(hasher hash.Hash, sumTrie bool) *TrieSpec { - spec := newTrieSpec(hasher, sumTrie) - opt := WithPathHasher(newNilPathHasher(hasher.Size())) - opt(&spec) - opt = WithValueHasher(nil) - opt(&spec) - return &spec -} diff --git a/proofs.go b/proofs.go index 64ce171..3a2b690 100644 --- a/proofs.go +++ b/proofs.go @@ -61,7 +61,11 @@ func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error { // Check that leaf data for non-membership proofs is a valid size. lps := len(leafPrefix) + spec.ph.PathSize() if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps { - return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), lps) + return fmt.Errorf( + "invalid non-membership leaf data size: got %d but min is %d", + len(proof.NonMembershipLeafData), + lps, + ) } // Check that all supplied sidenodes are the correct size. @@ -133,7 +137,11 @@ func (proof *SparseCompactMerkleProof) validateBasic(spec *TrieSpec) error { // Compact proofs: check that NumSideNodes is within the right range. if proof.NumSideNodes < 0 || proof.NumSideNodes > spec.ph.PathSize()*8 { - return fmt.Errorf("invalid number of side nodes: got %d, min is 0 and max is %d", len(proof.SideNodes), spec.ph.PathSize()*8) + return fmt.Errorf( + "invalid number of side nodes: got %d, min is 0 and max is %d", + len(proof.SideNodes), + spec.ph.PathSize()*8, + ) } // Compact proofs: check that the length of the bit mask is as expected @@ -185,6 +193,17 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error { return dec.Decode(proof) } +// GetValueHash returns the value hash of the closest proof. +func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte { + if proof.ClosestValueHash == nil { + return nil + } + if spec.sumTrie { + return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] + } + return proof.ClosestValueHash +} + func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error { // ensure the depth of the leaf node being proven is within the path size if proof.Depth < 0 || proof.Depth > spec.ph.PathSize()*8 { @@ -246,7 +265,12 @@ func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TrieSpec) erro } for i, b := range proof.FlippedBits { if len(b) > maxSliceLen { - return fmt.Errorf("invalid compressed flipped bit index %d: got length %d, max is %d]", i, bytesToInt(b), maxSliceLen) + return fmt.Errorf( + "invalid compressed flipped bit index %d: got length %d, max is %d]", + i, + bytesToInt(b), + maxSliceLen, + ) } } // perform a sanity check on the closest proof @@ -301,26 +325,38 @@ func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint6 // VerifyClosestProof verifies a Merkle proof for a proof of inclusion for a leaf // found to have the closest path to the one provided to the proof structure -// -// TO_AUDITOR: This is akin to an inclusion proof with N (num flipped bits) exclusion -// proof wrapped into one and needs to be reviewed from an algorithm POV. func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *TrieSpec) (bool, error) { if err := proof.validateBasic(spec); err != nil { return false, errors.Join(ErrBadProof, err) } - if !spec.sumTrie { - return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, spec) + // Create a new TrieSpec with a nil path hasher - as the ClosestProof + // already contains a hashed path, double hashing it will invalidate the + // proof. + nilSpec := &TrieSpec{ + th: spec.th, + ph: newNilPathHasher(spec.ph.PathSize()), + vh: spec.vh, + sumTrie: spec.sumTrie, + } + if !nilSpec.sumTrie { + return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, nilSpec) } if proof.ClosestValueHash == nil { - return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, spec) + return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, nilSpec) } sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSize:] sum := binary.BigEndian.Uint64(sumBz) valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] - return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec) + return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, nilSpec) } -func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, value []byte, spec *TrieSpec) (bool, [][][]byte, error) { +func verifyProofWithUpdates( + proof *SparseMerkleProof, + root []byte, + key []byte, + value []byte, + spec *TrieSpec, +) (bool, [][][]byte, error) { path := spec.ph.Path(key) if err := proof.validateBasic(spec); err != nil { @@ -384,7 +420,13 @@ func VerifyCompactProof(proof *SparseCompactMerkleProof, root []byte, key, value } // VerifyCompactSumProof is similar to VerifySumProof but for a compacted Merkle proof. -func VerifyCompactSumProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { +func VerifyCompactSumProof( + proof *SparseCompactMerkleProof, + root []byte, + key, value []byte, + sum uint64, + spec *TrieSpec, +) (bool, error) { decompactedProof, err := DecompactProof(proof, spec) if err != nil { return false, errors.Join(ErrBadProof, err) diff --git a/smst_proofs_test.go b/smst_proofs_test.go index 21fd454..c909d23 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -79,7 +79,14 @@ func TestSMST_Proof_Operations(t *testing.T) { result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey"), []byte("testValue"), 5, base) // invalid proof + result, err = VerifySumProof( + randomiseSumProof(proof), + root, + []byte("testKey"), + []byte("testValue"), + 5, + base, + ) // invalid proof require.NoError(t, err) require.False(t, result) @@ -98,7 +105,14 @@ func TestSMST_Proof_Operations(t *testing.T) { result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 10, base) // wrong value and sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey2"), []byte("testValue"), 5, base) // invalid proof + result, err = VerifySumProof( + randomiseSumProof(proof), + root, + []byte("testKey2"), + []byte("testValue"), + 5, + base, + ) // invalid proof require.NoError(t, err) require.False(t, result) @@ -129,7 +143,14 @@ func TestSMST_Proof_Operations(t *testing.T) { result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 5, base) // wrong sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey3"), defaultValue, 0, base) // invalid proof + result, err = VerifySumProof( + randomiseSumProof(proof), + root, + []byte("testKey3"), + defaultValue, + 0, + base, + ) // invalid proof require.NoError(t, err) require.False(t, result) } @@ -204,7 +225,6 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) { func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { smn := simplemap.NewSimpleMap() smst := NewSparseMerkleSumTrie(smn, sha256.New()) - np := NoPrehashSpec(sha256.New(), true) base := smst.Spec() path := sha256.Sum256([]byte("testKey2")) flipPathBit(path[:], 3) @@ -227,14 +247,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.Depth = -1 require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]") - result, err := VerifyClosestProof(proof, root, np) + result, err := VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.Depth = 257 require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -244,14 +264,14 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.FlippedBits[0] = -1 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.FlippedBits[0] = 9 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -265,7 +285,7 @@ func TestSMST_ClosestProof_ValidateBasic(t *testing.T) { proof.validateBasic(base), "invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)", ) - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -326,7 +346,7 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.NoError(t, err) require.True(t, result) @@ -352,7 +372,7 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, smst.Spec()) require.NoError(t, err) require.True(t, result) } @@ -381,7 +401,7 @@ func TestSMST_ProveClosest_Empty(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec()) require.NoError(t, err) require.True(t, result) } @@ -419,7 +439,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + result, err := VerifyClosestProof(proof, smst.Root(), smst.Spec()) require.NoError(t, err) require.True(t, result) } diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 2cf70c8..1c353df 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -178,7 +178,6 @@ func TestSMT_Proof_ValidateBasic(t *testing.T) { func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { smn := simplemap.NewSimpleMap() smt := NewSparseMerkleTrie(smn, sha256.New()) - np := NoPrehashSpec(sha256.New(), false) base := smt.Spec() path := sha256.Sum256([]byte("testKey2")) flipPathBit(path[:], 3) @@ -201,14 +200,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.Depth = -1 require.EqualError(t, proof.validateBasic(base), "invalid depth: got -1, outside of [0, 256]") - result, err := VerifyClosestProof(proof, root, np) + result, err := VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.Depth = 257 require.EqualError(t, proof.validateBasic(base), "invalid depth: got 257, outside of [0, 256]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -218,14 +217,14 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { require.NoError(t, err) proof.FlippedBits[0] = -1 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got -1, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) require.Error(t, err) proof.FlippedBits[0] = 9 require.EqualError(t, proof.validateBasic(base), "invalid flipped bit index 0: got 9, outside of [0, 8]") - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -239,7 +238,7 @@ func TestSMT_ClosestProof_ValidateBasic(t *testing.T) { proof.validateBasic(base), "invalid closest path: 8d13809f932d0296b88c1913231ab4b403f05c88363575476204fef6930f22ae (not equal at bit: 3)", ) - result, err = VerifyClosestProof(proof, root, np) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactClosestProof(proof, base) @@ -287,7 +286,7 @@ func TestSMT_ProveClosest(t *testing.T) { checkClosestCompactEquivalence(t, proof, smt.Spec()) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.NoError(t, err) require.True(t, result) closestPath := sha256.Sum256([]byte("testKey2")) @@ -304,7 +303,7 @@ func TestSMT_ProveClosest(t *testing.T) { checkClosestCompactEquivalence(t, proof, smt.Spec()) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, smt.Spec()) require.NoError(t, err) require.True(t, result) closestPath = sha256.Sum256([]byte("testKey4")) @@ -336,7 +335,7 @@ func TestSMT_ProveClosest_Empty(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec()) require.NoError(t, err) require.True(t, result) } @@ -368,7 +367,7 @@ func TestSMT_ProveClosest_OneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), smt.Spec()) require.NoError(t, err) require.True(t, result) }