From f83eeb56e5c353295ba5a31594b054bcf4d5f0b5 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Fri, 19 Jul 2024 12:16:23 +0200 Subject: [PATCH] [TODO] chore: update `smt.MerkleRoot#Sum()` error handling (#672) Co-authored-by: Redouane Lakrache --- go.mod | 2 +- go.sum | 4 +- pkg/client/supplier/client_test.go | 4 +- pkg/crypto/protocol/hasher.go | 12 ++++ pkg/relayer/session/sessiontree.go | 3 +- .../relay_mining_difficulty_test.go | 3 +- testutil/proof/fixture_generators.go | 39 +++++++------ x/application/keeper/auto_undelegate.go | 1 + .../keeper/msg_server_create_claim_test.go | 5 +- x/proof/types/claim.go | 29 +--------- .../keeper/settle_session_accounting.go | 55 ++++++++++++++----- .../keeper/settle_session_accounting_test.go | 37 ++++--------- 12 files changed, 102 insertions(+), 92 deletions(-) create mode 100644 pkg/crypto/protocol/hasher.go diff --git a/go.mod b/go.mod index 3fbc306fb..33c469839 100644 --- a/go.mod +++ b/go.mod @@ -57,7 +57,7 @@ require ( // repo is the first obvious idea, but has to be carefully considered, automated, and is not // a hard blocker. github.com/pokt-network/shannon-sdk v0.0.0-20240628223057-7d2928722749 - github.com/pokt-network/smt v0.11.1 + github.com/pokt-network/smt v0.12.0 github.com/pokt-network/smt/kvstore/badger v0.0.0-20240109205447-868237978c0b github.com/prometheus/client_golang v1.19.0 github.com/regen-network/gocuke v1.1.0 diff --git a/go.sum b/go.sum index fcb07cd82..5c9f760df 100644 --- a/go.sum +++ b/go.sum @@ -996,8 +996,8 @@ github.com/pokt-network/ring-go v0.1.0 h1:hF7mDR4VVCIqqDAsrloP8azM9y1mprc99YgnTj github.com/pokt-network/ring-go v0.1.0/go.mod h1:8NHPH7H3EwrPX3XHfpyRI6bz4gApkE3+fd0XZRbMWP0= github.com/pokt-network/shannon-sdk v0.0.0-20240628223057-7d2928722749 h1:V/3xzmykSABhAxRZLawWUoIPVlnp7EGCnCxFpLXD7R0= github.com/pokt-network/shannon-sdk v0.0.0-20240628223057-7d2928722749/go.mod h1:MfoRhzPRlxiaY3xQyZo28B7ibDuhricA//TGGy48TwM= -github.com/pokt-network/smt v0.11.1 h1:ySN8PjrPDKyvzLcX0qTHR2s5ReaZnjq25z0B7p6AWl0= -github.com/pokt-network/smt v0.11.1/go.mod h1:S4Ho4OPkK2v2vUCHNtA49XDjqUC/OFYpBbynRVYmxvA= +github.com/pokt-network/smt v0.12.0 h1:uqru/0ykC4LnBoMacakobNOd1iRK69PlohqjMtLmYNA= +github.com/pokt-network/smt v0.12.0/go.mod h1:S4Ho4OPkK2v2vUCHNtA49XDjqUC/OFYpBbynRVYmxvA= github.com/pokt-network/smt/kvstore/badger v0.0.0-20240109205447-868237978c0b h1:TjfgV3vgW0zW47Br/OgUXD4M8iyR74EYanbFfN4ed8o= github.com/pokt-network/smt/kvstore/badger v0.0.0-20240109205447-868237978c0b/go.mod h1:GbzcG5ebj8twKmBL1VzdPM4NS44okwYXBfQaVXT+6yU= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= diff --git a/pkg/client/supplier/client_test.go b/pkg/client/supplier/client_test.go index 122318bb4..09aa0dffb 100644 --- a/pkg/client/supplier/client_test.go +++ b/pkg/client/supplier/client_test.go @@ -2,7 +2,6 @@ package supplier_test import ( "context" - "crypto/sha256" "testing" "time" @@ -14,6 +13,7 @@ import ( "github.com/pokt-network/poktroll/pkg/client/keyring" "github.com/pokt-network/poktroll/pkg/client/supplier" + "github.com/pokt-network/poktroll/pkg/crypto/protocol" "github.com/pokt-network/poktroll/testutil/mockclient" "github.com/pokt-network/poktroll/testutil/testclient/testkeyring" "github.com/pokt-network/poktroll/testutil/testclient/testtx" @@ -181,7 +181,7 @@ func TestSupplierClient_SubmitProof(t *testing.T) { // Generating an ephemeral tree & spec just so we can submit // a proof of the right size. // TODO_TECHDEBT(#446): Centralize the configuration for the SMT spec. - tree := smt.NewSparseMerkleSumTrie(kvStore, sha256.New()) + tree := smt.NewSparseMerkleSumTrie(kvStore, protocol.NewTrieHasher()) emptyPath := make([]byte, tree.PathHasherSize()) proof, err := tree.ProveClosest(emptyPath) require.NoError(t, err) diff --git a/pkg/crypto/protocol/hasher.go b/pkg/crypto/protocol/hasher.go new file mode 100644 index 000000000..447c202ec --- /dev/null +++ b/pkg/crypto/protocol/hasher.go @@ -0,0 +1,12 @@ +package protocol + +import "crypto/sha256" + +const ( + TrieHasherSize = sha256.Size + TrieRootSize = TrieHasherSize + trieRootMetadataSize + // TODO_CONSIDERATION: Export this from the SMT package. + trieRootMetadataSize = 16 +) + +var NewTrieHasher = sha256.New diff --git a/pkg/relayer/session/sessiontree.go b/pkg/relayer/session/sessiontree.go index 5236225cd..b9ec7151e 100644 --- a/pkg/relayer/session/sessiontree.go +++ b/pkg/relayer/session/sessiontree.go @@ -11,6 +11,7 @@ import ( "github.com/pokt-network/smt" "github.com/pokt-network/smt/kvstore/badger" + "github.com/pokt-network/poktroll/pkg/crypto/protocol" "github.com/pokt-network/poktroll/pkg/relayer" sessiontypes "github.com/pokt-network/poktroll/x/session/types" ) @@ -85,7 +86,7 @@ func NewSessionTree( // Create the SMST from the KVStore and a nil value hasher so the proof would // contain a non-hashed Relay that could be used to validate the proof on-chain. - trie := smt.NewSparseMerkleSumTrie(treeStore, sha256.New(), smt.WithValueHasher(nil)) + trie := smt.NewSparseMerkleSumTrie(treeStore, protocol.NewTrieHasher(), smt.WithValueHasher(nil)) sessionTree := &sessionTree{ sessionHeader: sessionHeader, diff --git a/tests/integration/tokenomics/relay_mining_difficulty_test.go b/tests/integration/tokenomics/relay_mining_difficulty_test.go index 1b543c00a..09430a7b1 100644 --- a/tests/integration/tokenomics/relay_mining_difficulty_test.go +++ b/tests/integration/tokenomics/relay_mining_difficulty_test.go @@ -2,7 +2,6 @@ package integration_test import ( "context" - "crypto/sha256" "testing" "github.com/pokt-network/smt" @@ -193,7 +192,7 @@ func prepareSMST( integrationApp.GetRingClient(), ) - trie := smt.NewSparseMerkleSumTrie(kvStore, sha256.New(), smt.WithValueHasher(nil)) + trie := smt.NewSparseMerkleSumTrie(kvStore, protocol.NewTrieHasher(), smt.WithValueHasher(nil)) err = trie.Update(minedRelay.Hash, minedRelay.Bytes, 1) require.NoError(t, err) diff --git a/testutil/proof/fixture_generators.go b/testutil/proof/fixture_generators.go index 96d263ce0..19855b8af 100644 --- a/testutil/proof/fixture_generators.go +++ b/testutil/proof/fixture_generators.go @@ -5,9 +5,11 @@ import ( "math/rand" "testing" - "github.com/pokt-network/smt" "github.com/stretchr/testify/require" + "github.com/pokt-network/smt" + + "github.com/pokt-network/poktroll/pkg/crypto/protocol" testsession "github.com/pokt-network/poktroll/testutil/session" prooftypes "github.com/pokt-network/poktroll/x/proof/types" sessiontypes "github.com/pokt-network/poktroll/x/session/types" @@ -51,32 +53,37 @@ func ClaimWithRandomHash(t *testing.T, appAddr, supplierAddr string, sum uint64) // TODO_MAINNET: Revisit if the SMT should be big or little Endian. Refs: // https://github.com/pokt-network/smt/pull/46#discussion_r1636975124 // https://github.com/pokt-network/smt/blob/ea585c6c3bc31c804b6bafa83e985e473b275580/smst.go#L23C10-L23C76 -func SmstRootWithSum(sum uint64) smt.MerkleRoot { - root := [smt.SmstRootSizeBytes]byte{} - // Insert the sum into the root hash - binary.BigEndian.PutUint64(root[smt.SmtRootSizeBytes:], sum) - // Insert the count into the root hash - // TODO_TECHDEBT: This is a hard-coded count of 1, but could be a parameter. - // TODO_TECHDEBT: We are assuming the sum takes up 8 bytes. - binary.BigEndian.PutUint64(root[smt.SmtRootSizeBytes+8:], 1) - return smt.MerkleRoot(root[:]) +func SmstRootWithSum(sum uint64) smt.MerkleSumRoot { + root := [protocol.TrieRootSize]byte{} + return encodeSum(root, sum) } // RandSmstRootWithSum returns a randomized SMST root with the given sum that // can be used for testing. Randomizing the root is a simple way to randomize // test claim hashes for testing proof requirement cases. -func RandSmstRootWithSum(t *testing.T, sum uint64) smt.MerkleRoot { +func RandSmstRootWithSum(t *testing.T, sum uint64) smt.MerkleSumRoot { t.Helper() - root := [smt.SmstRootSizeBytes]byte{} + root := [protocol.TrieRootSize]byte{} // Only populate the first 32 bytes with random data, leave the last 8 bytes for the sum. - _, err := rand.Read(root[:smt.SmtRootSizeBytes]) //nolint:staticcheck // We need a deterministic pseudo-random source. + _, err := rand.Read(root[:protocol.TrieHasherSize]) //nolint:staticcheck // We need a deterministic pseudo-random source. require.NoError(t, err) - binary.BigEndian.PutUint64(root[smt.SmtRootSizeBytes:], sum) + return encodeSum(root, sum) +} + +// encodeSum returns a copy of the given root, binary encodes the given sum, +// and stores the encoded sum in the root copy. +func encodeSum(r [protocol.TrieRootSize]byte, sum uint64) smt.MerkleSumRoot { + root := make([]byte, protocol.TrieRootSize) + copy(root, r[:]) + + // Insert the sum into the root hash + binary.BigEndian.PutUint64(root[protocol.TrieHasherSize:], sum) // Insert the count into the root hash // TODO_TECHDEBT: This is a hard-coded count of 1, but could be a parameter. // TODO_TECHDEBT: We are assuming the sum takes up 8 bytes. - binary.BigEndian.PutUint64(root[smt.SmtRootSizeBytes+8:], 1) - return smt.MerkleRoot(root[:]) + binary.BigEndian.PutUint64(root[protocol.TrieHasherSize+8:], 1) + + return root } diff --git a/x/application/keeper/auto_undelegate.go b/x/application/keeper/auto_undelegate.go index dfe2fb67c..9e88cbb29 100644 --- a/x/application/keeper/auto_undelegate.go +++ b/x/application/keeper/auto_undelegate.go @@ -7,6 +7,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" proto "github.com/cosmos/gogoproto/proto" + gatewaytypes "github.com/pokt-network/poktroll/x/gateway/types" ) diff --git a/x/proof/keeper/msg_server_create_claim_test.go b/x/proof/keeper/msg_server_create_claim_test.go index 763d5ff6d..4ac0c4c72 100644 --- a/x/proof/keeper/msg_server_create_claim_test.go +++ b/x/proof/keeper/msg_server_create_claim_test.go @@ -5,11 +5,12 @@ import ( abci "github.com/cometbft/cometbft/abci/types" cosmostypes "github.com/cosmos/cosmos-sdk/types" - "github.com/pokt-network/smt" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/pokt-network/smt" + keepertest "github.com/pokt-network/poktroll/testutil/keeper" testproof "github.com/pokt-network/poktroll/testutil/proof" "github.com/pokt-network/poktroll/testutil/sample" @@ -490,7 +491,7 @@ func newTestClaimMsg( supplierAddr string, appAddr string, service *sharedtypes.Service, - merkleRoot smt.MerkleRoot, + merkleRoot smt.MerkleSumRoot, ) *types.MsgCreateClaim { t.Helper() diff --git a/x/proof/types/claim.go b/x/proof/types/claim.go index 724dbc2d3..a9c9bad3a 100644 --- a/x/proof/types/claim.go +++ b/x/proof/types/claim.go @@ -1,8 +1,6 @@ package types import ( - "fmt" - "github.com/cometbft/cometbft/crypto" "github.com/pokt-network/smt" @@ -11,36 +9,13 @@ import ( // GetNumComputeUnits returns the number of compute units for a given claim // as determined by the sum of the root hash. func (claim *Claim) GetNumComputeUnits() (numComputeUnits uint64, err error) { - // NB: smt.MerkleRoot#Sum() will panic if the root hash is not valid. - // Convert this panic into an error return. - defer func() { - if r := recover(); r != nil { - numComputeUnits = 0 - err = fmt.Errorf( - "unable to get sum of invalid merkle root: %x; error: %v", - claim.GetRootHash(), r, - ) - } - }() - - return smt.MerkleRoot(claim.GetRootHash()).Sum(), nil + return smt.MerkleSumRoot(claim.GetRootHash()).Sum() } // GetNumRelays returns the number of relays for a given claim // as determined by the count of the root hash. func (claim *Claim) GetNumRelays() (numRelays uint64, err error) { - // Convert this panic into an error return. - defer func() { - if r := recover(); r != nil { - numRelays = 0 - err = fmt.Errorf( - "unable to get count of invalid merkle root: %x; error: %v", - claim.GetRootHash(), r, - ) - } - }() - - return smt.MerkleRoot(claim.GetRootHash()).Count(), nil + return smt.MerkleSumRoot(claim.GetRootHash()).Count() } // GetHash returns the SHA-256 hash of the serialized claim. diff --git a/x/tokenomics/keeper/settle_session_accounting.go b/x/tokenomics/keeper/settle_session_accounting.go index abefd8eed..61dfb6aa6 100644 --- a/x/tokenomics/keeper/settle_session_accounting.go +++ b/x/tokenomics/keeper/settle_session_accounting.go @@ -6,6 +6,10 @@ import ( "cosmossdk.io/math" cosmostypes "github.com/cosmos/cosmos-sdk/types" + + "github.com/pokt-network/poktroll/app/volatile" + "github.com/pokt-network/poktroll/pkg/crypto/protocol" + "github.com/pokt-network/smt" "github.com/pokt-network/poktroll/telemetry" @@ -26,7 +30,7 @@ import ( func (k Keeper) SettleSessionAccounting( ctx context.Context, claim *prooftypes.Claim, -) error { +) (err error) { logger := k.Logger().With("method", "SettleSessionAccounting") settlementCoin := cosmostypes.NewCoin("upokt", math.NewInt(0)) @@ -34,7 +38,12 @@ func (k Keeper) SettleSessionAccounting( // This is emitted only when the function returns. defer telemetry.EventSuccessCounter( "settle_session_accounting", - func() float32 { return float32(settlementCoin.Amount.Int64()) }, + func() float32 { + if settlementCoin.Amount.BigInt() == nil { + return 0 + } + return float32(settlementCoin.Amount.Int64()) + }, func() bool { return isSuccessful }, ) @@ -50,7 +59,7 @@ func (k Keeper) SettleSessionAccounting( logger.Error("received a nil session header") return tokenomicstypes.ErrTokenomicsSessionHeaderNil } - if err := sessionHeader.ValidateBasic(); err != nil { + if err = sessionHeader.ValidateBasic(); err != nil { logger.Error("received an invalid session header", "error", err) return tokenomicstypes.ErrTokenomicsSessionHeaderInvalid } @@ -66,15 +75,19 @@ func (k Keeper) SettleSessionAccounting( } // Retrieve the sum of the root as a proxy into the amount of work done - root := (smt.MerkleRoot)(claim.GetRootHash()) + root := (smt.MerkleSumRoot)(claim.GetRootHash()) + + if !root.HasDigestSize(protocol.TrieHasherSize) { + return tokenomicstypes.ErrTokenomicsRootHashInvalid.Wrapf( + "root hash has invalid digest size (%d), expected (%d)", + root.DigestSize(), protocol.TrieHasherSize, + ) + } - // TODO_BLOCKER(@Olshansk): This check should be the responsibility of the SMST package - // since it's used to get compute units from the root hash. - if root == nil || len(root) != smt.SmstRootSizeBytes { - logger.Error(fmt.Sprintf("received an invalid root hash of size: %d", len(root))) - return tokenomicstypes.ErrTokenomicsRootHashInvalid + claimComputeUnits, err := root.Sum() + if err != nil { + return tokenomicstypes.ErrTokenomicsRootHashInvalid.Wrapf("%v", err) } - claimComputeUnits := root.Sum() // Helpers for logging the same metadata throughout this function calls logger = logger.With( @@ -96,7 +109,11 @@ func (k Keeper) SettleSessionAccounting( logger.Info(fmt.Sprintf("About to start settling claim for %d compute units", claimComputeUnits)) // Calculate the amount of tokens to mint & burn - settlementCoin = k.getCoinFromComputeUnits(ctx, root) + settlementCoin, err = k.getCoinFromComputeUnits(ctx, root) + if err != nil { + return err + } + settlementCoins := cosmostypes.NewCoins(settlementCoin) logger.Info(fmt.Sprintf( @@ -194,10 +211,20 @@ func (k Keeper) SettleSessionAccounting( return nil } -func (k Keeper) getCoinFromComputeUnits(ctx context.Context, root smt.MerkleRoot) cosmostypes.Coin { +func (k Keeper) getCoinFromComputeUnits(ctx context.Context, root smt.MerkleSumRoot) (cosmostypes.Coin, error) { // Retrieve the existing tokenomics params params := k.GetParams(ctx) - upokt := math.NewInt(int64(root.Sum() * params.ComputeUnitsToTokensMultiplier)) - return cosmostypes.NewCoin("upokt", upokt) + sum, err := root.Sum() + if err != nil { + return cosmostypes.Coin{}, err + } + + upokt := math.NewInt(int64(sum * params.ComputeUnitsToTokensMultiplier)) + + if upokt.IsNegative() { + return cosmostypes.Coin{}, tokenomicstypes.ErrTokenomicsRootHashInvalid.Wrap("sum * compute_units_to_tokens_multiplier is negative") + } + + return cosmostypes.NewCoin(volatile.DenomuPOKT, upokt), nil } diff --git a/x/tokenomics/keeper/settle_session_accounting_test.go b/x/tokenomics/keeper/settle_session_accounting_test.go index 480761b66..86a74b04f 100644 --- a/x/tokenomics/keeper/settle_session_accounting_test.go +++ b/x/tokenomics/keeper/settle_session_accounting_test.go @@ -11,9 +11,11 @@ import ( cosmostypes "github.com/cosmos/cosmos-sdk/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" - "github.com/pokt-network/smt" "github.com/stretchr/testify/require" + "github.com/pokt-network/smt" + + "github.com/pokt-network/poktroll/pkg/crypto/protocol" testkeeper "github.com/pokt-network/poktroll/testutil/keeper" testproof "github.com/pokt-network/poktroll/testutil/proof" "github.com/pokt-network/poktroll/testutil/sample" @@ -300,11 +302,10 @@ func TestSettleSessionAccounting_AppNotFound(t *testing.T) { func TestSettleSessionAccounting_InvalidRoot(t *testing.T) { keeper, ctx, appAddr, supplierAddr := testkeeper.TokenomicsKeeperWithActorAddrs(t) - rootHashSizeBytes := smt.SmstRootSizeBytes // Define test cases tests := []struct { desc string - root []byte // smst.MerkleRoot + root []byte // smst.MerkleSumRoot errExpected bool }{ { @@ -313,19 +314,19 @@ func TestSettleSessionAccounting_InvalidRoot(t *testing.T) { errExpected: true, }, { - desc: fmt.Sprintf("Less than %d bytes", rootHashSizeBytes), - root: make([]byte, rootHashSizeBytes-1), // Less than expected number of bytes + desc: fmt.Sprintf("Less than %d bytes", protocol.TrieRootSize), + root: make([]byte, protocol.TrieRootSize-1), // Less than expected number of bytes errExpected: true, }, { - desc: fmt.Sprintf("More than %d bytes", rootHashSizeBytes), - root: make([]byte, rootHashSizeBytes+1), // More than expected number of bytes + desc: fmt.Sprintf("More than %d bytes", protocol.TrieRootSize), + root: make([]byte, protocol.TrieRootSize+1), // More than expected number of bytes errExpected: true, }, { desc: "correct size but empty", root: func() []byte { - root := make([]byte, rootHashSizeBytes) // All 0s + root := make([]byte, protocol.TrieRootSize) // All 0s return root[:] }(), errExpected: false, @@ -333,7 +334,7 @@ func TestSettleSessionAccounting_InvalidRoot(t *testing.T) { { desc: "correct size but invalid value", root: func() []byte { - return bytes.Repeat([]byte("a"), rootHashSizeBytes) + return bytes.Repeat([]byte("a"), protocol.TrieRootSize) }(), errExpected: true, }, @@ -350,26 +351,12 @@ func TestSettleSessionAccounting_InvalidRoot(t *testing.T) { // Iterate over each test case for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - // Use defer-recover to catch any panic - defer func() { - if r := recover(); r != nil { - t.Errorf("Test panicked: %s", r) - } - }() - // Setup claim by copying the testproof.BaseClaim and updating the root claim := testproof.BaseClaim(appAddr, supplierAddr, 0) - claim.RootHash = smt.MerkleRoot(test.root[:]) + claim.RootHash = smt.MerkleSumRoot(test.root[:]) // Execute test function - err := func() (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic occurred: %v", r) - } - }() - return keeper.SettleSessionAccounting(ctx, &claim) - }() + err := keeper.SettleSessionAccounting(ctx, &claim) // Assert the error if test.errExpected {