diff --git a/marshal.go b/marshal.go index 96db43ebc..34cd443df 100644 --- a/marshal.go +++ b/marshal.go @@ -29,6 +29,7 @@ import ( "github.com/gocql/gocql/serialization/inet" "github.com/gocql/gocql/serialization/smallint" "github.com/gocql/gocql/serialization/text" + "github.com/gocql/gocql/serialization/timestamp" "github.com/gocql/gocql/serialization/timeuuid" "github.com/gocql/gocql/serialization/tinyint" "github.com/gocql/gocql/serialization/uuid" @@ -173,7 +174,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeTime: return marshalTime(value) case TypeTimestamp: - return marshalTimestamp(info, value) + return marshalTimestamp(value) case TypeList, TypeSet: return marshalList(info, value) case TypeMap: @@ -287,7 +288,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { case TypeTime: return unmarshalTime(data, value) case TypeTimestamp: - return unmarshalTimestamp(info, data, value) + return unmarshalTimestamp(data, value) case TypeList, TypeSet: return unmarshalList(info, data, value) case TypeMap: @@ -683,64 +684,20 @@ func unmarshalTime(data []byte, value interface{}) error { return nil } -func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encBigInt(v), nil - case time.Time: - if v.IsZero() { - return []byte{}, nil - } - x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - return encBigInt(x), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encBigInt(rv.Int()), nil +func marshalTimestamp(value interface{}) ([]byte, error) { + data, err := timestamp.Marshal(value) + if err != nil { + return nil, wrapMarshalError(err, "marshal error") } - return nil, marshalErrorf("can not marshal %T into %s", value, info) + return data, nil } -func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *int64: - *v = decBigInt(data) - return nil - case *time.Time: - if len(data) == 0 { - *v = time.Time{} - return nil - } - x := decBigInt(data) - sec := x / 1000 - nsec := (x - sec*1000) * 1000000 - *v = time.Unix(sec, nsec).In(time.UTC) - return nil - } - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - switch rv.Type().Kind() { - case reflect.Int64: - rv.SetInt(decBigInt(data)) - return nil +func unmarshalTimestamp(data []byte, value interface{}) error { + err := timestamp.Unmarshal(data, value) + if err != nil { + return wrapUnmarshalError(err, "unmarshal error") } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) + return nil } const millisecondsInADay int64 = 24 * 60 * 60 * 1000 diff --git a/marshal_test.go b/marshal_test.go index 6b793fe03..fbb7b9146 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -122,20 +122,6 @@ var marshalTests = []struct { nil, nil, }, - { - NativeType{proto: 2, typ: TypeTimestamp}, - []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), - time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), - nil, - nil, - }, - { - NativeType{proto: 2, typ: TypeTimestamp}, - []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), - int64(1376387523000), - nil, - nil, - }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91\x06"), @@ -317,23 +303,6 @@ var marshalTests = []struct { nil, nil, }, - { - NativeType{proto: 2, typ: TypeTimestamp}, - []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), - func() *time.Time { - t := time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC) - return &t - }(), - nil, - nil, - }, - { - NativeType{proto: 2, typ: TypeTimestamp}, - []byte(nil), - (*time.Time)(nil), - nil, - nil, - }, { NativeType{proto: 2, typ: TypeBoolean}, []byte("\x00"), @@ -863,72 +832,6 @@ func TestMarshalPointer(t *testing.T) { } } -func TestMarshalTimestamp(t *testing.T) { - var marshalTimestampTests = []struct { - Info TypeInfo - Data []byte - Value interface{} - }{ - { - NativeType{proto: 2, typ: TypeTimestamp}, - []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), - time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), - }, - { - NativeType{proto: 2, typ: TypeTimestamp}, - []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), - int64(1376387523000), - }, - { - // 9223372036854 is the maximum time representable in ms since the epoch - // with int64 if using UnixNano to convert - NativeType{proto: 2, typ: TypeTimestamp}, - []byte("\x00\x00\x08\x63\x7b\xd0\x5a\xf6"), - time.Date(2262, time.April, 11, 23, 47, 16, 854775807, time.UTC), - }, - { - // One nanosecond after causes overflow when using UnixNano - // Instead it should resolve to the same time in ms - NativeType{proto: 2, typ: TypeTimestamp}, - []byte("\x00\x00\x08\x63\x7b\xd0\x5a\xf6"), - time.Date(2262, time.April, 11, 23, 47, 16, 854775808, time.UTC), - }, - { - // -9223372036855 is the minimum time representable in ms since the epoch - // with int64 if using UnixNano to convert - NativeType{proto: 2, typ: TypeTimestamp}, - []byte("\xff\xff\xf7\x9c\x84\x2f\xa5\x09"), - time.Date(1677, time.September, 21, 00, 12, 43, 145224192, time.UTC), - }, - { - // One nanosecond earlier causes overflow when using UnixNano - // it should resolve to the same time in ms - NativeType{proto: 2, typ: TypeTimestamp}, - []byte("\xff\xff\xf7\x9c\x84\x2f\xa5\x09"), - time.Date(1677, time.September, 21, 00, 12, 43, 145224191, time.UTC), - }, - { - // Store the zero time as a blank slice - NativeType{proto: 2, typ: TypeTimestamp}, - []byte{}, - time.Time{}, - }, - } - - for i, test := range marshalTimestampTests { - t.Log(i, test) - data, err := Marshal(test.Info, test.Value) - if err != nil { - t.Errorf("marshalTest[%d]: %v", i, err) - continue - } - if !bytes.Equal(data, test.Data) { - t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decBigInt(test.Data), data, decBigInt(data), test.Value) - } - } -} - func TestMarshalTuple(t *testing.T) { info := TupleTypeInfo{ NativeType: NativeType{proto: 3, typ: TypeTuple}, diff --git a/serialization/timestamp/marshal.go b/serialization/timestamp/marshal.go new file mode 100644 index 000000000..50288a2f4 --- /dev/null +++ b/serialization/timestamp/marshal.go @@ -0,0 +1,30 @@ +package timestamp + +import ( + "reflect" + "time" +) + +func Marshal(value interface{}) ([]byte, error) { + switch v := value.(type) { + case nil: + return nil, nil + case int64: + return EncInt64(v) + case *int64: + return EncInt64R(v) + case time.Time: + return EncTime(v) + case *time.Time: + return EncTimeR(v) + + default: + // Custom types (type MyTime int64) can be serialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.TypeOf(value) + if rv.Kind() != reflect.Ptr { + return EncReflect(reflect.ValueOf(v)) + } + return EncReflectR(reflect.ValueOf(v)) + } +} diff --git a/serialization/timestamp/marshal_utils.go b/serialization/timestamp/marshal_utils.go new file mode 100644 index 000000000..834d539fc --- /dev/null +++ b/serialization/timestamp/marshal_utils.go @@ -0,0 +1,65 @@ +package timestamp + +import ( + "fmt" + "reflect" + "time" +) + +const ( + maxValInt64 int64 = 86399999999999 + minValInt64 int64 = 0 + maxValDur time.Duration = 86399999999999 + minValDur time.Duration = 0 +) + +func EncInt64(v int64) ([]byte, error) { + return encInt64(v), nil +} + +func EncInt64R(v *int64) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncInt64(*v) +} + +func EncTime(v time.Time) ([]byte, error) { + if v.IsZero() { + return make([]byte, 0), nil + } + ms := v.Unix()*1e3 + int64(v.Nanosecond())/1e6 + return []byte{byte(ms >> 56), byte(ms >> 48), byte(ms >> 40), byte(ms >> 32), byte(ms >> 24), byte(ms >> 16), byte(ms >> 8), byte(ms)}, nil +} + +func EncTimeR(v *time.Time) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncTime(*v) +} + +func EncReflect(v reflect.Value) ([]byte, error) { + switch v.Kind() { + case reflect.Int64: + return encInt64(v.Int()), nil + case reflect.Struct: + if v.Type().String() == "gocql.unsetColumn" { + return nil, nil + } + return nil, fmt.Errorf("failed to marshal timestamp: unsupported value type (%T)(%[1]v)", v.Interface()) + default: + return nil, fmt.Errorf("failed to marshal timestamp: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func EncReflectR(v reflect.Value) ([]byte, error) { + if v.IsNil() { + return nil, nil + } + return EncReflect(v.Elem()) +} + +func encInt64(v int64) []byte { + return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} diff --git a/serialization/timestamp/unmarshal.go b/serialization/timestamp/unmarshal.go new file mode 100644 index 000000000..9d9c92a60 --- /dev/null +++ b/serialization/timestamp/unmarshal.go @@ -0,0 +1,36 @@ +package timestamp + +import ( + "fmt" + "reflect" + "time" +) + +func Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case nil: + return nil + + case *int64: + return DecInt64(data, v) + case **int64: + return DecInt64R(data, v) + case *time.Time: + return DecTime(data, v) + case **time.Time: + return DecTimeR(data, v) + default: + + // Custom types (type MyTime int64) can be deserialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.ValueOf(value) + rt := rv.Type() + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("failed to unmarshal timestamp: unsupported value type (%T)(%[1]v)", value) + } + if rt.Elem().Kind() != reflect.Ptr { + return DecReflect(data, rv) + } + return DecReflectR(data, rv) + } +} diff --git a/serialization/timestamp/unmarshal_utils.go b/serialization/timestamp/unmarshal_utils.go new file mode 100644 index 000000000..41d0a0203 --- /dev/null +++ b/serialization/timestamp/unmarshal_utils.go @@ -0,0 +1,150 @@ +package timestamp + +import ( + "fmt" + "reflect" + "time" +) + +var ( + errWrongDataLen = fmt.Errorf("failed to unmarshal timestamp: the length of the data should be 0 or 8") +) + +func errNilReference(v interface{}) error { + return fmt.Errorf("failed to unmarshal timestamp: can not unmarshal into nil reference (%T)(%[1]v))", v) +} + +func DecInt64(p []byte, v *int64) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = 0 + case 8: + *v = decInt64(p) + default: + return errWrongDataLen + } + return nil +} + +func DecInt64R(p []byte, v **int64) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(int64) + } + case 8: + val := decInt64(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecTime(p []byte, v *time.Time) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = time.Time{} + case 8: + *v = decTime(p) + default: + return errWrongDataLen + } + return nil +} + +func DecTimeR(p []byte, v **time.Time) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(time.Time) + } + case 8: + val := decTime(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecReflect(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal timestamp: can not unmarshal into nil reference (%T)(%[1]v))", v.Interface()) + } + + switch v = v.Elem(); v.Kind() { + case reflect.Int64: + return decReflectInt64(p, v) + default: + return fmt.Errorf("failed to unmarshal timestamp: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectInt64(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetInt(0) + case 8: + v.SetInt(decInt64(p)) + default: + return errWrongDataLen + } + return nil +} + +func DecReflectR(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal timestamp: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface()) + } + + switch v.Type().Elem().Elem().Kind() { + case reflect.Int64: + return decReflectIntsR(p, v) + default: + return fmt.Errorf("failed to unmarshal timestamp: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectIntsR(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + if p == nil { + v.Elem().Set(reflect.Zero(v.Elem().Type())) + } else { + v.Elem().Set(reflect.New(v.Type().Elem().Elem())) + } + case 8: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetInt(decInt64(p)) + v.Elem().Set(val) + default: + return errWrongDataLen + } + return nil +} + +func decInt64(p []byte) int64 { + return int64(p[0])<<56 | int64(p[1])<<48 | int64(p[2])<<40 | int64(p[3])<<32 | int64(p[4])<<24 | int64(p[5])<<16 | int64(p[6])<<8 | int64(p[7]) +} + +func decTime(p []byte) time.Time { + msec := decInt64(p) + return time.Unix(msec/1e3, (msec%1e3)*1e6).UTC() +} diff --git a/tests/serialization/marshal_16_timestamp_corrupt_test.go b/tests/serialization/marshal_16_timestamp_corrupt_test.go index 19f1416fb..436c1c1df 100644 --- a/tests/serialization/marshal_16_timestamp_corrupt_test.go +++ b/tests/serialization/marshal_16_timestamp_corrupt_test.go @@ -10,39 +10,60 @@ import ( "github.com/gocql/gocql" "github.com/gocql/gocql/internal/tests/serialization" "github.com/gocql/gocql/internal/tests/serialization/mod" + "github.com/gocql/gocql/serialization/timestamp" ) func TestMarshalTimestampCorrupt(t *testing.T) { tType := gocql.NewNativeType(4, gocql.TypeTimestamp, "") - unmarshal := func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(tType, bytes, i) + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error } - // unmarshal of all supported `go types` does not return an error on all type of corruption. - brokenTypes := serialization.GetTypes(int64(0), (*int64)(nil), mod.Int64(0), (*mod.Int64)(nil), time.Time{}, (*time.Time)(nil)) - - serialization.NegativeUnmarshalSet{ - Data: []byte("\x7f\xff\xff\xff\xff\xff\xff\xff\xff"), - Values: mod.Values{ - int64(0), time.Time{}, - }.AddVariants(mod.All...), - BrokenTypes: brokenTypes, - }.Run("big_data", t, unmarshal) - - serialization.NegativeUnmarshalSet{ - Data: []byte("\xff\xff\xff\xff\xff\xff\xff"), - Values: mod.Values{ - int64(0), time.Time{}, - }.AddVariants(mod.All...), - BrokenTypes: brokenTypes, - }.Run("small_data1", t, unmarshal) - - serialization.NegativeUnmarshalSet{ - Data: []byte("\x00"), - Values: mod.Values{ - int64(0), time.Time{}, - }.AddVariants(mod.All...), - BrokenTypes: brokenTypes, - }.Run("small_data2", t, unmarshal) + testSuites := [2]testSuite{ + { + name: "serialization.timestamp", + marshal: timestamp.Marshal, + unmarshal: timestamp.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, + } + + for _, tSuite := range testSuites { + unmarshal := tSuite.unmarshal + + t.Run(tSuite.name, func(t *testing.T) { + + serialization.NegativeUnmarshalSet{ + Data: []byte("\x7f\xff\xff\xff\xff\xff\xff\xff\xff"), + Values: mod.Values{ + int64(0), time.Time{}, + }.AddVariants(mod.All...), + }.Run("big_data", t, unmarshal) + + serialization.NegativeUnmarshalSet{ + Data: []byte("\xff\xff\xff\xff\xff\xff\xff"), + Values: mod.Values{ + int64(0), time.Time{}, + }.AddVariants(mod.All...), + }.Run("small_data1", t, unmarshal) + + serialization.NegativeUnmarshalSet{ + Data: []byte("\x00"), + Values: mod.Values{ + int64(0), time.Time{}, + }.AddVariants(mod.All...), + }.Run("small_data2", t, unmarshal) + }) + } } diff --git a/tests/serialization/marshal_16_timestamp_test.go b/tests/serialization/marshal_16_timestamp_test.go index 9503bf1e4..5bc2ba86d 100644 --- a/tests/serialization/marshal_16_timestamp_test.go +++ b/tests/serialization/marshal_16_timestamp_test.go @@ -11,62 +11,90 @@ import ( "github.com/gocql/gocql" "github.com/gocql/gocql/internal/tests/serialization" "github.com/gocql/gocql/internal/tests/serialization/mod" + "github.com/gocql/gocql/serialization/timestamp" ) func TestMarshalsTimestamp(t *testing.T) { tType := gocql.NewNativeType(4, gocql.TypeTimestamp, "") - marshal := func(i interface{}) ([]byte, error) { return gocql.Marshal(tType, i) } - unmarshal := func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(tType, bytes, i) + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error } - zeroTime := time.UnixMilli(0).UTC() + testSuites := [2]testSuite{ + { + name: "serialization.timestamp", + marshal: timestamp.Marshal, + unmarshal: timestamp.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, + } + + zeroTime := time.Unix(0, 0).UTC() - // unmarshall `nil` and `zero` data returns a negative value of the `time.Time{}` + // The `time` package have a speciality - values `time.Time{}` and `time.Unix(0,0).UTC()` are different + // The old unmarshal function unmarshalls `nil` and `zero` data into `time.Time{}`, but data with zeros into `time.Unix(0,0).UTC()` brokenTime := serialization.GetTypes(time.Time{}, &time.Time{}) + _ = brokenTime - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{ - (*int64)(nil), (*time.Time)(nil), - }.AddVariants(mod.CustomType), - }.Run("[nil]nullable", t, marshal, unmarshal) + for _, tSuite := range testSuites { + marshal := tSuite.marshal + unmarshal := tSuite.unmarshal - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{ - int64(0), zeroTime, - }.AddVariants(mod.CustomType), - BrokenUnmarshalTypes: brokenTime, - }.Run("[nil]unmarshal", t, nil, unmarshal) + t.Run(tSuite.name, func(t *testing.T) { + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{ + (*int64)(nil), (*time.Time)(nil), + }.AddVariants(mod.CustomType), + }.Run("[nil]nullable", t, marshal, unmarshal) - serialization.PositiveSet{ - Data: make([]byte, 0), - Values: mod.Values{ - int64(0), zeroTime, - }.AddVariants(mod.All...), - BrokenUnmarshalTypes: brokenTime, - }.Run("[]unmarshal", t, nil, unmarshal) + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{ + int64(0), zeroTime, + }.AddVariants(mod.CustomType), + BrokenUnmarshalTypes: brokenTime, + }.Run("[nil]unmarshal", t, nil, unmarshal) - serialization.PositiveSet{ - Data: []byte("\x00\x00\x00\x00\x00\x00\x00\x00"), - Values: mod.Values{ - int64(0), zeroTime, - }.AddVariants(mod.All...), - }.Run("zeros", t, marshal, unmarshal) + serialization.PositiveSet{ + Data: make([]byte, 0), + Values: mod.Values{ + int64(0), zeroTime, + }.AddVariants(mod.All...), + BrokenUnmarshalTypes: brokenTime, + }.Run("[]unmarshal", t, nil, unmarshal) - serialization.PositiveSet{ - Data: []byte("\x7f\xff\xff\xff\xff\xff\xff\xff"), - Values: mod.Values{ - int64(math.MaxInt64), time.UnixMilli(math.MaxInt64).UTC(), - }.AddVariants(mod.All...), - }.Run("max", t, marshal, unmarshal) + serialization.PositiveSet{ + Data: []byte("\x00\x00\x00\x00\x00\x00\x00\x00"), + Values: mod.Values{ + int64(0), zeroTime, + }.AddVariants(mod.All...), + }.Run("zeros", t, marshal, unmarshal) - serialization.PositiveSet{ - Data: []byte("\x80\x00\x00\x00\x00\x00\x00\x00"), - Values: mod.Values{ - int64(math.MinInt64), time.UnixMilli(math.MinInt64).UTC(), - }.AddVariants(mod.All...), - }.Run("min", t, marshal, unmarshal) + serialization.PositiveSet{ + Data: []byte("\x7f\xff\xff\xff\xff\xff\xff\xff"), + Values: mod.Values{ + int64(math.MaxInt64), time.UnixMilli(math.MaxInt64).UTC(), + }.AddVariants(mod.All...), + }.Run("max", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x80\x00\x00\x00\x00\x00\x00\x00"), + Values: mod.Values{ + int64(math.MinInt64), time.UnixMilli(math.MinInt64).UTC(), + }.AddVariants(mod.All...), + }.Run("min", t, marshal, unmarshal) + }) + } }