diff --git a/pkg/solana/codec/codec_entry.go b/pkg/solana/codec/codec_entry.go index 4b11e9caf..bc42ae968 100644 --- a/pkg/solana/codec/codec_entry.go +++ b/pkg/solana/codec/codec_entry.go @@ -34,49 +34,75 @@ type entry struct { } func NewAccountEntry(offchainName string, idlAccount IdlTypeDef, idlTypes IdlTypeDefSlice, includeDiscriminator bool, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) { - refs := &codecRefs{ - builder: builder, - codecs: make(map[string]commonencodings.TypeCodec), - typeDefs: idlTypes, - dependencies: make(map[string][]string), + _, accCodec, err := createCodecType(idlAccount, createRefs(idlTypes, builder), false) + if err != nil { + return nil, err + } + + return newEntry( + offchainName, + idlAccount.Name, + accCodec, + includeDiscriminator, + mod, + ), nil +} + +func NewInstructionArgsEntry(offChainName string, instructions IdlInstruction, idlTypes IdlTypeDefSlice, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) { + _, instructionCodecArgs, err := asStruct(instructions.Args, createRefs(idlTypes, builder), instructions.Name, false, true) + if err != nil { + return nil, err } - _, accCodec, err := createCodecType(idlAccount, refs, false) + return newEntry( + offChainName, + instructions.Name, + instructionCodecArgs, + // Instruction arguments don't need a discriminator by default + false, + mod, + ), nil +} + +func NewEventArgsEntry(offChainName string, event IdlEvent, idlTypes IdlTypeDefSlice, includeDiscriminator bool, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) { + _, eventCodec, err := asStruct(eventFieldsToFields(event.Fields), createRefs(idlTypes, builder), event.Name, false, false) if err != nil { return nil, err } + return newEntry( + offChainName, + event.Name, + eventCodec, + includeDiscriminator, + mod, + ), nil +} + +func newEntry( + offchainName, onchainName string, + typeCodec commonencodings.TypeCodec, + includeDiscriminator bool, + mod codec.Modifier, +) Entry { return &entry{ offchainName: offchainName, - onchainName: idlAccount.Name, - reflectType: accCodec.GetType(), - typeCodec: accCodec, + onchainName: onchainName, + reflectType: typeCodec.GetType(), + typeCodec: typeCodec, mod: ensureModifier(mod), includeDiscriminator: includeDiscriminator, - discriminator: *NewDiscriminator(idlAccount.Name), - }, nil + discriminator: *NewDiscriminator(onchainName), + } } -func NewInstructionArgsEntry(offChainName string, instructions IdlInstruction, idlTypes IdlTypeDefSlice, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) { - refs := &codecRefs{ +func createRefs(idlTypes IdlTypeDefSlice, builder commonencodings.Builder) *codecRefs { + return &codecRefs{ builder: builder, codecs: make(map[string]commonencodings.TypeCodec), typeDefs: idlTypes, dependencies: make(map[string][]string), } - - _, instructionCodecArgs, err := asStruct(instructions.Args, refs, instructions.Name, false, true) - if err != nil { - return nil, err - } - - return &entry{ - offchainName: offChainName, - onchainName: instructions.Name, - typeCodec: instructionCodecArgs, - reflectType: instructionCodecArgs.GetType(), - mod: ensureModifier(mod), - }, nil } func (e *entry) Encode(value any, into []byte) ([]byte, error) { @@ -89,7 +115,8 @@ func (e *entry) Encode(value any, into []byte) ([]byte, error) { return []byte{}, nil } } - return nil, fmt.Errorf("%w: cannot encode nil value for offchainName: %q, onchainName: %q", commontypes.ErrInvalidType, e.offchainName, e.onchainName) + return nil, fmt.Errorf("%w: cannot encode nil value for offchainName: %q, onchainName: %q", + commontypes.ErrInvalidType, e.offchainName, e.onchainName) } encodedVal, err := e.typeCodec.Encode(value, into) @@ -112,11 +139,13 @@ func (e *entry) Encode(value any, into []byte) ([]byte, error) { func (e *entry) Decode(encoded []byte) (any, []byte, error) { if e.includeDiscriminator { if len(encoded) < discriminatorLength { - return nil, nil, fmt.Errorf("%w: encoded data too short to contain discriminator for offchainName: %q, onchainName: %q", commontypes.ErrInvalidType, e.offchainName, e.onchainName) + return nil, nil, fmt.Errorf("%w: encoded data too short to contain discriminator for offchainName: %q, onchainName: %q", + commontypes.ErrInvalidType, e.offchainName, e.onchainName) } - if !bytes.Equal(e.discriminator.hashPrefix, encoded[:8]) { - return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v for offchainName: %q, onchainName: %q", commontypes.ErrInvalidType, encoded[:8], e.offchainName, e.onchainName) + if !bytes.Equal(e.discriminator.hashPrefix, encoded[:discriminatorLength]) { + return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v for offchainName: %q, onchainName: %q", + commontypes.ErrInvalidType, encoded[:discriminatorLength], e.offchainName, e.onchainName) } encoded = encoded[discriminatorLength:] @@ -150,3 +179,14 @@ func ensureModifier(mod codec.Modifier) codec.Modifier { } return mod } + +func eventFieldsToFields(evFields []IdlEventField) []IdlField { + var idlFields []IdlField + for _, evField := range evFields { + idlFields = append(idlFields, IdlField{ + Name: evField.Name, + Type: evField.Type, + }) + } + return idlFields +} diff --git a/pkg/solana/codec/codec_test.go b/pkg/solana/codec/codec_test.go index 3f9023eea..8a9dedb45 100644 --- a/pkg/solana/codec/codec_test.go +++ b/pkg/solana/codec/codec_test.go @@ -9,12 +9,15 @@ import ( bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" looptestutils "github.com/smartcontractkit/chainlink-common/pkg/loop/testutils" clcommontypes "github.com/smartcontractkit/chainlink-common/pkg/types" . "github.com/smartcontractkit/chainlink-common/pkg/types/interfacetests" //nolint common practice to import test mods with . + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec/testutils" ) @@ -25,6 +28,27 @@ func TestCodec(t *testing.T) { tester := &codecInterfaceTester{} RunCodecInterfaceTests(t, tester) RunCodecInterfaceTests(t, looptestutils.WrapCodecTesterForLoop(tester)) + + t.Run("Events are encode-able and decode-able for a single item", func(t *testing.T) { + ctx := tests.Context(t) + item := CreateTestStruct[*testing.T](0, tester) + req := &EncodeRequest{TestStructs: []TestStruct{item}, TestOn: testutils.TestEventItem} + resp := tester.EncodeFields(t, req) + + codec := tester.GetCodec(t) + actualEncoding, err := codec.Encode(ctx, item, testutils.TestEventItem) + require.NoError(t, err) + assert.Equal(t, resp, actualEncoding) + + into := TestStruct{} + require.NoError(t, codec.Decode(ctx, actualEncoding, &into, testutils.TestEventItem)) + assert.Equal(t, item, into) + }) +} + +func FuzzCodec(f *testing.F) { + tester := &codecInterfaceTester{} + RunCodecInterfaceFuzzTests(f, tester) } type codecInterfaceTester struct { @@ -44,7 +68,7 @@ func (it *codecInterfaceTester) GetAccountString(i int) string { } func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeRequest) []byte { - if request.TestOn == TestItemType { + if request.TestOn == TestItemType || request.TestOn == testutils.TestEventItem { return encodeFieldsOnItem(t, request) } @@ -53,6 +77,7 @@ func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeReques func encodeFieldsOnItem(t *testing.T, request *EncodeRequest) ocr2types.Report { buf := new(bytes.Buffer) + // The underlying TestItemAsAccount adds a discriminator by default while being Borsh encoded. if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil { require.NoError(t, err) } @@ -98,7 +123,7 @@ func (it *codecInterfaceTester) GetCodec(t *testing.T) clcommontypes.Codec { } } - if slices.Contains([]string{TestItemType, TestItemSliceType, TestItemArray1Type, TestItemArray2Type, testutils.TestItemWithConfigExtraType}, offChainName) { + if slices.Contains([]string{TestItemType, TestItemSliceType, TestItemArray1Type, TestItemArray2Type, testutils.TestItemWithConfigExtraType, testutils.TestEventItem}, offChainName) { addressByteModifier := &commoncodec.AddressBytesToStringModifierConfig{ Fields: []string{"AccountStruct.AccountStr"}, Modifier: codec.SolanaAddressModifier{}, diff --git a/pkg/solana/codec/solana.go b/pkg/solana/codec/solana.go index dcf48af68..19fe40d3e 100644 --- a/pkg/solana/codec/solana.go +++ b/pkg/solana/codec/solana.go @@ -64,8 +64,6 @@ func NewCodec(conf Config) (commontypes.RemoteCodec, error) { for offChainName, cfg := range conf.Configs { var idl IDL - onChainName := cfg.OnChainName - if err := json.Unmarshal([]byte(cfg.IDL), &idl); err != nil { return nil, err } @@ -75,53 +73,59 @@ func NewCodec(conf Config) (commontypes.RemoteCodec, error) { return nil, err } + definition, err := findDefinitionFromIDL(cfg.Type, cfg.OnChainName, idl) + if err != nil { + return nil, err + } + var cEntry Entry - switch cfg.Type { - case ChainConfigTypeAccountDef: - var account *IdlTypeDef - for i := range idl.Accounts { - if idl.Accounts[i].Name == cfg.OnChainName { - account = &idl.Accounts[i] - break - } - } + switch v := definition.(type) { + case IdlTypeDef: + cEntry, err = NewAccountEntry(offChainName, v, idl.Types, true, mod, binary.LittleEndian()) + case IdlInstruction: + cEntry, err = NewInstructionArgsEntry(offChainName, v, idl.Types, mod, binary.LittleEndian()) + case IdlEvent: + cEntry, err = NewEventArgsEntry(offChainName, v, idl.Types, true, mod, binary.LittleEndian()) + } + if err != nil { + return nil, fmt.Errorf("failed to create %q codec entry: %w", offChainName, err) + } - if account == nil { - return nil, fmt.Errorf("failed to find account %q in IDL for offchainName %q", cfg.OnChainName, offChainName) - } + parsed.EncoderDefs[offChainName] = cEntry + parsed.DecoderDefs[offChainName] = cEntry + } - cEntry, err = NewAccountEntry(offChainName, *account, idl.Types, true, mod, binary.LittleEndian()) - if err != nil { - return nil, fmt.Errorf("failed to create %q codec entry: %w", offChainName, err) - } - case ChainConfigTypeInstructionDef: - var instruction *IdlInstruction - for i := range idl.Instructions { - if idl.Instructions[i].Name == onChainName { - instruction = &idl.Instructions[i] - break - } - } + return parsed.ToCodec() +} - if instruction == nil { - return nil, fmt.Errorf("failed to find instruction %q in IDL for offChainName %q", cfg.OnChainName, offChainName) +func findDefinitionFromIDL(cfgType ChainConfigType, onChainName string, idl IDL) (interface{}, error) { + // not the most efficient way to do this, but these slices should always be very, very small + switch cfgType { + case ChainConfigTypeAccountDef: + for i := range idl.Accounts { + if idl.Accounts[i].Name == onChainName { + return idl.Accounts[i], nil } + } + return nil, fmt.Errorf("failed to find account %q in IDL", onChainName) - cEntry, err = NewInstructionArgsEntry(offChainName, *instruction, idl.Types, mod, binary.LittleEndian()) - if err != nil { - return nil, fmt.Errorf("failed to create %q codec entry: %w", offChainName, err) + case ChainConfigTypeInstructionDef: + for i := range idl.Instructions { + if idl.Instructions[i].Name == onChainName { + return idl.Instructions[i], nil } - case ChainConfigTypeEventDef: - return nil, fmt.Errorf("TODO, unimplemented type: %q", cfg.Type) - default: - return nil, fmt.Errorf("unknown type: %q", cfg.Type) } + return nil, fmt.Errorf("failed to find instruction %q in IDL", onChainName) - parsed.EncoderDefs[offChainName] = cEntry - parsed.DecoderDefs[offChainName] = cEntry + case ChainConfigTypeEventDef: + for i := range idl.Events { + if idl.Events[i].Name == onChainName { + return idl.Events[i], nil + } + } + return nil, fmt.Errorf("failed to find event %q in IDL", onChainName) } - - return parsed.ToCodec() + return nil, fmt.Errorf("unknown type: %q", cfgType) } // NewIDLAccountCodec is for Anchor custom types diff --git a/pkg/solana/codec/testutils/eventItemTypeIDL.json b/pkg/solana/codec/testutils/eventItemTypeIDL.json new file mode 100644 index 000000000..f98f27671 --- /dev/null +++ b/pkg/solana/codec/testutils/eventItemTypeIDL.json @@ -0,0 +1,73 @@ +{ + "version": "0.1.0", + "name": "test_item_event_type", + "instructions": [], + "events": [ + { + "name": "TestItem", + "fields": [ + { "name": "Field", "type": "i32" }, + { "name": "OracleId", "type": "u8" }, + { "name": "OracleIds", "type": { "array": ["u8", 32] } }, + { "name": "AccountStruct", "type": { "defined": "AccountStruct" } }, + { "name": "Accounts", "type": { "vec": "publicKey" } }, + { "name": "DifferentField", "type": "string" }, + { "name": "BigField", "type": "i128" }, + { "name": "NestedDynamicStruct", "type": { "defined": "NestedDynamic" } }, + { "name": "NestedStaticStruct", "type": { "defined": "NestedStatic" } } + ] + } + ], + "types": [ + { + "name": "AccountStruct", + "type": { + "kind": "struct", + "fields": [ + { "name": "Account", "type": "publicKey" }, + { "name": "AccountStr", "type": "publicKey" } + ] + } + }, + { + "name": "InnerDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "S", "type": "string" } + ] + } + }, + { + "name": "NestedDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerDynamic" } } + ] + } + }, + { + "name": "InnerStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "A", "type": "publicKey" } + ] + } + }, + { + "name": "NestedStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerStatic" } } + ] + } + } + ] +} diff --git a/pkg/solana/codec/testutils/types.go b/pkg/solana/codec/testutils/types.go index 7143af49d..3c52adb0f 100644 --- a/pkg/solana/codec/testutils/types.go +++ b/pkg/solana/codec/testutils/types.go @@ -55,6 +55,7 @@ var ( EnumVal: 0, } TestItemWithConfigExtraType = "TestItemWithConfigExtra" + TestEventItem = "TestEventItem" ) type StructWithNestedStruct struct { @@ -97,6 +98,9 @@ var CircularDepIDL string //go:embed itemIDL.json var itemTypeJSONIDL string +//go:embed eventItemTypeIDL.json +var eventItemTypeJSONIDL string + //go:embed itemSliceTypeIDL.json var itemSliceTypeJSONIDL string @@ -147,6 +151,11 @@ var CodecDefs = map[string]CodecDef{ IDLTypeName: interfacetests.NilType, ItemType: codec.ChainConfigTypeAccountDef, }, + TestEventItem: { + IDL: eventItemTypeJSONIDL, + IDLTypeName: interfacetests.TestItemType, + ItemType: codec.ChainConfigTypeEventDef, + }, } type TestItemAsAccount struct {