diff --git a/marshal.go b/marshal.go index 34cd443df..d5e511242 100644 --- a/marshal.go +++ b/marshal.go @@ -6,7 +6,6 @@ package gocql import ( "bytes" - "encoding/binary" "errors" "fmt" "math" @@ -23,6 +22,7 @@ import ( "github.com/gocql/gocql/serialization/counter" "github.com/gocql/gocql/serialization/cqlint" "github.com/gocql/gocql/serialization/cqltime" + "github.com/gocql/gocql/serialization/date" "github.com/gocql/gocql/serialization/decimal" "github.com/gocql/gocql/serialization/double" "github.com/gocql/gocql/serialization/float" @@ -192,7 +192,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeUDT: return marshalUDT(info, value) case TypeDate: - return marshalDate(info, value) + return marshalDate(value) case TypeDuration: return marshalDuration(info, value) } @@ -304,7 +304,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { case TypeUDT: return unmarshalUDT(info, data, value) case TypeDate: - return unmarshalDate(info, data, value) + return unmarshalDate(data, value) case TypeDuration: return unmarshalDuration(info, data, value) } @@ -700,78 +700,20 @@ func unmarshalTimestamp(data []byte, value interface{}) error { return nil } -const millisecondsInADay int64 = 24 * 60 * 60 * 1000 - -func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { - var timestamp int64 - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - timestamp = v - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case *time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case string: - if v == "" { - return []byte{}, nil - } - t, err := time.Parse("2006-01-02", v) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) - } - timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - } - - if value == nil { - return nil, nil +func marshalDate(value interface{}) ([]byte, error) { + data, err := date.Marshal(value) + if err != nil { + return nil, wrapMarshalError(err, "marshal error") } - return nil, marshalErrorf("can not marshal %T into %s", value, info) + return data, nil } -func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *time.Time: - if len(data) == 0 { - *v = time.Time{} - return nil - } - var origin uint32 = 1 << 31 - var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay - *v = time.UnixMilli(timestamp).In(time.UTC) - return nil - case *string: - if len(data) == 0 { - *v = "" - return nil - } - var origin uint32 = 1 << 31 - var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay - *v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02") - return nil +func unmarshalDate(data []byte, value interface{}) error { + err := date.Unmarshal(data, value) + if err != nil { + return wrapUnmarshalError(err, "unmarshal error") } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) + return nil } func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { diff --git a/marshal_test.go b/marshal_test.go index fbb7b9146..59a5211dc 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -1221,126 +1221,6 @@ func TestUnmarshalInetCopyBytes(t *testing.T) { } } -func TestUnmarshalDate(t *testing.T) { - data := []uint8{0x80, 0x0, 0x43, 0x31} - var date time.Time - if err := unmarshalDate(NativeType{proto: 2, typ: TypeDate}, data, &date); err != nil { - t.Fatal(err) - } - - expectedDate := "2017-02-04" - formattedDate := date.Format("2006-01-02") - if expectedDate != formattedDate { - t.Errorf("marshalTest: expected %v, got %v", expectedDate, formattedDate) - return - } - var stringDate string - if err2 := unmarshalDate(NativeType{proto: 2, typ: TypeDate}, data, &stringDate); err2 != nil { - t.Fatal(err2) - } - if expectedDate != stringDate { - t.Errorf("marshalTest: expected %v, got %v", expectedDate, formattedDate) - return - } -} - -func TestMarshalDate(t *testing.T) { - now := time.Now().UTC() - timestamp := now.UnixNano() / int64(time.Millisecond) - expectedData := encInt(int32(timestamp/86400000 + int64(1<<31))) - - var marshalDateTests = []struct { - Info TypeInfo - Data []byte - Value interface{} - }{ - { - NativeType{proto: 4, typ: TypeDate}, - expectedData, - timestamp, - }, - { - NativeType{proto: 4, typ: TypeDate}, - expectedData, - now, - }, - { - NativeType{proto: 4, typ: TypeDate}, - expectedData, - &now, - }, - { - NativeType{proto: 4, typ: TypeDate}, - expectedData, - now.Format("2006-01-02"), - }, - } - - for i, test := range marshalDateTests { - t.Log(i, test) - data, err := Marshal(test.Info, test.Value) - if err != nil { - t.Errorf("marshalTest[%d]: %v", i, err) - continue - } - if !bytes.Equal(data, test.Data) { - t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) - } - } -} - -func TestLargeDate(t *testing.T) { - farFuture := time.Date(999999, time.December, 31, 0, 0, 0, 0, time.UTC) - expectedFutureData := encInt(int32(farFuture.UnixMilli()/86400000 + int64(1<<31))) - - farPast := time.Date(-999999, time.January, 1, 0, 0, 0, 0, time.UTC) - expectedPastData := encInt(int32(farPast.UnixMilli()/86400000 + int64(1<<31))) - - var marshalDateTests = []struct { - Data []byte - Value interface{} - ExpectedDate string - }{ - { - expectedFutureData, - farFuture, - "999999-12-31", - }, - { - expectedPastData, - farPast, - "-999999-01-01", - }, - } - - nativeType := NativeType{proto: 4, typ: TypeDate} - - for i, test := range marshalDateTests { - t.Log(i, test) - - data, err := Marshal(nativeType, test.Value) - if err != nil { - t.Errorf("largeDateTest[%d]: %v", i, err) - continue - } - if !bytes.Equal(data, test.Data) { - t.Errorf("largeDateTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) - } - - var date time.Time - if err := Unmarshal(nativeType, data, &date); err != nil { - t.Fatal(err) - } - - formattedDate := date.Format("2006-01-02") - if test.ExpectedDate != formattedDate { - t.Fatalf("largeDateTest: expected %v, got %v", test.ExpectedDate, formattedDate) - } - } -} - func BenchmarkUnmarshalVarchar(b *testing.B) { b.ReportAllocs() src := make([]byte, 1024) diff --git a/serialization/date/marshal.go b/serialization/date/marshal.go new file mode 100644 index 000000000..0115c1ac3 --- /dev/null +++ b/serialization/date/marshal.go @@ -0,0 +1,42 @@ +package date + +import ( + "reflect" + "time" +) + +func Marshal(value interface{}) ([]byte, error) { + switch v := value.(type) { + case nil: + return nil, nil + case int32: + return EncInt32(v) + case int64: + return EncInt64(v) + case uint32: + return EncUint32(v) + case string: + return EncString(v) + case time.Time: + return EncTime(v) + + case *int32: + return EncInt32R(v) + case *int64: + return EncInt64R(v) + case *uint32: + return EncUint32R(v) + case *string: + return EncStringR(v) + case *time.Time: + return EncTimeR(v) + default: + // Custom types (type MyDate uint32) can be serialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.TypeOf(value) + if rv.Kind() != reflect.Ptr { + return EncReflect(reflect.ValueOf(v)) + } + return EncReflectR(reflect.ValueOf(v)) + } +} diff --git a/serialization/date/marshal_utils.go b/serialization/date/marshal_utils.go new file mode 100644 index 000000000..e18cc1bb8 --- /dev/null +++ b/serialization/date/marshal_utils.go @@ -0,0 +1,223 @@ +package date + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +const ( + millisecondsInADay int64 = 24 * 60 * 60 * 1000 + centerEpoch int64 = 1 << 31 + maxYear int = 5881580 + minYear int = -5877641 + maxMilliseconds int64 = 185542587100800000 + minMilliseconds int64 = -185542587187200000 +) + +var ( + maxDate = time.Date(5881580, 07, 11, 0, 0, 0, 0, time.UTC) + minDate = time.Date(-5877641, 06, 23, 0, 0, 0, 0, time.UTC) +) + +func errWrongStringFormat(v interface{}) error { + return fmt.Errorf(`failed to marshal date: the (%T)(%[1]v) should have fromat "2006-01-02"`, v) +} + +func EncInt32(v int32) ([]byte, error) { + return encInt32(v), nil +} + +func EncInt32R(v *int32) ([]byte, error) { + if v == nil { + return nil, nil + } + return encInt32(*v), nil +} + +func EncInt64(v int64) ([]byte, error) { + if v > maxMilliseconds || v < minMilliseconds { + return nil, fmt.Errorf("failed to marshal date: the (int64)(%v) value out of range", v) + } + return encInt64(days(v)), nil +} + +func EncInt64R(v *int64) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncInt64(*v) +} + +func EncUint32(v uint32) ([]byte, error) { + return encUint32(v), nil +} + +func EncUint32R(v *uint32) ([]byte, error) { + if v == nil { + return nil, nil + } + return encUint32(*v), nil +} + +func EncTime(v time.Time) ([]byte, error) { + if v.After(maxDate) || v.Before(minDate) { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%s) value should be in the range from -5877641-06-23 to 5881580-07-11", v, v.Format("2006-01-02")) + } + return encTime(v), nil +} + +func EncTimeR(v *time.Time) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncTime(*v) +} + +func EncString(v string) ([]byte, error) { + if v == "" { + return nil, nil + } + var err error + var y, m, d int + var t time.Time + switch ps := strings.Split(v, "-"); len(ps) { + case 3: + if y, err = strconv.Atoi(ps[0]); err != nil { + return nil, errWrongStringFormat(v) + } + if m, err = strconv.Atoi(ps[1]); err != nil { + return nil, errWrongStringFormat(v) + } + if d, err = strconv.Atoi(ps[2]); err != nil { + return nil, errWrongStringFormat(v) + } + case 4: + if y, err = strconv.Atoi(ps[1]); err != nil || ps[0] != "" { + return nil, errWrongStringFormat(v) + } + y = -y + if m, err = strconv.Atoi(ps[2]); err != nil { + return nil, errWrongStringFormat(v) + } + if d, err = strconv.Atoi(ps[3]); err != nil { + return nil, errWrongStringFormat(v) + } + default: + return nil, errWrongStringFormat(v) + } + if y > maxYear || y < minYear { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v) + } + t = time.Date(y, time.Month(m), d, 0, 0, 0, 0, time.UTC) + if t.After(maxDate) || t.Before(minDate) { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v) + } + return encTime(t), nil +} + +func EncStringR(v *string) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncString(*v) +} + +func EncReflect(v reflect.Value) ([]byte, error) { + switch v.Kind() { + case reflect.Int32: + return encInt64(v.Int()), nil + case reflect.Int64: + val := v.Int() + if val > maxMilliseconds || val < minMilliseconds { + return nil, fmt.Errorf("failed to marshal date: the value (%T)(%[1]v) out of range", v.Interface()) + } + return encInt64(days(val)), nil + case reflect.Uint32: + val := v.Uint() + return []byte{byte(val >> 24), byte(val >> 16), byte(val >> 8), byte(val)}, nil + case reflect.String: + return encReflectString(v) + case reflect.Struct: + if v.Type().String() == "gocql.unsetColumn" { + return nil, nil + } + return nil, fmt.Errorf("failed to marshal date: unsupported value type (%T)(%[1]v)", v.Interface()) + default: + return nil, fmt.Errorf("failed to marshal date: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func EncReflectR(v reflect.Value) ([]byte, error) { + if v.IsNil() { + return nil, nil + } + return EncReflect(v.Elem()) +} + +func encReflectString(v reflect.Value) ([]byte, error) { + val := v.String() + if val == "" { + return nil, nil + } + var err error + var y, m, d int + var t time.Time + ps := strings.Split(val, "-") + switch len(ps) { + case 3: + if y, err = strconv.Atoi(ps[0]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + if m, err = strconv.Atoi(ps[1]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + if d, err = strconv.Atoi(ps[2]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + case 4: + if y, err = strconv.Atoi(ps[1]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + y = -y + if m, err = strconv.Atoi(ps[2]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + if d, err = strconv.Atoi(ps[3]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + default: + return nil, errWrongStringFormat(v.Interface()) + } + if y > maxYear || y < minYear { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v.Interface()) + } + t = time.Date(y, time.Month(m), d, 0, 0, 0, 0, time.UTC) + if t.After(maxDate) || t.Before(minDate) { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v.Interface()) + } + return encTime(t), nil +} + +func encInt64(v int64) []byte { + return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +func encInt32(v int32) []byte { + return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +func encUint32(v uint32) []byte { + return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +func encTime(v time.Time) []byte { + d := days(v.UnixMilli()) + return []byte{byte(d >> 24), byte(d >> 16), byte(d >> 8), byte(d)} +} + +func days(v int64) int64 { + return v/millisecondsInADay + centerEpoch +} diff --git a/serialization/date/unmarshal.go b/serialization/date/unmarshal.go new file mode 100644 index 000000000..bca27cbc2 --- /dev/null +++ b/serialization/date/unmarshal.go @@ -0,0 +1,49 @@ +package date + +import ( + "fmt" + "reflect" + "time" +) + +func Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case nil: + return nil + + case *int32: + return DecInt32(data, v) + case *int64: + return DecInt64(data, v) + case *uint32: + return DecUint32(data, v) + case *string: + return DecString(data, v) + case *time.Time: + return DecTime(data, v) + + case **int32: + return DecInt32R(data, v) + case **int64: + return DecInt64R(data, v) + case **uint32: + return DecUint32R(data, v) + case **string: + return DecStringR(data, v) + case **time.Time: + return DecTimeR(data, v) + default: + + // Custom types (type MyDate uint32) can be deserialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.ValueOf(value) + rt := rv.Type() + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("failed to unmarshal date: unsupported value type (%T)(%[1]v)", value) + } + if rt.Elem().Kind() != reflect.Ptr { + return DecReflect(data, rv) + } + return DecReflectR(data, rv) + } +} diff --git a/serialization/date/unmarshal_utils.go b/serialization/date/unmarshal_utils.go new file mode 100644 index 000000000..75cb6d10d --- /dev/null +++ b/serialization/date/unmarshal_utils.go @@ -0,0 +1,401 @@ +package date + +import ( + "fmt" + "math" + "reflect" + "time" +) + +const ( + negInt64 = int64(-1) << 32 + zeroDate = "-5877641-06-23" + zeroMS int64 = -185542587187200000 +) + +var errWrongDataLen = fmt.Errorf("failed to unmarshal date: the length of the data should be 0 or 4") + +func errNilReference(v interface{}) error { + return fmt.Errorf("failed to unmarshal date: can not unmarshal into nil reference (%T)(%[1]v))", v) +} + +func DecInt32(p []byte, v *int32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = 0 + case 4: + *v = decInt32(p) + default: + return errWrongDataLen + } + return nil +} + +func DecInt32R(p []byte, v **int32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(int32) + } + case 4: + val := decInt32(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecInt64(p []byte, v *int64) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = zeroMS + case 4: + *v = decMilliseconds(p) + default: + return errWrongDataLen + } + return nil +} + +func DecInt64R(p []byte, v **int64) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + val := zeroMS + *v = &val + } + case 4: + val := decMilliseconds(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecUint32(p []byte, v *uint32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = 0 + case 4: + *v = decUint32(p) + default: + return errWrongDataLen + } + return nil +} + +func DecUint32R(p []byte, v **uint32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(uint32) + } + case 4: + val := decUint32(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecString(p []byte, v *string) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = "" + } else { + *v = zeroDate + } + case 4: + *v = decString(p) + default: + return errWrongDataLen + } + return nil +} + +func DecStringR(p []byte, v **string) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + val := zeroDate + *v = &val + } + case 4: + val := decString(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecTime(p []byte, v *time.Time) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = minDate + case 4: + *v = decTime(p) + default: + return errWrongDataLen + } + return nil +} + +func DecTimeR(p []byte, v **time.Time) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + val := minDate + *v = &val + } + case 4: + val := decTime(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecReflect(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal date: can not unmarshal into nil reference (%T)(%[1]v))", v.Interface()) + } + + switch v = v.Elem(); v.Kind() { + case reflect.Int32: + return decReflectInt32(p, v) + case reflect.Int64: + return decReflectInt64(p, v) + case reflect.Uint32: + return decReflectUint32(p, v) + case reflect.String: + return decReflectString(p, v) + default: + return fmt.Errorf("failed to unmarshal date: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectInt32(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetInt(0) + case 4: + v.SetInt(decInt64(p)) + default: + return errWrongDataLen + } + return nil +} + +func decReflectInt64(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetInt(zeroMS) + case 4: + v.SetInt(decMilliseconds(p)) + default: + return errWrongDataLen + } + return nil +} + +func decReflectUint32(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetUint(0) + case 4: + v.SetUint(decUint64(p)) + default: + return errWrongDataLen + } + return nil +} + +func decReflectString(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + if p == nil { + v.SetString("") + } else { + v.SetString(zeroDate) + } + case 4: + v.SetString(decString(p)) + default: + return errWrongDataLen + } + return nil +} + +func DecReflectR(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal date: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface()) + } + + switch v.Type().Elem().Elem().Kind() { + case reflect.Int32: + return decReflectInt32R(p, v) + case reflect.Int64: + return decReflectInt64R(p, v) + case reflect.Uint32: + return decReflectUint32R(p, v) + case reflect.String: + return decReflectStringR(p, v) + default: + return fmt.Errorf("failed to unmarshal date: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectInt32R(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.Elem().Set(decReflectNullableR(p, v)) + case 4: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetInt(decInt64(p)) + v.Elem().Set(newVal) + default: + return errWrongDataLen + } + return nil +} + +func decReflectInt64R(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + var val reflect.Value + if p == nil { + val = reflect.Zero(v.Type().Elem()) + } else { + val = reflect.New(v.Type().Elem().Elem()) + val.Elem().SetInt(zeroMS) + v.Elem().Set(val) + } + v.Elem().Set(val) + case 4: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetInt(decMilliseconds(p)) + v.Elem().Set(val) + default: + return errWrongDataLen + } + return nil +} + +func decReflectUint32R(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.Elem().Set(decReflectNullableR(p, v)) + case 4: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetUint(decUint64(p)) + v.Elem().Set(newVal) + default: + return errWrongDataLen + } + return nil +} + +func decReflectStringR(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + var val reflect.Value + if p == nil { + val = reflect.Zero(v.Type().Elem()) + } else { + val = reflect.New(v.Type().Elem().Elem()) + val.Elem().SetString(zeroDate) + } + v.Elem().Set(val) + case 4: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetString(decString(p)) + v.Elem().Set(val) + default: + return errWrongDataLen + } + return nil +} + +func decReflectNullableR(p []byte, v reflect.Value) reflect.Value { + if p == nil { + return reflect.Zero(v.Elem().Type()) + } + return reflect.New(v.Type().Elem().Elem()) +} + +func decInt32(p []byte) int32 { + return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) +} + +func decInt64(p []byte) int64 { + if p[0] > math.MaxInt8 { + return negInt64 | int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3]) + } + return int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3]) +} + +func decMilliseconds(p []byte) int64 { + return (int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3]) - centerEpoch) * millisecondsInADay +} + +func decUint32(p []byte) uint32 { + return uint32(p[0])<<24 | uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3]) +} + +func decUint64(p []byte) uint64 { + return uint64(p[0])<<24 | uint64(p[1])<<16 | uint64(p[2])<<8 | uint64(p[3]) +} + +func decString(p []byte) string { + return decTime(p).Format("2006-01-02") +} + +func decTime(p []byte) time.Time { + return time.UnixMilli(decMilliseconds(p)).UTC() +} diff --git a/tests/serialization/marshal_17_date_corrupt_test.go b/tests/serialization/marshal_17_date_corrupt_test.go index 2f1f6197e..21b22764d 100644 --- a/tests/serialization/marshal_17_date_corrupt_test.go +++ b/tests/serialization/marshal_17_date_corrupt_test.go @@ -10,88 +10,101 @@ import ( "github.com/gocql/gocql" "github.com/gocql/gocql/internal/tests/serialization" "github.com/gocql/gocql/internal/tests/serialization/mod" + "github.com/gocql/gocql/serialization/date" ) func TestMarshalDateCorrupt(t *testing.T) { tType := gocql.NewNativeType(4, gocql.TypeDate, "") - marshal := func(i interface{}) ([]byte, error) { return gocql.Marshal(tType, i) } - unmarshal := func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(tType, bytes, i) + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error } - // marshal the `int64`, `time.Time` values which out of the `cql type` range, does not return an error. - brokenMarshalTypes := serialization.GetTypes(int64(0), (*int64)(nil), time.Time{}, &time.Time{}) + testSuites := [2]testSuite{ + { + name: "serialization.date", + marshal: date.Marshal, + unmarshal: date.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, + } - // unmarshal of `string`, `time.Time` does not return an error on all type of data corruption. - brokenUnmarshalTypes := serialization.GetTypes(string(""), (*string)(nil), time.Time{}, &time.Time{}) + for _, tSuite := range testSuites { + marshal := tSuite.marshal + unmarshal := tSuite.unmarshal - serialization.NegativeMarshalSet{ - Values: mod.Values{ - time.Date(5881580, 7, 12, 0, 0, 0, 0, time.UTC).UTC().UnixMilli(), - time.Date(5881580, 8, 11, 0, 0, 0, 0, time.UTC).UTC().UnixMilli(), - time.Date(5881581, 7, 11, 0, 0, 0, 0, time.UTC).UTC().UnixMilli(), - time.Date(5883581, 12, 20, 0, 0, 0, 0, time.UTC).UTC().UnixMilli(), - "5881580-07-12", "5881580-08-11", "5881581-07-11", "9223372036854775807-07-12", - time.Date(5881580, 7, 12, 0, 0, 0, 0, time.UTC).UTC(), - time.Date(5881580, 8, 11, 0, 0, 0, 0, time.UTC).UTC(), - time.Date(5881581, 7, 11, 0, 0, 0, 0, time.UTC).UTC(), - time.Date(5883581, 12, 20, 0, 0, 0, 0, time.UTC).UTC(), - }.AddVariants(mod.All...), - BrokenTypes: brokenMarshalTypes, - }.Run("big_vals", t, marshal) + t.Run(tSuite.name, func(t *testing.T) { + serialization.NegativeMarshalSet{ + Values: mod.Values{ + time.Date(5881580, 7, 12, 0, 0, 0, 0, time.UTC).UnixMilli(), + time.Date(5881580, 8, 11, 0, 0, 0, 0, time.UTC).UnixMilli(), + time.Date(5881581, 7, 11, 0, 0, 0, 0, time.UTC).UnixMilli(), + time.Date(5883581, 12, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), + "5881580-07-12", "5881580-08-11", "5881581-07-11", "9223372036854775807-07-12", + time.Date(5881580, 7, 12, 0, 0, 0, 0, time.UTC).UTC(), + time.Date(5881580, 8, 11, 0, 0, 0, 0, time.UTC).UTC(), + time.Date(5881581, 7, 11, 0, 0, 0, 0, time.UTC).UTC(), + time.Date(5883581, 12, 20, 0, 0, 0, 0, time.UTC).UTC(), + }.AddVariants(mod.All...), + }.Run("big_vals", t, marshal) - serialization.NegativeMarshalSet{ - Values: mod.Values{ - time.Date(-5877641, 06, 24, 0, 0, 0, 0, time.UTC).UTC().UnixMilli(), - time.Date(-5877641, 07, 23, 0, 0, 0, 0, time.UTC).UTC().UnixMilli(), - time.Date(-5877642, 06, 23, 0, 0, 0, 0, time.UTC).UTC().UnixMilli(), - time.Date(-5887641, 06, 23, 0, 0, 0, 0, time.UTC).UTC().UnixMilli(), - "5881580-07-12", "5881580-08-11", "5881581-07-11", "9223372036854775807-07-12", - "-5877641-06-24", "-5877641-07-23", "-5877642-06-23", "-9223372036854775807-07-12", - time.Date(-5877641, 06, 24, 0, 0, 0, 0, time.UTC).UTC(), - time.Date(-5877641, 07, 23, 0, 0, 0, 0, time.UTC).UTC(), - time.Date(-5877642, 06, 23, 0, 0, 0, 0, time.UTC).UTC(), - time.Date(-5887641, 06, 23, 0, 0, 0, 0, time.UTC).UTC(), - }.AddVariants(mod.All...), - BrokenTypes: brokenMarshalTypes, - }.Run("small_vals", t, marshal) + serialization.NegativeMarshalSet{ + Values: mod.Values{ + time.Date(-5877641, 06, 22, 0, 0, 0, 0, time.UTC).UnixMilli(), + time.Date(-5877641, 05, 23, 0, 0, 0, 0, time.UTC).UnixMilli(), + time.Date(-5877642, 06, 23, 0, 0, 0, 0, time.UTC).UnixMilli(), + time.Date(-5887641, 06, 23, 0, 0, 0, 0, time.UTC).UnixMilli(), + "-5877641-06-22", "-5877641-05-23", "-5877642-06-23", "-9223372036854775807-07-12", + time.Date(-5877641, 06, 22, 0, 0, 0, 0, time.UTC), + time.Date(-5877641, 05, 23, 0, 0, 0, 0, time.UTC), + time.Date(-5877642, 06, 23, 0, 0, 0, 0, time.UTC), + time.Date(-5887641, 06, 23, 0, 0, 0, 0, time.UTC), + }.AddVariants(mod.All...), + }.Run("small_vals", t, marshal) - serialization.NegativeMarshalSet{ - Values: mod.Values{ - "a1580-07-11", "1970-0d-11", "02-11", "1970-11", - }.AddVariants(mod.All...), - }.Run("corrupt_vals", t, marshal) + serialization.NegativeMarshalSet{ + Values: mod.Values{ + "a1580-07-11", "1970-0d-11", "02-11", "1970-11", + }.AddVariants(mod.All...), + }.Run("corrupt_vals", t, marshal) - serialization.NegativeUnmarshalSet{ - Data: []byte("\x00\x00\x00\x00\x00"), - Values: mod.Values{ - int64(0), time.Time{}, "", - }.AddVariants(mod.All...), - BrokenTypes: brokenUnmarshalTypes, - }.Run("big_data1", t, unmarshal) + serialization.NegativeUnmarshalSet{ + Data: []byte("\x00\x00\x00\x00\x00"), + Values: mod.Values{ + int64(0), time.Time{}, "", + }.AddVariants(mod.All...), + }.Run("big_data1", t, unmarshal) - serialization.NegativeUnmarshalSet{ - Data: []byte("\x00\x00\x4e\x94\x91\x4e\xff\xff\xff"), - Values: mod.Values{ - int64(0), time.Time{}, "", - }.AddVariants(mod.All...), - BrokenTypes: brokenUnmarshalTypes, - }.Run("big_data2", t, unmarshal) + serialization.NegativeUnmarshalSet{ + Data: []byte("\x00\x00\x4e\x94\x91\x4e\xff\xff\xff"), + Values: mod.Values{ + int64(0), time.Time{}, "", + }.AddVariants(mod.All...), + }.Run("big_data2", t, unmarshal) - serialization.NegativeUnmarshalSet{ - Data: []byte("\x00\x00\x00"), - Values: mod.Values{ - int64(0), time.Time{}, "", - }.AddVariants(mod.All...), - BrokenTypes: brokenUnmarshalTypes, - }.Run("small_data1", t, unmarshal) + serialization.NegativeUnmarshalSet{ + Data: []byte("\x00\x00\x00"), + Values: mod.Values{ + int64(0), time.Time{}, "", + }.AddVariants(mod.All...), + }.Run("small_data1", t, unmarshal) - serialization.NegativeUnmarshalSet{ - Data: []byte("\x00"), - Values: mod.Values{ - int64(0), time.Time{}, "", - }.AddVariants(mod.All...), - BrokenTypes: brokenUnmarshalTypes, - }.Run("small_data2", t, unmarshal) + serialization.NegativeUnmarshalSet{ + Data: []byte("\x00"), + Values: mod.Values{ + int64(0), time.Time{}, "", + }.AddVariants(mod.All...), + }.Run("small_data2", t, unmarshal) + }) + } } diff --git a/tests/serialization/marshal_17_date_test.go b/tests/serialization/marshal_17_date_test.go index 37fc4df55..30035eaab 100644 --- a/tests/serialization/marshal_17_date_test.go +++ b/tests/serialization/marshal_17_date_test.go @@ -4,88 +4,92 @@ package serialization_test import ( + "math" "testing" "time" "github.com/gocql/gocql" "github.com/gocql/gocql/internal/tests/serialization" "github.com/gocql/gocql/internal/tests/serialization/mod" + "github.com/gocql/gocql/serialization/date" ) func TestMarshalsDate(t *testing.T) { tType := gocql.NewNativeType(4, gocql.TypeDate, "") - marshal := func(i interface{}) ([]byte, error) { return gocql.Marshal(tType, i) } - unmarshal := func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(tType, bytes, i) + type testSuite struct { + name string + marshal func(interface{}) ([]byte, error) + unmarshal func(bytes []byte, i interface{}) error + } + + testSuites := [2]testSuite{ + { + name: "serialization.date", + marshal: date.Marshal, + unmarshal: date.Unmarshal, + }, + { + name: "glob", + marshal: func(i interface{}) ([]byte, error) { + return gocql.Marshal(tType, i) + }, + unmarshal: func(bytes []byte, i interface{}) error { + return gocql.Unmarshal(tType, bytes, i) + }, + }, } zeroDate := time.Date(-5877641, 06, 23, 0, 0, 0, 0, time.UTC).UTC() middleDate := time.UnixMilli(0).UTC() maxDate := time.Date(5881580, 07, 11, 0, 0, 0, 0, time.UTC).UTC() - // marshal strings with big year like "-5877641-06-23" returns an error - brokenBigString := serialization.GetTypes(string(""), (*string)(nil)) - - // marshal `custom string` and `custom int64` unsupported - brokenMarshal := serialization.GetTypes(mod.String(""), (*mod.String)(nil), mod.Int64(0), (*mod.Int64)(nil)) - - // unmarshal `zero` data into not zero string and time.Time - brokenZero := serialization.GetTypes(time.Time{}, &time.Time{}, string(""), (*string)(nil)) - - // unmarshal `nil` data into not zero time.Time - brokenNil := serialization.GetTypes(time.Time{}) - - // unmarshal into `custom string`, `int64` and `custom int64` unsupported - brokenUnmarshal := serialization.GetTypes(mod.String(""), (*mod.String)(nil), mod.Int64(0), (*mod.Int64)(nil), int64(0), (*int64)(nil)) - - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{ - (*int64)(nil), (*time.Time)(nil), (*string)(nil), - }.AddVariants(mod.CustomType), - }.Run("[nil]nullable", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: nil, - Values: mod.Values{ - int64(0), zeroDate, "", - }.AddVariants(mod.CustomType), - BrokenUnmarshalTypes: append(brokenUnmarshal, brokenNil...), - }.Run("[nil]unmarshal", t, nil, unmarshal) - - serialization.PositiveSet{ - Data: make([]byte, 0), - Values: mod.Values{ - int64(0), zeroDate, "-5877641-06-23", - }.AddVariants(mod.All...), - BrokenUnmarshalTypes: append(brokenUnmarshal, brokenZero...), - }.Run("[]unmarshal", t, nil, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\x00\x00\x00\x00"), - Values: mod.Values{ - zeroDate.UnixMilli(), zeroDate, "-5877641-06-23", - }.AddVariants(mod.All...), - BrokenMarshalTypes: append(brokenMarshal, brokenBigString...), - BrokenUnmarshalTypes: brokenUnmarshal, - }.Run("zeros", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\x80\x00\x00\x00"), - Values: mod.Values{ - middleDate.UnixMilli(), middleDate, "1970-01-01", - }.AddVariants(mod.All...), - BrokenMarshalTypes: brokenMarshal, - BrokenUnmarshalTypes: brokenUnmarshal, - }.Run("middle", t, marshal, unmarshal) - - serialization.PositiveSet{ - Data: []byte("\xff\xff\xff\xff"), - Values: mod.Values{ - maxDate.UnixMilli(), maxDate, "5881580-07-11", - }.AddVariants(mod.All...), - BrokenMarshalTypes: append(brokenMarshal, brokenBigString...), - BrokenUnmarshalTypes: brokenUnmarshal, - }.Run("max", t, marshal, unmarshal) + for _, tSuite := range testSuites { + marshal := tSuite.marshal + unmarshal := tSuite.unmarshal + + t.Run(tSuite.name, func(t *testing.T) { + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{ + (*uint32)(nil), (*int32)(nil), (*int64)(nil), (*string)(nil), (*time.Time)(nil), + }.AddVariants(mod.CustomType), + }.Run("[nil]nullable", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: nil, + Values: mod.Values{ + uint32(0), int32(0), zeroDate.UnixMilli(), "", zeroDate, + }.AddVariants(mod.CustomType), + }.Run("[nil]unmarshal", t, nil, unmarshal) + + serialization.PositiveSet{ + Data: make([]byte, 0), + Values: mod.Values{ + uint32(0), int32(0), zeroDate.UnixMilli(), zeroDate, "-5877641-06-23", + }.AddVariants(mod.All...), + }.Run("[]unmarshal", t, nil, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x00\x00\x00\x00"), + Values: mod.Values{ + uint32(0), int32(0), zeroDate.UnixMilli(), zeroDate, "-5877641-06-23", + }.AddVariants(mod.All...), + }.Run("zeros", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\x80\x00\x00\x00"), + Values: mod.Values{ + uint32(1 << 31), int32(math.MinInt32), middleDate.UnixMilli(), middleDate, "1970-01-01", + }.AddVariants(mod.All...), + }.Run("middle", t, marshal, unmarshal) + + serialization.PositiveSet{ + Data: []byte("\xff\xff\xff\xff"), + Values: mod.Values{ + uint32(math.MaxUint32), int32(-1), maxDate.UnixMilli(), maxDate, "5881580-07-11", + }.AddVariants(mod.All...), + }.Run("max", t, marshal, unmarshal) + }) + } }