Skip to content

Commit

Permalink
feat: migrate CBOR constructor for generic use
Browse files Browse the repository at this point in the history
Fixes #550
  • Loading branch information
agaffney committed Mar 26, 2024
1 parent e04a54d commit e3ed823
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 105 deletions.
128 changes: 128 additions & 0 deletions cbor/tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cbor

import (
"fmt"
"math/big"
"reflect"

Expand Down Expand Up @@ -75,6 +76,16 @@ func init() {
); err != nil {
panic(err)
}
// Plutus constructors (CBOR alternatives)
//for i := CborTagAlternative1Min; i <= CborTagAlternative1Max; i++ {
if err := customTagSet.Add(
tagOpts,
reflect.TypeOf(Constructor{}),
CborTagAlternative1Min+1,
); err != nil {
panic(err)
}
//}
}

// WrappedCbor corresponds to CBOR tag 24 and is used to encode nested CBOR data
Expand Down Expand Up @@ -118,3 +129,120 @@ type Set []any

// Map corresponds to CBOR tag 259 and is used to represent a map with key/value operations
type Map map[any]any

// TODO: docs
type Constructor struct {
DecodeStoreCbor
fieldsCbor []byte
constructor uint
value any
}

func NewConstructor(constructor uint, value any) Constructor {
c := Constructor{
constructor: constructor,
}
if value != nil {
c.value = value
}
fmt.Printf("NewConstructor(): value = %#v\n", value)
return c
}

func (v Constructor) Constructor() uint {
return v.constructor
}

func (v Constructor) Fields() []any {
return v.value.([]any)
}

func (c *Constructor) SetFieldsCbor(data []byte) {
c.fieldsCbor = data[:]
}

func (c Constructor) FieldsCbor() []byte {
return c.fieldsCbor[:]
}

func (c *Constructor) UnmarshalCBOR(data []byte) error {
// Save original CBOR
c.SetCbor(data)
// Parse as a raw tag to get number and nested CBOR data
tmpTag := RawTag{}
if _, err := Decode(data, &tmpTag); err != nil {
return err
}
fmt.Printf("tmpTag = %#v\n", tmpTag)
c.SetFieldsCbor([]byte(tmpTag.Content))
// Parse the tag value via our custom Value object to handle problem types
//tmpValue := Value{}
var tmpValue any
if _, err := Decode(tmpTag.Content, &tmpValue); err != nil {
return err
}
if tmpTag.Number >= CborTagAlternative1Min && tmpTag.Number <= CborTagAlternative1Max {
// Alternatives 0-6
c.constructor = uint(tmpTag.Number - CborTagAlternative1Min)
c.value = tmpValue
} else if tmpTag.Number >= CborTagAlternative2Min && tmpTag.Number <= CborTagAlternative2Max {
// Alternatives 7-127
c.constructor = uint(tmpTag.Number - CborTagAlternative2Min + 7)
c.value = tmpValue
} else if tmpTag.Number == CborTagAlternative3 {
// Alternatives 128+
tmpValues := tmpValue.([]any)
c.constructor = uint(tmpValues[0].(uint64))
/*
newValue := Value{
value: tmpValues[1],
}
c.value = newValue
*/
c.value = tmpValues[1]
} else {
return fmt.Errorf("unsupported tag: %d", tmpTag.Number)
}
fmt.Printf("UnmarshalCBOR(): c.value = %#v\n", c.value)
return nil
}

func (c Constructor) MarshalCBOR() ([]byte, error) {
var tmpTag Tag
if c.constructor <= 6 {
// Alternatives 0-6
tmpTag.Number = uint64(c.constructor + CborTagAlternative1Min)
tmpTag.Content = c.value
} else if c.constructor >= 7 && c.constructor <= 127 {
// Alternatives 7-127
tmpTag.Number = uint64(c.constructor + CborTagAlternative2Min - 7)
tmpTag.Content = c.value
} else if c.constructor >= 128 {
tmpTag.Number = CborTagAlternative3
tmpTag.Content = []any{
c.constructor,
c.value,
}
}
return Encode(&tmpTag)
}

func (v Constructor) MarshalJSON() ([]byte, error) {
tmpJson := fmt.Sprintf(`{"constructor":%d,"fields":[`, v.constructor)
tmpList := [][]byte{}
for _, val := range v.value.([]any) {
tmpVal, err := generateAstJson(val)
if err != nil {
return nil, err
}
tmpList = append(tmpList, tmpVal)
}
for idx, val := range tmpList {
tmpJson += string(val)
if idx != (len(tmpList) - 1) {
tmpJson += `,`
}
}
tmpJson += `]}`
return []byte(tmpJson), nil
}
29 changes: 29 additions & 0 deletions cbor/tags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ var tagsTestDefs = []struct {
},
),
},
/*
// 121([1, 2, 3])
{
cborHex: "d87983010203",
object: cbor.NewConstructor(
0,
[]any{0x1, 0x2, 0x3},
),
},
*/
// 122(1, 2, 3])
{
cborHex: "d87a83010203",
object: cbor.NewConstructor(
1,
[]any{uint64(1), uint64(2), uint64(3)},
),
},
}

func TestTagsDecode(t *testing.T) {
Expand All @@ -66,6 +84,17 @@ func TestTagsDecode(t *testing.T) {
if _, err := cbor.Decode(cborData, &dest); err != nil {
t.Fatalf("failed to decode CBOR: %s", err)
}
// Set stored CBOR for supported types to make comparison easier
switch v := dest.(type) {
case cbor.Constructor:
v.SetFieldsCbor(nil)
dest = v
}
switch v := testDef.object.(type) {
case cbor.DecodeStoreCborInterface:
v.SetCbor(cborData)
testDef.object = v
}
if !reflect.DeepEqual(dest, testDef.object) {
t.Fatalf(
"CBOR did not decode to expected object\n got: %#v\n wanted: %#v",
Expand Down
105 changes: 0 additions & 105 deletions cbor/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,111 +247,6 @@ func generateAstJsonMap[T map[any]any | Map](v T) ([]byte, error) {

}

type Constructor struct {
DecodeStoreCbor
constructor uint
value *Value
}

func NewConstructor(constructor uint, value any) Constructor {
c := Constructor{
constructor: constructor,
}
if value != nil {
c.value = &Value{
value: value,
}
}
return c
}

func (v Constructor) Constructor() uint {
return v.constructor
}

func (v Constructor) Fields() []any {
return v.value.Value().([]any)
}

func (c Constructor) FieldsCbor() []byte {
return c.value.Cbor()
}

func (c *Constructor) UnmarshalCBOR(data []byte) error {
// Save original CBOR
c.SetCbor(data)
// Parse as a raw tag to get number and nested CBOR data
tmpTag := RawTag{}
if _, err := Decode(data, &tmpTag); err != nil {
return err
}
// Parse the tag value via our custom Value object to handle problem types
tmpValue := Value{}
if _, err := Decode(tmpTag.Content, &tmpValue); err != nil {
return err
}
if tmpTag.Number >= CborTagAlternative1Min && tmpTag.Number <= CborTagAlternative1Max {
// Alternatives 0-6
c.constructor = uint(tmpTag.Number - CborTagAlternative1Min)
c.value = &tmpValue
} else if tmpTag.Number >= CborTagAlternative2Min && tmpTag.Number <= CborTagAlternative2Max {
// Alternatives 7-127
c.constructor = uint(tmpTag.Number - CborTagAlternative2Min + 7)
c.value = &tmpValue
} else if tmpTag.Number == CborTagAlternative3 {
// Alternatives 128+
tmpValues := tmpValue.Value().([]any)
c.constructor = uint(tmpValues[0].(uint64))
newValue := Value{
value: tmpValues[1],
}
c.value = &newValue
} else {
return fmt.Errorf("unsupported tag: %d", tmpTag.Number)
}
return nil
}

func (c Constructor) MarshalCBOR() ([]byte, error) {
var tmpTag Tag
if c.constructor <= 6 {
// Alternatives 0-6
tmpTag.Number = uint64(c.constructor + CborTagAlternative1Min)
tmpTag.Content = c.value.Value()
} else if c.constructor >= 7 && c.constructor <= 127 {
// Alternatives 7-127
tmpTag.Number = uint64(c.constructor + CborTagAlternative2Min - 7)
tmpTag.Content = c.value.Value()
} else if c.constructor >= 128 {
tmpTag.Number = CborTagAlternative3
tmpTag.Content = []any{
c.constructor,
c.value.Value(),
}
}
return Encode(&tmpTag)
}

func (v Constructor) MarshalJSON() ([]byte, error) {
tmpJson := fmt.Sprintf(`{"constructor":%d,"fields":[`, v.constructor)
tmpList := [][]byte{}
for _, val := range v.value.Value().([]any) {
tmpVal, err := generateAstJson(val)
if err != nil {
return nil, err
}
tmpList = append(tmpList, tmpVal)
}
for idx, val := range tmpList {
tmpJson += string(val)
if idx != (len(tmpList) - 1) {
tmpJson += `,`
}
}
tmpJson += `]}`
return []byte(tmpJson), nil
}

type LazyValue struct {
value *Value
}
Expand Down

0 comments on commit e3ed823

Please sign in to comment.