Skip to content

Commit

Permalink
Code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
ilija42 committed Dec 16, 2024
1 parent 17d3474 commit d795ae4
Show file tree
Hide file tree
Showing 30 changed files with 188 additions and 206 deletions.
12 changes: 6 additions & 6 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/types/query"
"github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/solanacodec"
)

const ServiceName = "SolanaChainReader"
Expand Down Expand Up @@ -223,12 +223,12 @@ func (s *SolanaChainReaderService) CreateContractType(readIdentifier string, for
func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReaderMethods) error {
for namespace, methods := range namespaces {
for methodName, method := range methods.Methods {
var idl solanacodec.IDL
var idl codec.IDL
if err := json.Unmarshal([]byte(method.AnchorIDL), &idl); err != nil {
return err
}

idlCodec, err := solanacodec.NewIDLAccountCodec(idl, config.BuilderForEncoding(method.Encoding))
idlCodec, err := codec.NewIDLAccountCodec(idl, config.BuilderForEncoding(method.Encoding))
if err != nil {
return err
}
Expand All @@ -239,12 +239,12 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReader

injectAddressModifier(procedure.OutputModifications)

mod, err := procedure.OutputModifications.ToModifier(solanacodec.DecoderHooks...)
mod, err := procedure.OutputModifications.ToModifier(codec.DecoderHooks...)
if err != nil {
return err
}

codecWithModifiers, err := solanacodec.NewNamedModifierCodec(idlCodec, procedure.IDLAccount, mod)
codecWithModifiers, err := codec.NewNamedModifierCodec(idlCodec, procedure.IDLAccount, mod)
if err != nil {
return err
}
Expand All @@ -265,7 +265,7 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReader
func injectAddressModifier(outputModifications codeccommon.ModifiersConfig) {
for i, modConfig := range outputModifications {
if addrModifierConfig, ok := modConfig.(*codeccommon.AddressBytesToStringModifierConfig); ok {
addrModifierConfig.Modifier = solanacodec.SolanaAddressModifier{}
addrModifierConfig.Modifier = codec.SolanaAddressModifier{}
outputModifications[i] = addrModifierConfig
}
}
Expand Down
14 changes: 7 additions & 7 deletions pkg/solana/chainreader/chain_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/utils/tests"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/chainreader"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec/testutils"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/solanacodec"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/solanacodec/testutils"
)

const (
Expand Down Expand Up @@ -269,16 +269,16 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) {
})
}

func newTestIDLAndCodec(t *testing.T) (string, solanacodec.IDL, types.RemoteCodec) {
func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) {
t.Helper()

var idl solanacodec.IDL
var idl codec.IDL
if err := json.Unmarshal([]byte(testutils.JSONIDLWithAllTypes), &idl); err != nil {
t.Logf("failed to unmarshal test IDL: %s", err.Error())
t.FailNow()
}

entry, err := solanacodec.NewIDLAccountCodec(idl, binary.LittleEndian())
entry, err := codec.NewIDLAccountCodec(idl, binary.LittleEndian())
if err != nil {
t.Logf("failed to create new codec from test IDL: %s", err.Error())
t.FailNow()
Expand Down Expand Up @@ -763,13 +763,13 @@ func (r *chainReaderInterfaceTester) MaxWaitTimeForEvents() time.Duration {
func makeTestCodec(t *testing.T, rawIDL string, encoding config.EncodingType) types.RemoteCodec {
t.Helper()

var idl solanacodec.IDL
var idl codec.IDL
if err := json.Unmarshal([]byte(rawIDL), &idl); err != nil {
t.Logf("failed to unmarshal test IDL: %s", err.Error())
t.FailNow()
}

testCodec, err := solanacodec.NewIDLAccountCodec(idl, config.BuilderForEncoding(encoding))
testCodec, err := codec.NewIDLAccountCodec(idl, config.BuilderForEncoding(encoding))
if err != nil {
t.Logf("failed to create new codec from test IDL: %s", err.Error())
t.FailNow()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package solanacodec
package codec

/*
copied from https://github.com/gagliardetto/anchor-go where the IDL definition is not importable due to being defined
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package solanacodec
package codec

import (
"fmt"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package solanacodec_test
package codec_test

import (
"testing"
Expand All @@ -9,11 +9,11 @@ import (

commontypes "github.com/smartcontractkit/chainlink-common/pkg/types"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/solanacodec"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec"
)

func TestSolanaAddressModifier(t *testing.T) {
modifier := solanacodec.SolanaAddressModifier{}
modifier := codec.SolanaAddressModifier{}

// Valid Solana address (32 bytes, Base58 encoded)
validAddressStr := "9nQhQ7iCyY5SgAX2Zm4DtxNh9Ubc4vbiLkiYbX43SDXY"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package solanacodec
package codec

import (
"fmt"
Expand All @@ -15,6 +15,8 @@ type Entry interface {
GetCodecType() commonencodings.TypeCodec
GetType() reflect.Type
Modifier() codec.Modifier
Size(numItems int) (int, error)
FixedSize() (int, error)
}

func NewAccountEntry(offchainName string, idlAccount IdlTypeDef, idlTypes IdlTypeDefSlice, includeDiscriminator bool, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) {
Expand All @@ -30,15 +32,14 @@ func NewAccountEntry(offchainName string, idlAccount IdlTypeDef, idlTypes IdlTyp
return nil, err
}

entry := &CodecEntry{
return &entry{
offchainName: offchainName,
onchainName: idlAccount.Name,
includeDiscriminator: includeDiscriminator,
typeCodec: accCodec,
reflectType: accCodec.GetType(),
mod: ensureModifier(mod),
}
return entry, nil
}, nil
}

func NewInstructionArgsEntry(offChainName string, instructions IdlInstruction, idlTypes IdlTypeDefSlice, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) {
Expand All @@ -54,7 +55,7 @@ func NewInstructionArgsEntry(offChainName string, instructions IdlInstruction, i
return nil, err
}

return &CodecEntry{
return &entry{
offchainName: offChainName,
onchainName: instructions.Name,
typeCodec: instructionCodecArgs,
Expand All @@ -63,7 +64,7 @@ func NewInstructionArgsEntry(offChainName string, instructions IdlInstruction, i
}, nil
}

type CodecEntry struct {
type entry struct {
// TODO this might not be needed in the end, it was handy to make tests simpler
offchainName string
onchainName string
Expand All @@ -73,35 +74,35 @@ type CodecEntry struct {
includeDiscriminator bool
}

func (entry *CodecEntry) GetType() reflect.Type {
return entry.reflectType
func (e *entry) GetType() reflect.Type {
return e.reflectType
}

func (entry *CodecEntry) GetCodecType() commonencodings.TypeCodec {
return entry.typeCodec
func (e *entry) GetCodecType() commonencodings.TypeCodec {
return e.typeCodec
}

func (entry *CodecEntry) Encode(value any, into []byte) ([]byte, error) {
func (e *entry) Encode(value any, into []byte) ([]byte, error) {
// Special handling for encoding a nil pointer to an empty struct.
t := entry.reflectType
t := e.reflectType
if value == nil {
if t.Kind() == reflect.Pointer {
elem := t.Elem()
if elem.Kind() == reflect.Struct && elem.NumField() == 0 {
return []byte{}, nil
}
}
return nil, fmt.Errorf("%w: cannot encode nil value for offchainName: %q, onchainName: %q", commontypes.ErrInvalidType, entry.offchainName, entry.onchainName)
return nil, fmt.Errorf("%w: cannot encode nil value for offchainName: %q, onchainName: %q", commontypes.ErrInvalidType, e.offchainName, e.onchainName)
}

encodedVal, err := entry.typeCodec.Encode(value, into)
encodedVal, err := e.typeCodec.Encode(value, into)
if err != nil {
return nil, err
}

if entry.includeDiscriminator {
if e.includeDiscriminator {
var byt []byte
disc := NewDiscriminator(entry.onchainName)
disc := NewDiscriminator(e.onchainName)
encodedDisc, err := disc.Encode(&disc.hashPrefix, byt)
if err != nil {
return nil, err
Expand All @@ -112,18 +113,18 @@ func (entry *CodecEntry) Encode(value any, into []byte) ([]byte, error) {
return encodedVal, nil
}

func (entry *CodecEntry) Decode(encoded []byte) (any, []byte, error) {
if entry.includeDiscriminator {
func (e *entry) Decode(encoded []byte) (any, []byte, error) {
if e.includeDiscriminator {
if len(encoded) < discriminatorLength {
return nil, nil, fmt.Errorf("%w: encoded data too short to contain discriminator for offchainName: %q, onchainName: %q", commontypes.ErrInvalidType, entry.offchainName, entry.onchainName)
return nil, nil, fmt.Errorf("%w: encoded data too short to contain discriminator for offchainName: %q, onchainName: %q", commontypes.ErrInvalidType, e.offchainName, e.onchainName)
}
encoded = encoded[discriminatorLength:]
}
return entry.typeCodec.Decode(encoded)
return e.typeCodec.Decode(encoded)
}

func (entry *CodecEntry) Modifier() codec.Modifier {
return entry.mod
func (e *entry) Modifier() codec.Modifier {
return e.mod
}

func ensureModifier(mod codec.Modifier) codec.Modifier {
Expand All @@ -132,3 +133,11 @@ func ensureModifier(mod codec.Modifier) codec.Modifier {
}
return mod
}

func (e *entry) Size(numItems int) (int, error) {
return e.typeCodec.Size(numItems)
}

func (e *entry) FixedSize() (int, error) {
return e.typeCodec.FixedSize()
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package solanacodec_test
package codec_test

import (
"bytes"
Expand All @@ -15,8 +15,8 @@ import (
looptestutils "github.com/smartcontractkit/chainlink-common/pkg/loop/testutils"
clcommontypes "github.com/smartcontractkit/chainlink-common/pkg/types"
. "github.com/smartcontractkit/chainlink-common/pkg/types/interfacetests" //nolint common practice to import test mods with .
"github.com/smartcontractkit/chainlink-solana/pkg/solana/solanacodec"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/solanacodec/testutils"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec/testutils"
)

const anyExtraValue = 3
Expand Down Expand Up @@ -83,7 +83,7 @@ func encodeFieldsOnSliceOrArray(t *testing.T, request *EncodeRequest) []byte {
}

func (it *codecInterfaceTester) GetCodec(t *testing.T) clcommontypes.Codec {
codecConfig := solanacodec.Config{Configs: map[string]solanacodec.ChainConfig{}}
codecConfig := codec.Config{Configs: map[string]codec.ChainConfig{}}
TestItem := CreateTestStruct[*testing.T](0, it)
for offChainName, v := range testutils.CodecDefs {
codecEntryCfg := codecConfig.Configs[offChainName]
Expand All @@ -101,7 +101,7 @@ func (it *codecInterfaceTester) GetCodec(t *testing.T) clcommontypes.Codec {
if slices.Contains([]string{testutils.TestItemType, testutils.TestItemSliceType, testutils.TestItemArray1Type, testutils.TestItemArray2Type, testutils.TestItemWithConfigExtraType}, offChainName) {
addressByteModifier := &commoncodec.AddressBytesToStringModifierConfig{
Fields: []string{"AccountStruct.AccountStr"},
Modifier: solanacodec.SolanaAddressModifier{},
Modifier: codec.SolanaAddressModifier{},
}
codecEntryCfg.ModifierConfigs = append(codecEntryCfg.ModifierConfigs, addressByteModifier)
}
Expand All @@ -119,7 +119,7 @@ func (it *codecInterfaceTester) GetCodec(t *testing.T) clcommontypes.Codec {
codecConfig.Configs[offChainName] = codecEntryCfg
}

c, err := solanacodec.NewCodec(codecConfig)
c, err := codec.NewCodec(codecConfig)
require.NoError(t, err)

return c
Expand Down
35 changes: 35 additions & 0 deletions pkg/solana/codec/decoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package codec

import (
"context"
"fmt"

"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings"
commontypes "github.com/smartcontractkit/chainlink-common/pkg/types"
)

type Decoder struct {
definitions map[string]Entry
codecFromTypeCodec encodings.CodecFromTypeCodec
}

var _ commontypes.Decoder = &Decoder{}

func (d *Decoder) Decode(ctx context.Context, raw []byte, into any, itemType string) (err error) {
if d.codecFromTypeCodec == nil {
d.codecFromTypeCodec = make(encodings.CodecFromTypeCodec)
for k, v := range d.definitions {
d.codecFromTypeCodec[k] = v
}
}

return d.codecFromTypeCodec.Decode(ctx, raw, into, itemType)
}

func (d *Decoder) GetMaxDecodingSize(_ context.Context, n int, itemType string) (int, error) {
codecEntry, ok := d.definitions[itemType]
if !ok {
return 0, fmt.Errorf("%w: nil entry", commontypes.ErrInvalidType)
}
return codecEntry.GetCodecType().Size(n)
}
Loading

0 comments on commit d795ae4

Please sign in to comment.