From d8502e48a6aad508beee2aa6ca55508fd3e089c4 Mon Sep 17 00:00:00 2001 From: illia-li Date: Fri, 20 Dec 2024 12:52:54 -0400 Subject: [PATCH] fix `duration` marshal, unmarshal functions Changes: 1. Unmarshalling into `int64`, `custom int64`, `time.Duration`, `string` and `custom string` was unsupported before, now is supported. 2. Marshalling of `custom string` and `custom int64` was unsupported before, now is supported. 3. Unmarshalling `broken data` did not return an error before, now returns an error. --- marshal.go | 200 +------ serialization/duration/duration.go | 17 + serialization/duration/marshal.go | 38 ++ serialization/duration/marshal_utils.go | 189 +++++++ serialization/duration/marshal_vint_test.go | 61 ++ serialization/duration/unmarshal.go | 45 ++ serialization/duration/unmarshal_utils.go | 531 ++++++++++++++++++ serialization/duration/unmarshal_vint_test.go | 51 ++ 8 files changed, 955 insertions(+), 177 deletions(-) create mode 100644 serialization/duration/duration.go create mode 100644 serialization/duration/marshal.go create mode 100644 serialization/duration/marshal_utils.go create mode 100644 serialization/duration/marshal_vint_test.go create mode 100644 serialization/duration/unmarshal.go create mode 100644 serialization/duration/unmarshal_utils.go create mode 100644 serialization/duration/unmarshal_vint_test.go diff --git a/marshal.go b/marshal.go index d5e511242..5f6c16ae7 100644 --- a/marshal.go +++ b/marshal.go @@ -9,11 +9,8 @@ import ( "errors" "fmt" "math" - "math/big" - "math/bits" "reflect" "strings" - "time" "unsafe" "github.com/gocql/gocql/serialization/ascii" @@ -25,6 +22,7 @@ import ( "github.com/gocql/gocql/serialization/date" "github.com/gocql/gocql/serialization/decimal" "github.com/gocql/gocql/serialization/double" + "github.com/gocql/gocql/serialization/duration" "github.com/gocql/gocql/serialization/float" "github.com/gocql/gocql/serialization/inet" "github.com/gocql/gocql/serialization/smallint" @@ -38,7 +36,6 @@ import ( ) var
( - bigOne = big.NewInt(1) emptyValue reflect.Value ) @@ -194,7 +191,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeDate: return marshalDate(value) case TypeDuration: - return marshalDuration(info, value) + return marshalDuration(value) } // detect protocol 2 UDT @@ -306,7 +303,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { case TypeDate: return unmarshalDate(data, value) case TypeDuration: - return unmarshalDuration(info, data, value) + return unmarshalDuration(data, value) } // detect protocol 2 UDT @@ -429,17 +426,6 @@ func marshalInt(value interface{}) ([]byte, error) { return data, nil } -func encInt(x int32) []byte { - return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} -} - -func decInt(x []byte) int32 { - if len(x) != 4 { - return 0 - } - return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]) -} - func marshalBigInt(value interface{}) ([]byte, error) { data, err := bigint.Marshal(value) if err != nil { @@ -456,11 +442,6 @@ func marshalCounter(value interface{}) ([]byte, error) { return data, nil } -func encBigInt(x int64) []byte { - return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32), - byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} -} - func unmarshalCounter(data []byte, value interface{}) error { err := counter.Unmarshal(data, value) if err != nil { @@ -628,46 +609,6 @@ func unmarshalDecimal(data []byte, value interface{}) error { return nil } -// decBigInt2C sets the value of n to the big-endian two's complement -// value stored in the given data. If data[0]&80 != 0, the number -// is negative. If data is empty, the result will be 0. -func decBigInt2C(data []byte, n *big.Int) *big.Int { - if n == nil { - n = new(big.Int) - } - n.SetBytes(data) - if len(data) > 0 && data[0]&0x80 > 0 { - n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8)) - } - return n -} - -// encBigInt2C returns the big-endian two's complement -// form of n. 
-func encBigInt2C(n *big.Int) []byte { - switch n.Sign() { - case 0: - return []byte{0} - case 1: - b := n.Bytes() - if b[0]&0x80 > 0 { - b = append([]byte{0}, b...) - } - return b - case -1: - length := uint(n.BitLen()/8+1) * 8 - b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes() - // When the most significant bit is on a byte - // boundary, we can get some extra significant - // bits, so strip them off when that happens. - if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 { - b = b[1:] - } - return b - } - return nil -} - func marshalTime(value interface{}) ([]byte, error) { data, err := cqltime.Marshal(value) if err != nil { @@ -716,131 +657,36 @@ func unmarshalDate(data []byte, value interface{}) error { return nil } -func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encVints(0, 0, v), nil - case time.Duration: - return encVints(0, 0, v.Nanoseconds()), nil - case string: - d, err := time.ParseDuration(v) - if err != nil { - return nil, err - } - return encVints(0, 0, d.Nanoseconds()), nil +func marshalDuration(value interface{}) ([]byte, error) { + switch uv := value.(type) { case Duration: - return encVints(v.Months, v.Days, v.Nanoseconds), nil - } - - if value == nil { - return nil, nil + value = duration.Duration(uv) + case *Duration: + value = (*duration.Duration)(uv) } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encVints(0, 0, rv.Int()), nil + data, err := duration.Marshal(value) + if err != nil { + return nil, wrapMarshalError(err, "marshal error") } - return nil, marshalErrorf("can not marshal %T into %s", value, info) + return data, nil } -func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) +func unmarshalDuration(data []byte, 
value interface{}) error { + switch uv := value.(type) { case *Duration: - if len(data) == 0 { - *v = Duration{ - Months: 0, - Days: 0, - Nanoseconds: 0, - } - return nil - } - months, days, nanos, err := decVints(data) - if err != nil { - return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error()) - } - *v = Duration{ - Months: months, - Days: days, - Nanoseconds: nanos, + value = (*duration.Duration)(uv) + case **Duration: + if uv == nil { + value = (**duration.Duration)(nil) + } else { + value = (**duration.Duration)(unsafe.Pointer(uv)) } - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func decVints(data []byte) (int32, int32, int64, error) { - month, i, err := decVint(data, 0) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error()) - } - days, i, err := decVint(data, i) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error()) } - nanos, _, err := decVint(data, i) + err := duration.Unmarshal(data, value) if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error()) - } - return int32(month), int32(days), nanos, err -} - -func decVint(data []byte, start int) (int64, int, error) { - if len(data) <= start { - return 0, 0, errors.New("unexpected eof") - } - firstByte := data[start] - if firstByte&0x80 == 0 { - return decIntZigZag(uint64(firstByte)), start + 1, nil - } - numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 - ret := uint64(firstByte & (0xff >> uint(numBytes))) - if len(data) < start+numBytes+1 { - return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) - } - for i := start; i < start+numBytes; i++ { - ret <<= 8 - ret |= uint64(data[i+1] & 0xff) - } - return decIntZigZag(ret), start + numBytes + 1, nil -} - -func decIntZigZag(n uint64) int64 { - return int64((n >> 1) ^ -(n & 1)) -} - -func encIntZigZag(n int64) uint64 { - return 
uint64((n >> 63) ^ (n << 1)) -} - -func encVints(months int32, days int32, nanos int64) []byte { - buf := append(encVint(int64(months)), encVint(int64(days))...) - return append(buf, encVint(nanos)...) -} - -func encVint(v int64) []byte { - vEnc := encIntZigZag(v) - lead0 := bits.LeadingZeros64(vEnc) - numBytes := (639 - lead0*9) >> 6 - - // It can be 1 or 0 is v ==0 - if numBytes <= 1 { - return []byte{byte(vEnc)} - } - extraBytes := numBytes - 1 - var buf = make([]byte, numBytes) - for i := extraBytes; i >= 0; i-- { - buf[i] = byte(vEnc) - vEnc >>= 8 + return wrapUnmarshalError(err, "unmarshal error") } - buf[0] |= byte(^(0xff >> uint(extraBytes))) - return buf + return nil } func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error { diff --git a/serialization/duration/duration.go b/serialization/duration/duration.go new file mode 100644 index 000000000..823fd87f4 --- /dev/null +++ b/serialization/duration/duration.go @@ -0,0 +1,17 @@ +package duration + +type Duration struct { + Months int32 + Days int32 + Nanoseconds int64 +} + +func (d Duration) Valid() bool { + if d.Months >= 0 && d.Days >= 0 && d.Nanoseconds >= 0 { + return true + } + if d.Months <= 0 && d.Days <= 0 && d.Nanoseconds <= 0 { + return true + } + return false +} diff --git a/serialization/duration/marshal.go b/serialization/duration/marshal.go new file mode 100644 index 000000000..470fe496a --- /dev/null +++ b/serialization/duration/marshal.go @@ -0,0 +1,38 @@ +package duration + +import ( + "reflect" + "time" +) + +func Marshal(value interface{}) ([]byte, error) { + switch v := value.(type) { + case nil: + return nil, nil + case int64: + return EncInt64(v) + case time.Duration: + return EncDur(v) + case string: + return EncString(v) + case Duration: + return EncDuration(v) + + case *int64: + return EncInt64R(v) + case *time.Duration: + return EncDurR(v) + case *string: + return EncStringR(v) + case *Duration: + return EncDurationR(v) + default: + // Custom types (type 
MyDate uint32) can be serialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. + rv := reflect.TypeOf(value) + if rv.Kind() != reflect.Ptr { + return EncReflect(reflect.ValueOf(v)) + } + return EncReflectR(reflect.ValueOf(v)) + } +} diff --git a/serialization/duration/marshal_utils.go b/serialization/duration/marshal_utils.go new file mode 100644 index 000000000..4c2020e24 --- /dev/null +++ b/serialization/duration/marshal_utils.go @@ -0,0 +1,189 @@ +package duration + +import ( + "fmt" + "reflect" + "time" +) + +const ( + vintPrefix1 byte = 128 + vintPrefix2 byte = 192 + vintPrefix3 byte = 224 + vintPrefix4 byte = 240 + vintPrefix5 byte = 248 + vintPrefix6 byte = 252 + vintPrefix7 byte = 254 + vintPrefix8 byte = 255 + + nanoDay = 24 * 60 * 60 * 1000 * 1000 * 1000 +) + +func EncInt64(v int64) ([]byte, error) { + return encInt64(v), nil +} + +func EncInt64R(v *int64) ([]byte, error) { + if v == nil { + return nil, nil + } + return encInt64(*v), nil +} + +func EncDur(v time.Duration) ([]byte, error) { + return encDur(v), nil +} + +func EncDurR(v *time.Duration) ([]byte, error) { + if v == nil { + return nil, nil + } + return encDur(*v), nil +} + +func EncString(v string) ([]byte, error) { + if v == "" { + return nil, nil + } + d, err := time.ParseDuration(v) + if err != nil { + return nil, fmt.Errorf("failed to marshal duration: the (string)(%s) have invalid format, %v", v, err) + } + return encDur(d), nil +} + +func EncStringR(v *string) ([]byte, error) { + if v == nil { + return nil, nil + } + return EncString(*v) +} + +func EncDuration(v Duration) ([]byte, error) { + if !v.Valid() { + return nil, fmt.Errorf("failed to marshal duration: the (Duration) values of months (%d), days (%d) and nanoseconds (%d) should have the same sign", v.Months, v.Days, v.Nanoseconds) + } + return append(append(encVint32(encIntZigZag32(v.Months)), encVint32(encIntZigZag32(v.Days))...), 
encVint64(encIntZigZag64(v.Nanoseconds))...), nil +} + +func EncDurationR(v *Duration) ([]byte, error) { + if v == nil { + return nil, nil + } + if !v.Valid() { + return nil, fmt.Errorf("failed to marshal duration: the (*Duration) values of the months (%d), days (%d) and nanoseconds (%d) should have same sign", v.Months, v.Days, v.Nanoseconds) + } + return append(append(encVint32(encIntZigZag32(v.Months)), encVint32(encIntZigZag32(v.Days))...), encVint64(encIntZigZag64(v.Nanoseconds))...), nil +} + +func EncReflect(v reflect.Value) ([]byte, error) { + switch v.Kind() { + case reflect.Int64: + return encInt64(v.Int()), nil + case reflect.String: + val := v.String() + if val == "" { + return nil, nil + } + d, err := time.ParseDuration(val) + if err != nil { + return nil, fmt.Errorf("failed to marshal duration: the (%T)(%[1]v) have invalid format, %v", v.Interface(), err) + } + return encDur(d), nil + case reflect.Struct: + if v.Type().String() == "gocql.unsetColumn" { + return nil, nil + } + return nil, fmt.Errorf("failed to marshal duration: unsupported value type (%T)(%[1]v)", v.Interface()) + default: + return nil, fmt.Errorf("failed to marshal duration: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func EncReflectR(v reflect.Value) ([]byte, error) { + if v.IsNil() { + return nil, nil + } + return EncReflect(v.Elem()) +} + +func encDur(v time.Duration) []byte { + return encNanos(encIntZigZagDur(v)) +} + +func encInt64(v int64) []byte { + return encNanos(encIntZigZag64(v)) +} + +func encIntZigZag32(v int32) uint32 { + return uint32((v >> 31) ^ (v << 1)) +} + +func encIntZigZag64(v int64) uint64 { + return uint64((v >> 63) ^ (v << 1)) +} + +func encIntZigZagDur(v time.Duration) uint64 { + return uint64((v >> 63) ^ (v << 1)) +} + +func encVint32(v uint32) []byte { + switch { + case byte(v>>28) != 0: + return []byte{vintPrefix4, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>21) != 0: + return []byte{vintPrefix3 | byte(v>>24), byte(v >> 16), 
byte(v >> 8), byte(v)} + case byte(v>>14) != 0: + return []byte{vintPrefix2 | byte(v>>16), byte(v >> 8), byte(v)} + case byte(v>>7) != 0: + return []byte{vintPrefix1 | byte(v>>8), byte(v)} + default: + return []byte{byte(v)} + } +} + +func encVint64(v uint64) []byte { + switch { + case byte(v>>56) != 0: + return []byte{vintPrefix8, byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>49) != 0: + return []byte{vintPrefix7 | byte(v>>56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>42) != 0: + return []byte{vintPrefix6 | byte(v>>48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>35) != 0: + return []byte{vintPrefix5 | byte(v>>40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>28) != 0: + return []byte{vintPrefix4 | byte(v>>32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>21) != 0: + return []byte{vintPrefix3 | byte(v>>24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>14) != 0: + return []byte{vintPrefix2 | byte(v>>16), byte(v >> 8), byte(v)} + case byte(v>>7) != 0: + return []byte{vintPrefix1 | byte(v>>8), byte(v)} + default: + return []byte{byte(v)} + } +} + +func encNanos(v uint64) []byte { + switch { + case byte(v>>56) != 0: + return []byte{0, 0, vintPrefix8, byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>49) != 0: + return []byte{0, 0, vintPrefix7 | byte(v>>56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>42) != 0: + return []byte{0, 0, vintPrefix6 | byte(v>>48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>35) != 0: + return []byte{0, 0, vintPrefix5 | byte(v>>40), byte(v >> 32), byte(v >> 24), byte(v >> 
16), byte(v >> 8), byte(v)} + case byte(v>>28) != 0: + return []byte{0, 0, vintPrefix4 | byte(v>>32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>21) != 0: + return []byte{0, 0, vintPrefix3 | byte(v>>24), byte(v >> 16), byte(v >> 8), byte(v)} + case byte(v>>14) != 0: + return []byte{0, 0, vintPrefix2 | byte(v>>16), byte(v >> 8), byte(v)} + case byte(v>>7) != 0: + return []byte{0, 0, vintPrefix1 | byte(v>>8), byte(v)} + default: + return []byte{0, 0, byte(v)} + } +} diff --git a/serialization/duration/marshal_vint_test.go b/serialization/duration/marshal_vint_test.go new file mode 100644 index 000000000..aa1ae9567 --- /dev/null +++ b/serialization/duration/marshal_vint_test.go @@ -0,0 +1,61 @@ +package duration + +import ( + "bytes" + "math" + "math/bits" + "testing" +) + +func TestEncVint32(t *testing.T) { + for i := int32(math.MaxInt32); i != 1; i = i / 2 { + testEnc32(t, i) + testEnc32(t, -i-1) + } +} + +func TestEncVint64(t *testing.T) { + for i := int64(math.MaxInt64); i != 1; i = i / 2 { + testEnc64(t, i) + testEnc64(t, -i-1) + } +} + +func testEnc32(t *testing.T, v int32) { + t.Helper() + expected := genVintData(int64(v)) + received := encVint32(encIntZigZag32(v)) + + if !bytes.Equal(expected, received) { + t.Fatalf("expected and recieved data not equal\nvalue:%d\ndata expected:%b\ndata received:%b", v, expected, received) + } +} + +func testEnc64(t *testing.T, v int64) { + t.Helper() + expected := genVintData(v) + received := encVint64(encIntZigZag64(v)) + + if !bytes.Equal(expected, received) { + t.Fatalf("expected and recieved data not equal\nvalue:%d\ndata expected:%b\ndata received:%b", v, expected, received) + } +} + +func genVintData(v int64) []byte { + vEnc := encIntZigZag64(v) + lead0 := bits.LeadingZeros64(vEnc) + numBytes := (639 - lead0*9) >> 6 + + // It can be 1 or 0 is v ==0 + if numBytes <= 1 { + return []byte{byte(vEnc)} + } + extraBytes := numBytes - 1 + var buf = make([]byte, numBytes) + for i := extraBytes; i >= 0; 
i-- { + buf[i] = byte(vEnc) + vEnc >>= 8 + } + buf[0] |= byte(^(0xff >> uint(extraBytes))) + return buf +} diff --git a/serialization/duration/unmarshal.go b/serialization/duration/unmarshal.go new file mode 100644 index 000000000..010b0ca73 --- /dev/null +++ b/serialization/duration/unmarshal.go @@ -0,0 +1,45 @@ +package duration + +import ( + "fmt" + "reflect" + "time" +) + +func Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case nil: + return nil + + case *int64: + return DecInt64(data, v) + case *string: + return DecString(data, v) + case *time.Duration: + return DecDur(data, v) + case *Duration: + return DecDuration(data, v) + + case **int64: + return DecInt64R(data, v) + case **string: + return DecStringR(data, v) + case **time.Duration: + return DecDurR(data, v) + case **Duration: + return DecDurationR(data, v) + default: + + // Custom types (type MyDate uint32) can be deserialized only via `reflect` package. + // Later, when generic-based serialization is introduced we can do that via generics. 
+ rv := reflect.ValueOf(value) + rt := rv.Type() + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("failed to unmarshal duration: unsupported value type (%T)(%[1]v)", value) + } + if rt.Elem().Kind() != reflect.Ptr { + return DecReflect(data, rv) + } + return DecReflectR(data, rv) + } +} diff --git a/serialization/duration/unmarshal_utils.go b/serialization/duration/unmarshal_utils.go new file mode 100644 index 000000000..43bf83293 --- /dev/null +++ b/serialization/duration/unmarshal_utils.go @@ -0,0 +1,531 @@ +package duration + +import ( + "fmt" + "math" + "reflect" + "time" +) + +const ( + maxDays = (math.MaxInt64 - math.MaxInt64%nanoDay) / nanoDay + minDays = -maxDays + maxDaysNanos = maxDays * nanoDay + minDaysNanos = minDays * nanoDay + zeroDuration = "0s" +) + +var ( + errWrongDataLen = fmt.Errorf("failed to unmarshal duration: the length of the data should be 0 or 3-19") + errBrokenData = fmt.Errorf("failed to unmarshal duration: the data is broken") + errInvalidSign = fmt.Errorf("failed to unmarshal duration: the data values of months, days and nanoseconds should have the same sign") +) + +func errNilReference(v interface{}) error { + return fmt.Errorf("failed to unmarshal duration: can not unmarshal into nil reference (%T)(%[1]v))", v) +} + +func DecInt64(p []byte, v *int64) error { + if v == nil { + return errNilReference(v) + } + switch l := len(p); { + case l == 0: + *v = 0 + case l < 3: + return errWrongDataLen + default: + if p[0] != 0 || p[1] != 0 { + return fmt.Errorf("failed to unmarshal duration: to unmarshal into (int64) the data values of the months and days should be 0") + } + var ok bool + if *v, ok = decNanos64(p); !ok { + return errBrokenData + } + } + return nil +} + +func DecInt64R(p []byte, v **int64) error { + if v == nil { + return errNilReference(v) + } + switch l := len(p); { + case l == 0: + if p == nil { + *v = nil + } else { + *v = new(int64) + } + case l < 3: + return errWrongDataLen + default: + if p[0] != 0 || p[1] != 0 { + 
return fmt.Errorf("failed to unmarshal duration: to unmarshal into (*int64) the data values of the months and days should be 0") + } + n, ok := decNanos64(p) + if !ok { + return errBrokenData + } + *v = &n + } + return nil +} + +func DecString(p []byte, v *string) error { + if v == nil { + return errNilReference(v) + } + switch l := len(p); { + case l == 0: + if p == nil { + *v = "" + } else { + *v = zeroDuration + } + case l < 3: + return errWrongDataLen + default: + if p[0] != 0 || p[1] != 0 { + return fmt.Errorf("failed to unmarshal duration: to unmarshal into (string) the data values of the months and days should be 0") + } + n, ok := decNanosDur(p) + if !ok { + return errBrokenData + } + *v = n.String() + } + return nil +} + +func DecStringR(p []byte, v **string) error { + if v == nil { + return errNilReference(v) + } + switch l := len(p); { + case l == 0: + if p == nil { + *v = nil + } else { + val := zeroDuration + *v = &val + } + case l < 3: + return errWrongDataLen + default: + if p[0] != 0 || p[1] != 0 { + return fmt.Errorf("failed to unmarshal duration: to unmarshal into (*string) the data values of the months and days should be 0") + } + n, ok := decNanosDur(p) + if !ok { + return errBrokenData + } + val := n.String() + *v = &val + } + return nil +} + +func DecDur(p []byte, v *time.Duration) error { + if v == nil { + return errNilReference(v) + } + switch l := len(p); { + case l == 0: + *v = 0 + case l < 3: + return errWrongDataLen + default: + if p[0] != 0 || p[1] != 0 { + return fmt.Errorf("failed to unmarshal duration: to unmarshal into (time.Duration) the data values of the months and days should be 0") + } + var ok bool + if *v, ok = decNanosDur(p); !ok { + return errBrokenData + } + } + return nil +} + +func DecDurR(p []byte, v **time.Duration) error { + if v == nil { + return errNilReference(v) + } + switch l := len(p); { + case l == 0: + if p == nil { + *v = nil + } else { + *v = new(time.Duration) + } + case l < 3: + return errWrongDataLen + 
default: + if p[0] != 0 || p[1] != 0 { + return fmt.Errorf("failed to unmarshal duration: to unmarshal into (*time.Duration) the data values of the months and days should be 0") + } + n, ok := decNanosDur(p) + if !ok { + return errBrokenData + } + *v = &n + } + return nil +} + +func DecDuration(p []byte, v *Duration) error { + if v == nil { + return errNilReference(v) + } + switch l := len(p); { + case l == 0: + *v = Duration{} + case l < 3: + return errWrongDataLen + default: + var ok bool + v.Months, v.Days, v.Nanoseconds, ok = decVints(p) + if !ok { + return errBrokenData + } + if !v.Valid() { + return errInvalidSign + } + } + return nil +} + +func DecDurationR(p []byte, v **Duration) error { + if v == nil { + return errNilReference(v) + } + switch l := len(p); { + case l == 0: + if p == nil { + *v = nil + } else { + *v = new(Duration) + } + case l < 3: + return errWrongDataLen + default: + var ok bool + var val Duration + val.Months, val.Days, val.Nanoseconds, ok = decVints(p) + if !ok { + return errBrokenData + } + if !val.Valid() { + return errInvalidSign + } + *v = &val + } + return nil +} + +func DecReflect(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal duration: can not unmarshal into nil reference (%T)(%[1]v))", v.Interface()) + } + + switch v = v.Elem(); v.Kind() { + case reflect.Int64: + return decReflectInt64(p, v) + case reflect.String: + return decReflectString(p, v) + default: + return fmt.Errorf("failed to unmarshal duration: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectInt64(p []byte, v reflect.Value) error { + switch l := len(p); { + case l == 0: + v.SetInt(0) + case l < 3: + return errWrongDataLen + default: + if p[0] != 0 || p[1] != 0 { + return fmt.Errorf("failed to unmarshal duration: to unmarshal into (%T) the data values of the months and days value should be 0", v.Interface()) + } + n, ok := decNanos64(p) + if !ok { + return errBrokenData + } + v.SetInt(n) + } + return nil +} 
+ +func decReflectString(p []byte, v reflect.Value) error { + switch l := len(p); { + case l == 0: + if p == nil { + v.SetString("") + } else { + v.SetString(zeroDuration) + } + case l < 3: + return errWrongDataLen + default: + if p[0] != 0 || p[1] != 0 { + return fmt.Errorf("failed to unmarshal duration: to unmarshal into (%T) the data values of the months and days value should be 0", v.Interface()) + } + n, ok := decNanosDur(p) + if !ok { + return errBrokenData + } + v.SetString(n.String()) + } + return nil +} + +func DecReflectR(p []byte, v reflect.Value) error { + if v.IsNil() { + return fmt.Errorf("failed to unmarshal duration: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface()) + } + + switch v.Type().Elem().Elem().Kind() { + case reflect.Int64: + return decReflectInt64R(p, v) + case reflect.String: + return decReflectStringR(p, v) + default: + return fmt.Errorf("failed to unmarshal duration: unsupported value type (%T)(%[1]v)", v.Interface()) + } +} + +func decReflectInt64R(p []byte, v reflect.Value) error { + switch l := len(p); { + case l == 0: + var val reflect.Value + if p == nil { + val = reflect.Zero(v.Type().Elem()) + } else { + val = reflect.New(v.Type().Elem().Elem()) + } + v.Elem().Set(val) + case l < 3: + return errWrongDataLen + default: + if p[0] != 0 || p[1] != 0 { + return fmt.Errorf("failed to unmarshal duration: to unmarshal into (%T) the data values of the months and days value should be 0", v.Interface()) + } + n, ok := decNanos64(p) + if !ok { + return errBrokenData + } + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetInt(n) + v.Elem().Set(val) + } + return nil +} + +func decReflectStringR(p []byte, v reflect.Value) error { + switch l := len(p); { + case l == 0: + var val reflect.Value + if p == nil { + val = reflect.Zero(v.Type().Elem()) + } else { + val = reflect.New(v.Type().Elem().Elem()) + val.Elem().SetString(zeroDuration) + } + v.Elem().Set(val) + case l < 3: + return errWrongDataLen + default: + if p[0] 
!= 0 || p[1] != 0 { + return fmt.Errorf("failed to unmarshal duration: to unmarshal into (%T) the data values of the months and days value should be 0", v.Interface()) + } + n, ok := decNanosDur(p) + if !ok { + return errBrokenData + } + val := reflect.New(v.Type().Elem().Elem()) + val.Elem().SetString(n.String()) + v.Elem().Set(val) + } + return nil +} + +func decVints(p []byte) (int32, int32, int64, bool) { + m, read := decVint32(p, 0) + if read == 0 { + return 0, 0, 0, false + } + d, read := decVint32(p, read) + if read == 0 { + return 0, 0, 0, false + } + n, read := decVint64(p, read) + if read == 0 { + return 0, 0, 0, false + } + return decZigZag32(m), decZigZag32(d), decZigZag64(n), true +} + +func decNanos64(p []byte) (int64, bool) { + n, read := decVint64(p, 2) + if read == 0 { + return 0, false + } + return decZigZag64(n), true +} + +func decNanosDur(p []byte) (time.Duration, bool) { + n, read := decVint64(p, 2) + if read == 0 { + return 0, false + } + return decZigZagDur(n), true +} + +func decVint64(p []byte, s int) (uint64, int) { + vintLen := decVintLen(p[s:]) + if vintLen+s != len(p) { + return 0, 0 + } + switch vintLen { + case 9: + return dec9Vint64(p[s:]), s + 9 + case 8: + return dec8Vint64(p[s:]), s + 8 + case 7: + return dec7Vint64(p[s:]), s + 7 + case 6: + return dec6Vint64(p[s:]), s + 6 + case 5: + return dec5Vint64(p[s:]), s + 5 + case 4: + return dec4Vint64(p[s:]), s + 4 + case 3: + return dec3Vint64(p[s:]), s + 3 + case 2: + return dec2Vint64(p[s:]), s + 2 + case 1: + return dec1Vint64(p[s:]), s + 1 + case 0: + return 0, s + 1 + default: + return 0, 0 + } +} + +func decVint32(p []byte, s int) (uint32, int) { + vintLen := decVintLen(p[s:]) + if vintLen+s >= len(p) { + return 0, 0 + } + switch vintLen { + case 5: + if p[s] != vintPrefix4 { + return 0, 0 + } + return dec5Vint32(p[s:]), s + 5 + case 4: + return dec4Vint32(p[s:]), s + 4 + case 3: + return dec3Vint32(p[s:]), s + 3 + case 2: + return dec2Vint32(p[s:]), s + 2 + case 1: + return 
dec1Vint32(p[s:]), s + 1 + case 0: + return 0, s + 1 + default: + return 0, 0 + } +} + +func decVintLen(p []byte) int { + switch { + case p[0] == 255: + return 9 + case p[0]>>1 == 127: + return 8 + case p[0]>>2 == 63: + return 7 + case p[0]>>3 == 31: + return 6 + case p[0]>>4 == 15: + return 5 + case p[0]>>5 == 7: + return 4 + case p[0]>>6 == 3: + return 3 + case p[0]>>7 == 1: + return 2 + default: + return 1 + } +} + +func decZigZag32(n uint32) int32 { + return int32((n >> 1) ^ -(n & 1)) +} + +func decZigZag64(n uint64) int64 { + return int64((n >> 1) ^ -(n & 1)) +} + +func decZigZagDur(n uint64) time.Duration { + return time.Duration((n >> 1) ^ -(n & 1)) +} + +func dec5Vint32(p []byte) uint32 { + return uint32(p[1])<<24 | uint32(p[2])<<16 | uint32(p[3])<<8 | uint32(p[4]) +} + +func dec4Vint32(p []byte) uint32 { + return uint32(p[0]&^vintPrefix3)<<24 | uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3]) +} + +func dec3Vint32(p []byte) uint32 { + return uint32(p[0]&^vintPrefix2)<<16 | uint32(p[1])<<8 | uint32(p[2]) +} + +func dec2Vint32(p []byte) uint32 { + return uint32(p[0]&^vintPrefix1)<<8 | uint32(p[1]) +} + +func dec1Vint32(p []byte) uint32 { + return uint32(p[0]) +} + +func dec9Vint64(p []byte) uint64 { + return uint64(p[1])<<56 | uint64(p[2])<<48 | uint64(p[3])<<40 | uint64(p[4])<<32 | uint64(p[5])<<24 | uint64(p[6])<<16 | uint64(p[7])<<8 | uint64(p[8]) +} + +func dec8Vint64(p []byte) uint64 { + return uint64(p[0]&^vintPrefix7)<<56 | uint64(p[1])<<48 | uint64(p[2])<<40 | uint64(p[3])<<32 | uint64(p[4])<<24 | uint64(p[5])<<16 | uint64(p[6])<<8 | uint64(p[7]) +} + +func dec7Vint64(p []byte) uint64 { + return uint64(p[0]&^vintPrefix6)<<48 | uint64(p[1])<<40 | uint64(p[2])<<32 | uint64(p[3])<<24 | uint64(p[4])<<16 | uint64(p[5])<<8 | uint64(p[6]) +} + +func dec6Vint64(p []byte) uint64 { + return uint64(p[0]&^vintPrefix5)<<40 | uint64(p[1])<<32 | uint64(p[2])<<24 | uint64(p[3])<<16 | uint64(p[4])<<8 | uint64(p[5]) +} + +func dec5Vint64(p []byte) uint64 { + return 
uint64(p[0]&^vintPrefix4)<<32 | uint64(p[1])<<24 | uint64(p[2])<<16 | uint64(p[3])<<8 | uint64(p[4]) +} + +func dec4Vint64(p []byte) uint64 { + return uint64(p[0]&^vintPrefix3)<<24 | uint64(p[1])<<16 | uint64(p[2])<<8 | uint64(p[3]) +} + +func dec3Vint64(p []byte) uint64 { + return uint64(p[0]&^vintPrefix2)<<16 | uint64(p[1])<<8 | uint64(p[2]) +} + +func dec2Vint64(p []byte) uint64 { + return uint64(p[0]&^vintPrefix1)<<8 | uint64(p[1]) +} + +func dec1Vint64(p []byte) uint64 { + return uint64(p[0]) +} diff --git a/serialization/duration/unmarshal_vint_test.go b/serialization/duration/unmarshal_vint_test.go new file mode 100644 index 000000000..46b47e10d --- /dev/null +++ b/serialization/duration/unmarshal_vint_test.go @@ -0,0 +1,51 @@ +package duration + +import ( + "math" + "testing" +) + +func TestDecVint32(t *testing.T) { + for i := int32(math.MaxInt32); i != 1; i = i / 2 { + testDec32(t, i) + testDec32(t, -i-1) + } +} + +func TestDecVint64(t *testing.T) { + for i := int64(math.MaxInt64); i != 1; i = i / 2 { + testDec64(t, i) + testDec64(t, -i-1) + } +} + +func testDec32(t *testing.T, expected int32) { + t.Helper() + // appending one byte is necessary because the `decVint32` function looks at the length of the data for the next vint len read. + data := append(genVintData(int64(expected)), 0) + + vint, read := decVint32(data, 0) + if read == 0 { + t.Fatalf("decVint32 function can`t read vint data: value %d, data %b", expected, data) + } + + received := decZigZag32(vint) + if expected != received { + t.Fatalf("\nexpected:%d\nreceived:%d\ndata:%b", expected, received, data) + } +} + +func testDec64(t *testing.T, expected int64) { + t.Helper() + data := genVintData(int64(expected)) + + vint, read := decVint64(data, 0) + if read == 0 { + t.Fatalf("decVint64 function can`t read vint data: value %d, data %b", expected, data) + } + + received := decZigZag64(vint) + if expected != received { + t.Fatalf("\nexpected:%d\nreceived:%d\ndata:%b", expected, received, data) + } +}