Skip to content

Commit

Permalink
[TODO] chore: update smt.MerkleRoot#Sum() error handling (#672)
Browse files Browse the repository at this point in the history
Co-authored-by: Redouane Lakrache <r3d0ne@gmail.com>
bryanchriswhite and red-0ne committed Jul 19, 2024
1 parent 69d97d3 commit 5a3cfd8
Showing 12 changed files with 95 additions and 92 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ require (
// repo is the first obvious idea, but has to be carefully considered, automated, and is not
// a hard blocker.
github.com/pokt-network/shannon-sdk v0.0.0-20240628223057-7d2928722749
github.com/pokt-network/smt v0.11.1
github.com/pokt-network/smt v0.12.0
github.com/pokt-network/smt/kvstore/badger v0.0.0-20240109205447-868237978c0b
github.com/prometheus/client_golang v1.19.0
github.com/regen-network/gocuke v1.1.0
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -996,8 +996,8 @@ github.com/pokt-network/ring-go v0.1.0 h1:hF7mDR4VVCIqqDAsrloP8azM9y1mprc99YgnTj
github.com/pokt-network/ring-go v0.1.0/go.mod h1:8NHPH7H3EwrPX3XHfpyRI6bz4gApkE3+fd0XZRbMWP0=
github.com/pokt-network/shannon-sdk v0.0.0-20240628223057-7d2928722749 h1:V/3xzmykSABhAxRZLawWUoIPVlnp7EGCnCxFpLXD7R0=
github.com/pokt-network/shannon-sdk v0.0.0-20240628223057-7d2928722749/go.mod h1:MfoRhzPRlxiaY3xQyZo28B7ibDuhricA//TGGy48TwM=
github.com/pokt-network/smt v0.11.1 h1:ySN8PjrPDKyvzLcX0qTHR2s5ReaZnjq25z0B7p6AWl0=
github.com/pokt-network/smt v0.11.1/go.mod h1:S4Ho4OPkK2v2vUCHNtA49XDjqUC/OFYpBbynRVYmxvA=
github.com/pokt-network/smt v0.12.0 h1:uqru/0ykC4LnBoMacakobNOd1iRK69PlohqjMtLmYNA=
github.com/pokt-network/smt v0.12.0/go.mod h1:S4Ho4OPkK2v2vUCHNtA49XDjqUC/OFYpBbynRVYmxvA=
github.com/pokt-network/smt/kvstore/badger v0.0.0-20240109205447-868237978c0b h1:TjfgV3vgW0zW47Br/OgUXD4M8iyR74EYanbFfN4ed8o=
github.com/pokt-network/smt/kvstore/badger v0.0.0-20240109205447-868237978c0b/go.mod h1:GbzcG5ebj8twKmBL1VzdPM4NS44okwYXBfQaVXT+6yU=
github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=
4 changes: 2 additions & 2 deletions pkg/client/supplier/client_test.go
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@ package supplier_test

import (
"context"
"crypto/sha256"
"testing"
"time"

@@ -14,6 +13,7 @@ import (

"github.com/pokt-network/poktroll/pkg/client/keyring"
"github.com/pokt-network/poktroll/pkg/client/supplier"
"github.com/pokt-network/poktroll/pkg/crypto/protocol"
"github.com/pokt-network/poktroll/testutil/mockclient"
"github.com/pokt-network/poktroll/testutil/testclient/testkeyring"
"github.com/pokt-network/poktroll/testutil/testclient/testtx"
@@ -181,7 +181,7 @@ func TestSupplierClient_SubmitProof(t *testing.T) {
// Generating an ephemeral tree & spec just so we can submit
// a proof of the right size.
// TODO_TECHDEBT(#446): Centralize the configuration for the SMT spec.
tree := smt.NewSparseMerkleSumTrie(kvStore, sha256.New())
tree := smt.NewSparseMerkleSumTrie(kvStore, protocol.NewTrieHasher())
emptyPath := make([]byte, tree.PathHasherSize())
proof, err := tree.ProveClosest(emptyPath)
require.NoError(t, err)
5 changes: 5 additions & 0 deletions pkg/crypto/protocol/hasher.go
Original file line number Diff line number Diff line change
@@ -4,8 +4,13 @@ import "crypto/sha256"

const (
RelayHasherSize = sha256.Size
TrieHasherSize = sha256.Size
TrieRootSize = TrieHasherSize + trieRootMetadataSize
// TODO_CONSIDERATION: Export this from the SMT package.
trieRootMetadataSize = 16
)

var (
NewRelayHasher = sha256.New
NewTrieHasher = sha256.New
)
3 changes: 2 additions & 1 deletion pkg/relayer/session/sessiontree.go
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ import (
"github.com/pokt-network/smt"
"github.com/pokt-network/smt/kvstore/badger"

"github.com/pokt-network/poktroll/pkg/crypto/protocol"
"github.com/pokt-network/poktroll/pkg/relayer"
sessiontypes "github.com/pokt-network/poktroll/x/session/types"
)
@@ -85,7 +86,7 @@ func NewSessionTree(

// Create the SMST from the KVStore and a nil value hasher so the proof would
// contain a non-hashed Relay that could be used to validate the proof on-chain.
trie := smt.NewSparseMerkleSumTrie(treeStore, sha256.New(), smt.WithValueHasher(nil))
trie := smt.NewSparseMerkleSumTrie(treeStore, protocol.NewTrieHasher(), smt.WithValueHasher(nil))

sessionTree := &sessionTree{
sessionHeader: sessionHeader,
3 changes: 1 addition & 2 deletions tests/integration/tokenomics/relay_mining_difficulty_test.go
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@ package integration_test

import (
"context"
"crypto/sha256"
"testing"

"github.com/pokt-network/smt"
@@ -193,7 +192,7 @@ func prepareSMST(
integrationApp.GetRingClient(),
)

trie := smt.NewSparseMerkleSumTrie(kvStore, sha256.New(), smt.WithValueHasher(nil))
trie := smt.NewSparseMerkleSumTrie(kvStore, protocol.NewTrieHasher(), smt.WithValueHasher(nil))
err = trie.Update(minedRelay.Hash, minedRelay.Bytes, 1)
require.NoError(t, err)

39 changes: 23 additions & 16 deletions testutil/proof/fixture_generators.go
Original file line number Diff line number Diff line change
@@ -5,9 +5,11 @@ import (
"math/rand"
"testing"

"github.com/pokt-network/smt"
"github.com/stretchr/testify/require"

"github.com/pokt-network/smt"

"github.com/pokt-network/poktroll/pkg/crypto/protocol"
testsession "github.com/pokt-network/poktroll/testutil/session"
prooftypes "github.com/pokt-network/poktroll/x/proof/types"
sessiontypes "github.com/pokt-network/poktroll/x/session/types"
@@ -51,32 +53,37 @@ func ClaimWithRandomHash(t *testing.T, appAddr, supplierAddr string, sum uint64)
// TODO_MAINNET: Revisit if the SMT should be big or little Endian. Refs:
// https://github.com/pokt-network/smt/pull/46#discussion_r1636975124
// https://github.com/pokt-network/smt/blob/ea585c6c3bc31c804b6bafa83e985e473b275580/smst.go#L23C10-L23C76
func SmstRootWithSum(sum uint64) smt.MerkleRoot {
root := [smt.SmstRootSizeBytes]byte{}
// Insert the sum into the root hash
binary.BigEndian.PutUint64(root[smt.SmtRootSizeBytes:], sum)
// Insert the count into the root hash
// TODO_TECHDEBT: This is a hard-coded count of 1, but could be a parameter.
// TODO_TECHDEBT: We are assuming the sum takes up 8 bytes.
binary.BigEndian.PutUint64(root[smt.SmtRootSizeBytes+8:], 1)
return smt.MerkleRoot(root[:])
func SmstRootWithSum(sum uint64) smt.MerkleSumRoot {
root := [protocol.TrieRootSize]byte{}
return encodeSum(root, sum)
}

// RandSmstRootWithSum returns a randomized SMST root with the given sum that
// can be used for testing. Randomizing the root is a simple way to randomize
// test claim hashes for testing proof requirement cases.
func RandSmstRootWithSum(t *testing.T, sum uint64) smt.MerkleRoot {
func RandSmstRootWithSum(t *testing.T, sum uint64) smt.MerkleSumRoot {
t.Helper()

root := [smt.SmstRootSizeBytes]byte{}
root := [protocol.TrieRootSize]byte{}
// Only populate the first 32 bytes with random data, leave the last 8 bytes for the sum.
_, err := rand.Read(root[:smt.SmtRootSizeBytes]) //nolint:staticcheck // We need a deterministic pseudo-random source.
_, err := rand.Read(root[:protocol.TrieHasherSize]) //nolint:staticcheck // We need a deterministic pseudo-random source.
require.NoError(t, err)

binary.BigEndian.PutUint64(root[smt.SmtRootSizeBytes:], sum)
return encodeSum(root, sum)
}

// encodeSum returns a copy of the given root, binary encodes the given sum,
// and stores the encoded sum in the root copy.
func encodeSum(r [protocol.TrieRootSize]byte, sum uint64) smt.MerkleSumRoot {
root := make([]byte, protocol.TrieRootSize)
copy(root, r[:])

// Insert the sum into the root hash
binary.BigEndian.PutUint64(root[protocol.TrieHasherSize:], sum)
// Insert the count into the root hash
// TODO_TECHDEBT: This is a hard-coded count of 1, but could be a parameter.
// TODO_TECHDEBT: We are assuming the sum takes up 8 bytes.
binary.BigEndian.PutUint64(root[smt.SmtRootSizeBytes+8:], 1)
return smt.MerkleRoot(root[:])
binary.BigEndian.PutUint64(root[protocol.TrieHasherSize+8:], 1)

return root
}
1 change: 1 addition & 0 deletions x/application/keeper/auto_undelegate.go
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ import (

sdk "github.com/cosmos/cosmos-sdk/types"
proto "github.com/cosmos/gogoproto/proto"

gatewaytypes "github.com/pokt-network/poktroll/x/gateway/types"
)

5 changes: 3 additions & 2 deletions x/proof/keeper/msg_server_create_claim_test.go
Original file line number Diff line number Diff line change
@@ -5,11 +5,12 @@ import (

abci "github.com/cometbft/cometbft/abci/types"
cosmostypes "github.com/cosmos/cosmos-sdk/types"
"github.com/pokt-network/smt"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/pokt-network/smt"

keepertest "github.com/pokt-network/poktroll/testutil/keeper"
testproof "github.com/pokt-network/poktroll/testutil/proof"
"github.com/pokt-network/poktroll/testutil/sample"
@@ -490,7 +491,7 @@ func newTestClaimMsg(
supplierAddr string,
appAddr string,
service *sharedtypes.Service,
merkleRoot smt.MerkleRoot,
merkleRoot smt.MerkleSumRoot,
) *types.MsgCreateClaim {
t.Helper()

29 changes: 2 additions & 27 deletions x/proof/types/claim.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package types

import (
"fmt"

"github.com/cometbft/cometbft/crypto"

"github.com/pokt-network/smt"
@@ -11,36 +9,13 @@ import (
// GetNumComputeUnits returns the number of compute units for a given claim
// as determined by the sum of the root hash.
func (claim *Claim) GetNumComputeUnits() (numComputeUnits uint64, err error) {
// NB: smt.MerkleRoot#Sum() will panic if the root hash is not valid.
// Convert this panic into an error return.
defer func() {
if r := recover(); r != nil {
numComputeUnits = 0
err = fmt.Errorf(
"unable to get sum of invalid merkle root: %x; error: %v",
claim.GetRootHash(), r,
)
}
}()

return smt.MerkleRoot(claim.GetRootHash()).Sum(), nil
return smt.MerkleSumRoot(claim.GetRootHash()).Sum()
}

// GetNumRelays returns the number of relays for a given claim
// as determined by the count of the root hash.
func (claim *Claim) GetNumRelays() (numRelays uint64, err error) {
// Convert this panic into an error return.
defer func() {
if r := recover(); r != nil {
numRelays = 0
err = fmt.Errorf(
"unable to get count of invalid merkle root: %x; error: %v",
claim.GetRootHash(), r,
)
}
}()

return smt.MerkleRoot(claim.GetRootHash()).Count(), nil
return smt.MerkleSumRoot(claim.GetRootHash()).Count()
}

// GetHash returns the SHA-256 hash of the serialized claim.
55 changes: 41 additions & 14 deletions x/tokenomics/keeper/settle_session_accounting.go
Original file line number Diff line number Diff line change
@@ -6,6 +6,10 @@ import (

"cosmossdk.io/math"
cosmostypes "github.com/cosmos/cosmos-sdk/types"

"github.com/pokt-network/poktroll/app/volatile"
"github.com/pokt-network/poktroll/pkg/crypto/protocol"

"github.com/pokt-network/smt"

"github.com/pokt-network/poktroll/telemetry"
@@ -26,15 +30,20 @@ import (
func (k Keeper) SettleSessionAccounting(
ctx context.Context,
claim *prooftypes.Claim,
) error {
) (err error) {
logger := k.Logger().With("method", "SettleSessionAccounting")

settlementCoin := cosmostypes.NewCoin("upokt", math.NewInt(0))
isSuccessful := false
// This is emitted only when the function returns.
defer telemetry.EventSuccessCounter(
"settle_session_accounting",
func() float32 { return float32(settlementCoin.Amount.Int64()) },
func() float32 {
if settlementCoin.Amount.BigInt() == nil {
return 0
}
return float32(settlementCoin.Amount.Int64())
},
func() bool { return isSuccessful },
)

@@ -50,7 +59,7 @@ func (k Keeper) SettleSessionAccounting(
logger.Error("received a nil session header")
return tokenomicstypes.ErrTokenomicsSessionHeaderNil
}
if err := sessionHeader.ValidateBasic(); err != nil {
if err = sessionHeader.ValidateBasic(); err != nil {
logger.Error("received an invalid session header", "error", err)
return tokenomicstypes.ErrTokenomicsSessionHeaderInvalid
}
@@ -66,15 +75,19 @@ func (k Keeper) SettleSessionAccounting(
}

// Retrieve the sum of the root as a proxy into the amount of work done
root := (smt.MerkleRoot)(claim.GetRootHash())
root := (smt.MerkleSumRoot)(claim.GetRootHash())

if !root.HasDigestSize(protocol.TrieHasherSize) {
return tokenomicstypes.ErrTokenomicsRootHashInvalid.Wrapf(
"root hash has invalid digest size (%d), expected (%d)",
root.DigestSize(), protocol.TrieHasherSize,
)
}

// TODO_BLOCKER(@Olshansk): This check should be the responsibility of the SMST package
// since it's used to get compute units from the root hash.
if root == nil || len(root) != smt.SmstRootSizeBytes {
logger.Error(fmt.Sprintf("received an invalid root hash of size: %d", len(root)))
return tokenomicstypes.ErrTokenomicsRootHashInvalid
claimComputeUnits, err := root.Sum()
if err != nil {
return tokenomicstypes.ErrTokenomicsRootHashInvalid.Wrapf("%v", err)
}
claimComputeUnits := root.Sum()

// Helpers for logging the same metadata throughout this function calls
logger = logger.With(
@@ -96,7 +109,11 @@ func (k Keeper) SettleSessionAccounting(
logger.Info(fmt.Sprintf("About to start settling claim for %d compute units", claimComputeUnits))

// Calculate the amount of tokens to mint & burn
settlementCoin = k.getCoinFromComputeUnits(ctx, root)
settlementCoin, err = k.getCoinFromComputeUnits(ctx, root)
if err != nil {
return err
}

settlementCoins := cosmostypes.NewCoins(settlementCoin)

logger.Info(fmt.Sprintf(
@@ -194,10 +211,20 @@ func (k Keeper) SettleSessionAccounting(
return nil
}

func (k Keeper) getCoinFromComputeUnits(ctx context.Context, root smt.MerkleRoot) cosmostypes.Coin {
func (k Keeper) getCoinFromComputeUnits(ctx context.Context, root smt.MerkleSumRoot) (cosmostypes.Coin, error) {
// Retrieve the existing tokenomics params
params := k.GetParams(ctx)

upokt := math.NewInt(int64(root.Sum() * params.ComputeUnitsToTokensMultiplier))
return cosmostypes.NewCoin("upokt", upokt)
sum, err := root.Sum()
if err != nil {
return cosmostypes.Coin{}, err
}

upokt := math.NewInt(int64(sum * params.ComputeUnitsToTokensMultiplier))

if upokt.IsNegative() {
return cosmostypes.Coin{}, tokenomicstypes.ErrTokenomicsRootHashInvalid.Wrap("sum * compute_units_to_tokens_multiplier is negative")
}

return cosmostypes.NewCoin(volatile.DenomuPOKT, upokt), nil
}
37 changes: 12 additions & 25 deletions x/tokenomics/keeper/settle_session_accounting_test.go
Original file line number Diff line number Diff line change
@@ -11,9 +11,11 @@ import (
cosmostypes "github.com/cosmos/cosmos-sdk/types"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
banktypes "github.com/cosmos/cosmos-sdk/x/bank/types"
"github.com/pokt-network/smt"
"github.com/stretchr/testify/require"

"github.com/pokt-network/smt"

"github.com/pokt-network/poktroll/pkg/crypto/protocol"
testkeeper "github.com/pokt-network/poktroll/testutil/keeper"
testproof "github.com/pokt-network/poktroll/testutil/proof"
"github.com/pokt-network/poktroll/testutil/sample"
@@ -300,11 +302,10 @@ func TestSettleSessionAccounting_AppNotFound(t *testing.T) {
func TestSettleSessionAccounting_InvalidRoot(t *testing.T) {
keeper, ctx, appAddr, supplierAddr := testkeeper.TokenomicsKeeperWithActorAddrs(t)

rootHashSizeBytes := smt.SmstRootSizeBytes
// Define test cases
tests := []struct {
desc string
root []byte // smst.MerkleRoot
root []byte // smst.MerkleSumRoot
errExpected bool
}{
{
@@ -313,27 +314,27 @@ func TestSettleSessionAccounting_InvalidRoot(t *testing.T) {
errExpected: true,
},
{
desc: fmt.Sprintf("Less than %d bytes", rootHashSizeBytes),
root: make([]byte, rootHashSizeBytes-1), // Less than expected number of bytes
desc: fmt.Sprintf("Less than %d bytes", protocol.TrieRootSize),
root: make([]byte, protocol.TrieRootSize-1), // Less than expected number of bytes
errExpected: true,
},
{
desc: fmt.Sprintf("More than %d bytes", rootHashSizeBytes),
root: make([]byte, rootHashSizeBytes+1), // More than expected number of bytes
desc: fmt.Sprintf("More than %d bytes", protocol.TrieRootSize),
root: make([]byte, protocol.TrieRootSize+1), // More than expected number of bytes
errExpected: true,
},
{
desc: "correct size but empty",
root: func() []byte {
root := make([]byte, rootHashSizeBytes) // All 0s
root := make([]byte, protocol.TrieRootSize) // All 0s
return root[:]
}(),
errExpected: false,
},
{
desc: "correct size but invalid value",
root: func() []byte {
return bytes.Repeat([]byte("a"), rootHashSizeBytes)
return bytes.Repeat([]byte("a"), protocol.TrieRootSize)
}(),
errExpected: true,
},
@@ -350,26 +351,12 @@ func TestSettleSessionAccounting_InvalidRoot(t *testing.T) {
// Iterate over each test case
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
// Use defer-recover to catch any panic
defer func() {
if r := recover(); r != nil {
t.Errorf("Test panicked: %s", r)
}
}()

// Setup claim by copying the testproof.BaseClaim and updating the root
claim := testproof.BaseClaim(appAddr, supplierAddr, 0)
claim.RootHash = smt.MerkleRoot(test.root[:])
claim.RootHash = smt.MerkleSumRoot(test.root[:])

// Execute test function
err := func() (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v", r)
}
}()
return keeper.SettleSessionAccounting(ctx, &claim)
}()
err := keeper.SettleSessionAccounting(ctx, &claim)

// Assert the error
if test.errExpected {

0 comments on commit 5a3cfd8

Please sign in to comment.