diff --git a/body_wrapper_test.go b/body_wrapper_test.go index 5ac362b8b..c80dd71ff 100644 --- a/body_wrapper_test.go +++ b/body_wrapper_test.go @@ -279,7 +279,7 @@ func TestBodyWrapper_Rewind(t *testing.T) { } func TestBodyWrapper_GetBody(t *testing.T) { - t.Run("independent readers", func(t *testing.T) { + t.Run("parallel readers", func(t *testing.T) { body := newMockBody("test_body") wrp := newBodyWrapper(body, nil) @@ -302,113 +302,187 @@ func TestBodyWrapper_GetBody(t *testing.T) { assert.Equal(t, 1, body.closeCount) }) - t.Run("start reading body, then call GetBody", func(t *testing.T) { + t.Run("start read - get body - finish read", func(t *testing.T) { body := newMockBody("test_body") wrp := newBodyWrapper(body, nil) // Start reading body - var buf bytes.Buffer - tee := io.TeeReader(wrp, &buf) - _, err := io.Copy(ioutil.Discard, tee) + b := make([]byte, len("test")) + n, err := wrp.Read(b) + assert.Equal(t, len(b), n) assert.NoError(t, err) + assert.Equal(t, "test", string(b)) // Call GetBody and read from it rd, err := wrp.GetBody() assert.NoError(t, err) - - b, err := ioutil.ReadAll(rd) + b, err = io.ReadAll(rd) assert.NoError(t, err) assert.Equal(t, "test_body", string(b)) + // Finish reading body + b, err = io.ReadAll(wrp) + assert.NoError(t, err) + assert.Equal(t, "_body", string(b)) + // Check body read count and close count - assert.Equal(t, 2, body.readCount) + assert.NotEqual(t, 0, body.readCount) assert.Equal(t, 1, body.closeCount) }) - t.Run("rewind - start reading body, then call GetBody", func(t *testing.T) { + t.Run("start read - rewind - get body - read again", func(t *testing.T) { body := newMockBody("test_body") - cancelCount := 0 - cancelFn := func() { - cancelCount++ - } + wrp := newBodyWrapper(body, nil) - wrp := newBodyWrapper(body, cancelFn) + // Start reading body + b := make([]byte, len("test")) + n, err := wrp.Read(b) + assert.Equal(t, len(b), n) + assert.NoError(t, err) + assert.Equal(t, "test", string(b)) + // Rewind wrp.Rewind() + + // Call GetBody and read from it rd, err := wrp.GetBody() assert.NoError(t, err) + b, err = io.ReadAll(rd) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) - // Rewind again to ensure that the body can be read multiple times - wrp.Rewind() + // Re-read body until EOF + b, err = io.ReadAll(wrp) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) + + // Check body read count and close count + assert.NotEqual(t, 0, body.readCount) + assert.Equal(t, 1, body.closeCount) + }) + + t.Run("read all - get body - read again", func(t *testing.T) { + body := newMockBody("test_body") - // Read the entire body to ensure it is fully consumed - b, err := ioutil.ReadAll(rd) + wrp := newBodyWrapper(body, nil) + + // Read body until EOF + b, err := io.ReadAll(wrp) assert.NoError(t, err) assert.Equal(t, "test_body", string(b)) - assert.Equal(t, 2, body.readCount) + // Call GetBody and read from it + rd, err := wrp.GetBody() + assert.NoError(t, err) + b, err = io.ReadAll(rd) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) + + // Try to read more + b, err = io.ReadAll(wrp) + assert.NoError(t, err) + assert.Equal(t, "", string(b)) + + // Check body read count and close count + assert.NotEqual(t, 0, body.readCount) assert.Equal(t, 1, body.closeCount) - assert.Equal(t, 1, cancelCount) }) - t.Run("rewind - read body until EOF (using io.ReadAll), then call GetBody", func(t *testing.T) { + t.Run("read all - rewind - get body - read again", func(t *testing.T) { body := newMockBody("test_body") - cancelCount := 0 - cancelFn := func() { - cancelCount++ - } + wrp := newBodyWrapper(body, nil) - wrp := newBodyWrapper(body, cancelFn) + // Read body until EOF + b, err := io.ReadAll(wrp) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) + // Rewind wrp.Rewind() + + // Call GetBody and read from it rd, err := wrp.GetBody() assert.NoError(t, err) - // rewind the body again to make the entire body available to be read - wrp.Rewind() + b, err = io.ReadAll(rd) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) - b, err := ioutil.ReadAll(rd) + // Re-read body until EOF + b, err = io.ReadAll(wrp) assert.NoError(t, err) assert.Equal(t, "test_body", string(b)) - assert.Equal(t, 2, body.readCount) + // Check body read count and close count + assert.NotEqual(t, 0, body.readCount) assert.Equal(t, 1, body.closeCount) - assert.Equal(t, 1, cancelCount) }) - t.Run("rewind - read body until EOF (using io.ReadAll), then call Close, then call GetBody", func(t *testing.T) { + t.Run("read all - close - get body - read again", func(t *testing.T) { body := newMockBody("test_body") - cancelCount := 0 - cancelFn := func() { - cancelCount++ - } + wrp := newBodyWrapper(body, nil) - wrp := newBodyWrapper(body, cancelFn) + // Read body until EOF + b, err := io.ReadAll(wrp) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) - wrp.Rewind() + // Close + err = wrp.Close() + assert.NoError(t, err) + + // Call GetBody and read from it rd, err := wrp.GetBody() assert.NoError(t, err) - // rewind the body again to make the entire body available to be read - wrp.Rewind() + b, err = io.ReadAll(rd) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) - // close the wrapper + // Try to read more + b, err = io.ReadAll(wrp) + assert.NoError(t, err) + assert.Equal(t, "", string(b)) + + // Check body read count and close count + assert.NotEqual(t, 0, body.readCount) + assert.Equal(t, 1, body.closeCount) + }) + + t.Run("read all - close - rewind - get body - read again", func(t *testing.T) { + body := newMockBody("test_body") + + wrp := newBodyWrapper(body, nil) + + // Read body until EOF + b, err := io.ReadAll(wrp) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) + + // Close err = wrp.Close() assert.NoError(t, err) - // call GetBody after Close should not return an error - _, err = wrp.GetBody() + // Rewind + wrp.Rewind() + + // Call GetBody and read from it + rd, err := wrp.GetBody() + assert.NoError(t, err) + b, err = io.ReadAll(rd) assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) - b, err := ioutil.ReadAll(rd) + // Try to read more + b, err = io.ReadAll(wrp) assert.NoError(t, err) assert.Equal(t, "test_body", string(b)) - assert.Equal(t, 2, body.readCount) + // Check body read count and close count + assert.NotEqual(t, 0, body.readCount) assert.Equal(t, 1, body.closeCount) - assert.Equal(t, 1, cancelCount) }) }