Skip to content

Commit

Permalink
Merge pull request #367 from illia-li/il/fix/marshal/decimal
Browse files Browse the repository at this point in the history
Fix `decimal` marshal, unmarshall functions
  • Loading branch information
dkropachev authored Dec 9, 2024
2 parents 7572e54 + f010340 commit 8f09151
Show file tree
Hide file tree
Showing 10 changed files with 1,044 additions and 82 deletions.
10 changes: 8 additions & 2 deletions internal/tests/serialization/utils_equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,16 @@ func equalVals(in1, in2 interface{}) bool {
return vin1.Cmp(vin2) == 0
case inf.Dec:
vin2 := in2.(inf.Dec)
return vin1.Cmp(&vin2) == 0
if vin1.Scale() != vin2.Scale() {
return false
}
return vin1.UnscaledBig().Cmp(vin2.UnscaledBig()) == 0
case *inf.Dec:
vin2 := in2.(*inf.Dec)
return vin1.Cmp(vin2) == 0
if vin1.Scale() != vin2.Scale() {
return false
}
return vin1.UnscaledBig().Cmp(vin2.UnscaledBig()) == 0
case fmt.Stringer:
vin2 := in2.(fmt.Stringer)
return vin1.String() == vin2.String()
Expand Down
49 changes: 12 additions & 37 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"encoding/binary"
"errors"
"fmt"
"gopkg.in/inf.v0"
"math"
"math/big"
"math/bits"
Expand All @@ -23,6 +22,7 @@ import (
"github.com/gocql/gocql/serialization/blob"
"github.com/gocql/gocql/serialization/counter"
"github.com/gocql/gocql/serialization/cqlint"
"github.com/gocql/gocql/serialization/decimal"
"github.com/gocql/gocql/serialization/double"
"github.com/gocql/gocql/serialization/float"
"github.com/gocql/gocql/serialization/inet"
Expand Down Expand Up @@ -168,7 +168,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeDouble:
return marshalDouble(value)
case TypeDecimal:
return marshalDecimal(info, value)
return marshalDecimal(value)
case TypeTime:
return marshalTime(info, value)
case TypeTimestamp:
Expand Down Expand Up @@ -282,7 +282,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeDouble:
return unmarshalDouble(data, value)
case TypeDecimal:
return unmarshalDecimal(info, data, value)
return unmarshalDecimal(data, value)
case TypeTime:
return unmarshalTime(info, data, value)
case TypeTimestamp:
Expand Down Expand Up @@ -611,44 +611,19 @@ func unmarshalDouble(data []byte, value interface{}) error {
return nil
}

func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) {
if value == nil {
return nil, nil
}

switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case inf.Dec:
unscaled := encBigInt2C(v.UnscaledBig())
if unscaled == nil {
return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

buf := make([]byte, 4+len(unscaled))
copy(buf[0:4], encInt(int32(v.Scale())))
copy(buf[4:], unscaled)
return buf, nil
func marshalDecimal(value interface{}) ([]byte, error) {
data, err := decimal.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}
return nil, marshalErrorf("can not marshal %T into %s", value, info)
return data, nil
}

func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
case *inf.Dec:
if len(data) < 4 {
return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data))
}
scale := decInt(data[0:4])
unscaled := decBigInt2C(data[4:], nil)
*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
return nil
func unmarshalDecimal(data []byte, value interface{}) error {
if err := decimal.Unmarshal(data, value); err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
return nil
}

// decBigInt2C sets the value of n to the big-endian two's complement
Expand Down
6 changes: 0 additions & 6 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,12 +598,6 @@ var unmarshalTests = []struct {
map[string]int{"foo": 1},
unmarshalErrorf("unmarshal map: unexpected eof"),
},
{
NativeType{proto: 2, typ: TypeDecimal},
[]byte("\xff\xff\xff"),
inf.NewDec(0, 0), // From the datastax/python-driver test suite
unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only 3"),
},
{
NativeType{proto: 5, typ: TypeDuration},
[]byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91"),
Expand Down
29 changes: 29 additions & 0 deletions serialization/decimal/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package decimal

import (
"gopkg.in/inf.v0"
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case inf.Dec:
return EncInfDec(v)
case *inf.Dec:
return EncInfDecR(v)
case string:
return EncString(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyString string) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
141 changes: 141 additions & 0 deletions serialization/decimal/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package decimal

import (
"fmt"
"gopkg.in/inf.v0"
"math/big"
"reflect"
"strconv"
"strings"

"github.com/gocql/gocql/serialization/varint"
)

func EncInfDec(v inf.Dec) ([]byte, error) {
sign := v.Sign()
if sign == 0 {
return []byte{0, 0, 0, 0, 0}, nil
}
return append(encScale(v.Scale()), varint.EncBigIntRS(v.UnscaledBig())...), nil
}

func EncInfDecR(v *inf.Dec) ([]byte, error) {
if v == nil {
return nil, nil
}
return encInfDecR(v), nil
}

// EncString encodes decimal string which should contains `scale` and `unscaled` strings separated by `;`.
func EncString(v string) ([]byte, error) {
if v == "" {
return nil, nil
}
vs := strings.Split(v, ";")
if len(vs) != 2 {
return nil, fmt.Errorf("failed to marshal decimal: invalid decimal string %s", v)
}
scale, err := strconv.ParseInt(vs[0], 10, 32)
if err != nil {
return nil, fmt.Errorf("failed to marshal decimal: invalid decimal scale string %s", vs[0])
}
unscaleData, err := encUnscaledString(vs[1])
if err != nil {
return nil, err
}
return append(encScale64(scale), unscaleData...), nil
}

func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Type().Kind() {
case reflect.String:
return encReflectString(v)
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal decimal: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal decimal: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}

func encReflectString(v reflect.Value) ([]byte, error) {
val := v.String()
if val == "" {
return nil, nil
}
vs := strings.Split(val, ";")
if len(vs) != 2 {
return nil, fmt.Errorf("failed to marshal decimal: invalid decimal string (%T)(%[1]v)", v.Interface())
}
scale, err := strconv.ParseInt(vs[0], 10, 32)
if err != nil {
return nil, fmt.Errorf("failed to marshal decimal: invalid decimal scale string (%T)(%s)", v.Interface(), vs[0])
}
unscaledData, err := encUnscaledString(vs[1])
if err != nil {
return nil, err
}
return append(encScale64(scale), unscaledData...), nil
}

func encInfDecR(v *inf.Dec) []byte {
sign := v.Sign()
if sign == 0 {
return []byte{0, 0, 0, 0, 0}
}
return append(encScale(v.Scale()), varint.EncBigIntRS(v.UnscaledBig())...)
}

func encScale(v inf.Scale) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

func encScale64(v int64) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

func encUnscaledString(v string) ([]byte, error) {
switch {
case len(v) == 0:
return nil, nil
case len(v) <= 18:
n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal decimal: invalid unscaled string %s, %s", v, err)
}
return varint.EncInt64Ext(n), nil
case len(v) <= 20:
n, err := strconv.ParseInt(v, 10, 64)
if err == nil {
return varint.EncInt64Ext(n), nil
}

t, ok := new(big.Int).SetString(v, 10)
if !ok {
return nil, fmt.Errorf("failed to marshal decimal: invalid unscaled string %s", v)
}
return varint.EncBigIntRS(t), nil
default:
t, ok := new(big.Int).SetString(v, 10)
if !ok {
return nil, fmt.Errorf("failed to marshal decimal: invalid unscaled string %s", v)
}
return varint.EncBigIntRS(t), nil
}
}
34 changes: 34 additions & 0 deletions serialization/decimal/unmarshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package decimal

import (
"fmt"
"gopkg.in/inf.v0"
"reflect"
)

func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *inf.Dec:
return DecInfDec(data, v)
case **inf.Dec:
return DecInfDecR(data, v)
case *string:
return DecString(data, v)
case **string:
return DecStringR(data, v)
default:
// Custom types (type MyString string) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal decimal: unsupported value type (%T)(%#[1]v)", value)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}
Loading

0 comments on commit 8f09151

Please sign in to comment.