Skip to content

Commit

Permalink
fix ascii marshal functions
Browse files Browse the repository at this point in the history
  • Loading branch information
illia-li committed Nov 14, 2024
1 parent aa6c8dc commit 0299959
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 71 deletions.
88 changes: 17 additions & 71 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"gopkg.in/inf.v0"

"github.com/gocql/gocql/serialization/ascii"
"github.com/gocql/gocql/serialization/bigint"
"github.com/gocql/gocql/serialization/blob"
"github.com/gocql/gocql/serialization/counter"
Expand Down Expand Up @@ -147,7 +148,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeBlob:
return marshalBlob(value)
case TypeAscii:
return marshalVarcharOld(info, value)
return marshalAscii(value)
case TypeBoolean:
return marshalBool(info, value)
case TypeTinyInt:
Expand Down Expand Up @@ -257,7 +258,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeBlob:
return unmarshalBlob(data, value)
case TypeAscii:
return unmarshalVarcharOld(info, data, value)
return unmarshalAscii(data, value)
case TypeBoolean:
return unmarshalBool(info, data, value)
case TypeInt:
Expand Down Expand Up @@ -341,6 +342,7 @@ func marshalVarchar(value interface{}) ([]byte, error) {
}
return data, nil
}

func marshalText(value interface{}) ([]byte, error) {
data, err := text.Marshal(value)
if err != nil {
Expand All @@ -357,6 +359,14 @@ func marshalBlob(value interface{}) ([]byte, error) {
return data, nil
}

func marshalAscii(value interface{}) ([]byte, error) {
data, err := ascii.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}
return data, nil
}

func unmarshalVarchar(data []byte, value interface{}) error {
err := varchar.Unmarshal(data, value)
if err != nil {
Expand All @@ -381,76 +391,12 @@ func unmarshalBlob(data []byte, value interface{}) error {
return nil
}

func marshalVarcharOld(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case string:
return []byte(v), nil
case []byte:
return v, nil
}

if value == nil {
return nil, nil
}

rv := reflect.ValueOf(value)
t := rv.Type()
k := t.Kind()
switch {
case k == reflect.String:
return []byte(rv.String()), nil
case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8:
return rv.Bytes(), nil
}
return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

func unmarshalVarcharOld(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
case *string:
*v = string(data)
return nil
case *[]byte:
if data != nil {
*v = append((*v)[:0], data...)
} else {
*v = nil
}
return nil
}

rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
}
rv = rv.Elem()
t := rv.Type()
k := t.Kind()
switch {
case k == reflect.String:
rv.SetString(string(data))
return nil
case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8, k == reflect.Interface:
var dataCopy []byte
if data != nil {
dataCopy = make([]byte, len(data))
copy(dataCopy, data)
}
if k == reflect.Slice {
rv.SetBytes(dataCopy)
} else {
rv.Set(reflect.ValueOf(dataCopy))
}
return nil
func unmarshalAscii(data []byte, value interface{}) error {
err := ascii.Unmarshal(data, value)
if err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}

return unmarshalErrorf("can not unmarshal %s into %T", info, value)
return nil
}

func marshalSmallInt(value interface{}) ([]byte, error) {
Expand Down
28 changes: 28 additions & 0 deletions serialization/ascii/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ascii

import (
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case string:
return EncString(v)
case *string:
return EncStringR(v)
case []byte:
return EncBytes(v)
case *[]byte:
return EncBytesR(v)
default:
// Custom types (type MyString string) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(rv)
}
return EncReflectR(rv)
}
}
56 changes: 56 additions & 0 deletions serialization/ascii/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package ascii

import (
"fmt"
"reflect"
)

func EncString(v string) ([]byte, error) {
return encString(v), nil
}

func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return encString(*v), nil
}

func EncBytes(v []byte) ([]byte, error) {
return v, nil
}

func EncBytesR(v *[]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return *v, nil
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.String:
return encString(v.String()), nil
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
return EncBytes(v.Bytes())
default:
return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}

func encString(v string) []byte {
if v == "" {
return make([]byte, 0)
}
return []byte(v)
}
35 changes: 35 additions & 0 deletions serialization/ascii/unmarshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package ascii

import (
"fmt"
"reflect"
)

func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *string:
return DecString(data, v)
case **string:
return DecStringR(data, v)
case *[]byte:
return DecBytes(data, v)
case **[]byte:
return DecBytesR(data, v)
case *interface{}:
return DecInterface(data, v)
default:
// Custom types (type MyString string) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}
Loading

0 comments on commit 0299959

Please sign in to comment.