Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix boolean, marshal, unmarshall functions #379

Merged
merged 2 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 12 additions & 53 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"errors"
"fmt"
"github.com/gocql/gocql/serialization/boolean"
"math"
"math/big"
"math/bits"
Expand Down Expand Up @@ -154,7 +155,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeAscii:
return marshalAscii(value)
case TypeBoolean:
return marshalBool(info, value)
return marshalBool(value)
case TypeTinyInt:
return marshalTinyInt(value)
case TypeSmallInt:
Expand Down Expand Up @@ -266,7 +267,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeAscii:
return unmarshalAscii(data, value)
case TypeBoolean:
return unmarshalBool(info, data, value)
return unmarshalBool(data, value)
case TypeInt:
return unmarshalInt(data, value)
case TypeBigInt:
Expand Down Expand Up @@ -525,61 +526,19 @@ func decBigInt(data []byte) int64 {
int64(data[6])<<8 | int64(data[7])
}

func marshalBool(info TypeInfo, value interface{}) ([]byte, error) {
switch v := value.(type) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, nil
case bool:
return encBool(v), nil
}

if value == nil {
return nil, nil
}

rv := reflect.ValueOf(value)
switch rv.Type().Kind() {
case reflect.Bool:
return encBool(rv.Bool()), nil
}
return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

func encBool(v bool) []byte {
if v {
return []byte{1}
}
return []byte{0}
}

func unmarshalBool(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
case *bool:
*v = decBool(data)
return nil
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
}
rv = rv.Elem()
switch rv.Type().Kind() {
case reflect.Bool:
rv.SetBool(decBool(data))
return nil
func marshalBool(value interface{}) ([]byte, error) {
data, err := boolean.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
return data, nil
}

func decBool(v []byte) bool {
if len(v) == 0 {
return false
func unmarshalBool(data []byte, value interface{}) error {
if err := boolean.Unmarshal(data, value); err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}
return v[0] != 0
return nil
}

func marshalFloat(value interface{}) ([]byte, error) {
Expand Down
41 changes: 0 additions & 41 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,6 @@ var marshalTests = []struct {
MarshalError error
UnmarshalError error
}{
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte("\x00"),
false,
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte("\x01"),
true,
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeDecimal},
[]byte("\x00\x00\x00\x00\x00"),
Expand Down Expand Up @@ -303,33 +289,6 @@ var marshalTests = []struct {
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte("\x00"),
func() *bool {
b := false
return &b
}(),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte("\x01"),
func() *bool {
b := true
return &b
}(),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeBoolean},
[]byte(nil),
(*bool)(nil),
nil,
nil,
},
{
NativeType{proto: 2, typ: TypeInet},
[]byte("\x7F\x00\x00\x01"),
Expand Down
24 changes: 24 additions & 0 deletions serialization/boolean/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package boolean

import (
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case bool:
return EncBool(v)
case *bool:
return EncBoolR(v)
default:
// Custom types (type MyBool bool) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
45 changes: 45 additions & 0 deletions serialization/boolean/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package boolean

import (
"fmt"
"reflect"
)

func EncBool(v bool) ([]byte, error) {
return encBool(v), nil
}

func EncBoolR(v *bool) ([]byte, error) {
if v == nil {
return nil, nil
}
return encBool(*v), nil
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Bool:
return encBool(v.Bool()), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}

func encBool(v bool) []byte {
if v {
return []byte{1}
}
return []byte{0}
}
29 changes: 29 additions & 0 deletions serialization/boolean/unmarshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package boolean

import (
"fmt"
"reflect"
)

func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *bool:
return DecBool(data, v)
case **bool:
return DecBoolR(data, v)
default:
// Custom types (type MyBool bool) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}
108 changes: 108 additions & 0 deletions serialization/boolean/unmarshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package boolean

import (
"fmt"
"reflect"
)

var errWrongDataLen = fmt.Errorf("failed to unmarshal boolean: the length of the data should be 0 or 1")

func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal boolean: can not unmarshal into nil reference(%T)(%[1]v)", v)
}

func DecBool(p []byte, v *bool) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = false
case 1:
*v = decBool(p)
default:
return errWrongDataLen
}
return nil
}

func DecBoolR(p []byte, v **bool) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(bool)
}
case 1:
val := decBool(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}

func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}

switch v = v.Elem(); v.Kind() {
case reflect.Bool:
return decReflectBool(p, v)
default:
return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}

switch v.Type().Elem().Elem().Kind() {
case reflect.Bool:
return decReflectBoolR(p, v)
default:
return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func decReflectBool(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetBool(false)
case 1:
v.SetBool(decBool(p))
default:
return errWrongDataLen
}
return nil
}

func decReflectBoolR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
if p == nil {
v.Elem().Set(reflect.Zero(v.Type().Elem()))
} else {
val := reflect.New(v.Type().Elem().Elem())
v.Elem().Set(val)
}
case 1:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetBool(decBool(p))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}

func decBool(p []byte) bool {
return p[0] != 0
}
Loading
Loading