diff --git a/pkg/solana/chain.go b/pkg/solana/chain.go index 5ed5eb8cb..5bbda3d26 100644 --- a/pkg/solana/chain.go +++ b/pkg/solana/chain.go @@ -28,6 +28,7 @@ import ( mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller" "github.com/smartcontractkit/chainlink-solana/pkg/solana/monitor" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" ) @@ -37,6 +38,7 @@ type Chain interface { ID() string Config() config.Config + LogPoller() logpoller.ILogPoller TxManager() TxManager // Reader returns a new Reader from the available list of nodes (if there are multiple, it will randomly select one) Reader() (client.Reader, error) @@ -89,6 +91,7 @@ type chain struct { services.StateMachine id string cfg *config.TOMLConfig + lp logpoller.ILogPoller txm *txm.Txm balanceMonitor services.Service lggr logger.Logger @@ -235,6 +238,7 @@ func newChain(id string, cfg *config.TOMLConfig, ks core.Keystore, lggr logger.L clientCache: map[string]*verifiedCachedClient{}, } + var lc internal.Loader[client.Reader] = utils.NewLazyLoad(func() (client.Reader, error) { return ch.getClient() }) var tc internal.Loader[client.ReaderWriter] = utils.NewLazyLoad(func() (client.ReaderWriter, error) { return ch.getClient() }) var bc internal.Loader[monitor.BalanceClient] = utils.NewLazyLoad(func() (monitor.BalanceClient, error) { return ch.getClient() }) @@ -303,10 +307,14 @@ func newChain(id string, cfg *config.TOMLConfig, ks core.Keystore, lggr logger.L return result.Signature(), result.Error() } - tc = internal.NewLoader[client.ReaderWriter](func() (client.ReaderWriter, error) { return ch.multiNode.SelectRPC() }) - bc = internal.NewLoader[monitor.BalanceClient](func() (monitor.BalanceClient, error) { return ch.multiNode.SelectRPC() }) + // TODO: Can we just remove these? 
They nullify the lazy loaders initialized earlier, don't they? + //lc = internal.NewLoader[client.Reader](func() (client.Reader, error) { return ch.multiNode.SelectRPC() }) + //tc = internal.NewLoader[client.ReaderWriter](func() (client.ReaderWriter, error) { return ch.multiNode.SelectRPC() }) + //bc = internal.NewLoader[monitor.BalanceClient](func() (monitor.BalanceClient, error) { return ch.multiNode.SelectRPC() }) } + // TODO: import typeProvider function from codec package and pass to constructor + ch.lp = logpoller.NewLogPoller(logger.Sugared(logger.Named(lggr, "LogPoller")), logpoller.NewORM(ch.ID(), ds, lggr), lc, nil) ch.txm = txm.NewTxm(ch.id, tc, sendTx, cfg, ks, lggr) ch.balanceMonitor = monitor.NewBalanceMonitor(ch.id, cfg, lggr, ks, bc) return &ch, nil @@ -396,6 +404,10 @@ func (c *chain) Config() config.Config { return c.cfg } +func (c *chain) LogPoller() logpoller.ILogPoller { + return c.lp +} + func (c *chain) TxManager() TxManager { return c.txm } diff --git a/pkg/solana/client/client.go b/pkg/solana/client/client.go index a015fdc1f..6558f96a1 100644 --- a/pkg/solana/client/client.go +++ b/pkg/solana/client/client.go @@ -39,6 +39,7 @@ type Reader interface { GetTransaction(ctx context.Context, txHash solana.Signature, opts *rpc.GetTransactionOpts) (*rpc.GetTransactionResult, error) GetBlocks(ctx context.Context, startSlot uint64, endSlot *uint64) (rpc.BlocksResult, error) GetBlocksWithLimit(ctx context.Context, startSlot uint64, limit uint64) (*rpc.BlocksResult, error) + GetBlockWithOpts(context.Context, uint64, *rpc.GetBlockOpts) (*rpc.GetBlockResult, error) GetBlock(ctx context.Context, slot uint64) (*rpc.GetBlockResult, error) GetSignaturesForAddressWithOpts(ctx context.Context, addr solana.PublicKey, opts *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error) } @@ -331,6 +332,15 @@ func (c *Client) GetLatestBlock(ctx context.Context) (*rpc.GetBlockResult, error return v.(*rpc.GetBlockResult), err } +func (c *Client) 
GetBlockWithOpts(ctx context.Context, slot uint64, opts *rpc.GetBlockOpts) (*rpc.GetBlockResult, error) { + // get block based on slot with custom options set + done := c.latency("get_block_with_opts") + defer done() + ctx, cancel := context.WithTimeout(ctx, c.txTimeout) + defer cancel() + return c.rpc.GetBlockWithOpts(ctx, slot, opts) +} + func (c *Client) GetBlock(ctx context.Context, slot uint64) (*rpc.GetBlockResult, error) { // get block based on slot done := c.latency("get_block") diff --git a/pkg/solana/client/mocks/reader_writer.go b/pkg/solana/client/mocks/reader_writer.go index c64a4a9ad..48cae23ad 100644 --- a/pkg/solana/client/mocks/reader_writer.go +++ b/pkg/solana/client/mocks/reader_writer.go @@ -257,6 +257,66 @@ func (_c *ReaderWriter_GetBlock_Call) RunAndReturn(run func(context.Context, uin return _c } +// GetBlockWithOpts provides a mock function with given fields: _a0, _a1, _a2 +func (_m *ReaderWriter) GetBlockWithOpts(_a0 context.Context, _a1 uint64, _a2 *rpc.GetBlockOpts) (*rpc.GetBlockResult, error) { + ret := _m.Called(_a0, _a1, _a2) + + if len(ret) == 0 { + panic("no return value specified for GetBlockWithOpts") + } + + var r0 *rpc.GetBlockResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uint64, *rpc.GetBlockOpts) (*rpc.GetBlockResult, error)); ok { + return rf(_a0, _a1, _a2) + } + if rf, ok := ret.Get(0).(func(context.Context, uint64, *rpc.GetBlockOpts) *rpc.GetBlockResult); ok { + r0 = rf(_a0, _a1, _a2) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rpc.GetBlockResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, uint64, *rpc.GetBlockOpts) error); ok { + r1 = rf(_a0, _a1, _a2) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReaderWriter_GetBlockWithOpts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBlockWithOpts' +type ReaderWriter_GetBlockWithOpts_Call struct { + *mock.Call +} + +// GetBlockWithOpts is a helper method to define 
mock.On call +// - _a0 context.Context +// - _a1 uint64 +// - _a2 *rpc.GetBlockOpts +func (_e *ReaderWriter_Expecter) GetBlockWithOpts(_a0 interface{}, _a1 interface{}, _a2 interface{}) *ReaderWriter_GetBlockWithOpts_Call { + return &ReaderWriter_GetBlockWithOpts_Call{Call: _e.mock.On("GetBlockWithOpts", _a0, _a1, _a2)} +} + +func (_c *ReaderWriter_GetBlockWithOpts_Call) Run(run func(_a0 context.Context, _a1 uint64, _a2 *rpc.GetBlockOpts)) *ReaderWriter_GetBlockWithOpts_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uint64), args[2].(*rpc.GetBlockOpts)) + }) + return _c +} + +func (_c *ReaderWriter_GetBlockWithOpts_Call) Return(_a0 *rpc.GetBlockResult, _a1 error) *ReaderWriter_GetBlockWithOpts_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ReaderWriter_GetBlockWithOpts_Call) RunAndReturn(run func(context.Context, uint64, *rpc.GetBlockOpts) (*rpc.GetBlockResult, error)) *ReaderWriter_GetBlockWithOpts_Call { + _c.Call.Return(run) + return _c +} + // GetBlocks provides a mock function with given fields: ctx, startSlot, endSlot func (_m *ReaderWriter) GetBlocks(ctx context.Context, startSlot uint64, endSlot *uint64) (rpc.BlocksResult, error) { ret := _m.Called(ctx, startSlot, endSlot) diff --git a/pkg/solana/codec/solana_test.go b/pkg/solana/codec/solana_test.go index 4dd116691..57e921fb7 100644 --- a/pkg/solana/codec/solana_test.go +++ b/pkg/solana/codec/solana_test.go @@ -143,6 +143,14 @@ func TestNewIDLCodec_CircularDependency(t *testing.T) { assert.ErrorIs(t, err, types.ErrInvalidConfig) } +func TestNewIDLInstructionCodec(t *testing.T) { + t.Parallel() + + var idl codec.IDL + _ = idl // TODO: implement test body; blank-assign avoids "declared and not used" compile error until then + +} + func newTestIDLAndCodec(t *testing.T, account bool) (string, codec.IDL, types.RemoteCodec) { t.Helper() diff --git a/pkg/solana/logpoller/filters.go b/pkg/solana/logpoller/filters.go new file mode 100644 index 000000000..bce439631 --- /dev/null +++ b/pkg/solana/logpoller/filters.go @@ -0,0 +1,307 @@ +package logpoller + +import ( 
+ "context" + "encoding/base64" + "errors" + "fmt" + "iter" + "maps" + "sync" + "sync/atomic" + + "github.com/gagliardetto/solana-go" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller/utils" +) + +type filters struct { + orm ORM + lggr logger.SugaredLogger + + filtersByName map[string]Filter + filtersByAddress map[PublicKey]map[EventSignature]map[int64]Filter + filtersToBackfill map[int64]Filter + filtersToDelete map[int64]Filter + filtersMutex sync.RWMutex + loadedFilters atomic.Bool + knownPrograms map[string]struct{} // fast lookup to see if a base58-encoded ProgramID matches any registered filters + knownDiscriminators map[string]struct{} // fast lookup by first 10 characters (60-bits) of a base64-encoded discriminator +} + +func newFilters(lggr logger.SugaredLogger, orm ORM) *filters { + return &filters{ + orm: orm, + lggr: lggr, + } +} + +// PruneFilters - prunes all filters marked to be deleted from the database and all corresponding logs. +func (fl *filters) PruneFilters(ctx context.Context) error { + err := fl.LoadFilters(ctx) + if err != nil { + return fmt.Errorf("failed to load filters: %w", err) + } + + fl.filtersMutex.Lock() + filtersToDelete := fl.filtersToDelete + fl.filtersToDelete = make(map[int64]Filter) + fl.filtersMutex.Unlock() + + if len(filtersToDelete) == 0 { + return nil + } + + err = fl.orm.DeleteFilters(ctx, filtersToDelete) + if err != nil { + fl.filtersMutex.Lock() + defer fl.filtersMutex.Unlock() + maps.Copy(fl.filtersToDelete, filtersToDelete) + return fmt.Errorf("failed to delete filters: %w", err) + } + + return nil +} + +// RegisterFilter persists provided filter and ensures that any log emitted by a contract with filter.Address +// that matches filter.EventSig signature will be captured starting from filter.StartingBlock. +// The filter may be unregistered later by filter.Name. 
+// In case of Filter.Name collision (within the chain scope) returns ErrFilterNameConflict if +// one of the fields defining resulting logs (Address, EventSig, EventIDL, SubkeyPaths) does not match original filter. +// Otherwise, updates remaining fields and schedules backfill. +// Warnings/debug information is keyed by filter name. +func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error { + if len(filter.Name) == 0 { + return errors.New("name is required") + } + + err := fl.LoadFilters(ctx) + if err != nil { + return fmt.Errorf("failed to load filters: %w", err) + } + + filter.EventSig = utils.Discriminator("event", filter.EventName) + + fl.filtersMutex.Lock() + defer fl.filtersMutex.Unlock() + + if existingFilter, ok := fl.filtersByName[filter.Name]; ok { + if !existingFilter.MatchSameLogs(filter) { + return ErrFilterNameConflict + } + + fl.removeFilterFromIndexes(existingFilter) + } + + filterID, err := fl.orm.InsertFilter(ctx, filter) + if err != nil { + return fmt.Errorf("failed to insert filter: %w", err) + } + + filter.ID = filterID + fl.filtersByName[filter.Name] = filter + filtersByEventSig, ok := fl.filtersByAddress[filter.Address] + if !ok { + filtersByEventSig = make(map[EventSignature]map[int64]Filter) + fl.filtersByAddress[filter.Address] = filtersByEventSig + } + + filtersByID, ok := filtersByEventSig[filter.EventSig] + if !ok { + filtersByID = make(map[int64]Filter) + filtersByEventSig[filter.EventSig] = filtersByID + } + + filtersByID[filter.ID] = filter + fl.filtersToBackfill[filterID] = filter + + fl.knownPrograms[filter.Address.ToSolana().String()] = struct{}{} + discriminator := base64.StdEncoding.EncodeToString(filter.EventSig[:]) + fl.knownDiscriminators[discriminator[:10]] = struct{}{} + + return nil +} + +// UnregisterFilter will mark the filter with the given name for pruning and async prune all corresponding logs. +// If the name does not exist, it will log an error but not return an error. 
+// Warnings/debug information is keyed by filter name. +func (fl *filters) UnregisterFilter(ctx context.Context, name string) error { + err := fl.LoadFilters(ctx) + if err != nil { + return fmt.Errorf("failed to load filters: %w", err) + } + + fl.filtersMutex.Lock() + defer fl.filtersMutex.Unlock() + + filter, ok := fl.filtersByName[name] + if !ok { + fl.lggr.Warnw("Filter not found in filtersByName", "name", name) + return nil + } + + if err := fl.orm.MarkFilterDeleted(ctx, filter.ID); err != nil { + return fmt.Errorf("failed to mark filter deleted: %w", err) + } + + fl.removeFilterFromIndexes(filter) + + fl.filtersToDelete[filter.ID] = filter + return nil +} + +func (fl *filters) removeFilterFromIndexes(filter Filter) { + delete(fl.filtersByName, filter.Name) + delete(fl.filtersToBackfill, filter.ID) + + filtersByEventSig, ok := fl.filtersByAddress[filter.Address] + if !ok { + fl.lggr.Warnw("Filter not found in filtersByAddress", "name", filter.Name, "address", filter.Address) + return + } + + filtersByID, ok := filtersByEventSig[filter.EventSig] + if !ok { + fl.lggr.Warnw("Filter not found in filtersByEventSig", "name", filter.Name, "address", filter.Address) + return + } + + delete(filtersByID, filter.ID) + if len(filtersByID) == 0 { + delete(filtersByEventSig, filter.EventSig) + } + + if len(filtersByEventSig) == 0 { + delete(fl.filtersByAddress, filter.Address) + } +} + +// MatchingFilters - returns iterator to go through all matching filters. +// Requires LoadFilters to be called at least once. 
+func (fl *filters) MatchingFilters(addr PublicKey, eventSignature EventSignature) iter.Seq[Filter] { + if !fl.loadedFilters.Load() { + fl.lggr.Critical("Invariant violation: expected filters to be loaded before call to MatchingFilters") + return nil + } + return func(yield func(Filter) bool) { + fl.filtersMutex.RLock() + defer fl.filtersMutex.RUnlock() + filters, ok := fl.filtersByAddress[addr] + if !ok { + return + } + + for _, filter := range filters[eventSignature] { + if !yield(filter) { + return + } + } + } +} + +// MatchingFiltersForEncodedEvent - similar to MatchingFilters but accepts a raw encoded event. Under normal operation, +// this will be called on every new event that happens on the blockchain, so it's important it returns immediately if it +// doesn't match any registered filters. +func (fl *filters) MatchingFiltersForEncodedEvent(event ProgramEvent) iter.Seq[Filter] { + if _, ok := fl.knownPrograms[event.Program]; !ok { + return nil + } + + // The first 64-bits of the event data is the event sig. Because it's base64 encoded, this corresponds to + // the first 10 characters plus 4 bits of the 11th character. We can quickly rule it out as not matching any known + // discriminators if the first 10 characters don't match. If it passes that initial test, we base64-decode the + // first 12 characters (StdEncoding requires input length to be a multiple of 4), and use the first 8 bytes of + // the result as the event sig to call MatchingFilters. The address + // also needs to be base58-decoded to pass to MatchingFilters + if len(event.Data) < 12 { + // too short to contain an 8-byte discriminator; cannot match any filter + return nil + } + if _, ok := fl.knownDiscriminators[event.Data[:10]]; !ok { + return nil + } + + addr, err := solana.PublicKeyFromBase58(event.Program) + if err != nil { + fl.lggr.Errorw("failed to parse Program ID for event", "EventProgram", event) + return nil + } + decoded, err := base64.StdEncoding.DecodeString(event.Data[:12]) + if err != nil { + fl.lggr.Errorw("failed to decode event data", "EventProgram", event) + return nil + } + eventSig := EventSignature(decoded[:8]) + + return fl.MatchingFilters(PublicKey(addr), eventSig) +} + +// ConsumeFiltersToBackfill - removes all filters from the backfill queue and returns them to caller. +// Requires LoadFilters to be called at least once. +func (fl *filters) ConsumeFiltersToBackfill() map[int64]Filter { + if !fl.loadedFilters.Load() { + fl.lggr.Critical("Invariant violation: expected filters to be loaded before call to ConsumeFiltersToBackfill") + return nil + } + fl.filtersMutex.Lock() + defer fl.filtersMutex.Unlock() + filtersToBackfill := fl.filtersToBackfill + fl.filtersToBackfill = make(map[int64]Filter) + return filtersToBackfill +} + +// LoadFilters - loads filters from database. Can be called multiple times without side effects. 
+func (fl *filters) LoadFilters(ctx context.Context) error { + if fl.loadedFilters.Load() { + return nil + } + + fl.lggr.Debugw("Loading filters from db") + fl.filtersMutex.Lock() + defer fl.filtersMutex.Unlock() + // reset filters' indexes to ensure we do not have partial data from the previous run + fl.filtersByName = make(map[string]Filter) + fl.filtersByAddress = make(map[PublicKey]map[EventSignature]map[int64]Filter) + fl.filtersToBackfill = make(map[int64]Filter) + fl.filtersToDelete = make(map[int64]Filter) + // these are never initialized elsewhere; without this, RegisterFilter panics on nil-map write + fl.knownPrograms = make(map[string]struct{}) + fl.knownDiscriminators = make(map[string]struct{}) + + filters, err := fl.orm.SelectFilters(ctx) + if err != nil { + return fmt.Errorf("failed to select filters from db: %w", err) + } + + for _, filter := range filters { + if filter.IsDeleted { + fl.filtersToDelete[filter.ID] = filter + continue + } + + if _, ok := fl.filtersByName[filter.Name]; ok { + errMsg := fmt.Sprintf("invariant violation while loading from db: expected filters to have unique name: %s ", filter.Name) + fl.lggr.Critical(errMsg) + return errors.New(errMsg) + } + + fl.filtersByName[filter.Name] = filter + filtersByEventSig, ok := fl.filtersByAddress[filter.Address] + if !ok { + filtersByEventSig = make(map[EventSignature]map[int64]Filter) + fl.filtersByAddress[filter.Address] = filtersByEventSig + } + + filtersByID, ok := filtersByEventSig[filter.EventSig] + if !ok { + filtersByID = make(map[int64]Filter) + filtersByEventSig[filter.EventSig] = filtersByID + } + + if _, ok := filtersByID[filter.ID]; ok { + errMsg := fmt.Sprintf("invariant violation while loading from db: expected filters to have unique ID: %d ", filter.ID) + fl.lggr.Critical(errMsg) + return errors.New(errMsg) + } + + filtersByID[filter.ID] = filter + fl.filtersToBackfill[filter.ID] = filter + // keep the fast-lookup sets in sync with persisted filters, so events for filters + // loaded from db are still matched by MatchingFiltersForEncodedEvent after a restart + fl.knownPrograms[filter.Address.ToSolana().String()] = struct{}{} + discriminator := base64.StdEncoding.EncodeToString(filter.EventSig[:]) + fl.knownDiscriminators[discriminator[:10]] = struct{}{} + } + + fl.loadedFilters.Store(true) + + return nil +} diff --git a/pkg/solana/logpoller/filters_test.go b/pkg/solana/logpoller/filters_test.go new file mode 100644 index 000000000..4b7f57f9b --- /dev/null +++ b/pkg/solana/logpoller/filters_test.go @@ -0,0 +1,307 @@ +package 
logpoller + +import ( + "errors" + "fmt" + "slices" + "testing" + + "github.com/gagliardetto/solana-go" + "github.com/google/uuid" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestFilters_LoadFilters(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(logger.Sugared(logger.Test(t)), orm) + ctx := tests.Context(t) + orm.On("SelectFilters", mock.Anything).Return(nil, errors.New("db failed")).Once() + deleted := Filter{ + ID: 3, + Name: "Deleted", + IsDeleted: true, + } + happyPath := Filter{ + ID: 1, + Name: "Happy path", + } + happyPath2 := Filter{ + ID: 2, + Name: "Happy path 2", + } + orm.On("SelectFilters", mock.Anything).Return([]Filter{ + deleted, + happyPath, + happyPath2, + }, nil).Once() + + err := fs.LoadFilters(ctx) + require.EqualError(t, err, "failed to select filters from db: db failed") + err = fs.LoadFilters(ctx) + require.NoError(t, err) + // only one filter to delete + require.Len(t, fs.filtersToDelete, 1) + require.Equal(t, deleted, fs.filtersToDelete[deleted.ID]) + // both happy path are indexed + require.Len(t, fs.filtersByAddress, 1) + require.Len(t, fs.filtersByAddress[happyPath.Address], 1) + require.Len(t, fs.filtersByAddress[happyPath.Address][happyPath.EventSig], 2) + require.Equal(t, happyPath, fs.filtersByAddress[happyPath.Address][happyPath.EventSig][happyPath.ID]) + require.Equal(t, happyPath2, fs.filtersByAddress[happyPath.Address][happyPath.EventSig][happyPath2.ID]) + require.Len(t, fs.filtersByName, 2) + require.Equal(t, fs.filtersByName[happyPath.Name], happyPath) + require.Equal(t, fs.filtersByName[happyPath2.Name], happyPath2) + // any call following successful should be noop + err = fs.LoadFilters(ctx) + require.NoError(t, err) +} + +func TestFilters_RegisterFilter(t *testing.T) { + lggr := logger.Sugared(logger.Test(t)) + t.Run("Returns an error if name is 
empty", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + err := fs.RegisterFilter(tests.Context(t), Filter{}) + require.EqualError(t, err, "name is required") + }) + t.Run("Returns an error if fails to load filters from db", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + orm.On("SelectFilters", mock.Anything).Return(nil, errors.New("db failed")).Once() + err := fs.RegisterFilter(tests.Context(t), Filter{Name: "Filter"}) + require.EqualError(t, err, "failed to load filters: failed to select filters from db: db failed") + }) + t.Run("Returns an error if trying to update primary fields", func(t *testing.T) { + testCases := []struct { + Name string + ModifyField func(*Filter) + }{ + { + Name: "Address", + ModifyField: func(f *Filter) { + privateKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + f.Address = PublicKey(privateKey.PublicKey()) + }, + }, + { + Name: "EventSig", + ModifyField: func(f *Filter) { + f.EventSig = EventSignature{3, 2, 1} + }, + }, + { + Name: "EventIDL", + ModifyField: func(f *Filter) { + f.EventIDL = uuid.NewString() + }, + }, + { + Name: "SubkeyPaths", + ModifyField: func(f *Filter) { + f.SubkeyPaths = [][]string{{uuid.NewString()}} + }, + }, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("Updating %s", tc.Name), func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + const filterName = "Filter" + dbFilter := Filter{Name: filterName} + orm.On("SelectFilters", mock.Anything).Return([]Filter{dbFilter}, nil).Once() + newFilter := dbFilter + tc.ModifyField(&newFilter) + err := fs.RegisterFilter(tests.Context(t), newFilter) + require.EqualError(t, err, ErrFilterNameConflict.Error()) + }) + } + }) + t.Run("Happy path", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + const filterName = "Filter" + orm.On("SelectFilters", mock.Anything).Return(nil, nil).Once() + orm.On("InsertFilter", mock.Anything, 
mock.Anything).Return(int64(0), errors.New("failed to insert")).Once() + filter := Filter{Name: filterName} + err := fs.RegisterFilter(tests.Context(t), filter) + require.Error(t, err) + // can read after db issue is resolved + orm.On("InsertFilter", mock.Anything, mock.Anything).Return(int64(1), nil).Once() + err = fs.RegisterFilter(tests.Context(t), filter) + require.NoError(t, err) + // can update non-primary fields + filter.EventName = uuid.NewString() + filter.StartingBlock++ + filter.Retention++ + filter.MaxLogsKept++ + orm.On("InsertFilter", mock.Anything, mock.Anything).Return(int64(1), nil).Once() + err = fs.RegisterFilter(tests.Context(t), filter) + require.NoError(t, err) + storedFilters := slices.Collect(fs.MatchingFilters(filter.Address, filter.EventSig)) + require.Len(t, storedFilters, 1) + filter.ID = 1 + require.Equal(t, filter, storedFilters[0]) + }) + t.Run("Can reregister after unregister", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + const filterName = "Filter" + orm.On("SelectFilters", mock.Anything).Return(nil, nil).Once() + const filterID = int64(10) + orm.On("InsertFilter", mock.Anything, mock.Anything).Return(filterID, nil).Once() + err := fs.RegisterFilter(tests.Context(t), Filter{Name: filterName}) + require.NoError(t, err) + orm.On("MarkFilterDeleted", mock.Anything, filterID).Return(nil).Once() + err = fs.UnregisterFilter(tests.Context(t), filterName) + require.NoError(t, err) + orm.On("InsertFilter", mock.Anything, mock.Anything).Return(filterID+1, nil).Once() + err = fs.RegisterFilter(tests.Context(t), Filter{Name: filterName}) + require.NoError(t, err) + require.Len(t, fs.filtersToDelete, 1) + require.Equal(t, Filter{Name: filterName, ID: filterID}, fs.filtersToDelete[filterID]) + require.Len(t, fs.filtersToBackfill, 1) + require.Equal(t, Filter{Name: filterName, ID: filterID + 1}, fs.filtersToBackfill[filterID+1]) + }) +} + +func TestFilters_UnregisterFilter(t *testing.T) { + lggr := 
logger.Sugared(logger.Test(t)) + t.Run("Returns an error if fails to load filters from db", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + orm.On("SelectFilters", mock.Anything).Return(nil, errors.New("db failed")).Once() + err := fs.UnregisterFilter(tests.Context(t), "Filter") + require.EqualError(t, err, "failed to load filters: failed to select filters from db: db failed") + }) + t.Run("Noop if filter is not present", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + const filterName = "Filter" + orm.On("SelectFilters", mock.Anything).Return(nil, nil).Once() + err := fs.UnregisterFilter(tests.Context(t), filterName) + require.NoError(t, err) + }) + t.Run("Returns error if fails to mark filter as deleted", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + const filterName = "Filter" + const id int64 = 10 + orm.On("SelectFilters", mock.Anything).Return([]Filter{{ID: id, Name: filterName}}, nil).Once() + orm.On("MarkFilterDeleted", mock.Anything, id).Return(errors.New("db query failed")).Once() + err := fs.UnregisterFilter(tests.Context(t), filterName) + require.EqualError(t, err, "failed to mark filter deleted: db query failed") + }) + t.Run("Happy path", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + const filterName = "Filter" + const id int64 = 10 + orm.On("SelectFilters", mock.Anything).Return([]Filter{{ID: id, Name: filterName}}, nil).Once() + orm.On("MarkFilterDeleted", mock.Anything, id).Return(nil).Once() + err := fs.UnregisterFilter(tests.Context(t), filterName) + require.NoError(t, err) + require.Len(t, fs.filtersToDelete, 1) + require.Len(t, fs.filtersToBackfill, 0) + require.Len(t, fs.filtersByName, 0) + require.Len(t, fs.filtersByAddress, 0) + }) +} + +func TestFilters_PruneFilters(t *testing.T) { + lggr := logger.Sugared(logger.Test(t)) + t.Run("Happy path", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + toDelete 
:= Filter{ + ID: 1, + Name: "To delete", + IsDeleted: true, + } + orm.On("SelectFilters", mock.Anything).Return([]Filter{ + toDelete, + { + ID: 2, + Name: "To keep", + }, + }, nil).Once() + orm.On("DeleteFilters", mock.Anything, map[int64]Filter{toDelete.ID: toDelete}).Return(nil).Once() + err := fs.PruneFilters(tests.Context(t)) + require.NoError(t, err) + require.Len(t, fs.filtersToDelete, 0) + }) + t.Run("If DB removal fails will add filters back into removal slice ", func(t *testing.T) { + orm := newMockORM(t) + fs := newFilters(lggr, orm) + toDelete := Filter{ + ID: 1, + Name: "To delete", + IsDeleted: true, + } + orm.On("SelectFilters", mock.Anything).Return([]Filter{ + toDelete, + { + ID: 2, + Name: "To keep", + }, + }, nil).Once() + newToDelete := Filter{ + ID: 3, + Name: "To delete 2", + } + orm.On("DeleteFilters", mock.Anything, map[int64]Filter{toDelete.ID: toDelete}).Return(errors.New("db failed")).Run(func(_ mock.Arguments) { + orm.On("MarkFilterDeleted", mock.Anything, newToDelete.ID).Return(nil).Once() + orm.On("InsertFilter", mock.Anything, mock.Anything).Return(newToDelete.ID, nil).Once() + require.NoError(t, fs.RegisterFilter(tests.Context(t), newToDelete)) + require.NoError(t, fs.UnregisterFilter(tests.Context(t), newToDelete.Name)) + }).Once() + err := fs.PruneFilters(tests.Context(t)) + require.EqualError(t, err, "failed to delete filters: db failed") + require.Equal(t, fs.filtersToDelete, map[int64]Filter{newToDelete.ID: newToDelete, toDelete.ID: toDelete}) + }) +} + +func TestFilters_MatchingFilters(t *testing.T) { + orm := newMockORM(t) + lggr := logger.Sugared(logger.Test(t)) + expectedFilter1 := Filter{ + ID: 1, + Name: "expectedFilter1", + Address: newRandomPublicKey(t), + EventSig: newRandomEventSignature(t), + } + expectedFilter2 := Filter{ + ID: 2, + Name: "expectedFilter2", + Address: expectedFilter1.Address, + EventSig: expectedFilter1.EventSig, + } + sameAddress := Filter{ + ID: 3, + Name: "sameAddressWrongEventSig", + Address: 
 expectedFilter1.Address, + EventSig: newRandomEventSignature(t), + } + + sameEventSig := Filter{ + ID: 4, + Name: "wrongAddressSameEventSig", + Address: newRandomPublicKey(t), + EventSig: expectedFilter1.EventSig, + } + orm.On("SelectFilters", mock.Anything).Return([]Filter{expectedFilter1, expectedFilter2, sameAddress, sameEventSig}, nil).Once() + filters := newFilters(lggr, orm) + err := filters.LoadFilters(tests.Context(t)) + require.NoError(t, err) + matchingFilters := slices.Collect(filters.MatchingFilters(expectedFilter1.Address, expectedFilter1.EventSig)) + require.Len(t, matchingFilters, 2) + require.Contains(t, matchingFilters, expectedFilter1) + require.Contains(t, matchingFilters, expectedFilter2) + // if at least one key does not match - returns empty iterator + require.Empty(t, slices.Collect(filters.MatchingFilters(newRandomPublicKey(t), expectedFilter1.EventSig))) + require.Empty(t, slices.Collect(filters.MatchingFilters(expectedFilter1.Address, newRandomEventSignature(t)))) + require.Empty(t, slices.Collect(filters.MatchingFilters(newRandomPublicKey(t), newRandomEventSignature(t)))) +} diff --git a/pkg/solana/logpoller/job.go b/pkg/solana/logpoller/job.go index 165c0b5fe..8e24d4165 100644 --- a/pkg/solana/logpoller/job.go +++ b/pkg/solana/logpoller/job.go @@ -36,6 +36,7 @@ type eventDetail struct { slotNumber uint64 blockHeight uint64 blockHash solana.Hash + blockTime solana.UnixTimeSeconds trxIdx int trxSig solana.Signature } @@ -114,12 +115,18 @@ func (j *getTransactionsFromBlockJob) Run(ctx context.Context) error { blockHash: block.Blockhash, } - if block.BlockHeight != nil { - detail.blockHeight = *block.BlockHeight + if block.BlockHeight == nil { + return fmt.Errorf("block at slot %d returned from rpc is missing block number", j.slotNumber) + } + detail.blockHeight = *block.BlockHeight + + if block.BlockTime == nil { + return fmt.Errorf("received block %d from rpc with missing block time", *block.BlockHeight) } + detail.blockTime = *block.BlockTime if len(block.Transactions) != len(blockSigsOnly.Signatures) { - return fmt.Errorf("block %d has %d transactions but %d signatures", j.slotNumber, len(block.Transactions), len(blockSigsOnly.Signatures)) + return fmt.Errorf("block %d has %d transactions but %d signatures", *block.BlockHeight, len(block.Transactions), len(blockSigsOnly.Signatures)) } j.parser.ExpectTxs(j.slotNumber, len(block.Transactions)) @@ -143,6 +150,7 @@ func messagesToEvents(messages []string, parser ProgramEventProcessor, detail ev event.SlotNumber = detail.slotNumber event.BlockHeight = detail.blockHeight event.BlockHash = detail.blockHash + event.BlockTime = detail.blockTime event.TransactionHash = detail.trxSig event.TransactionIndex = detail.trxIdx event.TransactionLogIndex = logIdx diff --git a/pkg/solana/logpoller/loader.go b/pkg/solana/logpoller/loader.go index d714f08ad..39a985a98 100644 --- a/pkg/solana/logpoller/loader.go +++ b/pkg/solana/logpoller/loader.go @@ -27,8 +27,8 @@ type ProgramEventProcessor interface { } type RPCClient interface { - GetLatestBlockhash(ctx context.Context, commitment rpc.CommitmentType) (out *rpc.GetLatestBlockhashResult, err error) - GetBlocks(ctx context.Context, startSlot uint64, endSlot *uint64, commitment rpc.CommitmentType) (out rpc.BlocksResult, err error) + LatestBlockhash(ctx context.Context) (out *rpc.GetLatestBlockhashResult, err error) + GetBlocks(ctx context.Context, startSlot uint64, endSlot *uint64) (out rpc.BlocksResult, err error) GetBlockWithOpts(context.Context, uint64, *rpc.GetBlockOpts) (*rpc.GetBlockResult, error) GetSignaturesForAddressWithOpts(context.Context, solana.PublicKey, *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error) } @@ -170,7 +170,7 @@ func (c *EncodedLogCollector) runSlotPolling(ctx context.Context) { ctxB, cancel := context.WithTimeout(ctx, c.rpcTimeLimit) // not to be run as a job, but as a blocking call - result, err := c.client.GetLatestBlockhash(ctxB, rpc.CommitmentFinalized) 
+ result, err := c.client.LatestBlockhash(ctxB) if err != nil { c.lggr.Error("failed to get latest blockhash", "err", err) cancel() @@ -276,7 +276,7 @@ func (c *EncodedLogCollector) loadSlotBlocksRange(ctx context.Context, start, en rpcCtx, cancel := context.WithTimeout(ctx, c.rpcTimeLimit) defer cancel() - if result, err = c.client.GetBlocks(rpcCtx, start, &end, rpc.CommitmentFinalized); err != nil { + if result, err = c.client.GetBlocks(rpcCtx, start, &end); err != nil { return err } diff --git a/pkg/solana/logpoller/loader_test.go b/pkg/solana/logpoller/loader_test.go index e3cbb7700..4d3dcd8cc 100644 --- a/pkg/solana/logpoller/loader_test.go +++ b/pkg/solana/logpoller/loader_test.go @@ -19,7 +19,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller" - mocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller/mocks" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller/mocks" ) var ( @@ -63,7 +63,7 @@ func TestEncodedLogCollector_ParseSingleEvent(t *testing.T) { latest.Store(uint64(40)) client.EXPECT(). - GetLatestBlockhash(mock.Anything, rpc.CommitmentFinalized). + LatestBlockhash(mock.Anything). RunAndReturn(latestBlockhashReturnFunc(&latest)) client.EXPECT(). @@ -71,7 +71,6 @@ func TestEncodedLogCollector_ParseSingleEvent(t *testing.T) { mock.Anything, mock.MatchedBy(getBlocksStartValMatcher), mock.MatchedBy(getBlocksEndValMatcher(&latest)), - rpc.CommitmentFinalized, ). RunAndReturn(getBlocksReturnFunc(false)) @@ -139,7 +138,7 @@ func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) { } client.EXPECT(). - GetLatestBlockhash(mock.Anything, rpc.CommitmentFinalized). + LatestBlockhash(mock.Anything). RunAndReturn(latestBlockhashReturnFunc(&latest)) client.EXPECT(). 
@@ -147,7 +146,6 @@ func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) { mock.Anything, mock.MatchedBy(getBlocksStartValMatcher), mock.MatchedBy(getBlocksEndValMatcher(&latest)), - rpc.CommitmentFinalized, ). RunAndReturn(getBlocksReturnFunc(false)) @@ -298,7 +296,7 @@ func TestEncodedLogCollector_BackfillForAddress(t *testing.T) { // GetLatestBlockhash might be called at start-up; make it take some time because the result isn't needed for this test client.EXPECT(). - GetLatestBlockhash(mock.Anything, rpc.CommitmentFinalized). + LatestBlockhash(mock.Anything). RunAndReturn(latestBlockhashReturnFunc(&latest)). After(2 * time.Second). Maybe() @@ -308,7 +306,6 @@ func TestEncodedLogCollector_BackfillForAddress(t *testing.T) { mock.Anything, mock.MatchedBy(getBlocksStartValMatcher), mock.MatchedBy(getBlocksEndValMatcher(&latest)), - rpc.CommitmentFinalized, ). RunAndReturn(getBlocksReturnFunc(true)) @@ -455,7 +452,7 @@ func (p *testBlockProducer) Count() uint64 { return p.count } -func (p *testBlockProducer) GetLatestBlockhash(_ context.Context, _ rpc.CommitmentType) (out *rpc.GetLatestBlockhashResult, err error) { +func (p *testBlockProducer) LatestBlockhash(_ context.Context) (out *rpc.GetLatestBlockhashResult, err error) { p.b.Helper() p.mu.Lock() @@ -474,7 +471,7 @@ func (p *testBlockProducer) GetLatestBlockhash(_ context.Context, _ rpc.Commitme }, nil } -func (p *testBlockProducer) GetBlocks(_ context.Context, startSlot uint64, endSlot *uint64, _ rpc.CommitmentType) (out rpc.BlocksResult, err error) { +func (p *testBlockProducer) GetBlocks(_ context.Context, startSlot uint64, endSlot *uint64) (out rpc.BlocksResult, err error) { p.b.Helper() p.mu.Lock() @@ -486,7 +483,7 @@ func (p *testBlockProducer) GetBlocks(_ context.Context, startSlot uint64, endSl blocks[idx] = startSlot + uint64(idx) } - return rpc.BlocksResult(blocks), nil + return blocks, nil } func (p *testBlockProducer) GetBlockWithOpts(_ context.Context, block uint64, opts *rpc.GetBlockOpts) 
(*rpc.GetBlockResult, error) { @@ -589,8 +586,8 @@ func (p *testParser) Events() []logpoller.ProgramEvent { return p.events } -func latestBlockhashReturnFunc(latest *atomic.Uint64) func(context.Context, rpc.CommitmentType) (*rpc.GetLatestBlockhashResult, error) { - return func(ctx context.Context, ct rpc.CommitmentType) (*rpc.GetLatestBlockhashResult, error) { +func latestBlockhashReturnFunc(latest *atomic.Uint64) func(context.Context) (*rpc.GetLatestBlockhashResult, error) { + return func(ctx context.Context) (*rpc.GetLatestBlockhashResult, error) { defer func() { latest.Store(latest.Load() + 2) }() @@ -608,8 +605,8 @@ func latestBlockhashReturnFunc(latest *atomic.Uint64) func(context.Context, rpc. } } -func getBlocksReturnFunc(empty bool) func(context.Context, uint64, *uint64, rpc.CommitmentType) (rpc.BlocksResult, error) { - return func(_ context.Context, u1 uint64, u2 *uint64, _ rpc.CommitmentType) (rpc.BlocksResult, error) { +func getBlocksReturnFunc(empty bool) func(context.Context, uint64, *uint64) (rpc.BlocksResult, error) { + return func(_ context.Context, u1 uint64, u2 *uint64) (rpc.BlocksResult, error) { blocks := []uint64{} if !empty { @@ -619,7 +616,7 @@ func getBlocksReturnFunc(empty bool) func(context.Context, uint64, *uint64, rpc. 
} } - return rpc.BlocksResult(blocks), nil + return blocks, nil } } diff --git a/pkg/solana/logpoller/log_data_parser.go b/pkg/solana/logpoller/log_data_parser.go index 4080a09e2..2549c40bc 100644 --- a/pkg/solana/logpoller/log_data_parser.go +++ b/pkg/solana/logpoller/log_data_parser.go @@ -19,6 +19,7 @@ type BlockData struct { SlotNumber uint64 BlockHeight uint64 BlockHash solana.Hash + BlockTime solana.UnixTimeSeconds TransactionHash solana.Signature TransactionIndex int TransactionLogIndex uint @@ -31,6 +32,7 @@ type ProgramLog struct { } type ProgramEvent struct { + Program string BlockData Prefix string Data string @@ -78,8 +80,9 @@ func parseProgramLogs(logs []string) []ProgramOutput { if len(dataMatches) > 1 { instLogs[lastLogIdx].Events = append(instLogs[lastLogIdx].Events, ProgramEvent{ - Prefix: prefixBuilder(depth), - Data: dataMatches[1], + Program: instLogs[lastLogIdx].Program, + Prefix: prefixBuilder(depth), + Data: dataMatches[1], }) } } else if strings.HasPrefix(log, "Log truncated") { diff --git a/pkg/solana/logpoller/log_poller.go b/pkg/solana/logpoller/log_poller.go new file mode 100644 index 000000000..6de5db431 --- /dev/null +++ b/pkg/solana/logpoller/log_poller.go @@ -0,0 +1,232 @@ +package logpoller + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "math" + "sync" + "time" + + bin "github.com/gagliardetto/binary" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/utils" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" +) + +var ( + ErrFilterNameConflict = errors.New("filter with such name already exists") + ErrMissingEventTypeProvider = errors.New("cannot start LogPoller without EventTypeProvider") +) + +//go:generate mockery --name ORM --inpackage --structname mockORM --filename mock_orm.go +type ORM interface { + 
ChainID() string + InsertFilter(ctx context.Context, filter Filter) (id int64, err error) + SelectFilters(ctx context.Context) ([]Filter, error) + DeleteFilters(ctx context.Context, filters map[int64]Filter) error + MarkFilterDeleted(ctx context.Context, id int64) (err error) + InsertLogs(context.Context, []Log) (err error) +} + +type ILogPoller interface { + Start(context.Context) error + Close() error + RegisterFilter(ctx context.Context, filter Filter) error + UnregisterFilter(ctx context.Context, name string) error +} + +type LogPoller struct { + services.StateMachine + lggr logger.SugaredLogger + orm ORM + client internal.Loader[client.Reader] + collector *EncodedLogCollector + filters *filters + typeProvider EventTypeProvider + chStop services.StopChan + wg sync.WaitGroup +} + +func NewLogPoller(lggr logger.SugaredLogger, orm ORM, cl internal.Loader[client.Reader], typeProvider EventTypeProvider) ILogPoller { + lggr = logger.Sugared(logger.Named(lggr, "LogPoller")) + lp := LogPoller{ + orm: orm, + client: cl, + lggr: lggr, + filters: newFilters(lggr, orm), + typeProvider: typeProvider, + } + return &lp +} + +func makeLogIndex(txIndex int, txLogIndex uint) int64 { + if txIndex < 0 || txIndex > math.MaxUint32 || txLogIndex > math.MaxUint32 { + panic(fmt.Sprintf("txIndex or txLogIndex out of range: txIndex=%d, txLogIndex=%d", txIndex, txLogIndex)) + } + return int64(math.MaxUint32*uint32(txIndex) + uint32(txLogIndex)) +} + +// Process - process stream of events coming from log ingester +func (lp *LogPoller) Process(programEvent ProgramEvent) (err error) { + ctx, cancel := utils.ContextFromChan(lp.chStop) + defer cancel() + + blockData := programEvent.BlockData + + var logs []Log + for filter := range lp.filters.MatchingFiltersForEncodedEvent(programEvent) { + log := Log{ + FilterID: filter.ID, + ChainID: lp.orm.ChainID(), + LogIndex: makeLogIndex(blockData.TransactionIndex, blockData.TransactionLogIndex), + BlockHash: Hash(blockData.BlockHash), + BlockNumber: 
int64(blockData.BlockHeight), + BlockTimestamp: blockData.BlockTime.Time(), // TODO: is this a timezone safe conversion? + Address: filter.Address, + EventSig: filter.EventSig, + TxHash: Signature(blockData.TransactionHash), + } + + log.Data, err = base64.StdEncoding.DecodeString(programEvent.Data) + if err != nil { + return err + } + + for _, path := range filter.SubkeyPaths { + + var event any + event, err = lp.typeProvider.CreateType(filter.EventIdl.IdlEvent, filter.EventIdl.IdlTypeDefSlice, path) + bin.UnmarshalBorsh(&event, log.Data) + if err != nil { + return err + } + } + + // TODO: fill in, and keep track of SequenceNumber for each filter. (Initialize from db on LoadFilters, then increment each time?) + + logs = append(logs, log) + } + + lp.orm.InsertLogs(ctx, logs) + return nil +} + +func (lp *LogPoller) Start(context.Context) error { + if lp.typeProvider == nil { + return ErrMissingEventTypeProvider + } + cl, err := lp.client.Get() + if err != nil { + return err + } + lp.collector = NewEncodedLogCollector(cl, lp, lp.lggr) + return lp.StartOnce("LogPoller", func() error { + lp.wg.Add(2) + go lp.run() + go lp.backgroundWorkerRun() + + return nil + }) +} + +func (lp *LogPoller) Close() error { + return lp.StopOnce("LogPoller", func() error { + close(lp.chStop) + lp.wg.Wait() + return nil + }) +} + +// RegisterFilter - refer to filters.RegisterFilter for details. 
+func (lp *LogPoller) RegisterFilter(ctx context.Context, filter Filter) error { + ctx, cancel := lp.chStop.Ctx(ctx) + defer cancel() + return lp.filters.RegisterFilter(ctx, filter) +} + +// UnregisterFilter refer to filters.UnregisterFilter for details +func (lp *LogPoller) UnregisterFilter(ctx context.Context, name string) error { + ctx, cancel := lp.chStop.Ctx(ctx) + defer cancel() + return lp.filters.UnregisterFilter(ctx, name) +} + +func (lp *LogPoller) loadFilters(ctx context.Context) error { + retryTicker := services.TickerConfig{Initial: 0, JitterPct: services.DefaultJitter}.NewTicker(time.Second) + defer retryTicker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-retryTicker.C: + } + err := lp.filters.LoadFilters(ctx) + if err != nil { + lp.lggr.Errorw("Failed loading filters in init logpoller loop, retrying later", "err", err) + } + } + // unreachable +} + +func (lp *LogPoller) run() { + defer lp.wg.Done() + ctx, cancel := lp.chStop.NewCtx() + defer cancel() + err := lp.loadFilters(ctx) + if err != nil { + lp.lggr.Warnw("Failed loading filters", "err", err) + return + } + + var blocks chan struct { + BlockNumber int64 + Logs any // to be defined + } + + for { + select { + case <-ctx.Done(): + return + case block := <-blocks: + filtersToBackfill := lp.filters.ConsumeFiltersToBackfill() + + // TODO: NONEVM-916 parse, filters and persist logs + // NOTE: removal of filters occurs in the separate goroutine, so there is a chance that upon insert + // of log corresponding filter won't be present in the db. 
Ensure to refilter and retry on insert error + for _, filter := range filtersToBackfill { + go lp.startFilterBackfill(ctx, filter, block.BlockNumber) + } + } + } +} + +func (lp *LogPoller) backgroundWorkerRun() { + defer lp.wg.Done() + ctx, cancel := lp.chStop.NewCtx() + defer cancel() + + pruneFilters := services.NewTicker(time.Minute) + defer pruneFilters.Stop() + for { + select { + case <-ctx.Done(): + return + case <-pruneFilters.C: + err := lp.filters.PruneFilters(ctx) + if err != nil { + lp.lggr.Errorw("Failed to prune filters", "err", err) + } + } + } +} + +func (lp *LogPoller) startFilterBackfill(ctx context.Context, filter Filter, toBlock int64) { + // TODO: NONEVM-916 start backfill + lp.lggr.Debugw("Starting filter backfill", "filter", filter) +} diff --git a/pkg/solana/logpoller/mock_orm.go b/pkg/solana/logpoller/mock_orm.go new file mode 100644 index 000000000..c22850864 --- /dev/null +++ b/pkg/solana/logpoller/mock_orm.go @@ -0,0 +1,140 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. 
+ +package logpoller + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// mockORM is an autogenerated mock type for the ORM type +type mockORM struct { + mock.Mock +} + +// DeleteFilters provides a mock function with given fields: ctx, filters +func (_m *mockORM) DeleteFilters(ctx context.Context, filters map[int64]Filter) error { + ret := _m.Called(ctx, filters) + + if len(ret) == 0 { + panic("no return value specified for DeleteFilters") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, map[int64]Filter) error); ok { + r0 = rf(ctx, filters) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// InsertFilter provides a mock function with given fields: ctx, filter +func (_m *mockORM) InsertFilter(ctx context.Context, filter Filter) (int64, error) { + ret := _m.Called(ctx, filter) + + if len(ret) == 0 { + panic("no return value specified for InsertFilter") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, Filter) (int64, error)); ok { + return rf(ctx, filter) + } + if rf, ok := ret.Get(0).(func(context.Context, Filter) int64); ok { + r0 = rf(ctx, filter) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, Filter) error); ok { + r1 = rf(ctx, filter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SelectFilters provides a mock function with given fields: ctx +func (_m *mockORM) SelectFilters(ctx context.Context) ([]Filter, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for SelectFilters") + } + + var r0 []Filter + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]Filter, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []Filter); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]Filter) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + 
+ return r0, r1 +} + +// MarkFilterBackfilled provides a mock function with given fields: ctx, id, earliestBlock +func (_m *mockORM) MarkFilterBackfilled(ctx context.Context, id int64, earliestBlock int64) error { + ret := _m.Called(ctx, id, earliestBlock) + + if len(ret) == 0 { + panic("no return value specified for MarkFilterBackfilled") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64, int64) error); ok { + r0 = rf(ctx, id, earliestBlock) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MarkFilterDeleted provides a mock function with given fields: ctx, id +func (_m *mockORM) MarkFilterDeleted(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for MarkFilterDeleted") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// newMockORM creates a new instance of mockORM. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func newMockORM(t interface { + mock.TestingT + Cleanup(func()) +}) *mockORM { + mock := &mockORM{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/solana/logpoller/mocks/rpc_client.go b/pkg/solana/logpoller/mocks/rpc_client.go index 851eba9ec..1d112f399 100644 --- a/pkg/solana/logpoller/mocks/rpc_client.go +++ b/pkg/solana/logpoller/mocks/rpc_client.go @@ -85,9 +85,9 @@ func (_c *RPCClient_GetBlockWithOpts_Call) RunAndReturn(run func(context.Context return _c } -// GetBlocks provides a mock function with given fields: ctx, startSlot, endSlot, commitment -func (_m *RPCClient) GetBlocks(ctx context.Context, startSlot uint64, endSlot *uint64, commitment rpc.CommitmentType) (rpc.BlocksResult, error) { - ret := _m.Called(ctx, startSlot, endSlot, commitment) +// GetBlocks provides a mock function with given fields: ctx, startSlot, endSlot +func (_m *RPCClient) GetBlocks(ctx context.Context, startSlot uint64, endSlot *uint64) (rpc.BlocksResult, error) { + ret := _m.Called(ctx, startSlot, endSlot) if len(ret) == 0 { panic("no return value specified for GetBlocks") @@ -95,19 +95,19 @@ func (_m *RPCClient) GetBlocks(ctx context.Context, startSlot uint64, endSlot *u var r0 rpc.BlocksResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, uint64, *uint64, rpc.CommitmentType) (rpc.BlocksResult, error)); ok { - return rf(ctx, startSlot, endSlot, commitment) + if rf, ok := ret.Get(0).(func(context.Context, uint64, *uint64) (rpc.BlocksResult, error)); ok { + return rf(ctx, startSlot, endSlot) } - if rf, ok := ret.Get(0).(func(context.Context, uint64, *uint64, rpc.CommitmentType) rpc.BlocksResult); ok { - r0 = rf(ctx, startSlot, endSlot, commitment) + if rf, ok := ret.Get(0).(func(context.Context, uint64, *uint64) rpc.BlocksResult); ok { + r0 = rf(ctx, startSlot, endSlot) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(rpc.BlocksResult) } } - if rf, ok := ret.Get(1).(func(context.Context, uint64, *uint64, 
rpc.CommitmentType) error); ok { - r1 = rf(ctx, startSlot, endSlot, commitment) + if rf, ok := ret.Get(1).(func(context.Context, uint64, *uint64) error); ok { + r1 = rf(ctx, startSlot, endSlot) } else { r1 = ret.Error(1) } @@ -124,14 +124,13 @@ type RPCClient_GetBlocks_Call struct { // - ctx context.Context // - startSlot uint64 // - endSlot *uint64 -// - commitment rpc.CommitmentType -func (_e *RPCClient_Expecter) GetBlocks(ctx interface{}, startSlot interface{}, endSlot interface{}, commitment interface{}) *RPCClient_GetBlocks_Call { - return &RPCClient_GetBlocks_Call{Call: _e.mock.On("GetBlocks", ctx, startSlot, endSlot, commitment)} +func (_e *RPCClient_Expecter) GetBlocks(ctx interface{}, startSlot interface{}, endSlot interface{}) *RPCClient_GetBlocks_Call { + return &RPCClient_GetBlocks_Call{Call: _e.mock.On("GetBlocks", ctx, startSlot, endSlot)} } -func (_c *RPCClient_GetBlocks_Call) Run(run func(ctx context.Context, startSlot uint64, endSlot *uint64, commitment rpc.CommitmentType)) *RPCClient_GetBlocks_Call { +func (_c *RPCClient_GetBlocks_Call) Run(run func(ctx context.Context, startSlot uint64, endSlot *uint64)) *RPCClient_GetBlocks_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(uint64), args[2].(*uint64), args[3].(rpc.CommitmentType)) + run(args[0].(context.Context), args[1].(uint64), args[2].(*uint64)) }) return _c } @@ -141,34 +140,34 @@ func (_c *RPCClient_GetBlocks_Call) Return(out rpc.BlocksResult, err error) *RPC return _c } -func (_c *RPCClient_GetBlocks_Call) RunAndReturn(run func(context.Context, uint64, *uint64, rpc.CommitmentType) (rpc.BlocksResult, error)) *RPCClient_GetBlocks_Call { +func (_c *RPCClient_GetBlocks_Call) RunAndReturn(run func(context.Context, uint64, *uint64) (rpc.BlocksResult, error)) *RPCClient_GetBlocks_Call { _c.Call.Return(run) return _c } -// GetLatestBlockhash provides a mock function with given fields: ctx, commitment -func (_m *RPCClient) GetLatestBlockhash(ctx 
context.Context, commitment rpc.CommitmentType) (*rpc.GetLatestBlockhashResult, error) { - ret := _m.Called(ctx, commitment) +// GetSignaturesForAddressWithOpts provides a mock function with given fields: _a0, _a1, _a2 +func (_m *RPCClient) GetSignaturesForAddressWithOpts(_a0 context.Context, _a1 solana.PublicKey, _a2 *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error) { + ret := _m.Called(_a0, _a1, _a2) if len(ret) == 0 { - panic("no return value specified for GetLatestBlockhash") + panic("no return value specified for GetSignaturesForAddressWithOpts") } - var r0 *rpc.GetLatestBlockhashResult + var r0 []*rpc.TransactionSignature var r1 error - if rf, ok := ret.Get(0).(func(context.Context, rpc.CommitmentType) (*rpc.GetLatestBlockhashResult, error)); ok { - return rf(ctx, commitment) + if rf, ok := ret.Get(0).(func(context.Context, solana.PublicKey, *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error)); ok { + return rf(_a0, _a1, _a2) } - if rf, ok := ret.Get(0).(func(context.Context, rpc.CommitmentType) *rpc.GetLatestBlockhashResult); ok { - r0 = rf(ctx, commitment) + if rf, ok := ret.Get(0).(func(context.Context, solana.PublicKey, *rpc.GetSignaturesForAddressOpts) []*rpc.TransactionSignature); ok { + r0 = rf(_a0, _a1, _a2) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*rpc.GetLatestBlockhashResult) + r0 = ret.Get(0).([]*rpc.TransactionSignature) } } - if rf, ok := ret.Get(1).(func(context.Context, rpc.CommitmentType) error); ok { - r1 = rf(ctx, commitment) + if rf, ok := ret.Get(1).(func(context.Context, solana.PublicKey, *rpc.GetSignaturesForAddressOpts) error); ok { + r1 = rf(_a0, _a1, _a2) } else { r1 = ret.Error(1) } @@ -176,58 +175,59 @@ func (_m *RPCClient) GetLatestBlockhash(ctx context.Context, commitment rpc.Comm return r0, r1 } -// RPCClient_GetLatestBlockhash_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestBlockhash' -type RPCClient_GetLatestBlockhash_Call 
struct { +// RPCClient_GetSignaturesForAddressWithOpts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSignaturesForAddressWithOpts' +type RPCClient_GetSignaturesForAddressWithOpts_Call struct { *mock.Call } -// GetLatestBlockhash is a helper method to define mock.On call -// - ctx context.Context -// - commitment rpc.CommitmentType -func (_e *RPCClient_Expecter) GetLatestBlockhash(ctx interface{}, commitment interface{}) *RPCClient_GetLatestBlockhash_Call { - return &RPCClient_GetLatestBlockhash_Call{Call: _e.mock.On("GetLatestBlockhash", ctx, commitment)} +// GetSignaturesForAddressWithOpts is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 solana.PublicKey +// - _a2 *rpc.GetSignaturesForAddressOpts +func (_e *RPCClient_Expecter) GetSignaturesForAddressWithOpts(_a0 interface{}, _a1 interface{}, _a2 interface{}) *RPCClient_GetSignaturesForAddressWithOpts_Call { + return &RPCClient_GetSignaturesForAddressWithOpts_Call{Call: _e.mock.On("GetSignaturesForAddressWithOpts", _a0, _a1, _a2)} } -func (_c *RPCClient_GetLatestBlockhash_Call) Run(run func(ctx context.Context, commitment rpc.CommitmentType)) *RPCClient_GetLatestBlockhash_Call { +func (_c *RPCClient_GetSignaturesForAddressWithOpts_Call) Run(run func(_a0 context.Context, _a1 solana.PublicKey, _a2 *rpc.GetSignaturesForAddressOpts)) *RPCClient_GetSignaturesForAddressWithOpts_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(rpc.CommitmentType)) + run(args[0].(context.Context), args[1].(solana.PublicKey), args[2].(*rpc.GetSignaturesForAddressOpts)) }) return _c } -func (_c *RPCClient_GetLatestBlockhash_Call) Return(out *rpc.GetLatestBlockhashResult, err error) *RPCClient_GetLatestBlockhash_Call { - _c.Call.Return(out, err) +func (_c *RPCClient_GetSignaturesForAddressWithOpts_Call) Return(_a0 []*rpc.TransactionSignature, _a1 error) *RPCClient_GetSignaturesForAddressWithOpts_Call { + _c.Call.Return(_a0, 
_a1) return _c } -func (_c *RPCClient_GetLatestBlockhash_Call) RunAndReturn(run func(context.Context, rpc.CommitmentType) (*rpc.GetLatestBlockhashResult, error)) *RPCClient_GetLatestBlockhash_Call { +func (_c *RPCClient_GetSignaturesForAddressWithOpts_Call) RunAndReturn(run func(context.Context, solana.PublicKey, *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error)) *RPCClient_GetSignaturesForAddressWithOpts_Call { _c.Call.Return(run) return _c } -// GetSignaturesForAddressWithOpts provides a mock function with given fields: _a0, _a1, _a2 -func (_m *RPCClient) GetSignaturesForAddressWithOpts(_a0 context.Context, _a1 solana.PublicKey, _a2 *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error) { - ret := _m.Called(_a0, _a1, _a2) +// LatestBlockhash provides a mock function with given fields: ctx +func (_m *RPCClient) LatestBlockhash(ctx context.Context) (*rpc.GetLatestBlockhashResult, error) { + ret := _m.Called(ctx) if len(ret) == 0 { - panic("no return value specified for GetSignaturesForAddressWithOpts") + panic("no return value specified for LatestBlockhash") } - var r0 []*rpc.TransactionSignature + var r0 *rpc.GetLatestBlockhashResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, solana.PublicKey, *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error)); ok { - return rf(_a0, _a1, _a2) + if rf, ok := ret.Get(0).(func(context.Context) (*rpc.GetLatestBlockhashResult, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func(context.Context, solana.PublicKey, *rpc.GetSignaturesForAddressOpts) []*rpc.TransactionSignature); ok { - r0 = rf(_a0, _a1, _a2) + if rf, ok := ret.Get(0).(func(context.Context) *rpc.GetLatestBlockhashResult); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*rpc.TransactionSignature) + r0 = ret.Get(0).(*rpc.GetLatestBlockhashResult) } } - if rf, ok := ret.Get(1).(func(context.Context, solana.PublicKey, *rpc.GetSignaturesForAddressOpts) error); ok { 
- r1 = rf(_a0, _a1, _a2) + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -235,32 +235,30 @@ func (_m *RPCClient) GetSignaturesForAddressWithOpts(_a0 context.Context, _a1 so return r0, r1 } -// RPCClient_GetSignaturesForAddressWithOpts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSignaturesForAddressWithOpts' -type RPCClient_GetSignaturesForAddressWithOpts_Call struct { +// RPCClient_LatestBlockhash_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LatestBlockhash' +type RPCClient_LatestBlockhash_Call struct { *mock.Call } -// GetSignaturesForAddressWithOpts is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 solana.PublicKey -// - _a2 *rpc.GetSignaturesForAddressOpts -func (_e *RPCClient_Expecter) GetSignaturesForAddressWithOpts(_a0 interface{}, _a1 interface{}, _a2 interface{}) *RPCClient_GetSignaturesForAddressWithOpts_Call { - return &RPCClient_GetSignaturesForAddressWithOpts_Call{Call: _e.mock.On("GetSignaturesForAddressWithOpts", _a0, _a1, _a2)} +// LatestBlockhash is a helper method to define mock.On call +// - ctx context.Context +func (_e *RPCClient_Expecter) LatestBlockhash(ctx interface{}) *RPCClient_LatestBlockhash_Call { + return &RPCClient_LatestBlockhash_Call{Call: _e.mock.On("LatestBlockhash", ctx)} } -func (_c *RPCClient_GetSignaturesForAddressWithOpts_Call) Run(run func(_a0 context.Context, _a1 solana.PublicKey, _a2 *rpc.GetSignaturesForAddressOpts)) *RPCClient_GetSignaturesForAddressWithOpts_Call { +func (_c *RPCClient_LatestBlockhash_Call) Run(run func(ctx context.Context)) *RPCClient_LatestBlockhash_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(solana.PublicKey), args[2].(*rpc.GetSignaturesForAddressOpts)) + run(args[0].(context.Context)) }) return _c } -func (_c *RPCClient_GetSignaturesForAddressWithOpts_Call) Return(_a0 
[]*rpc.TransactionSignature, _a1 error) *RPCClient_GetSignaturesForAddressWithOpts_Call { - _c.Call.Return(_a0, _a1) +func (_c *RPCClient_LatestBlockhash_Call) Return(out *rpc.GetLatestBlockhashResult, err error) *RPCClient_LatestBlockhash_Call { + _c.Call.Return(out, err) return _c } -func (_c *RPCClient_GetSignaturesForAddressWithOpts_Call) RunAndReturn(run func(context.Context, solana.PublicKey, *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error)) *RPCClient_GetSignaturesForAddressWithOpts_Call { +func (_c *RPCClient_LatestBlockhash_Call) RunAndReturn(run func(context.Context) (*rpc.GetLatestBlockhashResult, error)) *RPCClient_LatestBlockhash_Call { _c.Call.Return(run) return _c } diff --git a/pkg/solana/logpoller/models.go b/pkg/solana/logpoller/models.go new file mode 100644 index 000000000..4e786b782 --- /dev/null +++ b/pkg/solana/logpoller/models.go @@ -0,0 +1,44 @@ +package logpoller + +import ( + "time" + + "github.com/lib/pq" +) + +type Filter struct { + ID int64 // only for internal usage. Values set externally are ignored. + Name string + Address PublicKey + EventName string + EventSig EventSignature + StartingBlock int64 + EventIdl EventIdl + SubkeyPaths SubkeyPaths + Retention time.Duration + MaxLogsKept int64 + IsDeleted bool // only for internal usage. Values set externally are ignored. 
+} + +func (f Filter) MatchSameLogs(other Filter) bool { + return f.Address == other.Address && f.EventSig == other.EventSig && + f.EventIdl.Equal(other.EventIdl) && f.SubkeyPaths.Equal(other.SubkeyPaths) +} + +type Log struct { + ID int64 + FilterID int64 + ChainID string + LogIndex int64 + BlockHash Hash + BlockNumber int64 + BlockTimestamp time.Time + Address PublicKey + EventSig EventSignature + SubkeyValues pq.ByteaArray + TxHash Signature + Data []byte + CreatedAt time.Time + ExpiresAt *time.Time + SequenceNum int64 +} diff --git a/pkg/solana/logpoller/models_test.go b/pkg/solana/logpoller/models_test.go new file mode 100644 index 000000000..7aa6c651c --- /dev/null +++ b/pkg/solana/logpoller/models_test.go @@ -0,0 +1,58 @@ +package logpoller + +import ( + "testing" + "time" + + "github.com/gagliardetto/solana-go" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func newRandomFilter(t *testing.T) Filter { + return Filter{ + Name: uuid.NewString(), + Address: newRandomPublicKey(t), + EventName: "event", + EventSig: newRandomEventSignature(t), + StartingBlock: 1, + EventIDL: "{}", + SubkeyPaths: [][]string{{"a", "b"}, {"c"}}, + Retention: 1000, + MaxLogsKept: 3, + } +} + +func newRandomPublicKey(t *testing.T) PublicKey { + privateKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + pubKey := privateKey.PublicKey() + return PublicKey(pubKey) +} + +func newRandomEventSignature(t *testing.T) EventSignature { + pubKey := newRandomPublicKey(t) + return EventSignature(pubKey[:8]) +} + +func newRandomLog(t *testing.T, filterID int64, chainID string) Log { + privateKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + pubKey := privateKey.PublicKey() + data := []byte("solana is fun") + signature, err := privateKey.Sign(data) + require.NoError(t, err) + return Log{ + FilterID: filterID, + ChainID: chainID, + LogIndex: 1, + BlockHash: Hash(pubKey), + BlockNumber: 10, + BlockTimestamp: time.Unix(1731590113, 0), + 
Address: PublicKey(pubKey), + EventSig: EventSignature{3, 2, 1}, + SubkeyValues: [][]byte{{3, 2, 1}, {1}, {1, 2}, pubKey.Bytes()}, + TxHash: Signature(signature), + Data: data, + } +} diff --git a/pkg/solana/logpoller/orm.go b/pkg/solana/logpoller/orm.go new file mode 100644 index 000000000..85af4f1f3 --- /dev/null +++ b/pkg/solana/logpoller/orm.go @@ -0,0 +1,198 @@ +package logpoller + +import ( + "context" + "errors" + "fmt" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" +) + +var _ ORM = (*DSORM)(nil) + +type DSORM struct { + chainID string + ds sqlutil.DataSource + lggr logger.Logger +} + +// NewORM creates an DSORM scoped to chainID. +func NewORM(chainID string, ds sqlutil.DataSource, lggr logger.Logger) *DSORM { + return &DSORM{ + chainID: chainID, + ds: ds, + lggr: lggr, + } +} + +func (o *DSORM) ChainID() string { + return o.chainID +} + +func (o *DSORM) Transact(ctx context.Context, fn func(*DSORM) error) (err error) { + return sqlutil.Transact(ctx, o.new, o.ds, nil, fn) +} + +// new returns a NewORM like o, but backed by ds. +func (o *DSORM) new(ds sqlutil.DataSource) *DSORM { return NewORM(o.chainID, ds, o.lggr) } + +// InsertFilter is idempotent. +// +// Each address/event pair must have a unique job id, so it may be removed when the job is deleted. +// Returns ID for updated or newly inserted filter. +func (o *DSORM) InsertFilter(ctx context.Context, filter Filter) (id int64, err error) { + args, err := newQueryArgs(o.chainID). + withField("name", filter.Name). + withRetention(filter.Retention). + withMaxLogsKept(filter.MaxLogsKept). + withName(filter.Name). + withAddress(filter.Address). + withEventName(filter.EventName). + withEventSig(filter.EventSig). + withStartingBlock(filter.StartingBlock). + withEventIDL(filter.EventIdl). + withSubkeyPaths(filter.SubkeyPaths). 
+ toArgs() + if err != nil { + return 0, err + } + + // '::' has to be escaped in the query string + // https://github.com/jmoiron/sqlx/issues/91, https://github.com/jmoiron/sqlx/issues/428 + query := ` + INSERT INTO solana.log_poller_filters + (chain_id, name, address, event_name, event_sig, starting_block, event_idl, subkey_paths, retention, max_logs_kept) + VALUES (:chain_id, :name, :address, :event_name, :event_sig, :starting_block, :event_idl, :subkey_paths, :retention, :max_logs_kept) + ON CONFLICT (chain_id, name) WHERE NOT is_deleted DO UPDATE SET + event_name = EXCLUDED.event_name, + starting_block = EXCLUDED.starting_block, + retention = EXCLUDED.retention, + max_logs_kept = EXCLUDED.max_logs_kept + RETURNING id;` + + query, sqlArgs, err := o.ds.BindNamed(query, args) + if err != nil { + return 0, err + } + if err = o.ds.GetContext(ctx, &id, query, sqlArgs...); err != nil { + return 0, err + } + return id, nil +} + +// GetFilterByID returns filter by ID +func (o *DSORM) GetFilterByID(ctx context.Context, id int64) (Filter, error) { + query := filtersQuery("WHERE id = $1") + var result Filter + err := o.ds.GetContext(ctx, &result, query, id) + return result, err +} + +func (o *DSORM) MarkFilterDeleted(ctx context.Context, id int64) (err error) { + query := `UPDATE solana.log_poller_filters SET is_deleted = true WHERE id = $1` + _, err = o.ds.ExecContext(ctx, query, id) + return err +} + +func (o *DSORM) DeleteFilter(ctx context.Context, id int64) (err error) { + query := `DELETE FROM solana.log_poller_filters WHERE id = $1` + _, err = o.ds.ExecContext(ctx, query, id) + return err +} + +func (o *DSORM) DeleteFilters(ctx context.Context, filters map[int64]Filter) error { + for _, filter := range filters { + err := o.DeleteFilter(ctx, filter.ID) + if err != nil { + return fmt.Errorf("error deleting filter %s (%d): %w", filter.Name, filter.ID, err) + } + } + + return nil +} + +func (o *DSORM) SelectFilters(ctx context.Context) ([]Filter, error) { + query := 
// SelectFilters returns all filters registered for this ORM's chain ID,
// including ones marked deleted.
func (o *DSORM) SelectFilters(ctx context.Context) ([]Filter, error) {
	query := filtersQuery("WHERE chain_id = $1")
	var filters []Filter
	err := o.ds.SelectContext(ctx, &filters, query, o.chainID)
	return filters, err
}

// InsertLogs is idempotent to support replays.
func (o *DSORM) InsertLogs(ctx context.Context, logs []Log) error {
	if err := o.validateLogs(logs); err != nil {
		return err
	}
	// All batches are written in a single transaction so a replay either
	// lands completely or not at all.
	return o.Transact(ctx, func(orm *DSORM) error {
		return orm.insertLogsWithinTx(ctx, logs, orm.ds)
	})
}

// insertLogsWithinTx writes logs to tx in batches (initially 4000 rows).
// ON CONFLICT DO NOTHING makes re-inserting the same log a no-op.
// On a DB timeout the batch size is halved (down to a floor of 500) and the
// same span is retried.
// NOTE(review): the retry triggers on context.DeadlineExceeded — if the
// deadline came from the caller's ctx rather than a per-statement timeout,
// retrying cannot succeed; confirm the intended timeout source.
func (o *DSORM) insertLogsWithinTx(ctx context.Context, logs []Log, tx sqlutil.DataSource) error {
	batchInsertSize := 4000
	for i := 0; i < len(logs); i += batchInsertSize {
		start, end := i, i+batchInsertSize
		if end > len(logs) {
			end = len(logs)
		}

		query := `INSERT INTO solana.logs
			(filter_id, chain_id, log_index, block_hash, block_number, block_timestamp, address, event_sig, subkey_values, tx_hash, data, created_at, expires_at, sequence_num)
			VALUES
			(:filter_id, :chain_id, :log_index, :block_hash, :block_number, :block_timestamp, :address, :event_sig, :subkey_values, :tx_hash, :data, NOW(), :expires_at, :sequence_num)
			ON CONFLICT DO NOTHING`

		_, err := tx.NamedExecContext(ctx, query, logs[start:end])
		if err != nil {
			if errors.Is(err, context.DeadlineExceeded) && batchInsertSize > 500 {
				// In case of DB timeouts, try to insert again with a smaller batch upto a limit
				batchInsertSize /= 2
				i -= batchInsertSize // counteract +=batchInsertSize on next loop iteration
				continue
			}
			return err
		}
	}
	return nil
}

// validateLogs rejects the whole slice if any log belongs to a different chain
// than this ORM is scoped to.
func (o *DSORM) validateLogs(logs []Log) error {
	for _, log := range logs {
		if o.chainID != log.ChainID {
			return fmt.Errorf("invalid chainID in log got %v want %v", log.ChainID, o.chainID)
		}
	}
	return nil
}

// SelectLogs finds the logs in a given block range.
// Both start and end block numbers are inclusive; results are ordered by
// (block_number, log_index).
func (o *DSORM) SelectLogs(ctx context.Context, start, end int64, address PublicKey, eventSig EventSignature) ([]Log, error) {
	args, err := newQueryArgsForEvent(o.chainID, address, eventSig).
		withStartBlock(start).
		withEndBlock(end).
		toArgs()
	if err != nil {
		return nil, err
	}

	query := logsQuery(`
		WHERE chain_id = :chain_id
		AND address = :address
		AND event_sig = :event_sig
		AND block_number >= :start_block
		AND block_number <= :end_block
		ORDER BY block_number, log_index`)

	var logs []Log
	query, sqlArgs, err := o.ds.BindNamed(query, args)
	if err != nil {
		return nil, err
	}

	err = o.ds.SelectContext(ctx, &logs, query, sqlArgs...)
	if err != nil {
		return nil, err
	}
	return logs, nil
}
// TestLogPollerFilters exercises the filter CRUD paths of DSORM against a real
// postgres database (requires CL_DATABASE_URL; build tag db_tests).
func TestLogPollerFilters(t *testing.T) {
	lggr := logger.Test(t)
	dbURL, ok := os.LookupEnv("CL_DATABASE_URL")
	require.True(t, ok, "CL_DATABASE_URL must be set")
	t.Run("Ensure all fields are readable/writable", func(t *testing.T) {
		privateKey, err := solana.NewRandomPrivateKey()
		require.NoError(t, err)
		pubKey := privateKey.PublicKey()
		chainID := uuid.NewString()
		dbx := pg.NewSqlxDB(t, dbURL)
		orm := NewORM(chainID, dbx, lggr)
		// Three variants cover the SubkeyPaths serialization edge cases:
		// populated, empty slice, and nil.
		filters := []Filter{
			{
				Name:          "happy path",
				Address:       PublicKey(pubKey),
				EventName:     "event",
				EventSig:      EventSignature{1, 2, 3},
				StartingBlock: 1,
				EventIDL:      "{}",
				SubkeyPaths:   SubkeyPaths([][]string{{"a", "b"}, {"c"}}),
				Retention:     1000,
				MaxLogsKept:   3,
			},
			{
				Name:          "empty sub key paths",
				Address:       PublicKey(pubKey),
				EventName:     "event",
				EventSig:      EventSignature{1, 2, 3},
				StartingBlock: 1,
				EventIDL:      "{}",
				SubkeyPaths:   SubkeyPaths([][]string{}),
				Retention:     1000,
				MaxLogsKept:   3,
			},
			{
				Name:          "nil sub key paths",
				Address:       PublicKey(pubKey),
				EventName:     "event",
				EventSig:      EventSignature{1, 2, 3},
				StartingBlock: 1,
				EventIDL:      "{}",
				SubkeyPaths:   nil,
				Retention:     1000,
				MaxLogsKept:   3,
			},
		}

		for _, filter := range filters {
			t.Run("Read/write filter: "+filter.Name, func(t *testing.T) {
				ctx := tests.Context(t)
				id, err := orm.InsertFilter(ctx, filter)
				require.NoError(t, err)
				filter.ID = id
				dbFilter, err := orm.GetFilterByID(ctx, id)
				require.NoError(t, err)
				// round-trip must be lossless
				require.Equal(t, filter, dbFilter)
			})
		}
	})
	t.Run("Updates non primary fields if name and chainID is not unique", func(t *testing.T) {
		chainID := uuid.NewString()
		dbx := pg.NewSqlxDB(t, dbURL)
		orm := NewORM(chainID, dbx, lggr)
		filter := newRandomFilter(t)
		ctx := tests.Context(t)
		id, err := orm.InsertFilter(ctx, filter)
		require.NoError(t, err)
		// mutate every field the upsert is expected to overwrite
		filter.EventName = uuid.NewString()
		filter.StartingBlock++
		filter.Retention++
		filter.MaxLogsKept++
		id2, err := orm.InsertFilter(ctx, filter)
		require.NoError(t, err)
		// same (chain_id, name) => same row, updated in place
		require.Equal(t, id, id2)
		dbFilter, err := orm.GetFilterByID(ctx, id)
		require.NoError(t, err)
		filter.ID = id
		require.Equal(t, filter, dbFilter)
	})
	t.Run("Allows reuse name of a filter marked as deleted", func(t *testing.T) {
		chainID := uuid.NewString()
		dbx := pg.NewSqlxDB(t, dbURL)
		orm := NewORM(chainID, dbx, lggr)
		filter := newRandomFilter(t)
		ctx := tests.Context(t)
		filterID, err := orm.InsertFilter(ctx, filter)
		require.NoError(t, err)
		// mark deleted
		err = orm.MarkFilterDeleted(ctx, filterID)
		require.NoError(t, err)
		// ensure marked as deleted
		dbFilter, err := orm.GetFilterByID(ctx, filterID)
		require.NoError(t, err)
		require.True(t, dbFilter.IsDeleted, "expected to be deleted")
		newFilterID, err := orm.InsertFilter(ctx, filter)
		require.NoError(t, err)
		require.NotEqual(t, newFilterID, filterID, "expected db to generate new filter as we can not be sure that new one matches the same logs")
	})
	t.Run("Allows reuse name for a filter with different chainID", func(t *testing.T) {
		dbx := pg.NewSqlxDB(t, dbURL)
		orm1 := NewORM(uuid.NewString(), dbx, lggr)
		orm2 := NewORM(uuid.NewString(), dbx, lggr)
		filter := newRandomFilter(t)
		ctx := tests.Context(t)
		filterID1, err := orm1.InsertFilter(ctx, filter)
		require.NoError(t, err)
		filterID2, err := orm2.InsertFilter(ctx, filter)
		require.NoError(t, err)
		// uniqueness is per (chain_id, name), not per name
		require.NotEqual(t, filterID1, filterID2)
	})
	t.Run("Deletes log on parent filter deletion", func(t *testing.T) {
		dbx := pg.NewSqlxDB(t, dbURL)
		chainID := uuid.NewString()
		orm := NewORM(chainID, dbx, lggr)
		filter := newRandomFilter(t)
		ctx := tests.Context(t)
		filterID, err := orm.InsertFilter(ctx, filter)
		require.NoError(t, err)
		log := newRandomLog(t, filterID, chainID)
		err = orm.InsertLogs(ctx, []Log{log})
		require.NoError(t, err)
		logs, err := orm.SelectLogs(ctx, 0, log.BlockNumber, log.Address, log.EventSig)
		require.NoError(t, err)
		require.Len(t, logs, 1)
		err = orm.MarkFilterDeleted(ctx, filterID)
		require.NoError(t, err)
		// logs are expected to be present in db even if filter was marked as deleted
		logs, err = orm.SelectLogs(ctx, 0, log.BlockNumber, log.Address, log.EventSig)
		require.NoError(t, err)
		require.Len(t, logs, 1)
		// hard delete should cascade to the logs
		err = orm.DeleteFilter(ctx, filterID)
		require.NoError(t, err)
		logs, err = orm.SelectLogs(ctx, 0, log.BlockNumber, log.Address, log.EventSig)
		require.NoError(t, err)
		require.Len(t, logs, 0)
	})
}

// TestLogPollerLogs verifies that log insertion is idempotent and that a log
// survives a read/write round trip intact.
func TestLogPollerLogs(t *testing.T) {
	lggr := logger.Test(t)
	dbURL, ok := os.LookupEnv("CL_DATABASE_URL")
	require.True(t, ok, "CL_DATABASE_URL must be set")
	chainID := uuid.NewString()
	dbx := pg.NewSqlxDB(t, dbURL)
	orm := NewORM(chainID, dbx, lggr)

	ctx := tests.Context(t)
	// create filter as it's required for a log
	filterID, err := orm.InsertFilter(ctx, newRandomFilter(t))
	require.NoError(t, err)
	log := newRandomLog(t, filterID, chainID)
	err = orm.InsertLogs(ctx, []Log{log})
	require.NoError(t, err)
	// insert of the same Log should not produce two instances
	err = orm.InsertLogs(ctx, []Log{log})
	require.NoError(t, err)
	dbLogs, err := orm.SelectLogs(ctx, 0, 100, log.Address, log.EventSig)
	require.NoError(t, err)
	require.Len(t, dbLogs, 1)
	// ID and CreatedAt are DB-assigned; copy them before comparing the rest
	log.ID = dbLogs[0].ID
	log.CreatedAt = dbLogs[0].CreatedAt
	require.Equal(t, log, dbLogs[0])
}
"subkey_paths", "retention", "max_logs_kept", "is_deleted"} +) diff --git a/pkg/solana/logpoller/query.go b/pkg/solana/logpoller/query.go new file mode 100644 index 000000000..0bda6b2a0 --- /dev/null +++ b/pkg/solana/logpoller/query.go @@ -0,0 +1,130 @@ +package logpoller + +import ( + "errors" + "fmt" + "strings" + "time" +) + +// queryArgs is a helper for building the arguments to a postgres query created by DSORM +// Besides the convenience methods, it also keeps track of arguments validation and sanitization. +type queryArgs struct { + args map[string]any + idxLookup map[string]uint8 + err []error +} + +func newQueryArgs(chainID string) *queryArgs { + return &queryArgs{ + args: map[string]any{ + "chain_id": chainID, + }, + idxLookup: make(map[string]uint8), + err: []error{}, + } +} + +func (q *queryArgs) withField(fieldName string, value any) *queryArgs { + _, args := q.withIndexableField(fieldName, value, false) + + return args +} + +func (q *queryArgs) withIndexableField(fieldName string, value any, addIndex bool) (string, *queryArgs) { + if addIndex { + idx := q.nextIdx(fieldName) + idxName := fmt.Sprintf("%s_%d", fieldName, idx) + + q.idxLookup[fieldName] = idx + fieldName = idxName + } + + q.args[fieldName] = value + + return fieldName, q +} + +func (q *queryArgs) nextIdx(baseFieldName string) uint8 { + idx, ok := q.idxLookup[baseFieldName] + if !ok { + return 0 + } + + return idx + 1 +} + +// withName sets the Name field in queryArgs. +func (q *queryArgs) withName(name string) *queryArgs { + return q.withField("name", name) +} + +// withAddress sets the Address field in queryArgs. +func (q *queryArgs) withAddress(address PublicKey) *queryArgs { + return q.withField("address", address) +} + +// withEventName sets the EventName field in queryArgs. +func (q *queryArgs) withEventName(eventName string) *queryArgs { + return q.withField("event_name", eventName) +} + +// withEventSig sets the EventSig field in queryArgs. 
+func (q *queryArgs) withEventSig(eventSig EventSignature) *queryArgs { + return q.withField("event_sig", eventSig) +} + +// withStartingBlock sets the StartingBlock field in queryArgs. +func (q *queryArgs) withStartingBlock(startingBlock int64) *queryArgs { + return q.withField("starting_block", startingBlock) +} + +// withEventIDL sets the EventIDL field in queryArgs. +func (q *queryArgs) withEventIDL(eventIdl EventIdl) *queryArgs { + return q.withField("event_idl", eventIdl) +} + +// withSubkeyPaths sets the SubkeyPaths field in queryArgs. +func (q *queryArgs) withSubkeyPaths(subkeyPaths [][]string) *queryArgs { + return q.withField("subkey_paths", subkeyPaths) +} + +// withRetention sets the Retention field in queryArgs. +func (q *queryArgs) withRetention(retention time.Duration) *queryArgs { + return q.withField("retention", retention) +} + +// withMaxLogsKept sets the MaxLogsKept field in queryArgs. +func (q *queryArgs) withMaxLogsKept(maxLogsKept int64) *queryArgs { + return q.withField("max_logs_kept", maxLogsKept) +} + +func newQueryArgsForEvent(chainID string, address PublicKey, eventSig EventSignature) *queryArgs { + return newQueryArgs(chainID). + withAddress(address). + withEventSig(eventSig) +} + +func (q *queryArgs) withStartBlock(startBlock int64) *queryArgs { + return q.withField("start_block", startBlock) +} + +func (q *queryArgs) withEndBlock(endBlock int64) *queryArgs { + return q.withField("end_block", endBlock) +} + +func logsQuery(clause string) string { + return fmt.Sprintf(`SELECT %s FROM solana.logs %s`, strings.Join(logsFields[:], ", "), clause) +} + +func filtersQuery(clause string) string { + return fmt.Sprintf(`SELECT %s FROM solana.log_poller_filters %s`, strings.Join(filterFields[:], ", "), clause) +} + +func (q *queryArgs) toArgs() (map[string]any, error) { + if len(q.err) > 0 { + return nil, errors.Join(q.err...) 
+ } + + return q.args, nil +} diff --git a/pkg/solana/logpoller/types.go b/pkg/solana/logpoller/types.go new file mode 100644 index 000000000..ba3812a3e --- /dev/null +++ b/pkg/solana/logpoller/types.go @@ -0,0 +1,148 @@ +package logpoller + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "reflect" + "slices" + + "github.com/gagliardetto/solana-go" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" +) + +type PublicKey solana.PublicKey + +// Scan implements Scanner for database/sql. +func (k *PublicKey) Scan(src interface{}) error { + return scanFixedLengthArray("PublicKey", solana.PublicKeyLength, src, k[:]) +} + +// Value implements valuer for database/sql. +func (k PublicKey) Value() (driver.Value, error) { + return k[:], nil +} + +func (k PublicKey) ToSolana() solana.PublicKey { + return solana.PublicKey(k) +} + +type Hash solana.Hash + +// Scan implements Scanner for database/sql. +func (h *Hash) Scan(src interface{}) error { + return scanFixedLengthArray("Hash", solana.PublicKeyLength, src, h[:]) +} + +// Value implements valuer for database/sql. +func (h Hash) Value() (driver.Value, error) { + return h[:], nil +} + +func (h Hash) ToSolana() solana.Hash { + return solana.Hash(h) +} + +type Signature solana.Signature + +// Scan implements Scanner for database/sql. +func (s *Signature) Scan(src interface{}) error { + return scanFixedLengthArray("Signature", solana.SignatureLength, src, s[:]) +} + +// Value implements valuer for database/sql. 
+func (s Signature) Value() (driver.Value, error) { + return s[:], nil +} + +func (s Signature) ToSolana() solana.Signature { + return solana.Signature(s) +} + +func scanFixedLengthArray(name string, maxLength int, src interface{}, dest []byte) error { + srcB, ok := src.([]byte) + if !ok { + return fmt.Errorf("can't scan %T into %s", src, name) + } + if len(srcB) != maxLength { + return fmt.Errorf("can't scan []byte of len %d into %s, want %d", len(srcB), name, maxLength) + } + copy(dest, srcB) + return nil +} + +type SubkeyPaths [][]string + +func (p SubkeyPaths) Value() (driver.Value, error) { + return json.Marshal([][]string(p)) +} + +func (p *SubkeyPaths) Scan(src interface{}) error { + return scanJson("SubkeyPaths", p, src) +} + +func (p SubkeyPaths) Equal(o SubkeyPaths) bool { + return slices.EqualFunc(p, o, slices.Equal) +} + +const EventSignatureLength = 8 + +type EventSignature [EventSignatureLength]byte + +// Scan implements Scanner for database/sql. +func (s *EventSignature) Scan(src interface{}) error { + return scanFixedLengthArray("EventSignature", EventSignatureLength, src, s[:]) +} + +// Value implements valuer for database/sql. 
+func (s EventSignature) Value() (driver.Value, error) { + return s[:], nil +} + +type EventTypeProvider interface { + CreateType(eventIdl codec.IdlEvent, typedefSlice codec.IdlTypeDefSlice, subKeyPath []string) (any, error) +} + +type EventIdl struct { + codec.IdlEvent + codec.IdlTypeDefSlice +} + +func (e *EventIdl) Scan(src interface{}) error { + return scanJson("EventIdl", e, src) +} + +func (e EventIdl) Value() (driver.Value, error) { + return json.Marshal(map[string]any{ + "IdlEvent": e.IdlEvent, + "IdlTypeDefSlice": e.IdlTypeDefSlice, + }) +} + +func (p EventIdl) Equal(o EventIdl) bool { + return reflect.DeepEqual(p, o) +} + +func scanJson(name string, dest, src interface{}) error { + var bSrc []byte + switch src := src.(type) { + case string: + bSrc = []byte(src) + case []byte: + bSrc = src + default: + return fmt.Errorf("can't scan %T into %s", src, name) + } + + if len(bSrc) == 0 || string(bSrc) == "null" { + return nil + } + + err := json.Unmarshal(bSrc, dest) + if err != nil { + return fmt.Errorf("failed to scan %v into %s: %w", string(bSrc), name, err) + } + + return nil +} diff --git a/pkg/solana/logpoller/utils/anchor.go b/pkg/solana/logpoller/utils/anchor.go new file mode 100644 index 000000000..b042fb67d --- /dev/null +++ b/pkg/solana/logpoller/utils/anchor.go @@ -0,0 +1,290 @@ +package utils + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "regexp" + "strconv" + "strings" + "testing" + "time" + + bin "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/stretchr/testify/require" +) + +var ZeroAddress = [32]byte{} + +func MakeRandom32ByteArray() [32]byte { + a := make([]byte, 32) + if _, err := rand.Read(a); err != nil { + panic(err) // should never panic but check in case + } + return [32]byte(a) +} + +func Uint64ToLE(chain uint64) []byte { + chainLE := make([]byte, 8) + 
binary.LittleEndian.PutUint64(chainLE, chain) + return chainLE +} + +func To28BytesLE(value uint64) [28]byte { + le := make([]byte, 28) + binary.LittleEndian.PutUint64(le, value) + return [28]byte(le) +} + +func Map[T, V any](ts []T, fn func(T) V) []V { + result := make([]V, len(ts)) + for i, t := range ts { + result[i] = fn(t) + } + return result +} + +const DiscriminatorLength = 8 + +func Discriminator(namespace, name string) [DiscriminatorLength]byte { + h := sha256.New() + h.Write([]byte(fmt.Sprintf("%s:%s", namespace, name))) + return [DiscriminatorLength]byte(h.Sum(nil)[:DiscriminatorLength]) +} + +func FundAccounts(ctx context.Context, accounts []solana.PrivateKey, solanaGoClient *rpc.Client, t *testing.T) { + sigs := []solana.Signature{} + for _, v := range accounts { + sig, err := solanaGoClient.RequestAirdrop(ctx, v.PublicKey(), 1000*solana.LAMPORTS_PER_SOL, rpc.CommitmentFinalized) + require.NoError(t, err) + sigs = append(sigs, sig) + } + + // wait for confirmation so later transactions don't fail + remaining := len(sigs) + count := 0 + for remaining > 0 { + count++ + statusRes, sigErr := solanaGoClient.GetSignatureStatuses(ctx, true, sigs...) 
+ require.NoError(t, sigErr) + require.NotNil(t, statusRes) + require.NotNil(t, statusRes.Value) + + unconfirmedTxCount := 0 + for _, res := range statusRes.Value { + if res == nil || res.ConfirmationStatus == rpc.ConfirmationStatusProcessed || res.ConfirmationStatus == rpc.ConfirmationStatusConfirmed { + unconfirmedTxCount++ + } + } + remaining = unconfirmedTxCount + + time.Sleep(500 * time.Millisecond) + if count > 60 { + require.NoError(t, fmt.Errorf("unable to find transaction within timeout")) + } + } +} + +func IsEvent(event string, data []byte) bool { + if len(data) < 8 { + return false + } + d := Discriminator("event", event) + return bytes.Equal(d[:], data[:8]) +} + +func ParseEvent(logs []string, event string, obj interface{}, print ...bool) error { + for _, v := range logs { + if strings.Contains(v, "Program data:") { + encodedData := strings.TrimSpace(strings.TrimPrefix(v, "Program data:")) + data, err := base64.StdEncoding.DecodeString(encodedData) + if err != nil { + return err + } + if IsEvent(event, data) { + if err := bin.UnmarshalBorsh(obj, data); err != nil { + return err + } + + if len(print) > 0 && print[0] { + fmt.Printf("%s: %+v\n", event, obj) + } + return nil + } + } + } + return fmt.Errorf("%s: event not found", event) +} + +func ParseMultipleEvents[T any](logs []string, event string, print bool) ([]T, error) { + var results []T + for _, v := range logs { + if strings.Contains(v, "Program data:") { + encodedData := strings.TrimSpace(strings.TrimPrefix(v, "Program data:")) + data, err := base64.StdEncoding.DecodeString(encodedData) + if err != nil { + return nil, err + } + if IsEvent(event, data) { + var obj T + if err := bin.UnmarshalBorsh(&obj, data); err != nil { + return nil, err + } + + if print { + fmt.Printf("%s: %+v\n", event, obj) + } + + results = append(results, obj) + } + } + } + if len(results) == 0 { + return nil, fmt.Errorf("%s: event not found", event) + } + + return results, nil +} + +type AnchorInstruction struct { + Name 
string + ProgramID string + Logs []string + ComputeUnits int + InnerCalls []*AnchorInstruction +} + +// Parses the log messages from an Anchor program and returns a list of AnchorInstructions. +func ParseLogMessages(logMessages []string) []*AnchorInstruction { + var instructions []*AnchorInstruction + var stack []*AnchorInstruction + var currentInstruction *AnchorInstruction + + programInvokeRegex := regexp.MustCompile(`Program (\w+) invoke`) + programSuccessRegex := regexp.MustCompile(`Program (\w+) success`) + computeUnitsRegex := regexp.MustCompile(`Program (\w+) consumed (\d+) of \d+ compute units`) + + for _, line := range logMessages { + line = strings.TrimSpace(line) + + // Program invocation - push to stack + if match := programInvokeRegex.FindStringSubmatch(line); len(match) > 1 { + newInstruction := &AnchorInstruction{ + ProgramID: match[1], + Name: "", + Logs: []string{}, + ComputeUnits: 0, + InnerCalls: []*AnchorInstruction{}, + } + + if len(stack) == 0 { + instructions = append(instructions, newInstruction) + } else { + stack[len(stack)-1].InnerCalls = append(stack[len(stack)-1].InnerCalls, newInstruction) + } + + stack = append(stack, newInstruction) + currentInstruction = newInstruction + continue + } + + // Program success - pop from stack + if match := programSuccessRegex.FindStringSubmatch(line); len(match) > 1 { + if len(stack) > 0 { + stack = stack[:len(stack)-1] // pop + if len(stack) > 0 { + currentInstruction = stack[len(stack)-1] + } else { + currentInstruction = nil + } + } + continue + } + + // Instruction name + if strings.Contains(line, "Instruction:") { + if currentInstruction != nil { + currentInstruction.Name = strings.TrimSpace(strings.Split(line, "Instruction:")[1]) + } + continue + } + + // Program logs + if strings.HasPrefix(line, "Program log:") { + if currentInstruction != nil { + logMessage := strings.TrimSpace(strings.TrimPrefix(line, "Program log:")) + currentInstruction.Logs = append(currentInstruction.Logs, logMessage) + } 
+ continue + } + + // Compute units + if match := computeUnitsRegex.FindStringSubmatch(line); len(match) > 1 { + programID := match[1] + computeUnits, _ := strconv.Atoi(match[2]) + + // Find the instruction in the stack that matches this program ID + for i := len(stack) - 1; i >= 0; i-- { + if stack[i].ProgramID == programID { + stack[i].ComputeUnits = computeUnits + break + } + } + } + } + + return instructions +} + +// Pretty prints the given Anchor instructions. +// Example usage: +// parsed := utils.ParseLogMessages(result.Meta.LogMessages) +// output := utils.PrintInstructions(parsed) +// t.Logf("Parsed Instructions: %s", output) +func PrintInstructions(instructions []*AnchorInstruction) string { + var output strings.Builder + + var printInstruction func(*AnchorInstruction, int, string) + printInstruction = func(instruction *AnchorInstruction, index int, indent string) { + output.WriteString(fmt.Sprintf("%sInstruction %d: %s\n", indent, index, instruction.Name)) + output.WriteString(fmt.Sprintf("%s Program ID: %s\n", indent, instruction.ProgramID)) + output.WriteString(fmt.Sprintf("%s Compute Units: %d\n", indent, instruction.ComputeUnits)) + output.WriteString(fmt.Sprintf("%s Logs:\n", indent)) + for _, log := range instruction.Logs { + output.WriteString(fmt.Sprintf("%s %s\n", indent, log)) + } + if len(instruction.InnerCalls) > 0 { + output.WriteString(fmt.Sprintf("%s Inner Calls:\n", indent)) + for i, innerCall := range instruction.InnerCalls { + printInstruction(innerCall, i+1, indent+" ") + } + } + } + + for i, instruction := range instructions { + printInstruction(instruction, i+1, "") + } + + return output.String() +} + +func GetBlockTime(ctx context.Context, client *rpc.Client, commitment rpc.CommitmentType) (*solana.UnixTimeSeconds, error) { + block, err := client.GetBlockHeight(ctx, commitment) + if err != nil { + return nil, fmt.Errorf("failed to get block height: %w", err) + } + + blockTime, err := client.GetBlockTime(ctx, block) + if err != nil { 
+ return nil, fmt.Errorf("failed to get block time: %w", err) + } + + return blockTime, nil +}