diff --git a/pkg/solana/codec/solana_test.go b/pkg/solana/codec/solana_test.go index 4dd116691..57e921fb7 100644 --- a/pkg/solana/codec/solana_test.go +++ b/pkg/solana/codec/solana_test.go @@ -143,6 +143,13 @@ func TestNewIDLCodec_CircularDependency(t *testing.T) { assert.ErrorIs(t, err, types.ErrInvalidConfig) } +func TestNewIDLInstructionCodec(t *testing.T) { + t.Parallel() + + var idl codec.IDL + +} + func newTestIDLAndCodec(t *testing.T, account bool) (string, codec.IDL, types.RemoteCodec) { t.Helper() diff --git a/pkg/solana/logpoller/filters.go b/pkg/solana/logpoller/filters.go index 1f8bd4933..b550918e1 100644 --- a/pkg/solana/logpoller/filters.go +++ b/pkg/solana/logpoller/filters.go @@ -2,6 +2,7 @@ package logpoller import ( "context" + "encoding/base64" "errors" "fmt" "iter" @@ -9,19 +10,28 @@ import ( "sync" "sync/atomic" + "github.com/gagliardetto/solana-go" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller/utils" ) type filters struct { orm ORM lggr logger.SugaredLogger - filtersByName map[string]Filter - filtersByAddress map[PublicKey]map[EventSignature]map[int64]Filter - filtersToBackfill map[int64]Filter - filtersToDelete map[int64]Filter - filtersMutex sync.RWMutex - loadedFilters atomic.Bool + filtersByName map[string]Filter + filtersByAddress map[PublicKey]map[EventSignature]map[int64]Filter + filtersToBackfill map[int64]Filter + filtersToDelete map[int64]Filter + filtersMutex sync.RWMutex + loadedFilters atomic.Bool + eventCodecs map[int64]types.RemoteCodec + knownPrograms map[string]struct{} // fast lookup to see if a base58-encoded ProgramID matches any registered filters + knownDiscriminators map[string]struct{} // fast lookup by first 10 characters (60-bits) of a base64-encoded discriminator } func newFilters(lggr logger.SugaredLogger, orm ORM) *filters { @@ -75,6 +85,13 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error { return fmt.Errorf("failed to load filters: %w", err) } + eventCodec, err := codec.NewIDLEventCodec(filter.EventIDL, config.BuilderForEncoding(config.EncodingTypeBorsh)) + if err != nil { + return fmt.Errorf("invalid event IDL for filter %s: %w", filter.Name, err) + } + + filter.EventSig = utils.Discriminator("event", filter.EventName) + fl.filtersMutex.Lock() defer fl.filtersMutex.Unlock() @@ -107,6 +124,12 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error { filtersByID[filter.ID] = filter fl.filtersToBackfill[filterID] = filter + + fl.eventCodecs[filter.ID] = eventCodec + fl.knownPrograms[filter.Address.ToSolana().String()] = struct{}{} + discriminator := base64.StdEncoding.EncodeToString(filter.EventSig[:]) + fl.knownDiscriminators[discriminator[:10]] = struct{}{} + return nil } @@ -187,6 +210,42 @@ func (fl *filters) MatchingFilters(addr PublicKey, eventSignature EventSignature } } +func (fl *filters) EventCodec(ID int64) types.RemoteCodec { + return fl.eventCodecs[ID] +} + +// MatchchingFiltersForEncodedEvent - similar to MatchingFilters but accepts a raw encoded event. Under normal operation, +// this will be called on every new event that happens on the blockchain, so it's important it returns immediately if it +// doesn't match any registered filters. +func (fl *filters) MatchingFiltersForEncodedEvent(event ProgramEvent) iter.Seq[Filter] { + if _, ok := fl.knownPrograms[event.Program]; !ok { + return nil + } + + // The first 64-bits of the event data is the event sig. Because it's base64 encoded, this corresponds to + // the first 10 characters plus 4 bits of the 11th character. We can quickly rule it out as not matching any known + // discriminators if the first 10 characters don't match. If it passes that initial test, we base64-decode the + // first 11 characters, and use the first 8 bytes of that as the event sig to call MatchingFilters. The address + // also needs to be base58-decoded to pass to MatchingFilters + if _, ok := fl.knownDiscriminators[event.Data[:10]]; !ok { + return nil + } + + addr, err := solana.PublicKeyFromBase58(event.Program) + if err != nil { + fl.lggr.Errorw("failed to parse Program ID for event", "EventProgram", event) + return nil + } + decoded, err := base64.StdEncoding.DecodeString(event.Data[:11]) + if err != nil { + fl.lggr.Errorw("failed to decode event data", "EventProgram", event) + return nil + } + eventSig := EventSignature(decoded[:8]) + + return fl.MatchingFilters(PublicKey(addr), eventSig) +} + // ConsumeFiltersToBackfill - removes all filters from the backfill queue and returns them to caller. // Requires LoadFilters to be called at least once. func (fl *filters) ConsumeFiltersToBackfill() map[int64]Filter { diff --git a/pkg/solana/logpoller/log_poller.go b/pkg/solana/logpoller/log_poller.go index 37f897aa2..2a818f8a0 100644 --- a/pkg/solana/logpoller/log_poller.go +++ b/pkg/solana/logpoller/log_poller.go @@ -2,12 +2,18 @@ package logpoller import ( "context" + "encoding/base64" "errors" + "fmt" + "math" + "reflect" "sync" "time" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" @@ -19,10 +25,12 @@ var ( //go:generate mockery --name ORM --inpackage --structname mockORM --filename mock_orm.go type ORM interface { + ChainID() string InsertFilter(ctx context.Context, filter Filter) (id int64, err error) SelectFilters(ctx context.Context) ([]Filter, error) DeleteFilters(ctx context.Context, filters map[int64]Filter) error MarkFilterDeleted(ctx context.Context, id int64) (err error) + InsertLogs(context.Context, []Log) (err error) } type ILogPoller interface { @@ -39,8 +47,10 @@ type LogPoller struct { client internal.Loader[client.Reader] collector *EncodedLogCollector - filters *filters - events []ProgramEvent + filters *filters + discriminatorLookup map[string]string + events []ProgramEvent + codec commontypes.RemoteCodec chStop services.StopChan wg sync.WaitGroup @@ -57,12 +67,78 @@ func NewLogPoller(lggr logger.SugaredLogger, orm ORM, cl internal.Loader[client. return &lp } -func (lp *LogPoller) Process(event ProgramEvent) error { - // process stream of events coming from event loader - lp.events = append(lp.events, event) +func makeLogIndex(txIndex int, txLogIndex uint) int64 { + if txIndex < 0 || txIndex > math.MaxUint32 || txLogIndex > math.MaxUint32 { + panic(fmt.Sprintf("txIndex or txLogIndex out of range: txIndex=%d, txLogIndex=%d", txIndex, txLogIndex)) + } + return int64(math.MaxUint32*uint32(txIndex) + uint32(txLogIndex)) +} + +// Process - process stream of events coming from log ingester +func (lp *LogPoller) Process(programEvent ProgramEvent) (err error) { + ctx, cancel := utils.ContextFromChan(lp.chStop) + defer cancel() + + blockData := programEvent.BlockData + + var logs []Log + for filter := range lp.filters.MatchingFiltersForEncodedEvent(programEvent) { + log := Log{ + FilterID: filter.ID, + ChainID: lp.orm.ChainID(), + LogIndex: makeLogIndex(blockData.TransactionIndex, blockData.TransactionLogIndex), + BlockHash: Hash(blockData.BlockHash), + BlockNumber: int64(blockData.BlockHeight), + BlockTimestamp: blockData.BlockTime.Time(), // TODO: is this a timezone safe conversion? + Address: filter.Address, + EventSig: filter.EventSig, + TxHash: Signature(blockData.TransactionHash), + } + + log.Data, err = base64.StdEncoding.DecodeString(programEvent.Data) + if err != nil { + return err + } + + var event any + err = lp.filters.EventCodec(filter.ID).Decode(ctx, log.Data, &event, filter.EventName) + if err != nil { + return err + } + + err = lp.ExtractSubkeys(reflect.TypeOf(event), filter.SubkeyPaths) + if err != nil { + return err + } + + // TODO: fill in, and keep track of SequenceNumber for each filter. (Initialize from db on LoadFilters, then increment each time?) + + logs = append(logs, log) + } + + lp.orm.InsertLogs(ctx, logs) return nil } +func (lp *LogPoller) ExtractSubkeys(t reflect.Type, paths SubkeyPaths) error { + s := reflect.TypeOf(event) + if s.Kind() != reflect.Struct { + return fmt.Errorf("event type must be struct, got %v. event=%v", t, event) + } + + for _, path := range paths[0] { + field, err := s.FieldByName(path) + for depth := 0; depth < len(paths); depth++ { + for _, path := range paths[depth] { + field, err = field.Type.FieldByName(path) + } + } + } + +} + +func get + func (lp *LogPoller) Start(context.Context) error { cl, err := lp.client.Get() if err != nil { diff --git a/pkg/solana/logpoller/models.go b/pkg/solana/logpoller/models.go index 10f853682..f21d2549b 100644 --- a/pkg/solana/logpoller/models.go +++ b/pkg/solana/logpoller/models.go @@ -4,6 +4,8 @@ import ( "time" "github.com/lib/pq" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" ) type Filter struct { @@ -13,7 +15,7 @@ type Filter struct { EventName string EventSig EventSignature StartingBlock int64 - EventIDL string + EventIDL codec.IDL SubkeyPaths SubkeyPaths Retention time.Duration MaxLogsKept int64 diff --git a/pkg/solana/logpoller/utils/anchor.go b/pkg/solana/logpoller/utils/anchor.go new file mode 100644 index 000000000..4d11bcc83 --- /dev/null +++ b/pkg/solana/logpoller/utils/anchor.go @@ -0,0 +1,290 @@ +package utils + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "regexp" + "strconv" + "strings" + "testing" + "time" + + bin "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller" +) + +var ZeroAddress = [32]byte{} + +func MakeRandom32ByteArray() [32]byte { + a := make([]byte, 32) + if _, err := rand.Read(a); err != nil { + panic(err) // should never panic but check in case + } + return [32]byte(a) +} + +func Uint64ToLE(chain uint64) []byte { + chainLE := make([]byte, 8) + binary.LittleEndian.PutUint64(chainLE, chain) + return chainLE +} + +func To28BytesLE(value uint64) [28]byte { + le := make([]byte, 28) + binary.LittleEndian.PutUint64(le, value) + return [28]byte(le) +} + +func Map[T, V any](ts []T, fn func(T) V) []V { + result := make([]V, len(ts)) + for i, t := range ts { + result[i] = fn(t) + } + return result +} + +func Discriminator(namespace, name string) logpoller.EventSignature { + h := sha256.New() + h.Write([]byte(fmt.Sprintf("%s:%s", namespace, name))) + return logpoller.EventSignature(h.Sum(nil)[:8]) +} + +func FundAccounts(ctx context.Context, accounts []solana.PrivateKey, solanaGoClient *rpc.Client, t *testing.T) { + sigs := []solana.Signature{} + for _, v := range accounts { + sig, err := solanaGoClient.RequestAirdrop(ctx, v.PublicKey(), 1000*solana.LAMPORTS_PER_SOL, rpc.CommitmentFinalized) + require.NoError(t, err) + sigs = append(sigs, sig) + } + + // wait for confirmation so later transactions don't fail + remaining := len(sigs) + count := 0 + for remaining > 0 { + count++ + statusRes, sigErr := solanaGoClient.GetSignatureStatuses(ctx, true, sigs...) + require.NoError(t, sigErr) + require.NotNil(t, statusRes) + require.NotNil(t, statusRes.Value) + + unconfirmedTxCount := 0 + for _, res := range statusRes.Value { + if res == nil || res.ConfirmationStatus == rpc.ConfirmationStatusProcessed || res.ConfirmationStatus == rpc.ConfirmationStatusConfirmed { + unconfirmedTxCount++ + } + } + remaining = unconfirmedTxCount + + time.Sleep(500 * time.Millisecond) + if count > 60 { + require.NoError(t, fmt.Errorf("unable to find transaction within timeout")) + } + } +} + +func IsEvent(event string, data []byte) bool { + if len(data) < 8 { + return false + } + d := Discriminator("event", event) + return bytes.Equal(d, data[:8]) +} + +func ParseEvent(logs []string, event string, obj interface{}, print ...bool) error { + for _, v := range logs { + if strings.Contains(v, "Program data:") { + encodedData := strings.TrimSpace(strings.TrimPrefix(v, "Program data:")) + data, err := base64.StdEncoding.DecodeString(encodedData) + if err != nil { + return err + } + if IsEvent(event, data) { + if err := bin.UnmarshalBorsh(obj, data); err != nil { + return err + } + + if len(print) > 0 && print[0] { + fmt.Printf("%s: %+v\n", event, obj) + } + return nil + } + } + } + return fmt.Errorf("%s: event not found", event) +} + +func ParseMultipleEvents[T any](logs []string, event string, print bool) ([]T, error) { + var results []T + for _, v := range logs { + if strings.Contains(v, "Program data:") { + encodedData := strings.TrimSpace(strings.TrimPrefix(v, "Program data:")) + data, err := base64.StdEncoding.DecodeString(encodedData) + if err != nil { + return nil, err + } + if IsEvent(event, data) { + var obj T + if err := bin.UnmarshalBorsh(&obj, data); err != nil { + return nil, err + } + + if print { + fmt.Printf("%s: %+v\n", event, obj) + } + + results = append(results, obj) + } + } + } + if len(results) == 0 { + return nil, fmt.Errorf("%s: event not found", event) + } + + return results, nil +} + +type AnchorInstruction struct { + Name string + ProgramID string + Logs []string + ComputeUnits int + InnerCalls []*AnchorInstruction +} + +// Parses the log messages from an Anchor program and returns a list of AnchorInstructions. +func ParseLogMessages(logMessages []string) []*AnchorInstruction { + var instructions []*AnchorInstruction + var stack []*AnchorInstruction + var currentInstruction *AnchorInstruction + + programInvokeRegex := regexp.MustCompile(`Program (\w+) invoke`) + programSuccessRegex := regexp.MustCompile(`Program (\w+) success`) + computeUnitsRegex := regexp.MustCompile(`Program (\w+) consumed (\d+) of \d+ compute units`) + + for _, line := range logMessages { + line = strings.TrimSpace(line) + + // Program invocation - push to stack + if match := programInvokeRegex.FindStringSubmatch(line); len(match) > 1 { + newInstruction := &AnchorInstruction{ + ProgramID: match[1], + Name: "", + Logs: []string{}, + ComputeUnits: 0, + InnerCalls: []*AnchorInstruction{}, + } + + if len(stack) == 0 { + instructions = append(instructions, newInstruction) + } else { + stack[len(stack)-1].InnerCalls = append(stack[len(stack)-1].InnerCalls, newInstruction) + } + + stack = append(stack, newInstruction) + currentInstruction = newInstruction + continue + } + + // Program success - pop from stack + if match := programSuccessRegex.FindStringSubmatch(line); len(match) > 1 { + if len(stack) > 0 { + stack = stack[:len(stack)-1] // pop + if len(stack) > 0 { + currentInstruction = stack[len(stack)-1] + } else { + currentInstruction = nil + } + } + continue + } + + // Instruction name + if strings.Contains(line, "Instruction:") { + if currentInstruction != nil { + currentInstruction.Name = strings.TrimSpace(strings.Split(line, "Instruction:")[1]) + } + continue + } + + // Program logs + if strings.HasPrefix(line, "Program log:") { + if currentInstruction != nil { + logMessage := strings.TrimSpace(strings.TrimPrefix(line, "Program log:")) + currentInstruction.Logs = append(currentInstruction.Logs, logMessage) + } + continue + } + + // Compute units + if match := computeUnitsRegex.FindStringSubmatch(line); len(match) > 1 { + programID := match[1] + computeUnits, _ := strconv.Atoi(match[2]) + + // Find the instruction in the stack that matches this program ID + for i := len(stack) - 1; i >= 0; i-- { + if stack[i].ProgramID == programID { + stack[i].ComputeUnits = computeUnits + break + } + } + } + } + + return instructions +} + +// Pretty prints the given Anchor instructions. +// Example usage: +// parsed := utils.ParseLogMessages(result.Meta.LogMessages) +// output := utils.PrintInstructions(parsed) +// t.Logf("Parsed Instructions: %s", output) +func PrintInstructions(instructions []*AnchorInstruction) string { + var output strings.Builder + + var printInstruction func(*AnchorInstruction, int, string) + printInstruction = func(instruction *AnchorInstruction, index int, indent string) { + output.WriteString(fmt.Sprintf("%sInstruction %d: %s\n", indent, index, instruction.Name)) + output.WriteString(fmt.Sprintf("%s Program ID: %s\n", indent, instruction.ProgramID)) + output.WriteString(fmt.Sprintf("%s Compute Units: %d\n", indent, instruction.ComputeUnits)) + output.WriteString(fmt.Sprintf("%s Logs:\n", indent)) + for _, log := range instruction.Logs { + output.WriteString(fmt.Sprintf("%s %s\n", indent, log)) + } + if len(instruction.InnerCalls) > 0 { + output.WriteString(fmt.Sprintf("%s Inner Calls:\n", indent)) + for i, innerCall := range instruction.InnerCalls { + printInstruction(innerCall, i+1, indent+" ") + } + } + } + + for i, instruction := range instructions { + printInstruction(instruction, i+1, "") + } + + return output.String() +} + +func GetBlockTime(ctx context.Context, client *rpc.Client, commitment rpc.CommitmentType) (*solana.UnixTimeSeconds, error) { + block, err := client.GetBlockHeight(ctx, commitment) + if err != nil { + return nil, fmt.Errorf("failed to get block height: %w", err) + } + + blockTime, err := client.GetBlockTime(ctx, block) + if err != nil { + return nil, fmt.Errorf("failed to get block time: %w", err) + } + + return blockTime, nil +}