From b5865f5149648819c767e2582a3b6c1a9ecefd30 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Sun, 29 Jan 2023 16:11:14 +0300 Subject: [PATCH] test: add streaming encoding tests --- enc_b64_test.go | 39 ++++++++--------- enc_str_test.go | 60 +++++++++++++++----------- enc_stream_test.go | 103 +++++++++++++++++++++++++++++++++++++++++++++ enc_test.go | 73 +++++++++++++++++++++----------- float_test.go | 17 +++----- int_test.go | 28 ++++++------ jx_test.go | 10 ----- null_test.go | 7 ++- w_b64_test.go | 52 +++++++++++++++++++++++ w_str.go | 37 ++++++---------- w_str_escape.go | 64 ++++++++++------------------ 11 files changed, 316 insertions(+), 174 deletions(-) create mode 100644 enc_stream_test.go diff --git a/enc_b64_test.go b/enc_b64_test.go index cc496d8..579ef3d 100644 --- a/enc_b64_test.go +++ b/enc_b64_test.go @@ -1,42 +1,43 @@ package jx import ( - "encoding/base64" + "bytes" "fmt" "testing" - - "github.com/stretchr/testify/require" ) func TestEncoder_Base64(t *testing.T) { t.Run("Values", func(t *testing.T) { - for _, s := range [][]byte{ + for i, s := range [][]byte{ []byte(`1`), []byte(`12`), []byte(`2345`), {1, 2, 3, 4, 5, 6}, - } { - var e Encoder - e.Base64(s) - - expected := fmt.Sprintf("%q", base64.StdEncoding.EncodeToString(s)) - require.Equal(t, expected, e.String()) - requireCompat(t, e.Bytes(), s) + bytes.Repeat([]byte{1}, encoderBufSize-1), + bytes.Repeat([]byte{1}, encoderBufSize), + bytes.Repeat([]byte{1}, encoderBufSize+1), + } { + s := s + t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) { + requireCompat(t, func(e *Encoder) { + e.Base64(s) + }, s) + }) } }) t.Run("Zeroes", func(t *testing.T) { t.Run("Nil", func(t *testing.T) { - v := []byte(nil) - var e Encoder - e.Base64(v) - requireCompat(t, e.Bytes(), v) + s := []byte(nil) + requireCompat(t, func(e *Encoder) { + e.Base64(s) + }, s) }) t.Run("ZeroLen", func(t *testing.T) { - v := make([]byte, 0) - var e Encoder - e.Base64(v) - requireCompat(t, e.Bytes(), v) + s := make([]byte, 0) + requireCompat(t, func(e *Encoder) { + e.Base64(s) + }, s) }) }) } diff --git a/enc_str_test.go b/enc_str_test.go index ef702d3..2063420 100644 --- a/enc_str_test.go +++ b/enc_str_test.go @@ -3,6 +3,7 @@ package jx import ( "encoding/json" "fmt" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -21,25 +22,31 @@ func TestEncoder_Str(t *testing.T) { {"\x00"}, {"\x00 "}, {`"hello, world!"`}, + + {strings.Repeat("a", encoderBufSize)}, } for i, tt := range testCases { tt := tt t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) { for _, enc := range []struct { name string - enc func(e *Encoder, input string) + enc func(e *Encoder, input string) bool }{ {"Str", (*Encoder).Str}, - {"Bytes", func(e *Encoder, input string) { - e.ByteStr([]byte(tt.input)) + {"Bytes", func(e *Encoder, input string) bool { + return e.ByteStr([]byte(tt.input)) }}, } { enc := enc t.Run(enc.name, func(t *testing.T) { - e := GetEncoder() - enc.enc(e, tt.input) - requireCompat(t, e.Bytes(), tt.input) + requireCompat(t, func(e *Encoder) { + enc.enc(e, tt.input) + }, tt.input) + t.Run("Decode", func(t *testing.T) { + e := GetEncoder() + enc.enc(e, tt.input) + i := GetDecoder() i.ResetBytes(e.Bytes()) s, err := i.Str() @@ -54,9 +61,9 @@ func TestEncoder_Str(t *testing.T) { const ( v = "\"/\"" ) - var e Encoder - e.Str(v) - requireCompat(t, e.Bytes(), v) + requireCompat(t, func(e *Encoder) { + e.StrEscape(v) + }, v) }) t.Run("QuotesObj", func(t *testing.T) { const ( @@ -64,17 +71,21 @@ func TestEncoder_Str(t *testing.T) { v = "\"/\"" ) + cb := func(e *Encoder) { + e.ObjStart() + e.FieldStart(k) + e.Str(v) + e.ObjEnd() + t.Log(e) + } + var e Encoder - e.ObjStart() - e.FieldStart(k) - e.Str(v) - e.ObjEnd() - t.Log(e) + cb(&e) var target map[string]string require.NoError(t, json.Unmarshal(e.Bytes(), &target)) assert.Equal(t, v, target[k]) - requireCompat(t, e.Bytes(), map[string]string{k: v}) + requireCompat(t, cb, map[string]string{k: v}) }) } @@ -96,19 +107,18 @@ func TestEncoder_StrEscape(t *testing.T) { t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) { for _, enc := range []struct { name string - enc func(e *Encoder, input string) + enc func(e *Encoder, input string) bool }{ {"Str", (*Encoder).StrEscape}, - {"Bytes", func(e *Encoder, input string) { - e.ByteStrEscape([]byte(tt.input)) + {"Bytes", func(e *Encoder, input string) bool { + return e.ByteStrEscape([]byte(tt.input)) }}, } { enc := enc t.Run(enc.name, func(t *testing.T) { - e := GetEncoder() - enc.enc(e, tt.input) - require.Equal(t, tt.expect, string(e.Bytes())) - requireCompat(t, e.Bytes(), tt.input) + requireCompat(t, func(e *Encoder) { + enc.enc(e, tt.input) + }, tt.input) }) } }) @@ -117,8 +127,8 @@ func TestEncoder_StrEscape(t *testing.T) { const ( v = "\"/\"" ) - var e Encoder - e.StrEscape(v) - requireCompat(t, e.Bytes(), v) + requireCompat(t, func(e *Encoder) { + e.StrEscape(v) + }, v) }) } diff --git a/enc_stream_test.go b/enc_stream_test.go new file mode 100644 index 0000000..3eab255 --- /dev/null +++ b/enc_stream_test.go @@ -0,0 +1,103 @@ +package jx + +import ( + "io" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/go-faster/errors" +) + +func TestEncoderStreamingCheck(t *testing.T) { + a := require.New(t) + + e := NewStreamingEncoder(io.Discard, 512) + + _, err := e.Write([]byte("hello")) + a.ErrorIs(err, errStreaming) + + _, err = e.WriteTo(io.Discard) + a.ErrorIs(err, errStreaming) + + a.PanicsWithError(errStreaming.Error(), func() { + _ = e.String() + }) +} + +type errWriter struct { + err error + n int +} + +func (e *errWriter) Write(p []byte) (int, error) { + n := e.n + if n <= 0 { + n = len(p) + } + return n, e.err +} + +func TestEncoder_Close(t *testing.T) { + errTest := errors.New("test") + + t.Run("FlushErr", func(t *testing.T) { + ew := &errWriter{err: errTest} + e := NewStreamingEncoder(ew, -1) + e.Null() + + require.ErrorIs(t, e.Close(), errTest) + }) + t.Run("WriteErr", func(t *testing.T) { + ew := &errWriter{err: errTest} + e := NewStreamingEncoder(ew, 32) + e.Obj(func(e *Encoder) { + e.FieldStart(strings.Repeat("a", 32)) + e.Null() + }) + + require.ErrorIs(t, e.Close(), errTest) + }) + t.Run("ShortWrite", func(t *testing.T) { + ew := &errWriter{n: 1} + e := NewStreamingEncoder(ew, -1) + e.Null() + + require.ErrorIs(t, e.Close(), io.ErrShortWrite) + }) + t.Run("OK", func(t *testing.T) { + e := NewStreamingEncoder(io.Discard, -1) + e.Null() + + require.NoError(t, e.Close()) + }) + t.Run("NoStreaming", func(t *testing.T) { + var e Encoder + e.Null() + + require.NoError(t, e.Close()) + }) +} + +func TestEncoder_ResetWriter(t *testing.T) { + do := func(e *Encoder) { + e.ObjStart() + e.FieldStart(strings.Repeat("a", 32)) + e.Null() + e.ObjEnd() + + require.NoError(t, e.Close()) + } + + var e Encoder + do(&e) + expected := e.String() + + for range [3]struct{}{} { + var got strings.Builder + e.ResetWriter(&got) + do(&e) + require.Equal(t, expected, got.String()) + } +} diff --git a/enc_test.go b/enc_test.go index 8c07611..6e8febd 100644 --- a/enc_test.go +++ b/enc_test.go @@ -1,12 +1,37 @@ package jx import ( + "encoding/json" "fmt" + "strings" "testing" "github.com/stretchr/testify/require" ) +func testEncoderModes(t *testing.T, cb func(*Encoder), expected string) { + t.Run("Buffer", func(t *testing.T) { + e := GetEncoder() + cb(e) + require.Equal(t, expected, e.String()) + }) + t.Run("Writer", func(t *testing.T) { + var sb strings.Builder + e := NewStreamingEncoder(&sb, -1) + cb(e) + require.NoError(t, e.Close()) + require.Equal(t, expected, sb.String()) + }) +} + +// requireCompat fails if `encoding/json` will encode v differently than exp. +func requireCompat(t *testing.T, cb func(*Encoder), v any) { + t.Helper() + buf, err := json.Marshal(v) + require.NoError(t, err, "json.Marshal(%#v)", v) + testEncoderModes(t, cb, string(buf)) +} + func TestEncoderByteShouldGrowBuffer(t *testing.T) { should := require.New(t) e := GetEncoder() @@ -56,46 +81,46 @@ func TestEncoderStrShouldGrowBuffer(t *testing.T) { } func TestEncoder_ArrEmpty(t *testing.T) { - e := GetEncoder() - e.ArrEmpty() - require.Equal(t, "[]", string(e.Bytes())) + testEncoderModes(t, func(e *Encoder) { + e.ArrEmpty() + }, "[]") } func TestEncoder_ObjEmpty(t *testing.T) { - e := GetEncoder() - e.ObjEmpty() - require.Equal(t, "{}", string(e.Bytes())) + testEncoderModes(t, func(e *Encoder) { + e.ObjEmpty() + }, "{}") } func TestEncoder_Obj(t *testing.T) { - t.Run("FieldStart", func(t *testing.T) { - var e Encoder - e.Obj(func(e *Encoder) { - e.Field("hello", func(e *Encoder) { - e.Str("world") + t.Run("Field", func(t *testing.T) { + testEncoderModes(t, func(e *Encoder) { + e.Obj(func(e *Encoder) { + e.Field("hello", func(e *Encoder) { + e.Str("world") + }) }) - }) - require.Equal(t, `{"hello":"world"}`, e.String()) + }, `{"hello":"world"}`) }) t.Run("Nil", func(t *testing.T) { - var e Encoder - e.Obj(nil) - require.Equal(t, `{}`, e.String()) + testEncoderModes(t, func(e *Encoder) { + e.Obj(nil) + }, `{}`) }) } func TestEncoder_Arr(t *testing.T) { t.Run("Elem", func(t *testing.T) { - var e Encoder - e.Arr(func(e *Encoder) { - e.Str("world") - }) - require.Equal(t, `["world"]`, e.String()) + testEncoderModes(t, func(e *Encoder) { + e.Arr(func(e *Encoder) { + e.Str("world") + }) + }, `["world"]`) }) t.Run("Nil", func(t *testing.T) { - var e Encoder - e.Arr(nil) - require.Equal(t, `[]`, e.String()) + testEncoderModes(t, func(e *Encoder) { + e.Arr(nil) + }, `[]`) }) } diff --git a/float_test.go b/float_test.go index ef14a83..d3cf615 100644 --- a/float_test.go +++ b/float_test.go @@ -123,12 +123,9 @@ func TestWriteFloat32(t *testing.T) { } for _, val := range vals { t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { - should := require.New(t) - w := GetEncoder() - w.Float32(val) - output, err := json.Marshal(val) - should.Nil(err) - should.Equal(output, w.Bytes()) + requireCompat(t, func(e *Encoder) { + e.Float32(val) + }, val) }) } should := require.New(t) @@ -144,11 +141,9 @@ func TestWriteFloat64(t *testing.T) { } for _, val := range vals { t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { - should := require.New(t) - e := GetEncoder() - e.Float64(val) - s := strconv.FormatFloat(val, 'f', -1, 64) - should.Equal(s, string(e.Bytes())) + requireCompat(t, func(e *Encoder) { + e.Float64(val) + }, val) }) } should := require.New(t) diff --git a/int_test.go b/int_test.go index 2405e2b..2d732d3 100644 --- a/int_test.go +++ b/int_test.go @@ -475,10 +475,9 @@ func TestWriteUint32(t *testing.T) { vals := []uint32{0, 1, 11, 111, 255, 999999, 0xfff, 0xffff, 0xfffff, 0xffffff, 0xfffffff, 0xffffffff} for _, val := range vals { t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { - should := require.New(t) - e := GetEncoder() - e.UInt32(val) - should.Equal(strconv.FormatUint(uint64(val), 10), e.String()) + requireCompat(t, func(e *Encoder) { + e.UInt32(val) + }, val) }) } should := require.New(t) @@ -492,10 +491,9 @@ func TestWriteInt32(t *testing.T) { vals := []int32{0, 1, 11, 111, 255, 999999, 0xfff, 0xffff, 0xfffff, 0xffffff, 0xfffffff, 0x7fffffff, -0x80000000} for _, val := range vals { t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { - should := require.New(t) - e := GetEncoder() - e.Int32(val) - should.Equal(strconv.FormatInt(int64(val), 10), e.String()) + requireCompat(t, func(e *Encoder) { + e.Int32(val) + }, val) }) } should := require.New(t) @@ -513,10 +511,9 @@ func TestWriteUint64(t *testing.T) { } for _, val := range vals { t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { - should := require.New(t) - e := GetEncoder() - e.UInt64(val) - should.Equal(strconv.FormatUint(val, 10), e.String()) + requireCompat(t, func(e *Encoder) { + e.UInt64(val) + }, val) }) } should := require.New(t) @@ -534,10 +531,9 @@ func TestWriteInt64(t *testing.T) { } for _, val := range vals { t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { - should := require.New(t) - e := GetEncoder() - e.Int64(val) - should.Equal(strconv.FormatInt(val, 10), e.String()) + requireCompat(t, func(e *Encoder) { + e.Int64(val) + }, val) }) } should := require.New(t) diff --git a/jx_test.go b/jx_test.go index fdf9d6c..12a1a9b 100644 --- a/jx_test.go +++ b/jx_test.go @@ -2,22 +2,12 @@ package jx import ( "bytes" - "encoding/json" "sync" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -// requireCompat fails if `encoding/json` will encode v differently than exp. -func requireCompat(t testing.TB, got []byte, v interface{}) { - t.Helper() - buf, err := json.Marshal(v) - require.NoError(t, err) - require.Equal(t, string(buf), string(got)) -} - func TestPutEncoder(t *testing.T) { var wg sync.WaitGroup for j := 0; j < 4; j++ { diff --git a/null_test.go b/null_test.go index fff6394..9778ca6 100644 --- a/null_test.go +++ b/null_test.go @@ -7,10 +7,9 @@ import ( ) func TestWriteNull(t *testing.T) { - should := require.New(t) - e := GetEncoder() - e.Null() - should.Equal("null", e.String()) + testEncoderModes(t, func(e *Encoder) { + e.Null() + }, "null") } func TestDecodeNullArrayElement(t *testing.T) { diff --git a/w_b64_test.go b/w_b64_test.go index f5e0983..d7b7ecd 100644 --- a/w_b64_test.go +++ b/w_b64_test.go @@ -1,10 +1,62 @@ package jx import ( + "bytes" "fmt" + "io" + "strings" "testing" + + "github.com/stretchr/testify/require" + + "github.com/go-faster/errors" ) +type limitWriter struct { + w io.Writer + n int64 +} + +func (t *limitWriter) Write(p []byte) (n int, err error) { + if t.n-int64(len(p)) < 0 { + return 0, errors.New("limit reached") + } + // real write + n = len(p) + if int64(n) > t.n { + n = int(t.n) + } + n, err = t.w.Write(p[0:n]) + t.n -= int64(n) + if err == nil { + n = len(p) + } + return +} + +func TestWriter_Base64(t *testing.T) { + const bufSize = minEncoderBufSize + const fieldLength = bufSize - len(`{"":`) + + limits := []int64{ + 31, // write '"' + 32, // flush + 33, // Write base64 + 73, // Write tail of base64 + } + + data := bytes.Repeat([]byte{0}, bufSize) + for _, n := range limits { + // Write '"' error. + e := NewStreamingEncoder(&limitWriter{w: io.Discard, n: n}, bufSize) + e.Obj(func(e *Encoder) { + e.FieldStart(strings.Repeat("a", fieldLength)) + e.Base64(data) + }) + require.Error(t, e.Close()) + } +} + func BenchmarkWriter_Base64(b *testing.B) { for _, n := range []int{ 128, diff --git a/w_str.go b/w_str.go index eb11b69..1b6aa4f 100644 --- a/w_str.go +++ b/w_str.go @@ -38,10 +38,8 @@ func (w *Writer) ByteStr(v []byte) bool { return writeStr(w, v) } -func writeStr[S byteseq.Byteseq](w *Writer, v S) bool { - if w.byte('"') { - return true - } +func writeStr[S byteseq.Byteseq](w *Writer, v S) (fail bool) { + fail = w.byte('"') // Fast path, without utf8 and escape support. var ( @@ -58,54 +56,45 @@ func writeStr[S byteseq.Byteseq](w *Writer, v S) bool { return writeStreamByteseq(w, v) || w.byte('"') } slow: - w.Buf = append(w.Buf, v[:i]...) - return strSlow[S](w, v[i:]) + return writeStreamByteseq(w, v[:i]) || strSlow[S](w, v[i:]) } -func strSlow[S byteseq.Byteseq](w *Writer, v S) bool { +func strSlow[S byteseq.Byteseq](w *Writer, v S) (fail bool) { var i, start int // for the remaining parts, we process them char by char - for i < len(v) { + for i < len(v) && !fail { b := v[i] if safeSet[b] == 0 { i++ continue } if start < i { - if writeStreamByteseq(w, v[start:i]) { - return true - } + fail = fail || writeStreamByteseq(w, v[start:i]) } - var fail bool switch b { case '\\', '"': - fail = w.twoBytes('\\', b) + fail = fail || w.twoBytes('\\', b) case '\n': - fail = w.twoBytes('\\', 'n') + fail = fail || w.twoBytes('\\', 'n') case '\r': - fail = w.twoBytes('\\', 'r') + fail = fail || w.twoBytes('\\', 'r') case '\t': - fail = w.twoBytes('\\', 't') + fail = fail || w.twoBytes('\\', 't') default: // This encodes bytes < 0x20 except for \t, \n and \r. // If escapeHTML is set, it also escapes <, >, and & // because they can lead to security holes when // user-controlled strings are rendered into JSON // and served to some browsers. - fail = w.rawStr(`\u00`) || w.twoBytes(hexChars[b>>4], hexChars[b&0xF]) - } - if fail { - return true + fail = fail || w.rawStr(`\u00`) || w.twoBytes(hexChars[b>>4], hexChars[b&0xF]) } i++ start = i continue } if start < len(v) { - if writeStreamByteseq(w, v[start:]) { - return true - } + fail = fail || writeStreamByteseq(w, v[start:]) } - return w.byte('"') + return fail || w.byte('"') } diff --git a/w_str_escape.go b/w_str_escape.go index cd25945..14bf3b2 100644 --- a/w_str_escape.go +++ b/w_str_escape.go @@ -122,63 +122,55 @@ func (w *Writer) ByteStrEscape(v []byte) bool { return strEscape(w, v) } -func strEscape[S byteseq.Byteseq](w *Writer, v S) bool { - length := len(v) - if w.byte('"') { - return true - } +func strEscape[S byteseq.Byteseq](w *Writer, v S) (fail bool) { + fail = w.byte('"') // Fast path, probably does not require escaping. - i := 0 + var ( + i = 0 + length = len(v) + ) for ; i < length; i++ { c := v[i] if c >= utf8.RuneSelf || !(htmlSafeSet[c]) { break } } - if writeStreamByteseq(w, v[:i]) { - return true - } + fail = fail || writeStreamByteseq(w, v[:i]) if i == length { - return w.byte('"') + return fail || w.byte('"') } - return strEscapeSlow[S](w, i, v, length) + return fail || strEscapeSlow[S](w, i, v, length) } -func strEscapeSlow[S byteseq.Byteseq](w *Writer, i int, v S, valLen int) bool { +func strEscapeSlow[S byteseq.Byteseq](w *Writer, i int, v S, valLen int) (fail bool) { start := i // for the remaining parts, we process them char by char - for i < valLen { + for i < valLen && !fail { if b := v[i]; b < utf8.RuneSelf { if htmlSafeSet[b] { i++ continue } if start < i { - if writeStreamByteseq(w, v[start:i]) { - return true - } + fail = fail || writeStreamByteseq(w, v[start:i]) } - var fail bool switch b { case '\\', '"': - fail = w.twoBytes('\\', b) + fail = fail || w.twoBytes('\\', b) case '\n': - fail = w.twoBytes('\\', 'n') + fail = fail || w.twoBytes('\\', 'n') case '\r': - fail = w.twoBytes('\\', 'r') + fail = fail || w.twoBytes('\\', 'r') case '\t': - fail = w.twoBytes('\\', 't') + fail = fail || w.twoBytes('\\', 't') default: // This encodes bytes < 0x20 except for \t, \n and \r. // If escapeHTML is set, it also escapes <, >, and & // because they can lead to security holes when // user-controlled strings are rendered into JSON // and served to some browsers. - fail = w.rawStr(`\u00`) || w.twoBytes(hexChars[b>>4], hexChars[b&0xF]) - } - if fail { - return true + fail = fail || w.rawStr(`\u00`) || w.twoBytes(hexChars[b>>4], hexChars[b&0xF]) } i++ start = i @@ -187,13 +179,9 @@ func strEscapeSlow[S byteseq.Byteseq](w *Writer, i int, v S, valLen int) bool { c, size := byteseq.DecodeRuneInByteseq(v[i:]) if c == utf8.RuneError && size == 1 { if start < i { - if writeStreamByteseq(w, v[start:i]) { - return true - } - } - if w.rawStr(`\ufffd`) { - return true + fail = fail || writeStreamByteseq(w, v[start:i]) } + fail = fail || w.rawStr(`\ufffd`) i++ start = i continue @@ -207,13 +195,9 @@ func strEscapeSlow[S byteseq.Byteseq](w *Writer, i int, v S, valLen int) bool { // See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion. if c == '\u2028' || c == '\u2029' { if start < i { - if writeStreamByteseq(w, v[start:i]) { - return true - } - } - if w.rawStr(`\u202`) || w.byte(hexChars[c&0xF]) { - return true + fail = fail || writeStreamByteseq(w, v[start:i]) } + fail = fail || w.rawStr(`\u202`) || w.byte(hexChars[c&0xF]) i += size start = i continue @@ -221,9 +205,7 @@ func strEscapeSlow[S byteseq.Byteseq](w *Writer, i int, v S, valLen int) bool { i += size } if start < len(v) { - if writeStreamByteseq(w, v[start:]) { - return true - } + fail = fail || writeStreamByteseq(w, v[start:]) } - return w.byte('"') + return fail || w.byte('"') }