diff --git a/pkg/solana/chain.go b/pkg/solana/chain.go index 0fa097e85..5bbda3d26 100644 --- a/pkg/solana/chain.go +++ b/pkg/solana/chain.go @@ -313,7 +313,8 @@ func newChain(id string, cfg *config.TOMLConfig, ks core.Keystore, lggr logger.L //bc = internal.NewLoader[monitor.BalanceClient](func() (monitor.BalanceClient, error) { return ch.multiNode.SelectRPC() }) } - ch.lp = logpoller.NewLogPoller(logger.Sugared(logger.Named(lggr, "LogPoller")), logpoller.NewORM(ch.ID(), ds, lggr), lc) + // TODO: import typeProvider function from codec package and pass to constructor + ch.lp = logpoller.NewLogPoller(logger.Sugared(logger.Named(lggr, "LogPoller")), logpoller.NewORM(ch.ID(), ds, lggr), lc, nil) ch.txm = txm.NewTxm(ch.id, tc, sendTx, cfg, ks, lggr) ch.balanceMonitor = monitor.NewBalanceMonitor(ch.id, cfg, lggr, ks, bc) return &ch, nil diff --git a/pkg/solana/logpoller/filters.go b/pkg/solana/logpoller/filters.go index b550918e1..bce439631 100644 --- a/pkg/solana/logpoller/filters.go +++ b/pkg/solana/logpoller/filters.go @@ -12,10 +12,7 @@ import ( "github.com/gagliardetto/solana-go" "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/types" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller/utils" ) @@ -29,7 +26,6 @@ type filters struct { filtersToDelete map[int64]Filter filtersMutex sync.RWMutex loadedFilters atomic.Bool - eventCodecs map[int64]types.RemoteCodec knownPrograms map[string]struct{} // fast lookup to see if a base58-encoded ProgramID matches any registered filters knownDiscriminators map[string]struct{} // fast lookup by first 10 characters (60-bits) of a base64-encoded discriminator } @@ -85,11 +81,6 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error { return fmt.Errorf("failed to load filters: %w", err) } - eventCodec, err := codec.NewIDLEventCodec(filter.EventIDL, config.BuilderForEncoding(config.EncodingTypeBorsh)) - if err != nil { - return fmt.Errorf("invalid event IDL for filter %s: %w", filter.Name, err) - } - filter.EventSig = utils.Discriminator("event", filter.EventName) fl.filtersMutex.Lock() @@ -125,7 +116,6 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error { filtersByID[filter.ID] = filter fl.filtersToBackfill[filterID] = filter - fl.eventCodecs[filter.ID] = eventCodec fl.knownPrograms[filter.Address.ToSolana().String()] = struct{}{} discriminator := base64.StdEncoding.EncodeToString(filter.EventSig[:]) fl.knownDiscriminators[discriminator[:10]] = struct{}{} @@ -210,10 +200,6 @@ func (fl *filters) MatchingFilters(addr PublicKey, eventSignature EventSignature } } -func (fl *filters) EventCodec(ID int64) types.RemoteCodec { - return fl.eventCodecs[ID] -} - // MatchchingFiltersForEncodedEvent - similar to MatchingFilters but accepts a raw encoded event. Under normal operation, // this will be called on every new event that happens on the blockchain, so it's important it returns immediately if it // doesn't match any registered filters. diff --git a/pkg/solana/logpoller/log_poller.go b/pkg/solana/logpoller/log_poller.go index 2a818f8a0..6de5db431 100644 --- a/pkg/solana/logpoller/log_poller.go +++ b/pkg/solana/logpoller/log_poller.go @@ -6,13 +6,12 @@ import ( "errors" "fmt" "math" - "reflect" "sync" "time" + bin "github.com/gagliardetto/binary" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" - commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" @@ -20,7 +19,8 @@ import ( ) var ( - ErrFilterNameConflict = errors.New("filter with such name already exists") + ErrFilterNameConflict = errors.New("filter with such name already exists") + ErrMissingEventTypeProvider = errors.New("cannot start LogPoller without EventTypeProvider") ) //go:generate mockery --name ORM --inpackage --structname mockORM --filename mock_orm.go @@ -42,27 +42,24 @@ type ILogPoller interface { type LogPoller struct { services.StateMachine - lggr logger.SugaredLogger - orm ORM - client internal.Loader[client.Reader] - collector *EncodedLogCollector - - filters *filters - discriminatorLookup map[string]string - events []ProgramEvent - codec commontypes.RemoteCodec - - chStop services.StopChan - wg sync.WaitGroup + lggr logger.SugaredLogger + orm ORM + client internal.Loader[client.Reader] + collector *EncodedLogCollector + filters *filters + typeProvider EventTypeProvider + chStop services.StopChan + wg sync.WaitGroup } -func NewLogPoller(lggr logger.SugaredLogger, orm ORM, cl internal.Loader[client.Reader]) ILogPoller { +func NewLogPoller(lggr logger.SugaredLogger, orm ORM, cl internal.Loader[client.Reader], typeProvider EventTypeProvider) ILogPoller { lggr = logger.Sugared(logger.Named(lggr, "LogPoller")) lp := LogPoller{ - orm: orm, - client: cl, - lggr: lggr, - filters: newFilters(lggr, orm), + orm: orm, + client: cl, + lggr: lggr, + filters: newFilters(lggr, orm), + typeProvider: typeProvider, } return &lp } @@ -100,15 +97,14 @@ func (lp *LogPoller) Process(programEvent ProgramEvent) (err error) { return err } - var event any - err = lp.filters.EventCodec(filter.ID).Decode(ctx, log.Data, &event, filter.EventName) - if err != nil { - return err - } + for _, path := range filter.SubkeyPaths { - err = lp.ExtractSubkeys(reflect.TypeOf(event), filter.SubkeyPaths) - if err != nil { - return err + var event any + event, err = lp.typeProvider.CreateType(filter.EventIdl.IdlEvent, filter.EventIdl.IdlTypeDefSlice, path) + bin.UnmarshalBorsh(&event, log.Data) + if err != nil { + return err + } } // TODO: fill in, and keep track of SequenceNumber for each filter. (Initialize from db on LoadFilters, then increment each time?) @@ -120,26 +116,10 @@ func (lp *LogPoller) Process(programEvent ProgramEvent) (err error) { return nil } -func (lp *LogPoller) ExtractSubkeys(t reflect.Type, paths SubkeyPaths) error { - s := reflect.TypeOf(event) - if s.Kind() != reflect.Struct { - return fmt.Errorf("event type must be struct, got %v. event=%v", t, event) - } - - for _, path := range paths[0] { - field, err := s.FieldByName(path) - for depth := 0; depth < len(paths); depth++ { - for _, path := range paths[depth] { - field, err = field.Type.FieldByName(path) - } - } - } - -} - -func get - func (lp *LogPoller) Start(context.Context) error { + if lp.typeProvider == nil { + return ErrMissingEventTypeProvider + } cl, err := lp.client.Get() if err != nil { return err diff --git a/pkg/solana/logpoller/models.go b/pkg/solana/logpoller/models.go index f21d2549b..4e786b782 100644 --- a/pkg/solana/logpoller/models.go +++ b/pkg/solana/logpoller/models.go @@ -4,8 +4,6 @@ import ( "time" "github.com/lib/pq" - - "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" ) type Filter struct { @@ -15,7 +13,7 @@ type Filter struct { EventName string EventSig EventSignature StartingBlock int64 - EventIDL codec.IDL + EventIdl EventIdl SubkeyPaths SubkeyPaths Retention time.Duration MaxLogsKept int64 @@ -23,7 +21,8 @@ type Filter struct { } func (f Filter) MatchSameLogs(other Filter) bool { - return f.Address == other.Address && f.EventSig == other.EventSig && f.EventIDL == other.EventIDL && f.SubkeyPaths.Equal(other.SubkeyPaths) + return f.Address == other.Address && f.EventSig == other.EventSig && + f.EventIdl.Equal(other.EventIdl) && f.SubkeyPaths.Equal(other.SubkeyPaths) } type Log struct { diff --git a/pkg/solana/logpoller/orm.go b/pkg/solana/logpoller/orm.go index 97fd6f520..85af4f1f3 100644 --- a/pkg/solana/logpoller/orm.go +++ b/pkg/solana/logpoller/orm.go @@ -26,6 +26,10 @@ func NewORM(chainID string, ds sqlutil.DataSource, lggr logger.Logger) *DSORM { } } +func (o *DSORM) ChainID() string { + return o.chainID +} + func (o *DSORM) Transact(ctx context.Context, fn func(*DSORM) error) (err error) { return sqlutil.Transact(ctx, o.new, o.ds, nil, fn) } @@ -47,7 +51,7 @@ func (o *DSORM) InsertFilter(ctx context.Context, filter Filter) (id int64, err withEventName(filter.EventName). withEventSig(filter.EventSig). withStartingBlock(filter.StartingBlock). - withEventIDL(filter.EventIDL). + withEventIDL(filter.EventIdl). withSubkeyPaths(filter.SubkeyPaths). toArgs() if err != nil { diff --git a/pkg/solana/logpoller/query.go b/pkg/solana/logpoller/query.go index 8b89cc35f..0bda6b2a0 100644 --- a/pkg/solana/logpoller/query.go +++ b/pkg/solana/logpoller/query.go @@ -80,8 +80,8 @@ func (q *queryArgs) withStartingBlock(startingBlock int64) *queryArgs { } // withEventIDL sets the EventIDL field in queryArgs. -func (q *queryArgs) withEventIDL(eventIDL string) *queryArgs { - return q.withField("event_idl", eventIDL) +func (q *queryArgs) withEventIDL(eventIdl EventIdl) *queryArgs { + return q.withField("event_idl", eventIdl) } // withSubkeyPaths sets the SubkeyPaths field in queryArgs. diff --git a/pkg/solana/logpoller/types.go b/pkg/solana/logpoller/types.go index 143c28898..ba3812a3e 100644 --- a/pkg/solana/logpoller/types.go +++ b/pkg/solana/logpoller/types.go @@ -4,9 +4,12 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "reflect" "slices" "github.com/gagliardetto/solana-go" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" ) type PublicKey solana.PublicKey @@ -76,26 +79,7 @@ func (p SubkeyPaths) Value() (driver.Value, error) { } func (p *SubkeyPaths) Scan(src interface{}) error { - var bSrc []byte - switch src := src.(type) { - case string: - bSrc = []byte(src) - case []byte: - bSrc = src - default: - return fmt.Errorf("can't scan %T into SubkeyPaths", src) - } - - if len(bSrc) == 0 || string(bSrc) == "null" { - return nil - } - - err := json.Unmarshal(bSrc, p) - if err != nil { - return fmt.Errorf("failed to scan %v into SubkeyPaths: %w", string(bSrc), err) - } - - return nil + return scanJson("SubkeyPaths", p, src) } func (p SubkeyPaths) Equal(o SubkeyPaths) bool { @@ -115,3 +99,50 @@ func (s *EventSignature) Scan(src interface{}) error { func (s EventSignature) Value() (driver.Value, error) { return s[:], nil } + +type EventTypeProvider interface { + CreateType(eventIdl codec.IdlEvent, typedefSlice codec.IdlTypeDefSlice, subKeyPath []string) (any, error) +} + +type EventIdl struct { + codec.IdlEvent + codec.IdlTypeDefSlice +} + +func (e *EventIdl) Scan(src interface{}) error { + return scanJson("EventIdl", e, src) +} + +func (e EventIdl) Value() (driver.Value, error) { + return json.Marshal(map[string]any{ + "IdlEvent": e.IdlEvent, + "IdlTypeDefSlice": e.IdlTypeDefSlice, + }) +} + +func (p EventIdl) Equal(o EventIdl) bool { + return reflect.DeepEqual(p, o) +} + +func scanJson(name string, dest, src interface{}) error { + var bSrc []byte + switch src := src.(type) { + case string: + bSrc = []byte(src) + case []byte: + bSrc = src + default: + return fmt.Errorf("can't scan %T into %s", src, name) + } + + if len(bSrc) == 0 || string(bSrc) == "null" { + return nil + } + + err := json.Unmarshal(bSrc, dest) + if err != nil { + return fmt.Errorf("failed to scan %v into %s: %w", string(bSrc), name, err) + } + + return nil +} diff --git a/pkg/solana/logpoller/utils/anchor.go b/pkg/solana/logpoller/utils/anchor.go index 4d11bcc83..b042fb67d 100644 --- a/pkg/solana/logpoller/utils/anchor.go +++ b/pkg/solana/logpoller/utils/anchor.go @@ -19,8 +19,6 @@ import ( "github.com/gagliardetto/solana-go/rpc" "github.com/stretchr/testify/require" - - "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller" ) var ZeroAddress = [32]byte{} @@ -53,10 +51,12 @@ func Map[T, V any](ts []T, fn func(T) V) []V { return result } -func Discriminator(namespace, name string) logpoller.EventSignature { +const DiscriminatorLength = 8 + +func Discriminator(namespace, name string) [DiscriminatorLength]byte { h := sha256.New() h.Write([]byte(fmt.Sprintf("%s:%s", namespace, name))) - return logpoller.EventSignature(h.Sum(nil)[:8]) + return [DiscriminatorLength]byte(h.Sum(nil)[:DiscriminatorLength]) } func FundAccounts(ctx context.Context, accounts []solana.PrivateKey, solanaGoClient *rpc.Client, t *testing.T) { @@ -97,7 +97,7 @@ func IsEvent(event string, data []byte) bool { return false } d := Discriminator("event", event) - return bytes.Equal(d, data[:8]) + return bytes.Equal(d[:], data[:8]) } func ParseEvent(logs []string, event string, obj interface{}, print ...bool) error {