Skip to content

Commit

Permalink
fix inet marshal, unmarshal functions
Browse files Browse the repository at this point in the history
  • Loading branch information
illia-li committed Nov 29, 2024
1 parent 85b37dc commit f587a93
Show file tree
Hide file tree
Showing 5 changed files with 806 additions and 61 deletions.
74 changes: 13 additions & 61 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"math"
"math/big"
"math/bits"
"net"
"reflect"
"strings"
"time"
Expand All @@ -26,6 +25,7 @@ import (
"github.com/gocql/gocql/serialization/cqlint"
"github.com/gocql/gocql/serialization/double"
"github.com/gocql/gocql/serialization/float"
"github.com/gocql/gocql/serialization/inet"
"github.com/gocql/gocql/serialization/smallint"
"github.com/gocql/gocql/serialization/text"
"github.com/gocql/gocql/serialization/timeuuid"
Expand Down Expand Up @@ -184,7 +184,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeVarint:
return marshalVarint(value)
case TypeInet:
return marshalInet(info, value)
return marshalInet(value)
case TypeTuple:
return marshalTuple(info, value)
case TypeUDT:
Expand Down Expand Up @@ -296,7 +296,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeUUID:
return unmarshalUUID(data, value)
case TypeInet:
return unmarshalInet(info, data, value)
return unmarshalInet(data, value)
case TypeTuple:
return unmarshalTuple(info, data, value)
case TypeUDT:
Expand Down Expand Up @@ -1360,68 +1360,20 @@ func unmarshalTimeUUID(data []byte, value interface{}) error {
return nil
}

func marshalInet(info TypeInfo, value interface{}) ([]byte, error) {
// we return either the 4 or 16 byte representation of an
// ip address here otherwise the db value will be prefixed
// with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1
switch val := value.(type) {
case unsetColumn:
return nil, nil
case net.IP:
t := val.To4()
if t == nil {
return val.To16(), nil
}
return t, nil
case string:
b := net.ParseIP(val)
if b != nil {
t := b.To4()
if t == nil {
return b.To16(), nil
}
return t, nil
}
return nil, marshalErrorf("cannot marshal. invalid ip string %s", val)
}

if value == nil {
return nil, nil
func marshalInet(value interface{}) ([]byte, error) {
data, err := inet.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}

return nil, marshalErrorf("cannot marshal %T into %s", value, info)
return data, nil
}

func unmarshalInet(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
case *net.IP:
if x := len(data); !(x == 4 || x == 16) {
return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x)
}
buf := copyBytes(data)
ip := net.IP(buf)
if v4 := ip.To4(); v4 != nil {
*v = v4
return nil
}
*v = ip
return nil
case *string:
if len(data) == 0 {
*v = ""
return nil
}
ip := net.IP(data)
if v4 := ip.To4(); v4 != nil {
*v = v4.String()
return nil
}
*v = ip.String()
return nil
func unmarshalInet(data []byte, value interface{}) error {
err := inet.Unmarshal(data, value)
if err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}
return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
return nil
}

func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) {
Expand Down
41 changes: 41 additions & 0 deletions serialization/inet/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package inet

import (
"net"
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case []byte:
return EncBytes(v)
case *[]byte:
return EncBytesR(v)
case net.IP:
return EncNetIP(v)
case *net.IP:
return EncNetIPr(v)
case [4]byte:
return EncArray4(v)
case *[4]byte:
return EncArray4R(v)
case [16]byte:
return EncArray16(v)
case *[16]byte:
return EncArray16R(v)
case string:
return EncString(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyIP []byte) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
192 changes: 192 additions & 0 deletions serialization/inet/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package inet

import (
"fmt"
"net"
"reflect"
)

func EncBytes(v []byte) ([]byte, error) {
switch len(v) {
case 0:
if v == nil {
return nil, nil
}
return make([]byte, 0), nil
case 4:
tmp := make([]byte, 4)
copy(tmp, v)
return tmp, nil
case 16:
tmp := make([]byte, 16)
copy(tmp, v)
return tmp, nil
default:
return nil, fmt.Errorf("failed to marshal inet: the ([]byte) length can be 0,4,16")
}
}

func EncBytesR(v *[]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncBytes(*v)
}

func EncNetIP(v net.IP) ([]byte, error) {
switch len(v) {
case 0:
if v == nil {
return nil, nil
}
return make([]byte, 0), nil
case 4, 16:
t := v.To4()
if t == nil {
return v.To16(), nil
}
return t, nil
default:
return nil, fmt.Errorf("failed to marshal inet: the (net.IP) length can be 0,4,16")
}
}

func EncNetIPr(v *net.IP) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncNetIP(*v)
}

func EncArray16(v [16]byte) ([]byte, error) {
tmp := make([]byte, 16)
copy(tmp, v[:])
return tmp, nil
}

func EncArray16R(v *[16]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncArray16(*v)
}

func EncArray4(v [4]byte) ([]byte, error) {
tmp := make([]byte, 4)
copy(tmp, v[:])
return tmp, nil
}

func EncArray4R(v *[4]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncArray4(*v)
}

func EncString(v string) ([]byte, error) {
if len(v) == 0 {
return nil, nil
}
b := net.ParseIP(v)
if b != nil {
t := b.To4()
if t == nil {
return b.To16(), nil
}
return t, nil
}
return nil, fmt.Errorf("failed to marshal inet: invalid IP string %s", v)
}

func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Array:
if l := v.Len(); v.Type().Elem().Kind() != reflect.Uint8 || (l != 16 && l != 4) {
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
nv := reflect.New(v.Type())
nv.Elem().Set(v)
return nv.Elem().Bytes(), nil
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
return encReflectBytes(v)
case reflect.String:
return encReflectString(v)
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
switch ev := v.Elem(); ev.Kind() {
case reflect.Array:
if l := v.Len(); ev.Type().Elem().Kind() != reflect.Uint8 || (l != 16 && l != 4) {
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
return v.Elem().Bytes(), nil
case reflect.Slice:
if ev.Type().Elem().Kind() != reflect.Uint8 {
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
return encReflectBytes(ev)
case reflect.String:
return encReflectString(ev)
default:
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func encReflectString(v reflect.Value) ([]byte, error) {
val := v.String()
if len(val) == 0 {
return nil, nil
}
b := net.ParseIP(val)
if b != nil {
t := b.To4()
if t == nil {
return b.To16(), nil
}
return t, nil
}
return nil, fmt.Errorf("failed to marshal inet: invalid IP string (%T)(%[1]v)", v.Interface())
}

func encReflectBytes(v reflect.Value) ([]byte, error) {
val := v.Bytes()
switch len(val) {
case 0:
if val == nil {
return nil, nil
}
return make([]byte, 0), nil
case 4:
tmp := make([]byte, 4)
copy(tmp, val)
return tmp, nil
case 16:
tmp := make([]byte, 16)
copy(tmp, val)
return tmp, nil
default:
return nil, fmt.Errorf("failed to marshal inet: the (%T) length can be 0,4,16", v.Interface())
}
}
46 changes: 46 additions & 0 deletions serialization/inet/unmarshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package inet

import (
"fmt"
"net"
"reflect"
)

func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *[]byte:
return DecBytes(data, v)
case **[]byte:
return DecBytesR(data, v)
case *net.IP:
return DecNetIP(data, v)
case **net.IP:
return DecNetIPr(data, v)
case *[4]byte:
return DecArray4(data, v)
case **[4]byte:
return DecArray4R(data, v)
case *[16]byte:
return DecArray16(data, v)
case **[16]byte:
return DecArray16R(data, v)
case *string:
return DecString(data, v)
case **string:
return DecStringR(data, v)
default:
// Custom types (type MyIP []byte) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal inet: unsupported value type (%T)(%[1]v)", v)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}
Loading

0 comments on commit f587a93

Please sign in to comment.