diff --git a/README.md b/README.md index 5d91879..02fe2c1 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![Tests](https://github.com/pokt-network/smt/actions/workflows/test.yml/badge.svg)](https://github.com/pokt-network/smt/actions/workflows/test.yml) [![codecov](https://codecov.io/gh/pokt-network/smt/branch/main/graph/badge.svg)](https://codecov.io/gh/pokt-network/smt) -Note: **Requires Go 1.19+** +Note: **Requires Go 1.20+** - [Overview](#overview) - [Documentation](#documentation) diff --git a/bulk_test.go b/bulk_test.go index 47ab8ca..e0e4574 100644 --- a/bulk_test.go +++ b/bulk_test.go @@ -111,14 +111,16 @@ func bulkCheckAll(t *testing.T, smt *SMTWithStorage, kv []bulkop) { if err != nil { t.Errorf("error: %v", err) } - if !VerifyProof(proof, smt.Root(), []byte(k), []byte(v), smt.Spec()) { + valid, err := VerifyProof(proof, smt.Root(), []byte(k), []byte(v), smt.Spec()) + if !valid || err != nil { t.Fatalf("Merkle proof failed to verify (i=%d): %v", ki, []byte(k)) } compactProof, err := ProveCompact([]byte(k), smt) if err != nil { t.Errorf("error: %v", err) } - if !VerifyCompactProof(compactProof, smt.Root(), []byte(k), []byte(v), smt.Spec()) { + valid, err = VerifyCompactProof(compactProof, smt.Root(), []byte(k), []byte(v), smt.Spec()) + if !valid || err != nil { t.Fatalf("Compact Merkle proof failed to verify (i=%d): %v", ki, []byte(k)) } @@ -135,7 +137,7 @@ func bulkCheckAll(t *testing.T, smt *SMTWithStorage, kv []bulkop) { } ph := smt.Spec().ph - commonPrefix := countCommonPrefix(ph.Path([]byte(k)), ph.Path([]byte(k2)), 0) + commonPrefix := countCommonPrefixBits(ph.Path([]byte(k)), ph.Path([]byte(k2)), 0) if commonPrefix != smt.Spec().depth() && commonPrefix > largestCommonPrefix { largestCommonPrefix = commonPrefix } diff --git a/go.mod b/go.mod index 4ffaade..aa402d0 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pokt-network/smt -go 1.19 +go 1.20 require ( github.com/dgraph-io/badger/v4 v4.2.0 diff --git a/proofs.go b/proofs.go index 8e00fef..90ac01a 100644 --- a/proofs.go +++ b/proofs.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/gob" "errors" + "fmt" "math" ) @@ -51,33 +52,37 @@ func (proof *SparseMerkleProof) Unmarshal(bz []byte) error { return dec.Decode(proof) } -func (proof *SparseMerkleProof) sanityCheck(spec *TreeSpec) bool { +func (proof *SparseMerkleProof) validateBasic(spec *TreeSpec) error { // Do a basic sanity check on the proof, so that a malicious proof cannot // cause the verifier to fatally exit (e.g. due to an index out-of-range // error) or cause a CPU DoS attack. // Check that the number of supplied sidenodes does not exceed the maximum possible. - if len(proof.SideNodes) > spec.ph.PathSize()*8 || - - // Check that leaf data for non-membership proofs is a valid size. - (proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < len(leafPrefix)+spec.ph.PathSize()) { - return false + if len(proof.SideNodes) > spec.ph.PathSize()*8 { + return fmt.Errorf("too many side nodes: %d", len(proof.SideNodes)) + } + // Check that leaf data for non-membership proofs is a valid size. + if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < len(leafPrefix)+spec.ph.PathSize() { + return fmt.Errorf("invalid non-membership leaf data size: %d", len(proof.NonMembershipLeafData)) } // Check that all supplied sidenodes are the correct size. for _, v := range proof.SideNodes { if len(v) != hashSize(spec) { - return false + return fmt.Errorf("invalid side node size: %d", len(v)) } } // Check that the sibling data hashes to the first side node if not nil if proof.SiblingData == nil || len(proof.SideNodes) == 0 { - return true + return nil } - siblingHash := hashPreimage(spec, proof.SiblingData) - return bytes.Equal(proof.SideNodes[0], siblingHash) + if eq := bytes.Equal(proof.SideNodes[0], siblingHash); !eq { + return fmt.Errorf("invalid sibling data hash: %x", siblingHash) + } + + return nil } // SparseCompactMerkleProof is a compact Merkle proof for an element in a SparseMerkleTree. @@ -121,7 +126,7 @@ func (proof *SparseCompactMerkleProof) Unmarshal(bz []byte) error { return dec.Decode(proof) } -func (proof *SparseCompactMerkleProof) sanityCheck(spec *TreeSpec) bool { +func (proof *SparseCompactMerkleProof) validateBasic(spec *TreeSpec) error { // Do a basic sanity check on the proof on the fields of the proof specific to // the compact proof only. // @@ -129,19 +134,23 @@ func (proof *SparseCompactMerkleProof) sanityCheck(spec *TreeSpec) bool { // de-compacted proof should be executed. // Compact proofs: check that NumSideNodes is within the right range. - if proof.NumSideNodes < 0 || proof.NumSideNodes > spec.ph.PathSize()*8 || + if proof.NumSideNodes < 0 || proof.NumSideNodes > spec.ph.PathSize()*8 { + return fmt.Errorf("invalid number of side nodes: %d", proof.NumSideNodes) + } - // Compact proofs: check that the length of the bit mask is as expected - // according to NumSideNodes. - len(proof.BitMask) != int(math.Ceil(float64(proof.NumSideNodes)/float64(8))) || + // Compact proofs: check that the length of the bit mask is as expected + // according to NumSideNodes. + if len(proof.BitMask) != int(math.Ceil(float64(proof.NumSideNodes)/float64(8))) { + return fmt.Errorf("invalid bit mask length: %d", len(proof.BitMask)) + } - // Compact proofs: check that the correct number of sidenodes have been - // supplied according to the bit mask. - (proof.NumSideNodes > 0 && len(proof.SideNodes) != proof.NumSideNodes-countSetBits(proof.BitMask)) { - return false + // Compact proofs: check that the correct number of sidenodes have been + // supplied according to the bit mask. + if proof.NumSideNodes > 0 && len(proof.SideNodes) != proof.NumSideNodes-countSetBits(proof.BitMask) { + return fmt.Errorf("invalid number of side nodes: %d", len(proof.SideNodes)) } - return true + return nil } // SparseMerkleClosestProof is a wrapper around a SparseMerkleProof that @@ -172,33 +181,38 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error { return dec.Decode(proof) } -func (proof *SparseMerkleClosestProof) sanityCheck(spec *TreeSpec) bool { - if proof.Depth > spec.ph.PathSize()*8 { - return false +func (proof *SparseMerkleClosestProof) validateBasic(spec *TreeSpec) error { + // ensure the depth of the leaf node being proven is within the path size + if proof.Depth > spec.ph.PathSize()*8 || proof.Depth > 255 { + return fmt.Errorf("invalid depth: %d", proof.Depth) } + // for each of the bits flipped ensure that they are within the path size + // and that they are not greater than the depth of the leaf node being proven for _, i := range proof.FlippedBits { + // as proof.Depth <= spec.ph.PathSize()*8, i <= proof.Depth if i > proof.Depth { - return false - } - if i > spec.ph.PathSize()*8 { - return false + return fmt.Errorf("invalid flipped bit index: %d", i) } } + // create the path of the leaf node using the flipped bits metadata workingPath := proof.Path for _, i := range proof.FlippedBits { flipPathBit(workingPath, i) } - if prefix := countCommonPrefix( + // ensure that the path of the leaf node being proven has a prefix + // of length depth as the path provided (with bits flipped) + if prefix := countCommonPrefixBits( workingPath[:proof.Depth/8], proof.ClosestPath[:proof.Depth/8], 0, ); prefix != proof.Depth { - return false + return fmt.Errorf("invalid closest path: %x", proof.ClosestPath) } - if !proof.ClosestProof.sanityCheck(spec) { - return false + // validate the proof itself + if err := proof.ClosestProof.validateBasic(spec); err != nil { + return fmt.Errorf("invalid closest proof: %w", err) } - return true + return nil } // SparseCompactMerkleClosestProof is a compressed representation of the SparseMerkleClosestProof @@ -213,33 +227,38 @@ type SparseCompactMerkleClosestProof struct { ClosestProof *SparseCompactMerkleProof // the proof of the leaf closest to the path provided } -func (proof *SparseCompactMerkleClosestProof) sanityCheck(spec *TreeSpec) bool { - if int(proof.Depth) > spec.ph.PathSize()*8 { - return false +func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TreeSpec) error { + // ensure the depth of the leaf node being proven is within the path size + if int(proof.Depth) > spec.ph.PathSize()*8 || proof.Depth > 255 { + return fmt.Errorf("invalid depth: %d", proof.Depth) } + // for each of the bits flipped ensure that they are within the path size + // and that they are not greater than the depth of the leaf node being proven for _, i := range proof.FlippedBits { + // as proof.Depth <= spec.ph.PathSize()*8, i <= proof.Depth if i > proof.Depth { - return false - } - if int(i) > spec.ph.PathSize()*8 { - return false + return fmt.Errorf("invalid flipped bit index: %d", i) } } + // create the path of the leaf node using the flipped bits metadata workingPath := proof.Path for _, i := range proof.FlippedBits { flipPathBit(workingPath, int(i)) } - if prefix := countCommonPrefix( + // ensure that the path of the leaf node being proven has a prefix + // of length depth as the path provided (with bits flipped) + if prefix := countCommonPrefixBits( workingPath[:proof.Depth/8], proof.ClosestPath[:proof.Depth/8], 0, ); prefix != int(proof.Depth) { - return false + return fmt.Errorf("invalid closest path: %x", proof.ClosestPath) } - if !proof.ClosestProof.sanityCheck(spec) { - return false + // validate the proof itself + if err := proof.ClosestProof.validateBasic(spec); err != nil { + return fmt.Errorf("invalid closest proof: %w", err) } - return true + return nil } // Marshal serialises the SparseCompactMerkleClosestProof to bytes @@ -260,13 +279,13 @@ func (proof *SparseCompactMerkleClosestProof) Unmarshal(bz []byte) error { } // VerifyProof verifies a Merkle proof. -func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TreeSpec) bool { - result, _ := verifyProofWithUpdates(proof, root, key, value, spec) - return result +func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TreeSpec) (bool, error) { + result, _, err := verifyProofWithUpdates(proof, root, key, value, spec) + return result, err } // VerifySumProof verifies a Merkle proof for a sum tree. -func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TreeSpec) bool { +func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TreeSpec) (bool, error) { var sumBz [sumSize]byte binary.BigEndian.PutUint64(sumBz[:], sum) valueHash := spec.digestValue(value) @@ -285,11 +304,11 @@ func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint6 return VerifyProof(proof, root, key, valueHash, smtSpec) } -// VerifyClosestProof verifies a Merkle proof for a proof of a leaf found to -// have the closest path to the one provided to the proof function -func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *TreeSpec) bool { - if !proof.sanityCheck(spec) { - return false +// VerifyClosestProof verifies a Merkle proof for a proof of inclusion for a leaf +// found to have the closest path to the one provided to the proof structure +func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *TreeSpec) (bool, error) { + if err := proof.validateBasic(spec); err != nil { + return false, errors.Join(ErrBadProof, err) } if !spec.sumTree { return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, spec) @@ -299,14 +318,15 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Tree } sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSize:] sum := binary.BigEndian.Uint64(sumBz) - return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize], sum, spec) + valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] + return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec) } -func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, value []byte, spec *TreeSpec) (bool, [][][]byte) { +func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, value []byte, spec *TreeSpec) (bool, [][][]byte, error) { path := spec.ph.Path(key) - if !proof.sanityCheck(spec) { - return false, nil + if err := proof.validateBasic(spec); err != nil { + return false, nil, errors.Join(ErrBadProof, err) } var updates [][][]byte @@ -321,7 +341,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v actualPath, valueHash = parseLeaf(proof.NonMembershipLeafData, spec.ph) if bytes.Equal(actualPath, path) { // This is not an unrelated leaf; non-membership proof failed. - return false, nil + return false, nil, errors.Join(ErrBadProof, errors.New("non-membership proof on related leaf")) } currentHash, currentData = digestLeaf(spec, actualPath, valueHash) @@ -353,40 +373,40 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v updates = append(updates, update) } - return bytes.Equal(currentHash, root), updates + return bytes.Equal(currentHash, root), updates, nil } -// VerifyCompactProof verifies a compacted Merkle proof. -func VerifyCompactProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, spec *TreeSpec) bool { +// VerifyCompactProof is similar to VerifyProof but for a compacted Merkle proof. +func VerifyCompactProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, spec *TreeSpec) (bool, error) { decompactedProof, err := DecompactProof(proof, spec) if err != nil { - return false + return false, errors.Join(ErrBadProof, err) } return VerifyProof(decompactedProof, root, key, value, spec) } -// VerifyCompactSumProof verifies a compacted Merkle proof. -func VerifyCompactSumProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, sum uint64, spec *TreeSpec) bool { +// VerifyCompactSumProof is similar to VerifySumProof but for a compacted Merkle proof. +func VerifyCompactSumProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, sum uint64, spec *TreeSpec) (bool, error) { decompactedProof, err := DecompactProof(proof, spec) if err != nil { - return false + return false, errors.Join(ErrBadProof, err) } return VerifySumProof(decompactedProof, root, key, value, sum, spec) } -// VerifyCompactClosestProof verifies a compacted Merkle proof -func VerifyCompactClosestProof(proof *SparseCompactMerkleClosestProof, root []byte, spec *TreeSpec) bool { +// VerifyCompactClosestProof is similar to VerifyClosestProof but for a compacted merkle proof +func VerifyCompactClosestProof(proof *SparseCompactMerkleClosestProof, root []byte, spec *TreeSpec) (bool, error) { decompactedProof, err := DecompactClosestProof(proof, spec) if err != nil { - return false + return false, errors.Join(ErrBadProof, err) } return VerifyClosestProof(decompactedProof, root, spec) } // CompactProof compacts a proof, to reduce its size. func CompactProof(proof *SparseMerkleProof, spec *TreeSpec) (*SparseCompactMerkleProof, error) { - if !proof.sanityCheck(spec) { - return nil, ErrBadProof + if err := proof.validateBasic(spec); err != nil { + return nil, errors.Join(ErrBadProof, err) } bitMask := make([]byte, int(math.Ceil(float64(len(proof.SideNodes))/float64(8)))) @@ -412,8 +432,8 @@ func CompactProof(proof *SparseMerkleProof, spec *TreeSpec) (*SparseCompactMerkl // DecompactProof decompacts a proof, so that it can be used for VerifyProof. func DecompactProof(proof *SparseCompactMerkleProof, spec *TreeSpec) (*SparseMerkleProof, error) { - if !proof.sanityCheck(spec) { - return nil, ErrBadProof + if err := proof.validateBasic(spec); err != nil { + return nil, errors.Join(ErrBadProof, err) } decompactedSideNodes := make([][]byte, proof.NumSideNodes) @@ -436,8 +456,8 @@ func DecompactProof(proof *SparseCompactMerkleProof, spec *TreeSpec) (*SparseMer // CompactClosestProof compacts a proof, to reduce its size. func CompactClosestProof(proof *SparseMerkleClosestProof, spec *TreeSpec) (*SparseCompactMerkleClosestProof, error) { - if !proof.sanityCheck(spec) { - return nil, ErrBadProof + if err := proof.validateBasic(spec); err != nil { + return nil, errors.Join(ErrBadProof, err) } compactedProof, err := CompactProof(proof.ClosestProof, spec) if err != nil { @@ -445,12 +465,12 @@ func CompactClosestProof(proof *SparseMerkleClosestProof, spec *TreeSpec) (*Spar } flippedBits := make([]byte, len(proof.FlippedBits)) for i, v := range proof.FlippedBits { - flippedBits[i] = byte(v) + flippedBits[i] = intToByte(v) } return &SparseCompactMerkleClosestProof{ Path: proof.Path, FlippedBits: flippedBits, - Depth: byte(proof.Depth), + Depth: intToByte(proof.Depth), ClosestPath: proof.ClosestPath, ClosestValueHash: proof.ClosestValueHash, ClosestProof: compactedProof, @@ -459,8 +479,8 @@ func CompactClosestProof(proof *SparseMerkleClosestProof, spec *TreeSpec) (*Spar // DecompactClosestProof decompacts a proof, so that it can be used for VerifyClosestProof. func DecompactClosestProof(proof *SparseCompactMerkleClosestProof, spec *TreeSpec) (*SparseMerkleClosestProof, error) { - if !proof.sanityCheck(spec) { - return nil, ErrBadProof + if err := proof.validateBasic(spec); err != nil { + return nil, errors.Join(ErrBadProof, err) } decompactedProof, err := DecompactProof(proof.ClosestProof, spec) if err != nil { diff --git a/smst_proofs_test.go b/smst_proofs_test.go index 0def3fb..d40bc97 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -28,9 +28,11 @@ func TestSMST_ProofsBasic(t *testing.T) { proof, err = smst.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifySumProof(proof, placeholder(base), []byte("testKey3"), defaultValue, 0, base) + result, err = VerifySumProof(proof, placeholder(base), []byte("testKey3"), defaultValue, 0, base) + require.NoError(t, err) require.True(t, result) - result = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 5, base) + result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 5, base) + require.NoError(t, err) require.False(t, result) // Add a key, generate and verify a Merkle proof. @@ -40,13 +42,17 @@ func TestSMST_ProofsBasic(t *testing.T) { proof, err = smst.Prove([]byte("testKey")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, base) // valid + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, base) // valid + require.NoError(t, err) require.True(t, result) - result = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 5, base) // wrong value + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 5, base) // wrong value + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 10, base) // wrong sum + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 10, base) // wrong sum + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum + require.NoError(t, err) require.False(t, result) // Add a key, generate and verify both Merkle proofs. @@ -56,29 +62,39 @@ func TestSMST_ProofsBasic(t *testing.T) { proof, err = smst.Prove([]byte("testKey")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, base) // valid + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, base) // valid + require.NoError(t, err) require.True(t, result) - result = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 5, base) // wrong value + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 5, base) // wrong value + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 10, base) // wrong sum + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 10, base) // wrong sum + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey"), []byte("testValue"), 5, base) // invalid proof + result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey"), []byte("testValue"), 5, base) // invalid proof + require.NoError(t, err) require.False(t, result) proof, err = smst.Prove([]byte("testKey2")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifySumProof(proof, root, []byte("testKey2"), []byte("testValue"), 5, base) // valid + result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("testValue"), 5, base) // valid + require.NoError(t, err) require.True(t, result) - result = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 5, base) // wrong value + result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 5, base) // wrong value + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(proof, root, []byte("testKey2"), []byte("testValue"), 10, base) // wrong sum + result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("testValue"), 10, base) // wrong sum + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 10, base) // wrong value and sum + result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 10, base) // wrong value and sum + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey2"), []byte("testValue"), 5, base) // invalid proof + result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey2"), []byte("testValue"), 5, base) // invalid proof + require.NoError(t, err) require.False(t, result) // Try proving a default value for a non-default leaf. @@ -91,20 +107,25 @@ func TestSMST_ProofsBasic(t *testing.T) { SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, } - result = VerifySumProof(proof, root, []byte("testKey2"), defaultValue, 0, base) + result, err = VerifySumProof(proof, root, []byte("testKey2"), defaultValue, 0, base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) // Generate and verify a proof on an empty key. proof, err = smst.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 0, base) // valid + result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 0, base) // valid + require.NoError(t, err) require.True(t, result) - result = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 0, base) // wrong value + result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 0, base) // wrong value + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 5, base) // wrong sum + result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 5, base) // wrong sum + require.NoError(t, err) require.False(t, result) - result = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey3"), defaultValue, 0, base) // invalid proof + result, err = VerifySumProof(randomiseSumProof(proof), root, []byte("testKey3"), defaultValue, 0, base) // invalid proof + require.NoError(t, err) require.False(t, result) require.NoError(t, smn.Stop()) @@ -112,7 +133,7 @@ func TestSMST_ProofsBasic(t *testing.T) { } // Test sanity check cases for non-compact proofs. -func TestSMST_ProofsSanityCheck(t *testing.T) { +func TestSMST_ProofsValidateBasic(t *testing.T) { smn, err := NewKVStore("") require.NoError(t, err) smv, err := NewKVStore("") @@ -137,8 +158,9 @@ func TestSMST_ProofsSanityCheck(t *testing.T) { sideNodes[i] = proof.SideNodes[0] } proof.SideNodes = sideNodes - require.False(t, proof.sanityCheck(base)) - result := VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + require.EqualError(t, proof.validateBasic(base), "too many side nodes: 257") + result, err := VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) require.Error(t, err) @@ -146,8 +168,9 @@ func TestSMST_ProofsSanityCheck(t *testing.T) { // Case: incorrect size for NonMembershipLeafData. proof, _ = smst.Prove([]byte("testKey1")) proof.NonMembershipLeafData = make([]byte, 1) - require.False(t, proof.sanityCheck(base)) - result = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + require.EqualError(t, proof.validateBasic(base), "invalid non-membership leaf data size: 1") + result, err = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) require.Error(t, err) @@ -155,8 +178,9 @@ func TestSMST_ProofsSanityCheck(t *testing.T) { // Case: unexpected sidenode size. proof, _ = smst.Prove([]byte("testKey1")) proof.SideNodes[0] = make([]byte, 1) - require.False(t, proof.sanityCheck(base)) - result = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + require.EqualError(t, proof.validateBasic(base), "invalid side node size: 1") + result, err = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) require.Error(t, err) @@ -164,9 +188,10 @@ func TestSMST_ProofsSanityCheck(t *testing.T) { // Case: incorrect non-nil sibling data proof, _ = smst.Prove([]byte("testKey1")) proof.SiblingData = base.th.digest(proof.SiblingData) - require.False(t, proof.sanityCheck(base)) + require.EqualError(t, proof.validateBasic(base), "invalid sibling data hash: 437437455c0f5ca33597b9dd2a307bdfcc6833d3c272e101f30ed6358783fc247f0b9966865746c1") - result = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + result, err = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) require.Error(t, err) @@ -228,7 +253,8 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + require.NoError(t, err) require.True(t, result) // testKey4 is the neighbour of testKey2, by flipping the final bit of the @@ -252,13 +278,14 @@ func TestSMST_ProveClosest(t *testing.T) { ClosestProof: proof.ClosestProof, // copy of proof as we are checking equality of other fields }) - result = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), true)) + require.NoError(t, err) require.True(t, result) require.NoError(t, smn.Stop()) } -func TestSMST_ProveClosestEmptyAndOneNode(t *testing.T) { +func TestSMST_ProveClosestEmpty(t *testing.T) { var smn KVStore var smst *SMST var proof *SparseMerkleClosestProof @@ -281,12 +308,31 @@ func TestSMST_ProveClosestEmptyAndOneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + require.NoError(t, err) require.True(t, result) + require.NoError(t, smn.Stop()) +} + +func TestSMST_ProveClosestOneNode(t *testing.T) { + var smn KVStore + var smst *SMST + var proof *SparseMerkleClosestProof + var err error + + smn, err = NewKVStore("") + require.NoError(t, err) + smst = NewSparseMerkleSumTree(smn, sha256.New(), WithValueHasher(nil)) + require.NoError(t, smst.Update([]byte("foo"), []byte("bar"), 5)) + + path := sha256.Sum256([]byte("testKey2")) + flipPathBit(path[:], 3) + flipPathBit(path[:], 6) proof, err = smst.ProveClosest(path[:]) require.NoError(t, err) + closestPath := sha256.Sum256([]byte("foo")) closestValueHash := []byte("bar") var sumBz [sumSize]byte @@ -301,7 +347,8 @@ func TestSMST_ProveClosestEmptyAndOneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result = VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + result, err := VerifyClosestProof(proof, smst.Root(), NoPrehashSpec(sha256.New(), true)) + require.NoError(t, err) require.True(t, result) require.NoError(t, smn.Stop()) diff --git a/smst_test.go b/smst_test.go index 9ffa6cf..2414443 100644 --- a/smst_test.go +++ b/smst_test.go @@ -472,7 +472,8 @@ func TestSMST_TotalSum(t *testing.T) { proof, err := smst.Prove([]byte("key1")) require.NoError(t, err) checkCompactEquivalence(t, proof, smst.Spec()) - valid := VerifySumProof(proof, root1, []byte("key1"), []byte("value1"), 5, smst.Spec()) + valid, err := VerifySumProof(proof, root1, []byte("key1"), []byte("value1"), 5, smst.Spec()) + require.NoError(t, err) require.True(t, valid) // Check that the sum is correct after deleting a key diff --git a/smt.go b/smt.go index c218b53..0abff03 100644 --- a/smt.go +++ b/smt.go @@ -156,7 +156,7 @@ func (smt *SMT) update( return newLeaf, nil } if leaf, ok := node.(*leafNode); ok { - prefixlen := countCommonPrefix(path, leaf.path, depth) + prefixlen := countCommonPrefixBits(path, leaf.path, depth) if prefixlen == smt.depth() { // replace leaf if paths are equal smt.addOrphan(orphans, node) return newLeaf, nil diff --git a/smt_proofs_test.go b/smt_proofs_test.go index a8ae6c7..5ae55ba 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -27,9 +27,11 @@ func TestSMT_ProofsBasic(t *testing.T) { proof, err = smt.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifyProof(proof, base.th.placeholder(), []byte("testKey3"), defaultValue, base) + result, err = VerifyProof(proof, base.th.placeholder(), []byte("testKey3"), defaultValue, base) + require.NoError(t, err) require.True(t, result) - result = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), base) + result, err = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), base) + require.NoError(t, err) require.False(t, result) // Add a key, generate and verify a Merkle proof. @@ -39,9 +41,11 @@ func TestSMT_ProofsBasic(t *testing.T) { proof, err = smt.Prove([]byte("testKey")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifyProof(proof, root, []byte("testKey"), []byte("testValue"), base) + result, err = VerifyProof(proof, root, []byte("testKey"), []byte("testValue"), base) + require.NoError(t, err) require.True(t, result) - result = VerifyProof(proof, root, []byte("testKey"), []byte("badValue"), base) + result, err = VerifyProof(proof, root, []byte("testKey"), []byte("badValue"), base) + require.NoError(t, err) require.False(t, result) // Add a key, generate and verify both Merkle proofs. @@ -51,21 +55,27 @@ func TestSMT_ProofsBasic(t *testing.T) { proof, err = smt.Prove([]byte("testKey")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifyProof(proof, root, []byte("testKey"), []byte("testValue"), base) + result, err = VerifyProof(proof, root, []byte("testKey"), []byte("testValue"), base) + require.NoError(t, err) require.True(t, result) - result = VerifyProof(proof, root, []byte("testKey"), []byte("badValue"), base) + result, err = VerifyProof(proof, root, []byte("testKey"), []byte("badValue"), base) + require.NoError(t, err) require.False(t, result) - result = VerifyProof(randomiseProof(proof), root, []byte("testKey"), []byte("testValue"), base) + result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey"), []byte("testValue"), base) + require.NoError(t, err) require.False(t, result) proof, err = smt.Prove([]byte("testKey2")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifyProof(proof, root, []byte("testKey2"), []byte("testValue"), base) + result, err = VerifyProof(proof, root, []byte("testKey2"), []byte("testValue"), base) + require.NoError(t, err) require.True(t, result) - result = VerifyProof(proof, root, []byte("testKey2"), []byte("badValue"), base) + result, err = VerifyProof(proof, root, []byte("testKey2"), []byte("badValue"), base) + require.NoError(t, err) require.False(t, result) - result = VerifyProof(randomiseProof(proof), root, []byte("testKey2"), []byte("testValue"), base) + result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey2"), []byte("testValue"), base) + require.NoError(t, err) require.False(t, result) // Try proving a default value for a non-default leaf. @@ -74,18 +84,22 @@ func TestSMT_ProofsBasic(t *testing.T) { SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, } - result = VerifyProof(proof, root, []byte("testKey2"), defaultValue, base) + result, err = VerifyProof(proof, root, []byte("testKey2"), defaultValue, base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) // Generate and verify a proof on an empty key. proof, err = smt.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result = VerifyProof(proof, root, []byte("testKey3"), defaultValue, base) + result, err = VerifyProof(proof, root, []byte("testKey3"), defaultValue, base) + require.NoError(t, err) require.True(t, result) - result = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), base) + result, err = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), base) + require.NoError(t, err) require.False(t, result) - result = VerifyProof(randomiseProof(proof), root, []byte("testKey3"), defaultValue, base) + result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey3"), defaultValue, base) + require.NoError(t, err) require.False(t, result) require.NoError(t, smn.Stop()) @@ -93,7 +107,7 @@ func TestSMT_ProofsBasic(t *testing.T) { } // Test sanity check cases for non-compact proofs. -func TestSMT_ProofsSanityCheck(t *testing.T) { +func TestSMT_ProofsValidateBasic(t *testing.T) { smn, err := NewKVStore("") require.NoError(t, err) smv, err := NewKVStore("") @@ -118,8 +132,9 @@ func TestSMT_ProofsSanityCheck(t *testing.T) { sideNodes[i] = proof.SideNodes[0] } proof.SideNodes = sideNodes - require.False(t, proof.sanityCheck(base)) - result := VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), base) + require.EqualError(t, proof.validateBasic(base), "too many side nodes: 257") + result, err := VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) require.Error(t, err) @@ -127,8 +142,9 @@ func TestSMT_ProofsSanityCheck(t *testing.T) { // Case: incorrect size for NonMembershipLeafData. proof, _ = smt.Prove([]byte("testKey1")) proof.NonMembershipLeafData = make([]byte, 1) - require.False(t, proof.sanityCheck(base)) - result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), base) + require.EqualError(t, proof.validateBasic(base), "invalid non-membership leaf data size: 1") + result, err = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) require.Error(t, err) @@ -136,8 +152,9 @@ func TestSMT_ProofsSanityCheck(t *testing.T) { // Case: unexpected sidenode size. proof, _ = smt.Prove([]byte("testKey1")) proof.SideNodes[0] = make([]byte, 1) - require.False(t, proof.sanityCheck(base)) - result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), base) + require.EqualError(t, proof.validateBasic(base), "invalid side node size: 1") + result, err = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) require.Error(t, err) @@ -145,9 +162,10 @@ func TestSMT_ProofsSanityCheck(t *testing.T) { // Case: incorrect non-nil sibling data proof, _ = smt.Prove([]byte("testKey1")) proof.SiblingData = base.th.digest(proof.SiblingData) - require.False(t, proof.sanityCheck(base)) + require.EqualError(t, proof.validateBasic(base), "invalid sibling data hash: 187864587bac133246face60f98b8214407aa314f37dfc9ce8e1f5c80284a866") - result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), base) + result, err = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), base) + require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) require.Error(t, err) @@ -196,7 +214,8 @@ func TestSMT_ProveClosest(t *testing.T) { require.NoError(t, err) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + require.NoError(t, err) require.True(t, result) closestPath := sha256.Sum256([]byte("testKey2")) require.Equal(t, closestPath[:], proof.ClosestPath) @@ -211,7 +230,8 @@ func TestSMT_ProveClosest(t *testing.T) { require.NoError(t, err) require.NotEqual(t, proof, &SparseMerkleClosestProof{}) - result = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + result, err = VerifyClosestProof(proof, root, NoPrehashSpec(sha256.New(), false)) + require.NoError(t, err) require.True(t, result) closestPath = sha256.Sum256([]byte("testKey4")) require.Equal(t, closestPath[:], proof.ClosestPath) @@ -220,7 +240,7 @@ func TestSMT_ProveClosest(t *testing.T) { require.NoError(t, smn.Stop()) } -func TestSMT_ProveClosestEmptyAndOneNode(t *testing.T) { +func TestSMT_ProveClosestEmpty(t *testing.T) { var smn KVStore var smt *SMT var proof *SparseMerkleClosestProof @@ -243,9 +263,27 @@ func TestSMT_ProveClosestEmptyAndOneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + require.NoError(t, err) require.True(t, result) + require.NoError(t, smn.Stop()) +} + +func TestSMT_ProveClosestOneNode(t *testing.T) { + var smn KVStore + var smt *SMT + var proof *SparseMerkleClosestProof + var err error + + smn, err = NewKVStore("") + require.NoError(t, err) + smt = NewSparseMerkleTree(smn, sha256.New(), WithValueHasher(nil)) + + path := sha256.Sum256([]byte("testKey2")) + flipPathBit(path[:], 3) + flipPathBit(path[:], 6) + require.NoError(t, smt.Update([]byte("foo"), []byte("bar"))) proof, err = smt.ProveClosest(path[:]) require.NoError(t, err) @@ -259,7 +297,8 @@ func TestSMT_ProveClosestEmptyAndOneNode(t *testing.T) { ClosestProof: &SparseMerkleProof{}, }) - result = VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + result, err := VerifyClosestProof(proof, smt.Root(), NoPrehashSpec(sha256.New(), false)) + require.NoError(t, err) require.True(t, result) require.NoError(t, smn.Stop()) diff --git a/utils.go b/utils.go index d7c93ba..c9c2cad 100644 --- a/utils.go +++ b/utils.go @@ -1,5 +1,7 @@ package smt +import "fmt" + type nilPathHasher struct { hashSize int } @@ -47,6 +49,7 @@ func flipPathBit(data []byte, position int) { data[position/8] = byte(n) } +// countSetBits counts the number of bits set in the data provided (ie the number of 1s) func countSetBits(data []byte) int { count := 0 for i := 0; i < len(data)*8; i++ { @@ -57,8 +60,8 @@ func countSetBits(data []byte) int { return count } -// counts common bits in each path, starting from some position -func countCommonPrefix(data1, data2 []byte, from int) int { +// countCommonPrefixBits counts common bits in each path, starting from some position +func countCommonPrefixBits(data1, data2 []byte, from int) int { count := 0 for i := from; i < len(data1)*8; i++ { if getPathBit(data1, i) == getPathBit(data2, i) { @@ -70,6 +73,14 @@ func countCommonPrefix(data1, data2 []byte, from int) int { return count + from } +// intToByte converts an int safely to a byte panicing on error +func intToByte(i int) byte { + if i > 255 || i < 0 { + panic(fmt.Errorf("int outside of byte range [0, 255): %d", i)) + } + return byte(i) +} + // placeholder returns the default placeholder value depending on the tree type func placeholder(spec *TreeSpec) []byte { if spec.sumTree {