diff --git a/block.go b/block.go index 85cefab..b750620 100644 --- a/block.go +++ b/block.go @@ -117,7 +117,11 @@ func UnmarshalBlock(r io.Reader) (*Block, error) { // MarshalBlock marshals EBML Block structure func MarshalBlock(b *Block, w io.Writer) error { - if _, err := w.Write(encodeVInt(b.TrackNumber)); err != nil { + n, err := encodeElementID(b.TrackNumber) + if err != nil { + return err + } + if _, err := w.Write(n); err != nil { return err } if _, err := w.Write([]byte{byte(b.Timecode >> 8), byte(b.Timecode)}); err != nil { diff --git a/marshal.go b/marshal.go index ed10fed..3965e80 100644 --- a/marshal.go +++ b/marshal.go @@ -90,7 +90,7 @@ func marshalImpl(vo reflect.Value, w io.Writer) error { var bw io.Writer if inf { // Directly write length unspecified element - bsz := encodeVInt(uint64(sizeInf)) + bsz := encodeDataSize(uint64(sizeInf)) if _, err := w.Write(bsz); err != nil { return err } @@ -115,7 +115,7 @@ func marshalImpl(vo reflect.Value, w io.Writer) error { // Write element with length if !inf { - bsz := encodeVInt(uint64(bw.(*bytes.Buffer).Len())) + bsz := encodeDataSize(uint64(bw.(*bytes.Buffer).Len())) if _, err := w.Write(bsz); err != nil { return err } diff --git a/unmarshal.go b/unmarshal.go index f7d4f40..430c084 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -101,7 +101,7 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) { switch v.t { case TypeMaster: if v.top && !vnext.IsValid() { - b := bytes.Join([][]byte{table[v.e].b, encodeVInt(size)}, []byte{}) + b := bytes.Join([][]byte{table[v.e].b, encodeDataSize(size)}, []byte{}) return bytes.NewBuffer(b), io.EOF } var vn reflect.Value diff --git a/value.go b/value.go index 19ca53c..757bec1 100644 --- a/value.go +++ b/value.go @@ -28,8 +28,9 @@ const ( ) var ( - errInvalidFloatSize = errors.New("Invalid float size") - errInvalidType = errors.New("Invalid type") + errInvalidFloatSize = errors.New("Invalid float size") + errInvalidType = errors.New("Invalid type") + errUnsupportedElementID = errors.New("Unsupported Element ID") ) var perTypeReader = map[Type]func(io.Reader, uint64) (interface{}, error){ @@ -180,20 +181,20 @@ var perTypeEncoder = map[Type]func(interface{}) ([]byte, error){ TypeBlock: encodeBlock, } -func encodeVInt(v uint64) []byte { - if v < 0x80 { +func encodeDataSize(v uint64) []byte { + if v < 0x80-1 { return []byte{byte(v) | 0x80} - } else if v < 0x4000 { + } else if v < 0x4000-1 { return []byte{byte(v>>8) | 0x40, byte(v)} - } else if v < 0x200000 { + } else if v < 0x200000-1 { return []byte{byte(v>>16) | 0x20, byte(v >> 8), byte(v)} - } else if v < 0x10000000 { + } else if v < 0x10000000-1 { return []byte{byte(v>>24) | 0x10, byte(v >> 16), byte(v >> 8), byte(v)} - } else if v < 0x8000000000 { + } else if v < 0x800000000-1 { return []byte{byte(v>>32) | 0x8, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} - } else if v < 0x400000000000 { + } else if v < 0x40000000000-1 { return []byte{byte(v>>40) | 0x4, byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} - } else if v < 0x20000000000000 { + } else if v < 0x2000000000000-1 { return []byte{byte(v>>48) | 0x2, byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} } else if v < sizeInf { return []byte{0x1, byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} @@ -201,6 +202,24 @@ func encodeVInt(v uint64) []byte { return []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} } } +func encodeElementID(v uint64) ([]byte, error) { + if v < 0x80 { + return []byte{byte(v) | 0x80}, nil + } else if v < 0x4000 { + return []byte{byte(v>>8) | 0x40, byte(v)}, nil + } else if v < 0x200000 { + return []byte{byte(v>>16) | 0x20, byte(v >> 8), byte(v)}, nil + } else if v < 0x10000000 { + return []byte{byte(v>>24) | 0x10, byte(v >> 16), byte(v >> 8), byte(v)}, nil + } else if v < 0x800000000 { + return []byte{byte(v>>32) | 0x8, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil + } else if v < 0x40000000000 { + return []byte{byte(v>>40) | 0x4, byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil + } else if v < 0x2000000000000 { + return []byte{byte(v>>48) | 0x2, byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil + } + return nil, errUnsupportedElementID +} func encodeBinary(i interface{}) ([]byte, error) { v, ok := i.([]byte) if !ok { diff --git a/value_test.go b/value_test.go index 3849617..83cb96a 100644 --- a/value_test.go +++ b/value_test.go @@ -7,20 +7,27 @@ import ( "time" ) -func TestVInt(t *testing.T) { +func TestDataSize(t *testing.T) { testCases := map[string]struct { b []byte i uint64 }{ - "1 byte": {[]byte{0x81}, 0x01}, - "2 bytes": {[]byte{0x41, 0x23}, 0x0123}, - "3 bytes": {[]byte{0x21, 0x23, 0x45}, 0x012345}, - "4 bytes": {[]byte{0x11, 0x23, 0x45, 0x67}, 0x01234567}, - "5 bytes": {[]byte{0x09, 0x23, 0x45, 0x67, 0x89}, 0x0123456789}, - "6 bytes": {[]byte{0x05, 0x23, 0x45, 0x67, 0x89, 0xab}, 0x0123456789ab}, - "7 bytes": {[]byte{0x03, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd}, 0x0123456789abcd}, - "8 bytes": {[]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, 0x23456789abcdef}, - "Indefinite": {[]byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, sizeInf}, + "1 byte (upper bound)": {[]byte{0xFE}, 0x80 - 2}, + "2 bytes (lower bound)": {[]byte{0x40, 0x7F}, 0x80 - 1}, + "2 bytes (upper bound)": {[]byte{0x7F, 0xFE}, 0x4000 - 2}, + "3 bytes (lower bound)": {[]byte{0x20, 0x3F, 0xFF}, 0x4000 - 1}, + "3 bytes (upper bound)": {[]byte{0x3F, 0xFF, 0xFE}, 0x200000 - 2}, + "4 bytes (lower bound)": {[]byte{0x10, 0x1F, 0xFF, 0xFF}, 0x200000 - 1}, + "4 bytes (upper bound)": {[]byte{0x1F, 0xFF, 0xFF, 0xFE}, 0x10000000 - 2}, + "5 bytes (lower bound)": {[]byte{0x08, 0x0F, 0xFF, 0xFF, 0xFF}, 0x10000000 - 1}, + "5 bytes (upper bound)": {[]byte{0x0F, 0xFF, 0xFF, 0xFF, 0xFE}, 0x800000000 - 2}, + "6 bytes (lower bound)": {[]byte{0x04, 0x07, 0xFF, 0xFF, 0xFF, 0xFF}, 0x800000000 - 1}, + "6 bytes (upper bound)": {[]byte{0x07, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE}, 0x40000000000 - 2}, + "7 bytes (lower bound)": {[]byte{0x02, 0x03, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, 0x40000000000 - 1}, + "7 bytes (upper bound)": {[]byte{0x03, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE}, 0x2000000000000 - 2}, + "8 bytes (lower bound)": {[]byte{0x01, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, 0x2000000000000 - 1}, + "8 bytes (upper bound)": {[]byte{0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE}, 0xffffffffffffff - 1}, + "Indefinite": {[]byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, sizeInf}, } for n, c := range testCases { @@ -36,14 +43,64 @@ func TestVInt(t *testing.T) { } for n, c := range testCases { t.Run("Encode "+n, func(t *testing.T) { - b := encodeVInt(c.i) + b := encodeDataSize(c.i) if bytes.Compare(b, c.b) != 0 { - t.Errorf("Unexpected encodeVInt result, expected: %d, got: %d", c.b, b) + t.Errorf("Unexpected encodeDataSize result, expected: %d, got: %d", c.b, b) } }) } } +func TestElementID(t *testing.T) { + testCases := map[string]struct { + b []byte + i uint64 + }{ + "1 byte (upper bound)": {[]byte{0xFF}, 0x80 - 1}, + "2 bytes (lower bound)": {[]byte{0x40, 0x80}, 0x80}, + "2 bytes (upper bound)": {[]byte{0x7F, 0xFF}, 0x4000 - 1}, + "3 bytes (lower bound)": {[]byte{0x20, 0x40, 0x00}, 0x4000}, + "3 bytes (upper bound)": {[]byte{0x3F, 0xFF, 0xFF}, 0x200000 - 1}, + "4 bytes (lower bound)": {[]byte{0x10, 0x20, 0x00, 0x00}, 0x200000}, + "4 bytes (upper bound)": {[]byte{0x1F, 0xFF, 0xFF, 0xFF}, 0x10000000 - 1}, + "5 bytes (lower bound)": {[]byte{0x08, 0x10, 0x00, 0x00, 0x00}, 0x10000000}, + "5 bytes (upper bound)": {[]byte{0x0F, 0xFF, 0xFF, 0xFF, 0xFF}, 0x800000000 - 1}, + "6 bytes (lower bound)": {[]byte{0x04, 0x08, 0x00, 0x00, 0x00, 0x00}, 0x800000000}, + "6 bytes (upper bound)": {[]byte{0x07, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, 0x40000000000 - 1}, + "7 bytes (lower bound)": {[]byte{0x02, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00}, 0x40000000000}, + "7 bytes (upper bound)": {[]byte{0x03, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, 0x2000000000000 - 1}, + } + + for n, c := range testCases { + t.Run("Decode "+n, func(t *testing.T) { + r, err := readVInt(bytes.NewBuffer(c.b)) + if err != nil { + t.Fatalf("Failed to readVInt: %v", err) + } + if r != c.i { + t.Errorf("Unexpected readVInt result, expected: %d, got: %d", c.i, r) + } + }) + } + for n, c := range testCases { + t.Run("Encode "+n, func(t *testing.T) { + b, err := encodeElementID(c.i) + if err != nil { + t.Fatalf("Failed to encodeElementID: %v", err) + } + if bytes.Compare(b, c.b) != 0 { + t.Errorf("Unexpected encodeDataSize result, expected: %d, got: %d", c.b, b) + } + }) + } + + _, err := encodeElementID(0x2000000000000) + if err != errUnsupportedElementID { + t.Errorf("Unexpected error type result, expected: %s, got: %s", errUnsupportedElementID, err) + } + +} + func TestValue(t *testing.T) { testCases := map[string]struct { b []byte