Skip to content

Commit

Permalink
Add discriminator value check in codec entry Decode
Browse files Browse the repository at this point in the history
  • Loading branch information
ilija42 committed Dec 18, 2024
1 parent aa89442 commit 6bc1d31
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions pkg/solana/codec/codec_entry.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package codec

import (
"bytes"
"fmt"
"reflect"

Expand Down Expand Up @@ -29,6 +30,7 @@ type entry struct {
// includeDiscriminator during Encode adds a discriminator to the encoded bytes under an assumption that the provided value didn't have a discriminator.
// During Decode includeDiscriminator removes discriminator from bytes under an assumption that the provided struct doesn't need a discriminator.
includeDiscriminator bool
discriminator Discriminator
}

func NewAccountEntry(offchainName string, idlAccount IdlTypeDef, idlTypes IdlTypeDefSlice, includeDiscriminator bool, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) {
Expand All @@ -47,10 +49,11 @@ func NewAccountEntry(offchainName string, idlAccount IdlTypeDef, idlTypes IdlTyp
return &entry{
offchainName: offchainName,
onchainName: idlAccount.Name,
includeDiscriminator: includeDiscriminator,
typeCodec: accCodec,
reflectType: accCodec.GetType(),
typeCodec: accCodec,
mod: ensureModifier(mod),
includeDiscriminator: includeDiscriminator,
discriminator: *NewDiscriminator(idlAccount.Name),
}, nil
}

Expand Down Expand Up @@ -96,8 +99,7 @@ func (e *entry) Encode(value any, into []byte) ([]byte, error) {

if e.includeDiscriminator {
var byt []byte
disc := NewDiscriminator(e.onchainName)
encodedDisc, err := disc.Encode(&disc.hashPrefix, byt)
encodedDisc, err := e.discriminator.Encode(&e.discriminator.hashPrefix, byt)
if err != nil {
return nil, err
}
Expand All @@ -112,6 +114,11 @@ func (e *entry) Decode(encoded []byte) (any, []byte, error) {
if len(encoded) < discriminatorLength {
return nil, nil, fmt.Errorf("%w: encoded data too short to contain discriminator for offchainName: %q, onchainName: %q", commontypes.ErrInvalidType, e.offchainName, e.onchainName)
}

if !bytes.Equal(e.discriminator.hashPrefix, encoded[:8]) {
return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v for offchainName: %q, onchainName: %q", commontypes.ErrInvalidType, encoded[:8], e.offchainName, e.onchainName)
}

encoded = encoded[discriminatorLength:]
}
return e.typeCodec.Decode(encoded)
Expand Down

0 comments on commit 6bc1d31

Please sign in to comment.