Skip to content

Commit

Permalink
Add UnmarshalHook (#47)
Browse files Browse the repository at this point in the history
* Return bytesRead in "readVInt"

* Add UnmarshalHook and a hook "WithElementMap"

* Make hooks functional options

* Fix design of options

* WithElementHooks -> WithElementReadHooks

* Fix comment
  • Loading branch information
kamatama41 authored and at-wat committed Nov 28, 2019
1 parent eb13e34 commit 0e6763b
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 17 deletions.
2 changes: 1 addition & 1 deletion block.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type Lace struct {
func UnmarshalBlock(r io.Reader) (*Block, error) {
var b Block
var err error
if b.TrackNumber, err = readVInt(r); err != nil {
if b.TrackNumber, _, err = readVInt(r); err != nil {
return nil, err
}
if v, err := readInt(r, 2); err == nil {
Expand Down
2 changes: 1 addition & 1 deletion elementtable.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func init() {
revTable = make(elementRevTable)

for k, v := range table {
e, err := readVInt(bytes.NewBuffer(v.b))
e, _, err := readVInt(bytes.NewBuffer(v.b))
if err != nil {
panic(err)
}
Expand Down
56 changes: 49 additions & 7 deletions unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ var (
)

// Unmarshal EBML stream
func Unmarshal(r io.Reader, val interface{}) error {
func Unmarshal(r io.Reader, val interface{}, opts ...UnmarshalOption) error {
options := &UnmarshalOptions{}
for _, o := range opts {
o(options)
}

vo := reflect.ValueOf(val)
if !vo.IsValid() {
return errIndefiniteType
}
voe := vo.Elem()

for {
if _, err := readElement(r, sizeInf, voe); err != nil {
if _, err := readElement(r, sizeInf, voe, 0, nil, options); err != nil {
if err == io.EOF {
return nil
}
Expand All @@ -50,7 +54,7 @@ func Unmarshal(r io.Reader, val interface{}) error {
}
}

func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) {
func readElement(r0 io.Reader, n int64, vo reflect.Value, pos uint64, parent *Element, options *UnmarshalOptions) (io.Reader, error) {
var r io.Reader
if n != sizeInf {
r = io.LimitReader(r0, n)
Expand Down Expand Up @@ -80,7 +84,9 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) {
}

for {
e, err := readVInt(r)
var headerSize uint64
e, nb, err := readVInt(r)
headerSize += uint64(nb)
if err != nil {
return nil, err
}
Expand All @@ -89,7 +95,8 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) {
return nil, errUnknownElement
}

size, err := readVInt(r)
size, nb, err := readVInt(r)
headerSize += uint64(nb)
if err != nil {
return nil, err
}
Expand All @@ -116,7 +123,19 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) {
vn = vnext
}
}
r0, err := readElement(r, int64(size), vn)

elem := &Element{
Value: vn.Interface(),
Name: v.e.String(),
Position: pos,
Size: size,
Parent: parent,
}
r0, err := readElement(r, int64(size), vn, pos+headerSize, elem, options)
for _, hook := range options.hooks {
hook(elem)
}

if err != nil && err != io.EOF {
return r0, err
}
Expand All @@ -137,5 +156,28 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value) (io.Reader, error) {
}
}
}
pos += headerSize + size
}
}

type UnmarshalOption func(*UnmarshalOptions)

type UnmarshalOptions struct {
hooks []func(elem *Element)
}

// Element represents an EBML element
type Element struct {
Value interface{}
Name string
Position uint64
Size uint64
Parent *Element
}

// WithElementReadHooks returns an UnmarshalOption which registers element hooks
func WithElementReadHooks(hooks ...func(*Element)) UnmarshalOption {
return func(opts *UnmarshalOptions) {
opts.hooks = hooks
}
}
73 changes: 73 additions & 0 deletions unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,59 @@ func ExampleUnmarshal() {
// Output: {{webm 2 2}}
}

func TestUnmarshal_WithElementReadHooks(t *testing.T) {
TestBinary := []byte{
0x18, 0x53, 0x80, 0x67, 0xa1, // Segment
0x16, 0x54, 0xae, 0x6b, 0x9c, // Tracks
0xae, 0x8c, // TrackEntry[0]
0x53, 0x6e, 0x86, 0x56, 0x69, 0x64, 0x65, 0x6f, 0x00, // Name=Video
0xd7, 0x81, 0x01, // TrackNumber=1
0xae, 0x8c, // TrackEntry[1]
0x53, 0x6e, 0x86, 0x41, 0x75, 0x64, 0x69, 0x6f, 0x00, // Name=Audio
0xd7, 0x81, 0x02, // TrackNumber=2
}

type TestEBML struct {
Segment struct {
Tracks struct {
TrackEntry []struct {
Name string `ebml:"Name,omitempty"`
TrackNumber uint64 `ebml:"TrackNumber"`
} `ebml:"TrackEntry"`
} `ebml:"Tracks"`
} `ebml:"Segment"`
}

r := bytes.NewReader(TestBinary)

var ret TestEBML
m := make(map[string][]*Element)
hook := withElementMap(m)
if err := Unmarshal(r, &ret, WithElementReadHooks(hook)); err != nil {
t.Errorf("error: %+v\n", err)
}

// Verify positions of elements
expected := map[string][]uint64{
"Segment.Tracks": {5},
"Segment.Tracks.TrackEntry": {10, 24},
}
for key, positions := range expected {
elem, ok := m[key]
if !ok {
t.Errorf("Key '%s' doesn't exist\n", key)
}
if len(elem) != len(positions) {
t.Errorf("Unexpected element size of '%s', expected: %d, got: %d\n", key, len(positions), len(elem))
}
for i, pos := range positions {
if elem[i].Position != pos {
t.Errorf("Unexpected element positon of '%s[%d]', expected: %d, got: %d\n", key, i, pos, elem[i].Position)
}
}
}
}

func TestUnmarshal_Tag(t *testing.T) {
var tagged struct {
DocCustomNamedType string `ebml:"EBMLDocType"`
Expand Down Expand Up @@ -92,3 +145,23 @@ func BenchmarkUnmarshal(b *testing.B) {
}
}
}

func withElementMap(m map[string][]*Element) func(*Element) {
return func(elem *Element) {
key := elem.Name
e := elem
for {
if e.Parent == nil {
break
}
e = e.Parent
key = fmt.Sprintf("%s.%s", e.Name, key)
}
elements, ok := m[key]
if !ok {
elements = make([]*Element, 0)
}
elements = append(elements, elem)
m[key] = elements
}
}
13 changes: 7 additions & 6 deletions value.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ var perTypeReader = map[Type]func(io.Reader, uint64) (interface{}, error){
TypeBlock: readBlock,
}

func readVInt(r io.Reader) (uint64, error) {
func readVInt(r io.Reader) (uint64, int, error) {
var bs [1]byte
_, err := r.Read(bs[:])
bytesRead, err := r.Read(bs[:])
if err != nil {
return 0, err
return 0, bytesRead, err
}

var vc int
Expand Down Expand Up @@ -82,14 +82,15 @@ func readVInt(r io.Reader) (uint64, error) {

for {
if vc == 0 {
return value, nil
return value, bytesRead, nil
}

var bs [1]byte
_, err := r.Read(bs[:])
n, err := r.Read(bs[:])
if err != nil {
return 0, err
return 0, bytesRead, err
}
bytesRead += n
value = value<<8 | uint64(bs[0])
vc--
}
Expand Down
4 changes: 2 additions & 2 deletions value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestDataSize(t *testing.T) {

for n, c := range testCases {
t.Run("Decode "+n, func(t *testing.T) {
r, err := readVInt(bytes.NewBuffer(c.b))
r, _, err := readVInt(bytes.NewBuffer(c.b))
if err != nil {
t.Fatalf("Failed to readVInt: %v", err)
}
Expand Down Expand Up @@ -73,7 +73,7 @@ func TestElementID(t *testing.T) {

for n, c := range testCases {
t.Run("Decode "+n, func(t *testing.T) {
r, err := readVInt(bytes.NewBuffer(c.b))
r, _, err := readVInt(bytes.NewBuffer(c.b))
if err != nil {
t.Fatalf("Failed to readVInt: %v", err)
}
Expand Down

0 comments on commit 0e6763b

Please sign in to comment.