diff --git a/block.go b/block.go index b750620..de73d5f 100644 --- a/block.go +++ b/block.go @@ -79,7 +79,7 @@ type Lace struct { func UnmarshalBlock(r io.Reader) (*Block, error) { var b Block var err error - if b.TrackNumber, err = readVInt(r); err != nil { + if b.TrackNumber, _, err = readVInt(r); err != nil { return nil, err } if v, err := readInt(r, 2); err == nil { diff --git a/elementtable.go b/elementtable.go index 862a4cf..30baea5 100644 --- a/elementtable.go +++ b/elementtable.go @@ -91,7 +91,7 @@ func init() { revTable = make(elementRevTable) for k, v := range table { - e, err := readVInt(bytes.NewBuffer(v.b)) + e, _, err := readVInt(bytes.NewBuffer(v.b)) if err != nil { panic(err) } diff --git a/unmarshal.go b/unmarshal.go index 430c084..02a4e67 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -33,15 +33,19 @@ var ( ) // Unmarshal EBML stream -func Unmarshal(r io.Reader, val interface{}) error { +func Unmarshal(r io.Reader, val interface{}, opts ...UnmarshalOption) error { + options := &UnmarshalOptions{} + for _, o := range opts { + o(options) + } + vo := reflect.ValueOf(val) if !vo.IsValid() { return errIndefiniteType } voe := vo.Elem() - for { - if _, err := readElement(r, sizeInf, voe); err != nil { + if _, err := readElement(r, sizeInf, voe, 0, nil, options); err != nil { if err == io.EOF { return nil } @@ -50,7 +54,7 @@ func Unmarshal(r io.Reader, val interface{}) error { } } -func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) { +func readElement(r0 io.Reader, n int64, vo reflect.Value, pos uint64, parent *Element, options *UnmarshalOptions) (io.Reader, error) { var r io.Reader if n != sizeInf { r = io.LimitReader(r0, n) @@ -80,7 +84,9 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) { } for { - e, err := readVInt(r) + var headerSize uint64 + e, nb, err := readVInt(r) + headerSize += uint64(nb) if err != nil { return nil, err } @@ -89,7 +95,8 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) { return nil, errUnknownElement } - size, err := readVInt(r) + size, nb, err := readVInt(r) + headerSize += uint64(nb) if err != nil { return nil, err } @@ -116,7 +123,19 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) { vn = vnext } } - r0, err := readElement(r, int64(size), vn) + + elem := &Element{ + Value: vn.Interface(), + Name: v.e.String(), + Position: pos, + Size: size, + Parent: parent, + } + r0, err := readElement(r, int64(size), vn, pos+headerSize, elem, options) + for _, hook := range options.hooks { + hook(elem) + } + if err != nil && err != io.EOF { return r0, err } @@ -137,5 +156,28 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) { } } } + pos += headerSize + size + } +} + +type UnmarshalOption func(*UnmarshalOptions) + +type UnmarshalOptions struct { + hooks []func(elem *Element) +} + +// Element represents an EBML element +type Element struct { + Value interface{} + Name string + Position uint64 + Size uint64 + Parent *Element +} + +// WithElementReadHooks returns an UnmarshalOption which registers element hooks +func WithElementReadHooks(hooks ...func(*Element)) UnmarshalOption { + return func(opts *UnmarshalOptions) { + opts.hooks = hooks } } diff --git a/unmarshal_test.go b/unmarshal_test.go index 37c83d5..cf9ee16 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -46,6 +46,59 @@ func ExampleUnmarshal() { // Output: {{webm 2 2}} } +func TestUnmarshal_WithElementReadHooks(t *testing.T) { + TestBinary := []byte{ + 0x18, 0x53, 0x80, 0x67, 0xa1, // Segment + 0x16, 0x54, 0xae, 0x6b, 0x9c, // Tracks + 0xae, 0x8c, // TrackEntry[0] + 0x53, 0x6e, 0x86, 0x56, 0x69, 0x64, 0x65, 0x6f, 0x00, // Name=Video + 0xd7, 0x81, 0x01, // TrackNumber=1 + 0xae, 0x8c, // TrackEntry[1] + 0x53, 0x6e, 0x86, 0x41, 0x75, 0x64, 0x69, 0x6f, 0x00, // Name=Audio + 0xd7, 0x81, 0x02, // TrackNumber=2 + } + + type TestEBML struct { + Segment struct { + Tracks struct { + TrackEntry []struct { + Name string `ebml:"Name,omitempty"` + TrackNumber uint64 `ebml:"TrackNumber"` + } `ebml:"TrackEntry"` + } `ebml:"Tracks"` + } `ebml:"Segment"` + } + + r := bytes.NewReader(TestBinary) + + var ret TestEBML + m := make(map[string][]*Element) + hook := withElementMap(m) + if err := Unmarshal(r, &ret, WithElementReadHooks(hook)); err != nil { + t.Errorf("error: %+v\n", err) + } + + // Verify positions of elements + expected := map[string][]uint64{ + "Segment.Tracks": {5}, + "Segment.Tracks.TrackEntry": {10, 24}, + } + for key, positions := range expected { + elem, ok := m[key] + if !ok { + t.Errorf("Key '%s' doesn't exist\n", key) + } + if len(elem) != len(positions) { + t.Errorf("Unexpected element size of '%s', expected: %d, got: %d\n", key, len(positions), len(elem)) + } + for i, pos := range positions { + if elem[i].Position != pos { + t.Errorf("Unexpected element positon of '%s[%d]', expected: %d, got: %d\n", key, i, pos, elem[i].Position) + } + } + } +} + func TestUnmarshal_Tag(t *testing.T) { var tagged struct { DocCustomNamedType string `ebml:"EBMLDocType"` @@ -92,3 +145,23 @@ func BenchmarkUnmarshal(b *testing.B) { } } } + +func withElementMap(m map[string][]*Element) func(*Element) { + return func(elem *Element) { + key := elem.Name + e := elem + for { + if e.Parent == nil { + break + } + e = e.Parent + key = fmt.Sprintf("%s.%s", e.Name, key) + } + elements, ok := m[key] + if !ok { + elements = make([]*Element, 0) + } + elements = append(elements, elem) + m[key] = elements + } +} diff --git a/value.go b/value.go index 757bec1..603c001 100644 --- a/value.go +++ b/value.go @@ -43,11 +43,11 @@ var perTypeReader = map[Type]func(io.Reader, uint64) (interface{}, error){ TypeBlock: readBlock, } -func readVInt(r io.Reader) (uint64, error) { +func readVInt(r io.Reader) (uint64, int, error) { var bs [1]byte - _, err := r.Read(bs[:]) + bytesRead, err := r.Read(bs[:]) if err != nil { - return 0, err + return 0, bytesRead, err } var vc int @@ -82,14 +82,15 @@ func readVInt(r io.Reader) (uint64, error) { for { if vc == 0 { - return value, nil + return value, bytesRead, nil } var bs [1]byte - _, err := r.Read(bs[:]) + n, err := r.Read(bs[:]) if err != nil { - return 0, err + return 0, bytesRead, err } + bytesRead += n value = value<<8 | uint64(bs[0]) vc-- } diff --git a/value_test.go b/value_test.go index 83cb96a..a7538cc 100644 --- a/value_test.go +++ b/value_test.go @@ -32,7 +32,7 @@ func TestDataSize(t *testing.T) { for n, c := range testCases { t.Run("Decode "+n, func(t *testing.T) { - r, err := readVInt(bytes.NewBuffer(c.b)) + r, _, err := readVInt(bytes.NewBuffer(c.b)) if err != nil { t.Fatalf("Failed to readVInt: %v", err) } @@ -73,7 +73,7 @@ func TestElementID(t *testing.T) { for n, c := range testCases { t.Run("Decode "+n, func(t *testing.T) { - r, err := readVInt(bytes.NewBuffer(c.b)) + r, _, err := readVInt(bytes.NewBuffer(c.b)) if err != nil { t.Fatalf("Failed to readVInt: %v", err) }