diff --git a/marshal.go b/marshal.go index 34cd443df..d5e511242 100644 --- a/marshal.go +++ b/marshal.go @@ -6,7 +6,6 @@ package gocql import ( "bytes" - "encoding/binary" "errors" "fmt" "math" @@ -23,6 +22,7 @@ import ( "github.com/gocql/gocql/serialization/counter" "github.com/gocql/gocql/serialization/cqlint" "github.com/gocql/gocql/serialization/cqltime" + "github.com/gocql/gocql/serialization/date" "github.com/gocql/gocql/serialization/decimal" "github.com/gocql/gocql/serialization/double" "github.com/gocql/gocql/serialization/float" @@ -192,7 +192,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeUDT: return marshalUDT(info, value) case TypeDate: - return marshalDate(info, value) + return marshalDate(value) case TypeDuration: return marshalDuration(info, value) } @@ -304,7 +304,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { case TypeUDT: return unmarshalUDT(info, data, value) case TypeDate: - return unmarshalDate(info, data, value) + return unmarshalDate(data, value) case TypeDuration: return unmarshalDuration(info, data, value) } @@ -700,78 +700,20 @@ func unmarshalTimestamp(data []byte, value interface{}) error { return nil } -const millisecondsInADay int64 = 24 * 60 * 60 * 1000 - -func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { - var timestamp int64 - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - timestamp = v - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case *time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case string: - if v == "" { - return []byte{}, nil - } - t, err := time.Parse("2006-01-02", v) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) - } - timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - } - - if value == nil { - return nil, nil +func marshalDate(value interface{}) ([]byte, error) { + data, err := date.Marshal(value) + if err != nil { + return nil, wrapMarshalError(err, "marshal error") } - return nil, marshalErrorf("can not marshal %T into %s", value, info) + return data, nil } -func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *time.Time: - if len(data) == 0 { - *v = time.Time{} - return nil - } - var origin uint32 = 1 << 31 - var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay - *v = time.UnixMilli(timestamp).In(time.UTC) - return nil - case *string: - if len(data) == 0 { - *v = "" - return nil - } - var origin uint32 = 1 << 31 - var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay - *v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02") - return nil +func unmarshalDate(data []byte, value interface{}) error { + err := date.Unmarshal(data, value) + if err != nil { + return wrapUnmarshalError(err, "unmarshal error") } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) + return nil } func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { diff --git a/serialization/date/marshal.go b/serialization/date/marshal.go new file mode 100644 index 000000000..0115c1ac3 --- /dev/null +++ b/serialization/date/marshal.go @@ -0,0 +1,42 @@ +package date + +import ( + "reflect" + "time" +) + +func Marshal(value interface{}) ([]byte, error) { + switch v := value.(type) { + case nil: + return nil, nil + case int32: + return EncInt32(v) + case int64: + return EncInt64(v) + case uint32: + return EncUint32(v) + case string: + return EncString(v) + case time.Time: + return EncTime(v) + + case *int32: + return EncInt32R(v) + case *int64: + return EncInt64R(v) + case *uint32: + return EncUint32R(v) + case *string: + return EncStringR(v) + case *time.Time: + return EncTimeR(v) + default: + // Custom types (type MyDate uint32) can be serialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.TypeOf(value) + if rv.Kind() != reflect.Ptr { + return EncReflect(reflect.ValueOf(v)) + } + return EncReflectR(reflect.ValueOf(v)) + } +} diff --git a/serialization/date/marshal_utils.go b/serialization/date/marshal_utils.go new file mode 100644 index 000000000..e18cc1bb8 --- /dev/null +++ b/serialization/date/marshal_utils.go @@ -0,0 +1,223 @@ +package date + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +const ( + millisecondsInADay int64 = 24 * 60 * 60 * 1000 + centerEpoch int64 = 1 << 31 + maxYear int = 5881580 + minYear int = -5877641 + maxMilliseconds int64 = 185542587100800000 + minMilliseconds int64 = -185542587187200000 +) + +var ( + maxDate = time.Date(5881580, 07, 11, 0, 0, 0, 0, time.UTC) + minDate = time.Date(-5877641, 06, 23, 0, 0, 0, 0, time.UTC) +) + +func errWrongStringFormat(v interface{}) error { + return fmt.Errorf(`failed to marshal date: the (%T)(%[1]v) should have fromat "2006-01-02"`, v) +} + +func EncInt32(v int32) ([]byte, error) { + return encInt32(v), nil +} + +func EncInt32R(v *int32) ([]byte, error) { + if v == nil { + return nil, nil + } + return encInt32(*v), nil +} + +func EncInt64(v int64) ([]byte, error) { + if v > maxMilliseconds || v < minMilliseconds { + return nil, fmt.Errorf("failed to marshal date: the (int64)(%v) value out of range", v) + } + return encInt64(days(v)), nil +} + +func EncInt64R(v *int64) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncInt64(*v) +} + +func EncUint32(v uint32) ([]byte, error) { + return encUint32(v), nil +} + +func EncUint32R(v *uint32) ([]byte, error) { + if v == nil { + return nil, nil + } + return encUint32(*v), nil +} + +func EncTime(v time.Time) ([]byte, error) { + if v.After(maxDate) || v.Before(minDate) { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%s) value should be in the range from -5877641-06-23 to 5881580-07-11", v, v.Format("2006-01-02")) + } + return encTime(v), nil +} + +func EncTimeR(v *time.Time) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncTime(*v) +} + +func EncString(v string) ([]byte, error) { + if v == "" { + return nil, nil + } + var err error + var y, m, d int + var t time.Time + switch ps := strings.Split(v, "-"); len(ps) { + case 3: + if y, err = strconv.Atoi(ps[0]); err != nil { + return nil, errWrongStringFormat(v) + } + if m, err = strconv.Atoi(ps[1]); err != nil { + return nil, errWrongStringFormat(v) + } + if d, err = strconv.Atoi(ps[2]); err != nil { + return nil, errWrongStringFormat(v) + } + case 4: + if y, err = strconv.Atoi(ps[1]); err != nil || ps[0] != "" { + return nil, errWrongStringFormat(v) + } + y = -y + if m, err = strconv.Atoi(ps[2]); err != nil { + return nil, errWrongStringFormat(v) + } + if d, err = strconv.Atoi(ps[3]); err != nil { + return nil, errWrongStringFormat(v) + } + default: + return nil, errWrongStringFormat(v) + } + if y > maxYear || y < minYear { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v) + } + t = time.Date(y, time.Month(m), d, 0, 0, 0, 0, time.UTC) + if t.After(maxDate) || t.Before(minDate) { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v) + } + return encTime(t), nil +} + +func EncStringR(v *string) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncString(*v) +} + +func EncReflect(v reflect.Value) ([]byte, error) { + switch v.Kind() { + case reflect.Int32: + return encInt64(v.Int()), nil + case reflect.Int64: + val := v.Int() + if val > maxMilliseconds || val < minMilliseconds { + return nil, fmt.Errorf("failed to marshal date: the value (%T)(%[1]v) out of range", v.Interface()) + } + return encInt64(days(val)), nil + case reflect.Uint32: + val := v.Uint() + return []byte{byte(val >> 24), byte(val >> 16), byte(val >> 8), byte(val)}, nil + case reflect.String: + return encReflectString(v) + case reflect.Struct: + if v.Type().String() == "gocql.unsetColumn" { + return nil, nil + } + return nil, fmt.Errorf("failed to marshal date: unsupported value type (%T)(%[1]v)", v.Interface()) + default: + return nil, fmt.Errorf("failed to marshal date: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func EncReflectR(v reflect.Value) ([]byte, error) { + if v.IsNil() { + return nil, nil + } + return EncReflect(v.Elem()) +} + +func encReflectString(v reflect.Value) ([]byte, error) { + val := v.String() + if val == "" { + return nil, nil + } + var err error + var y, m, d int + var t time.Time + ps := strings.Split(val, "-") + switch len(ps) { + case 3: + if y, err = strconv.Atoi(ps[0]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + if m, err = strconv.Atoi(ps[1]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + if d, err = strconv.Atoi(ps[2]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + case 4: + if y, err = strconv.Atoi(ps[1]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + y = -y + if m, err = strconv.Atoi(ps[2]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + if d, err = strconv.Atoi(ps[3]); err != nil { + return nil, errWrongStringFormat(v.Interface()) + } + default: + return nil, errWrongStringFormat(v.Interface()) + } + if y > maxYear || y < minYear { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v.Interface()) + } + t = time.Date(y, time.Month(m), d, 0, 0, 0, 0, time.UTC) + if t.After(maxDate) || t.Before(minDate) { + return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v.Interface()) + } + return encTime(t), nil +} + +func encInt64(v int64) []byte { + return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +func encInt32(v int32) []byte { + return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +func encUint32(v uint32) []byte { + return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} +} + +func encTime(v time.Time) []byte { + d := days(v.UnixMilli()) + return []byte{byte(d >> 24), byte(d >> 16), byte(d >> 8), byte(d)} +} + +func days(v int64) int64 { + return v/millisecondsInADay + centerEpoch +} diff --git a/serialization/date/unmarshal.go b/serialization/date/unmarshal.go new file mode 100644 index 000000000..bca27cbc2 --- /dev/null +++ b/serialization/date/unmarshal.go @@ -0,0 +1,49 @@ +package date + +import ( + "fmt" + "reflect" + "time" +) + +func Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case nil: + return nil + + case *int32: + return DecInt32(data, v) + case *int64: + return DecInt64(data, v) + case *uint32: + return DecUint32(data, v) + case *string: + return DecString(data, v) + case *time.Time: + return DecTime(data, v) + + case **int32: + return DecInt32R(data, v) + case **int64: + return DecInt64R(data, v) + case **uint32: + return DecUint32R(data, v) + case **string: + return DecStringR(data, v) + case **time.Time: + return DecTimeR(data, v) + default: + + // Custom types (type MyDate uint32) can be deserialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.ValueOf(value) + rt := rv.Type() + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("failed to unmarshal date: unsupported value type (%T)(%[1]v)", value) + } + if rt.Elem().Kind() != reflect.Ptr { + return DecReflect(data, rv) + } + return DecReflectR(data, rv) + } +} diff --git a/serialization/date/unmarshal_utils.go b/serialization/date/unmarshal_utils.go new file mode 100644 index 000000000..75cb6d10d --- /dev/null +++ b/serialization/date/unmarshal_utils.go @@ -0,0 +1,401 @@ +package date + +import ( + "fmt" + "math" + "reflect" + "time" +) + +const ( + negInt64 = int64(-1) << 32 + zeroDate = "-5877641-06-23" + zeroMS int64 = -185542587187200000 +) + +var errWrongDataLen = fmt.Errorf("failed to unmarshal date: the length of the data should be 0 or 4") + +func errNilReference(v interface{}) error { + return fmt.Errorf("failed to unmarshal date: can not unmarshal into nil reference (%T)(%[1]v))", v) +} + +func DecInt32(p []byte, v *int32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = 0 + case 4: + *v = decInt32(p) + default: + return errWrongDataLen + } + return nil +} + +func DecInt32R(p []byte, v **int32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(int32) + } + case 4: + val := decInt32(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecInt64(p []byte, v *int64) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = zeroMS + case 4: + *v = decMilliseconds(p) + default: + return errWrongDataLen + } + return nil +} + +func DecInt64R(p []byte, v **int64) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + val := zeroMS + *v = &val + } + case 4: + val := decMilliseconds(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecUint32(p []byte, v *uint32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = 0 + case 4: + *v = decUint32(p) + default: + return errWrongDataLen + } + return nil +} + +func DecUint32R(p []byte, v **uint32) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + *v = new(uint32) + } + case 4: + val := decUint32(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecString(p []byte, v *string) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = "" + } else { + *v = zeroDate + } + case 4: + *v = decString(p) + default: + return errWrongDataLen + } + return nil +} + +func DecStringR(p []byte, v **string) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + val := zeroDate + *v = &val + } + case 4: + val := decString(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecTime(p []byte, v *time.Time) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + *v = minDate + case 4: + *v = decTime(p) + default: + return errWrongDataLen + } + return nil +} + +func DecTimeR(p []byte, v **time.Time) error { + if v == nil { + return errNilReference(v) + } + switch len(p) { + case 0: + if p == nil { + *v = nil + } else { + val := minDate + *v = &val + } + case 4: + val := decTime(p) + *v = &val + default: + return errWrongDataLen + } + return nil +} + +func DecReflect(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal date: can not unmarshal into nil reference (%T)(%[1]v))", v.Interface()) + } + + switch v = v.Elem(); v.Kind() { + case reflect.Int32: + return decReflectInt32(p, v) + case reflect.Int64: + return decReflectInt64(p, v) + case reflect.Uint32: + return decReflectUint32(p, v) + case reflect.String: + return decReflectString(p, v) + default: + return fmt.Errorf("failed to unmarshal date: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectInt32(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetInt(0) + case 4: + v.SetInt(decInt64(p)) + default: + return errWrongDataLen + } + return nil +} + +func decReflectInt64(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetInt(zeroMS) + case 4: + v.SetInt(decMilliseconds(p)) + default: + return errWrongDataLen + } + return nil +} + +func decReflectUint32(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.SetUint(0) + case 4: + v.SetUint(decUint64(p)) + default: + return errWrongDataLen + } + return nil +} + +func decReflectString(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + if p == nil { + v.SetString("") + } else { + v.SetString(zeroDate) + } + case 4: + v.SetString(decString(p)) + default: + return errWrongDataLen + } + return nil +} + +func DecReflectR(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal date: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface()) + } + + switch v.Type().Elem().Elem().Kind() { + case reflect.Int32: + return decReflectInt32R(p, v) + case reflect.Int64: + return decReflectInt64R(p, v) + case reflect.Uint32: + return decReflectUint32R(p, v) + case reflect.String: + return decReflectStringR(p, v) + default: + return fmt.Errorf("failed to unmarshal date: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectInt32R(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.Elem().Set(decReflectNullableR(p, v)) + case 4: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetInt(decInt64(p)) + v.Elem().Set(newVal) + default: + return errWrongDataLen + } + return nil +} + +func decReflectInt64R(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + var val reflect.Value + if p == nil { + val = reflect.Zero(v.Type().Elem()) + } else { + val = reflect.New(v.Type().Elem().Elem()) + val.Elem().SetInt(zeroMS) + v.Elem().Set(val) + } + v.Elem().Set(val) + case 4: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetInt(decMilliseconds(p)) + v.Elem().Set(val) + default: + return errWrongDataLen + } + return nil +} + +func decReflectUint32R(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + v.Elem().Set(decReflectNullableR(p, v)) + case 4: + newVal := reflect.New(v.Type().Elem().Elem()) + newVal.Elem().SetUint(decUint64(p)) + v.Elem().Set(newVal) + default: + return errWrongDataLen + } + return nil +} + +func decReflectStringR(p []byte, v reflect.Value) error { + switch len(p) { + case 0: + var val reflect.Value + if p == nil { + val = reflect.Zero(v.Type().Elem()) + } else { + val = reflect.New(v.Type().Elem().Elem()) + val.Elem().SetString(zeroDate) + } + v.Elem().Set(val) + case 4: + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetString(decString(p)) + v.Elem().Set(val) + default: + return errWrongDataLen + } + return nil +} + +func decReflectNullableR(p []byte, v reflect.Value) reflect.Value { + if p == nil { + return reflect.Zero(v.Elem().Type()) + } + return reflect.New(v.Type().Elem().Elem()) +} + +func decInt32(p []byte) int32 { + return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) +} + +func decInt64(p []byte) int64 { + if p[0] > math.MaxInt8 { + return negInt64 | int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3]) + } + return int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3]) +} + +func decMilliseconds(p []byte) int64 { + return (int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3]) - centerEpoch) * millisecondsInADay +} + +func decUint32(p []byte) uint32 { + return uint32(p[0])<<24 | uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3]) +} + +func decUint64(p []byte) uint64 { + return uint64(p[0])<<24 | uint64(p[1])<<16 | uint64(p[2])<<8 | uint64(p[3]) +} + +func decString(p []byte) string { + return decTime(p).Format("2006-01-02") +} + +func decTime(p []byte) time.Time { + return time.UnixMilli(decMilliseconds(p)).UTC() +}