diff --git a/README.md b/README.md index 292a799..aed4b32 100644 --- a/README.md +++ b/README.md @@ -244,7 +244,7 @@ d.ObjBytes(func(d *Decoder, key []byte) error { ## Roadmap - [ ] Rework and export `Any` - [ ] Support `Raw` for io.Reader -- [ ] Support `Capture` for io.Reader +- [x] Support `Capture` for io.Reader - [ ] Improve Num - Better validation on decoding - Support BigFloat and BigInt diff --git a/dec_capture.go b/dec_capture.go index 6a5beab..5a9e643 100644 --- a/dec_capture.go +++ b/dec_capture.go @@ -1,19 +1,25 @@ package jx import ( - "github.com/go-faster/errors" + "bytes" + "io" ) // Capture calls f and then rolls back to state before call. -// -// Does not work with reader. func (d *Decoder) Capture(f func(d *Decoder) error) error { - if d.reader != nil { - return errors.New("capture is not supported with reader") - } if f == nil { return nil } + + if d.reader != nil { + // TODO(tdakkota): May it be more efficient? + var buf bytes.Buffer + reader := io.TeeReader(d.reader, &buf) + defer func() { + d.reader = io.MultiReader(&buf, d.reader) + }() + d.reader = reader + } head, tail, depth := d.head, d.tail, d.depth err := f(d) d.head, d.tail, d.depth = head, tail, depth diff --git a/dec_capture_test.go b/dec_capture_test.go index bf6a566..24c0b6f 100644 --- a/dec_capture_test.go +++ b/dec_capture_test.go @@ -84,25 +84,56 @@ func BenchmarkIterator_Skip(b *testing.B) { } func TestDecoder_Capture(t *testing.T) { - i := DecodeStr(`["foo", "bar", "baz"]`) - var elems int - if err := i.Capture(func(i *Decoder) error { - return i.Arr(func(i *Decoder) error { - elems++ - return i.Skip() - }) - }); err != nil { - t.Fatal(err) + strs := []string{ + "foo", + "bar", + "baz", } - require.Equal(t, Array, i.Next()) - require.Equal(t, 3, elems) - t.Run("Nil", func(t *testing.T) { - require.NoError(t, i.Capture(nil)) - require.Equal(t, Array, i.Next()) - }) -} + test := func(i *Decoder) func(t *testing.T) { + return func(t *testing.T) { + var elems int + if err := i.Capture(func(i *Decoder) error { + return i.Arr(func(i *Decoder) error { + elems++ + return i.Skip() + }) + }); err != nil { + t.Fatal(err) + } + require.Equal(t, Array, i.Next()) + require.Equal(t, 6, elems) + t.Run("Nil", func(t *testing.T) { + require.NoError(t, i.Capture(nil)) + require.Equal(t, Array, i.Next()) + }) -func TestDecoder_Capture_reader(t *testing.T) { - i := Decode(new(bytes.Buffer), 0) - require.Error(t, i.Capture(nil)) + idx := 0 + require.NoError(t, i.Arr(func(d *Decoder) error { + v, err := d.Str() + if err != nil { + return err + } + require.Equal(t, strs[idx%len(strs)], v) + + idx++ + return nil + })) + } + } + + var e Encoder + e.ArrStart() + for i := 0; i < 6; i++ { + e.Str(strs[i%len(strs)]) + } + e.ArrEnd() + testData := e.Bytes() + + t.Run("Str", test(DecodeBytes(testData))) + // Check that we get correct result even if buffer smaller than captured data. + decoder := Decoder{ + reader: bytes.NewReader(testData), + buf: make([]byte, 8), + } + t.Run("Reader", test(&decoder)) }