From 3741a01ab08dc05e8b87d620799af53b8244b3ed Mon Sep 17 00:00:00 2001 From: shakezula Date: Tue, 11 Jul 2023 14:52:57 -0600 Subject: [PATCH] [chore] submodule pattern updates to TreeStore --- persistence/module.go | 11 ++- persistence/sql/sql.go | 11 --- persistence/trees/module.go | 58 ++++++++++----- persistence/trees/module_test.go | 101 +++++++++++++++++++++++++++ persistence/trees/trees.go | 82 +++++++++++++--------- runtime/bus.go | 4 ++ shared/modules/persistence_module.go | 2 +- 7 files changed, 203 insertions(+), 66 deletions(-) create mode 100644 persistence/trees/module_test.go diff --git a/persistence/module.go b/persistence/module.go index dec15112b4..b80ada2429 100644 --- a/persistence/module.go +++ b/persistence/module.go @@ -106,9 +106,14 @@ func (*persistenceModule) Create(bus modules.Bus, options ...modules.ModuleOptio treeModule, err := trees.Create( bus, trees.WithTreeStoreDirectory(persistenceCfg.TreesStoreDir), - trees.WithLogger(m.logger)) + trees.WithLogger(m.logger), + trees.WithTxIndexer(txIndexer)) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create TreeStoreModule: %w", err) + } + treeStoreModule, ok := treeModule.(modules.TreeStoreModule) + if !ok { + return nil, fmt.Errorf("failed to cast %T as TreeStoreModule", treeModule) } m.config = persistenceCfg @@ -117,7 +122,7 @@ func (*persistenceModule) Create(bus modules.Bus, options ...modules.ModuleOptio m.blockStore = blockStore m.txIndexer = txIndexer - m.stateTrees = treeModule + m.stateTrees = treeStoreModule // TECHDEBT: reconsider if this is the best place to call `populateGenesisState`. Note that // this forces the genesis state to be reloaded on every node startup until state diff --git a/persistence/sql/sql.go b/persistence/sql/sql.go index 48bbec2d38..7dea7cee23 100644 --- a/persistence/sql/sql.go +++ b/persistence/sql/sql.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/jackc/pgx/v5" - "github.com/pokt-network/pocket/persistence/indexer" ptypes "github.com/pokt-network/pocket/persistence/types" coreTypes "github.com/pokt-network/pocket/shared/core/types" ) @@ -91,16 +90,6 @@ func GetAccountsUpdated( return accounts, nil } -// GetTransactions takes a transaction indexer and returns the transactions for the current height -func GetTransactions(txi indexer.TxIndexer, height uint64) ([]*coreTypes.IndexedTransaction, error) { - // TECHDEBT(#813): Avoid this cast to int64 - indexedTxs, err := txi.GetByHeight(int64(height), false) - if err != nil { - return nil, fmt.Errorf("failed to get transactions by height: %w", err) - } - return indexedTxs, nil -} - // GetPools returns the pools updated at the given height func GetPools(pgtx pgx.Tx, height uint64) ([]*coreTypes.Account, error) { pools, err := GetAccountsUpdated(pgtx, ptypes.Pool, height) diff --git a/persistence/trees/module.go b/persistence/trees/module.go index 9c0d5eff01..452a44e458 100644 --- a/persistence/trees/module.go +++ b/persistence/trees/module.go @@ -3,19 +3,26 @@ package trees import ( "fmt" + "github.com/pokt-network/pocket/persistence/indexer" "github.com/pokt-network/pocket/persistence/kvstore" "github.com/pokt-network/pocket/shared/modules" "github.com/pokt-network/smt" ) -func (*treeStore) Create(bus modules.Bus, options ...modules.TreeStoreOption) (modules.TreeStoreModule, error) { - m := &treeStore{} +var _ modules.Module = &TreeStore{} + +func (*TreeStore) Create(bus modules.Bus, options ...modules.ModuleOption) (modules.Module, error) { + m := &TreeStore{} + + bus.RegisterModule(m) for _, option := range options { option(m) } - m.SetBus(bus) + if m.TXI == nil { + m.TXI = bus.GetPersistenceModule().GetTxIndexer() + } if err := m.setupTrees(); err != nil { return nil, err @@ -24,14 +31,14 @@ func (*treeStore) Create(bus modules.Bus, options ...modules.TreeStoreOption) (m return m, nil } -func Create(bus modules.Bus, options ...modules.TreeStoreOption) (modules.TreeStoreModule, error) { - return new(treeStore).Create(bus, options...) +func Create(bus modules.Bus, options ...modules.ModuleOption) (modules.Module, error) { + return new(TreeStore).Create(bus, options...) } // WithLogger assigns a logger for the tree store -func WithLogger(logger *modules.Logger) modules.TreeStoreOption { - return func(m modules.TreeStoreModule) { - if mod, ok := m.(*treeStore); ok { +func WithLogger(logger *modules.Logger) modules.ModuleOption { + return func(m modules.InjectableModule) { + if mod, ok := m.(*TreeStore); ok { mod.logger = logger } } @@ -39,20 +46,37 @@ func WithLogger(logger *modules.Logger) modules.TreeStoreOption { // WithTreeStoreDirectory assigns the path where the tree store // saves its data. -func WithTreeStoreDirectory(path string) modules.TreeStoreOption { - return func(m modules.TreeStoreModule) { - if mod, ok := m.(*treeStore); ok { - mod.treeStoreDir = path +func WithTreeStoreDirectory(path string) modules.ModuleOption { + return func(m modules.InjectableModule) { + mod, ok := m.(*TreeStore) + if ok { + mod.TreeStoreDir = path + } + } +} + +// WithTxIndexer assigns a TxIndexer for use during operation. +func WithTxIndexer(txi indexer.TxIndexer) modules.ModuleOption { + return func(m modules.InjectableModule) { + mod, ok := m.(*TreeStore) + if ok { + mod.TXI = txi } } } -func (t *treeStore) setupTrees() error { - if t.treeStoreDir == ":memory:" { +func (t *TreeStore) GetModuleName() string { return modules.TreeStoreModuleName } +func (t *TreeStore) Start() error { return nil } +func (t *TreeStore) Stop() error { return nil } +func (t *TreeStore) GetBus() modules.Bus { return t.Bus } +func (t *TreeStore) SetBus(bus modules.Bus) { t.Bus = bus } + +func (t *TreeStore) setupTrees() error { + if t.TreeStoreDir == ":memory:" { return t.setupInMemory() } - nodeStore, err := kvstore.NewKVStore(fmt.Sprintf("%s/%s_nodes", t.treeStoreDir, RootTreeName)) + nodeStore, err := kvstore.NewKVStore(fmt.Sprintf("%s/%s_nodes", t.TreeStoreDir, RootTreeName)) if err != nil { return err } @@ -64,7 +88,7 @@ func (t *treeStore) setupTrees() error { t.merkleTrees = make(map[string]*stateTree, len(stateTreeNames)) for i := 0; i < len(stateTreeNames); i++ { - nodeStore, err := kvstore.NewKVStore(fmt.Sprintf("%s/%s_nodes", t.treeStoreDir, stateTreeNames[i])) + nodeStore, err := kvstore.NewKVStore(fmt.Sprintf("%s/%s_nodes", t.TreeStoreDir, stateTreeNames[i])) if err != nil { return err } @@ -78,7 +102,7 @@ func (t *treeStore) setupTrees() error { return nil } -func (t *treeStore) setupInMemory() error { +func (t *TreeStore) setupInMemory() error { nodeStore := kvstore.NewMemKVStore() t.rootTree = &stateTree{ name: RootTreeName, diff --git a/persistence/trees/module_test.go b/persistence/trees/module_test.go new file mode 100644 index 0000000000..a238d84917 --- /dev/null +++ b/persistence/trees/module_test.go @@ -0,0 +1,101 @@ +package trees_test + +import ( + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/pokt-network/pocket/internal/testutil" + "github.com/pokt-network/pocket/persistence/trees" + "github.com/pokt-network/pocket/runtime/genesis" + "github.com/pokt-network/pocket/runtime/test_artifacts" + coreTypes "github.com/pokt-network/pocket/shared/core/types" + cryptoPocket "github.com/pokt-network/pocket/shared/crypto" + "github.com/pokt-network/pocket/shared/modules" + mockModules "github.com/pokt-network/pocket/shared/modules/mocks" +) + +const ( + serviceURLFormat = "node%d.consensus:42069" +) + +func TestTreeStore_Create(t *testing.T) { + ctrl := gomock.NewController(t) + mockRuntimeMgr := mockModules.NewMockRuntimeMgr(ctrl) + mockBus := createMockBus(t, mockRuntimeMgr) + + genesisStateMock := createMockGenesisState(nil) + persistenceMock := preparePersistenceMock(t, mockBus, genesisStateMock) + + mockBus.EXPECT().GetPersistenceModule().Return(persistenceMock).AnyTimes() + persistenceMock.EXPECT().GetBus().AnyTimes().Return(mockBus) + persistenceMock.EXPECT().NewRWContext(int64(0)).AnyTimes() + persistenceMock.EXPECT().GetTxIndexer().AnyTimes() + + treemod, err := trees.Create(mockBus, + trees.WithTreeStoreDirectory(":memory:")) + assert.NoError(t, err) + got := treemod.GetBus() + assert.Equal(t, got, mockBus) +} + +func TestTreeStore_DebugClearAll(t *testing.T) { + // TODO: Write test case for the DebugClearAll method + t.Skip("TODO: Write test case for DebugClearAll method") +} + +// createMockGenesisState configures and returns a mocked GenesisState +func createMockGenesisState(valKeys []cryptoPocket.PrivateKey) *genesis.GenesisState { + genesisState := new(genesis.GenesisState) + validators := make([]*coreTypes.Actor, len(valKeys)) + for i, valKey := range valKeys { + addr := valKey.Address().String() + mockActor := &coreTypes.Actor{ + ActorType: coreTypes.ActorType_ACTOR_TYPE_VAL, + Address: addr, + PublicKey: valKey.PublicKey().String(), + ServiceUrl: validatorId(i + 1), + StakedAmount: test_artifacts.DefaultStakeAmountString, + PausedHeight: int64(0), + UnstakingHeight: int64(0), + Output: addr, + } + validators[i] = mockActor + } + genesisState.Validators = validators + + return genesisState +} + +// Persistence mock - only needed for validatorMap access +func preparePersistenceMock(t *testing.T, busMock *mockModules.MockBus, genesisState *genesis.GenesisState) *mockModules.MockPersistenceModule { + ctrl := gomock.NewController(t) + + persistenceModuleMock := mockModules.NewMockPersistenceModule(ctrl) + readCtxMock := mockModules.NewMockPersistenceReadContext(ctrl) + + readCtxMock.EXPECT().GetAllValidators(gomock.Any()).Return(genesisState.GetValidators(), nil).AnyTimes() + readCtxMock.EXPECT().GetAllStakedActors(gomock.Any()).DoAndReturn(func(height int64) ([]*coreTypes.Actor, error) { + return testutil.Concatenate[*coreTypes.Actor]( + genesisState.GetValidators(), + genesisState.GetServicers(), + genesisState.GetFishermen(), + genesisState.GetApplications(), + ), nil + }).AnyTimes() + persistenceModuleMock.EXPECT().NewReadContext(gomock.Any()).Return(readCtxMock, nil).AnyTimes() + readCtxMock.EXPECT().Release().AnyTimes() + + persistenceModuleMock.EXPECT().GetBus().Return(busMock).AnyTimes() + persistenceModuleMock.EXPECT().SetBus(busMock).AnyTimes() + persistenceModuleMock.EXPECT().GetModuleName().Return(modules.PersistenceModuleName).AnyTimes() + busMock.RegisterModule(persistenceModuleMock) + + return persistenceModuleMock +} + +func validatorId(i int) string { + return fmt.Sprintf(serviceURLFormat, i) +} diff --git a/persistence/trees/trees.go b/persistence/trees/trees.go index 2f397c7816..9d47675a5f 100644 --- a/persistence/trees/trees.go +++ b/persistence/trees/trees.go @@ -71,26 +71,29 @@ type stateTree struct { nodeStore kvstore.KVStore } -var _ modules.TreeStoreModule = &treeStore{} +var _ modules.TreeStoreModule = &TreeStore{} -// treeStore stores a set of merkle trees that +// TreeStore stores a set of merkle trees that // it manages. It fulfills the modules.TreeStore interface. // * It is responsible for atomic commit or rollback behavior // of the underlying trees by utilizing the lazy loading // functionality provided by the underlying smt library. -type treeStore struct { +type TreeStore struct { base_modules.IntegrableModule - logger *modules.Logger - treeStoreDir string + logger *modules.Logger + Bus modules.Bus + TXI indexer.TxIndexer + + TreeStoreDir string rootTree *stateTree merkleTrees map[string]*stateTree } -// GetTree returns the name, root hash, and nodeStore for the matching tree tree -// stored in the TreeStore. This enables the caller to import the smt and not -// change the one stored -func (t *treeStore) GetTree(name string) ([]byte, kvstore.KVStore) { +// GetTree returns the root hash and nodeStore for the matching tree +// stored in the TreeStore. This enables the caller to import the smt +// and not change the one stored. +func (t *TreeStore) GetTree(name string) ([]byte, kvstore.KVStore) { if name == RootTreeName { return t.rootTree.tree.Root(), t.rootTree.nodeStore } @@ -100,18 +103,24 @@ func (t *treeStore) GetTree(name string) ([]byte, kvstore.KVStore) { return nil, nil } -// Update takes a transaction and a height and updates -// all of the trees in the treeStore for that height. -func (t *treeStore) Update(pgtx pgx.Tx, height uint64) (string, error) { - txi := t.GetBus().GetPersistenceModule().GetTxIndexer() +// Update takes a pgx transaction and a height and updates all of the trees in the TreeStore for that height. +// It is atomic and handles its own savepoint and rollback creation. +func (t *TreeStore) Update(pgtx pgx.Tx, height uint64) (string, error) { t.logger.Info().Msgf("🌴 updating state trees at height %d", height) - return t.updateMerkleTrees(pgtx, txi, height) + + stateHash, err := t.updateMerkleTrees(pgtx, t.TXI, height) + if err != nil { + t.Rollback() + return "", fmt.Errorf("failed to update merkle trees: %w", err) + } + + return stateHash, nil } // DebugClearAll is used by the debug cli to completely reset all merkle trees. // This should only be called by the debug CLI. // TECHDEBT: Move this into a separate file with a debug build flag to avoid accidental usage in prod -func (t *treeStore) DebugClearAll() error { +func (t *TreeStore) DebugClearAll() error { if err := t.rootTree.nodeStore.ClearAll(); err != nil { return fmt.Errorf("failed to clear root node store: %w", err) } @@ -132,8 +141,10 @@ func (t *treeStore) GetModuleName() string { } // updateMerkleTrees updates all of the merkle trees in order defined by `numMerkleTrees` -// * it returns the new state hash capturing the state of all the trees or an error if one occurred -func (t *treeStore) updateMerkleTrees(pgtx pgx.Tx, txi indexer.TxIndexer, height uint64) (string, error) { +// * it returns the new state hash capturing the state of all the trees or an error if one occurred. +// * This function does not commit state to disk. The caller must manually invoke `Commit` to persist +// changes to disk. +func (t *TreeStore) updateMerkleTrees(pgtx pgx.Tx, txi indexer.TxIndexer, height uint64) (string, error) { for treeName := range t.merkleTrees { switch treeName { // Actor Merkle Trees @@ -207,18 +218,11 @@ func (t *treeStore) updateMerkleTrees(pgtx pgx.Tx, txi indexer.TxIndexer, height return t.getStateHash(), nil } -func (t *treeStore) commit() error { - for treeName, stateTree := range t.merkleTrees { - if err := stateTree.tree.Commit(); err != nil { - return fmt.Errorf("failed to commit %s: %w", treeName, err) - } - } - return nil -} - -func (t *treeStore) getStateHash() string { +func (t *TreeStore) getStateHash() string { for _, stateTree := range t.merkleTrees { - if err := t.rootTree.tree.Update([]byte(stateTree.name), stateTree.tree.Root()); err != nil { + key := []byte(stateTree.name) + val := stateTree.tree.Root() + if err := t.rootTree.tree.Update(key, val); err != nil { log.Fatalf("failed to update root tree with %s tree's hash: %v", stateTree.name, err) } } @@ -234,7 +238,7 @@ func (t *treeStore) getStateHash() string { //////////////////////// // NB: I think this needs to be done manually for all 4 types. -func (t *treeStore) updateActorsTree(actorType coreTypes.ActorType, actors []*coreTypes.Actor) error { +func (t *TreeStore) updateActorsTree(actorType coreTypes.ActorType, actors []*coreTypes.Actor) error { for _, actor := range actors { bzAddr, err := hex.DecodeString(actor.GetAddress()) if err != nil { @@ -262,7 +266,7 @@ func (t *treeStore) updateActorsTree(actorType coreTypes.ActorType, actors []*co // Account Tree Helpers // ////////////////////////// -func (t *treeStore) updateAccountTrees(accounts []*coreTypes.Account) error { +func (t *TreeStore) updateAccountTrees(accounts []*coreTypes.Account) error { for _, account := range accounts { bzAddr, err := hex.DecodeString(account.GetAddress()) if err != nil { @@ -282,7 +286,7 @@ func (t *treeStore) updateAccountTrees(accounts []*coreTypes.Account) error { return nil } -func (t *treeStore) updatePoolTrees(pools []*coreTypes.Account) error { +func (t *TreeStore) updatePoolTrees(pools []*coreTypes.Account) error { for _, pool := range pools { bzAddr, err := hex.DecodeString(pool.GetAddress()) if err != nil { @@ -306,7 +310,7 @@ func (t *treeStore) updatePoolTrees(pools []*coreTypes.Account) error { // Data Tree Helpers // /////////////////////// -func (t *treeStore) updateTransactionsTree(indexedTxs []*coreTypes.IndexedTransaction) error { +func (t *TreeStore) updateTransactionsTree(indexedTxs []*coreTypes.IndexedTransaction) error { for _, idxTx := range indexedTxs { txBz := idxTx.GetTx() txHash := crypto.SHA3Hash(txBz) @@ -317,7 +321,7 @@ func (t *treeStore) updateTransactionsTree(indexedTxs []*coreTypes.IndexedTransa return nil } -func (t *treeStore) updateParamsTree(params []*coreTypes.Param) error { +func (t *TreeStore) updateParamsTree(params []*coreTypes.Param) error { for _, param := range params { paramBz, err := codec.GetCodec().Marshal(param) paramKey := crypto.SHA3Hash([]byte(param.Name)) @@ -332,7 +336,7 @@ func (t *treeStore) updateParamsTree(params []*coreTypes.Param) error { return nil } -func (t *treeStore) updateFlagsTree(flags []*coreTypes.Flag) error { +func (t *TreeStore) updateFlagsTree(flags []*coreTypes.Flag) error { for _, flag := range flags { flagBz, err := codec.GetCodec().Marshal(flag) flagKey := crypto.SHA3Hash([]byte(flag.Name)) @@ -346,3 +350,13 @@ func (t *treeStore) updateFlagsTree(flags []*coreTypes.Flag) error { return nil } + +// getTransactions takes a transaction indexer and returns the transactions for the current height +func getTransactions(txi indexer.TxIndexer, height uint64) ([]*coreTypes.IndexedTransaction, error) { + // TECHDEBT(#813): Avoid this cast to int64 + indexedTxs, err := txi.GetByHeight(int64(height), false) + if err != nil { + return nil, fmt.Errorf("failed to get transactions by height: %w", err) + } + return indexedTxs, nil +} diff --git a/runtime/bus.go b/runtime/bus.go index 53df657528..1a72aac563 100644 --- a/runtime/bus.go +++ b/runtime/bus.go @@ -71,6 +71,10 @@ func (m *bus) GetPersistenceModule() modules.PersistenceModule { return getModuleFromRegistry[modules.PersistenceModule](m, modules.PersistenceModuleName) } +func (m *bus) GetTreeStoreModule() modules.TreeStoreModule { + return getModuleFromRegistry[modules.TreeStoreModule](m, modules.TreeStoreModuleName) +} + func (m *bus) GetP2PModule() modules.P2PModule { return getModuleFromRegistry[modules.P2PModule](m, modules.P2PModuleName) } diff --git a/shared/modules/persistence_module.go b/shared/modules/persistence_module.go index 0dbccfc6e2..393c54165a 100644 --- a/shared/modules/persistence_module.go +++ b/shared/modules/persistence_module.go @@ -1,6 +1,6 @@ package modules -//go:generate mockgen -destination=./mocks/persistence_module_mock.go github.com/pokt-network/pocket/shared/modules PersistenceModule,PersistenceRWContext,PersistenceReadContext,PersistenceWriteContext,PersistenceLocalContext +//go:generate mockgen -destination=./mocks/persistence_module_mock.go github.com/pokt-network/pocket/shared/modules PersistenceModule,PersistenceRWContext,PersistenceReadContext,PersistenceWriteContext,PersistenceLocalContext,TreeStoreModule import ( "math/big"