Skip to content

Commit

Permalink
Merge pull request #362 from illia-li/il/fix/marshal/inet
Browse files Browse the repository at this point in the history
Fix `inet` marshal, unmarshall functions
  • Loading branch information
dkropachev authored Dec 5, 2024
2 parents 615c6d9 + 62d8df1 commit a5b05d0
Show file tree
Hide file tree
Showing 8 changed files with 1,066 additions and 261 deletions.
74 changes: 13 additions & 61 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"math"
"math/big"
"math/bits"
"net"
"reflect"
"strings"
"time"
Expand All @@ -26,6 +25,7 @@ import (
"github.com/gocql/gocql/serialization/cqlint"
"github.com/gocql/gocql/serialization/double"
"github.com/gocql/gocql/serialization/float"
"github.com/gocql/gocql/serialization/inet"
"github.com/gocql/gocql/serialization/smallint"
"github.com/gocql/gocql/serialization/text"
"github.com/gocql/gocql/serialization/timeuuid"
Expand Down Expand Up @@ -184,7 +184,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeVarint:
return marshalVarint(value)
case TypeInet:
return marshalInet(info, value)
return marshalInet(value)
case TypeTuple:
return marshalTuple(info, value)
case TypeUDT:
Expand Down Expand Up @@ -296,7 +296,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeUUID:
return unmarshalUUID(data, value)
case TypeInet:
return unmarshalInet(info, data, value)
return unmarshalInet(data, value)
case TypeTuple:
return unmarshalTuple(info, data, value)
case TypeUDT:
Expand Down Expand Up @@ -1360,68 +1360,20 @@ func unmarshalTimeUUID(data []byte, value interface{}) error {
return nil
}

func marshalInet(info TypeInfo, value interface{}) ([]byte, error) {
// we return either the 4 or 16 byte representation of an
// ip address here otherwise the db value will be prefixed
// with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1
switch val := value.(type) {
case unsetColumn:
return nil, nil
case net.IP:
t := val.To4()
if t == nil {
return val.To16(), nil
}
return t, nil
case string:
b := net.ParseIP(val)
if b != nil {
t := b.To4()
if t == nil {
return b.To16(), nil
}
return t, nil
}
return nil, marshalErrorf("cannot marshal. invalid ip string %s", val)
}

if value == nil {
return nil, nil
func marshalInet(value interface{}) ([]byte, error) {
data, err := inet.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}

return nil, marshalErrorf("cannot marshal %T into %s", value, info)
return data, nil
}

func unmarshalInet(info TypeInfo, data []byte, value interface{}) error {
switch v := value.(type) {
case Unmarshaler:
return v.UnmarshalCQL(info, data)
case *net.IP:
if x := len(data); !(x == 4 || x == 16) {
return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x)
}
buf := copyBytes(data)
ip := net.IP(buf)
if v4 := ip.To4(); v4 != nil {
*v = v4
return nil
}
*v = ip
return nil
case *string:
if len(data) == 0 {
*v = ""
return nil
}
ip := net.IP(data)
if v4 := ip.To4(); v4 != nil {
*v = v4.String()
return nil
}
*v = ip.String()
return nil
func unmarshalInet(data []byte, value interface{}) error {
err := inet.Unmarshal(data, value)
if err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}
return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
return nil
}

func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,7 @@ func TestMarshalNil(t *testing.T) {
func TestUnmarshalInetCopyBytes(t *testing.T) {
data := []byte{127, 0, 0, 1}
var ip net.IP
if err := unmarshalInet(NativeType{proto: 2, typ: TypeInet}, data, &ip); err != nil {
if err := unmarshalInet(data, &ip); err != nil {
t.Fatal(err)
}

Expand Down
41 changes: 41 additions & 0 deletions serialization/inet/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package inet

import (
"net"
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case []byte:
return EncBytes(v)
case *[]byte:
return EncBytesR(v)
case net.IP:
return EncNetIP(v)
case *net.IP:
return EncNetIPr(v)
case [4]byte:
return EncArray4(v)
case *[4]byte:
return EncArray4R(v)
case [16]byte:
return EncArray16(v)
case *[16]byte:
return EncArray16R(v)
case string:
return EncString(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyIP []byte) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
192 changes: 192 additions & 0 deletions serialization/inet/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package inet

import (
"fmt"
"net"
"reflect"
)

func EncBytes(v []byte) ([]byte, error) {
switch len(v) {
case 0:
if v == nil {
return nil, nil
}
return make([]byte, 0), nil
case 4:
tmp := make([]byte, 4)
copy(tmp, v)
return tmp, nil
case 16:
tmp := make([]byte, 16)
copy(tmp, v)
return tmp, nil
default:
return nil, fmt.Errorf("failed to marshal inet: the ([]byte) length can be 0,4,16")
}
}

func EncBytesR(v *[]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncBytes(*v)
}

func EncNetIP(v net.IP) ([]byte, error) {
switch len(v) {
case 0:
if v == nil {
return nil, nil
}
return make([]byte, 0), nil
case 4, 16:
t := v.To4()
if t == nil {
return v.To16(), nil
}
return t, nil
default:
return nil, fmt.Errorf("failed to marshal inet: the (net.IP) length can be 0,4,16")
}
}

func EncNetIPr(v *net.IP) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncNetIP(*v)
}

func EncArray16(v [16]byte) ([]byte, error) {
tmp := make([]byte, 16)
copy(tmp, v[:])
return tmp, nil
}

func EncArray16R(v *[16]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncArray16(*v)
}

func EncArray4(v [4]byte) ([]byte, error) {
tmp := make([]byte, 4)
copy(tmp, v[:])
return tmp, nil
}

func EncArray4R(v *[4]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncArray4(*v)
}

func EncString(v string) ([]byte, error) {
if len(v) == 0 {
return nil, nil
}
b := net.ParseIP(v)
if b != nil {
t := b.To4()
if t == nil {
return b.To16(), nil
}
return t, nil
}
return nil, fmt.Errorf("failed to marshal inet: invalid IP string %s", v)
}

func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Array:
if l := v.Len(); v.Type().Elem().Kind() != reflect.Uint8 || (l != 16 && l != 4) {
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
nv := reflect.New(v.Type())
nv.Elem().Set(v)
return nv.Elem().Bytes(), nil
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
return encReflectBytes(v)
case reflect.String:
return encReflectString(v)
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
switch ev := v.Elem(); ev.Kind() {
case reflect.Array:
if l := v.Len(); ev.Type().Elem().Kind() != reflect.Uint8 || (l != 16 && l != 4) {
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
return v.Elem().Bytes(), nil
case reflect.Slice:
if ev.Type().Elem().Kind() != reflect.Uint8 {
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
return encReflectBytes(ev)
case reflect.String:
return encReflectString(ev)
default:
return nil, fmt.Errorf("failed to marshal inet: unsupported value type (%T)(%[1]v)", v.Interface())
}
}

func encReflectString(v reflect.Value) ([]byte, error) {
val := v.String()
if len(val) == 0 {
return nil, nil
}
b := net.ParseIP(val)
if b != nil {
t := b.To4()
if t == nil {
return b.To16(), nil
}
return t, nil
}
return nil, fmt.Errorf("failed to marshal inet: invalid IP string (%T)(%[1]v)", v.Interface())
}

func encReflectBytes(v reflect.Value) ([]byte, error) {
val := v.Bytes()
switch len(val) {
case 0:
if val == nil {
return nil, nil
}
return make([]byte, 0), nil
case 4:
tmp := make([]byte, 4)
copy(tmp, val)
return tmp, nil
case 16:
tmp := make([]byte, 16)
copy(tmp, val)
return tmp, nil
default:
return nil, fmt.Errorf("failed to marshal inet: the (%T) length can be 0,4,16", v.Interface())
}
}
Loading

0 comments on commit a5b05d0

Please sign in to comment.