From 0b7c063fa1ac8f64f260c8b3feef2bd8d6fbb519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torbjo=CC=88rn=20Einarsson?= Date: Fri, 8 Nov 2024 21:13:14 +0100 Subject: [PATCH] test: more unit tests for package bits --- bits/ebspreader.go | 7 +-- bits/ebspreader_test.go | 51 +++++++++++++++++++- bits/ebspwriter_test.go | 90 +++++++++++++++++++++++++++++++++++ bits/fixedslicereader_test.go | 17 +++++++ 4 files changed, 158 insertions(+), 7 deletions(-) diff --git a/bits/ebspreader.go b/bits/ebspreader.go index bb91da68..1f788966 100644 --- a/bits/ebspreader.go +++ b/bits/ebspreader.go @@ -216,12 +216,7 @@ func (r *EBSPReader) MoreRbspData() (bool, error) { // reset resets EBSPReader based on copy of previous state. func (r *EBSPReader) reset(prevState EBSPReader) error { - rdSeek, ok := r.rd.(io.ReadSeeker) - - if !ok { - return ErrNotReadSeeker - } - + rdSeek, _ := r.rd.(io.ReadSeeker) _, err := rdSeek.Seek(int64(prevState.pos+1), 0) if err != nil { return err diff --git a/bits/ebspreader_test.go b/bits/ebspreader_test.go index cdebcaec..10bdf1c9 100644 --- a/bits/ebspreader_test.go +++ b/bits/ebspreader_test.go @@ -3,6 +3,7 @@ package bits_test import ( "bytes" "encoding/hex" + "errors" "io" "testing" @@ -327,7 +328,7 @@ func TestEBSPReader(t *testing.T) { r := bits.NewEBSPReader(brd) r.SetError(io.ErrUnexpectedEOF) // Read shold never result in panic - // Error should be preseverd + // Error should be preservedd _ = r.Read(100) if r.AccError() != io.ErrUnexpectedEOF { t.Errorf("Expected error not found") @@ -349,4 +350,52 @@ func TestEBSPReader(t *testing.T) { t.Errorf("Expected error not found") } }) + + t.Run("try seek in non-seekable reader", func(t *testing.T) { + emptyBuf := bytes.Buffer{} + r := bits.NewEBSPReader(&emptyBuf) + _, err := r.MoreRbspData() + if !errors.Is(err, bits.ErrNotReadSeeker) { + t.Error("Expected error checking for more data in empty buffer") + } + }) + + t.Run("no rbsp bits left", func(t *testing.T) { + input := []byte{0b1} + brd := bytes.NewReader(input) + r := bits.NewEBSPReader(brd) + for i := 0; i < 8; i++ { + _ = r.ReadFlag() + } + more, err := r.MoreRbspData() + if more { + t.Error("Expected no more rbsp data") + } + if err != nil { + t.Error("Expected error to be nil when no more data") + } + }) + + t.Run("not last rbsp bit bit", func(t *testing.T) { + input := []byte{0b01000001} + brd := bytes.NewReader(input) + r := bits.NewEBSPReader(brd) + more, err := r.MoreRbspData() + if !more { + t.Error("Expected more rbsp data") + } + if err != nil { + t.Error("Expected error to be nil when no more data") + } + err = r.ReadRbspTrailingBits() + if err == nil { + t.Error("Expected error when reading trailing bits") + } + err = r.ReadRbspTrailingBits() + if err == nil { + t.Error("Expected error when reading trailing bits") + } + + }) + } diff --git a/bits/ebspwriter_test.go b/bits/ebspwriter_test.go index ca70ce80..0a81bafa 100644 --- a/bits/ebspwriter_test.go +++ b/bits/ebspwriter_test.go @@ -66,6 +66,73 @@ func TestEBSPWriter(t *testing.T) { if gotBits != tc.bits { t.Errorf("wanted %s but got %s for %d", tc.bits, gotBits, tc.n) } + nrBitsInBuffer := w.NrBitsInBuffer() + if int(nrBitsInBuffer) != len(tc.bits)%8 { + t.Errorf("wanted %d bits in buffer but got %d", len(tc.bits)%8, nrBitsInBuffer) + } + if w.AccError() != nil { + t.Errorf("unexpected error: %v", w.AccError()) + } + } + }) + + t.Run("write to limited writer", func(t *testing.T) { + lw := newLimitedWriter(3) + w := bits.NewEBSPWriter(lw) + w.Write(0, 16) + if lw.nrWritten != 2 { + t.Errorf("wanted 2 bytes written but got %d", lw.nrWritten) + } + if w.AccError() != nil { + t.Errorf("unexpected error: %v", w.AccError()) + } + w.Write(1, 8) + // Now we should have written 4 due to start code emulation prevention byte + if lw.nrWritten != 4 { + t.Errorf("wanted 4 bytes written but got %d", lw.nrWritten) + } + if w.AccError() == nil { + t.Errorf("wanted error but got nil") + } + w.Write(1, 8) + if w.AccError() == nil { + t.Errorf("error should stay") + } + if lw.nrWritten != 4 { + t.Errorf("wanted 4 bytes written but got %d", lw.nrWritten) + } + }) + + t.Run("start code emulation prevention error", func(t *testing.T) { + lw := newLimitedWriter(2) + w := bits.NewEBSPWriter(lw) + w.Write(0, 16) + if lw.nrWritten != 2 { + t.Errorf("wanted 2 bytes written but got %d", lw.nrWritten) + } + if w.AccError() != nil { + t.Errorf("unexpected error: %v", w.AccError()) + } + w.Write(1, 8) + // Now we should have written 3 since start-code emulation triggered error + if lw.nrWritten != 3 { + t.Errorf("wanted 3 bytes written but got %d", lw.nrWritten) + } + if w.AccError() == nil { + t.Errorf("wanted error but got nil") + } + }) + + t.Run("write SEI and RBSP", func(t *testing.T) { + b := bytes.Buffer{} + w := bits.NewEBSPWriter(&b) + w.WriteSEIValue(300) + w.WriteRbspTrailingBits() + gotBits := getBitsWritten(w, &b) + expectedBits := "111111110010110110000000" + if gotBits != expectedBits { + t.Errorf("wanted %s but got %s", expectedBits, gotBits) + } }) } @@ -81,4 +148,27 @@ func getBitsWritten(w *bits.EBSPWriter, b *bytes.Buffer) string { bits += fullByte[8-nrBitsInWriter:] } return bits + +} + +type limitedWriter struct { + nrWritten uint + maxNrBytes uint +} + +func newLimitedWriter(maxNrBytes uint) *limitedWriter { + return &limitedWriter{nrWritten: 0, maxNrBytes: maxNrBytes} +} + +func (w *limitedWriter) Write(p []byte) (n int, err error) { + prevNrWritten := w.nrWritten + w.nrWritten += uint(len(p)) + if w.nrWritten > w.maxNrBytes { + n = int(w.maxNrBytes - prevNrWritten) + if n < 0 { + n = 0 + } + return n, fmt.Errorf("write limit reached") + } + return len(p), nil } diff --git a/bits/fixedslicereader_test.go b/bits/fixedslicereader_test.go index 5fd977b5..c1904eaf 100644 --- a/bits/fixedslicereader_test.go +++ b/bits/fixedslicereader_test.go @@ -234,6 +234,23 @@ func TestFixedSliceReader(t *testing.T) { t.Errorf("got error msg %q instead of %q", sr.AccError().Error(), wantedErrMsg) } }) + + t.Run("read possibly zero terminated string", func(t *testing.T) { + data := []byte("hej\x00") + sr := bits.NewFixedSliceReader(data) + _, ok := sr.ReadPossiblyZeroTerminatedString(-1) + if ok { + t.Errorf("got ok but impossible") + } + val, ok := sr.ReadPossiblyZeroTerminatedString(0) + if !ok || val != "" { + t.Errorf("got %q instead of empty string", val) + } + val, ok = sr.ReadPossiblyZeroTerminatedString(4) + if !ok || val != "hej" { + t.Errorf("got %q instead of 'hej'", val) + } + }) } func verifyAccErrorInt(t *testing.T, sr *bits.FixedSliceReader, val int) {