From fc89ea6ff042714b9c07777d6ec3974b39b47ac2 Mon Sep 17 00:00:00 2001 From: 9seconds Date: Tue, 15 Jun 2021 16:17:31 +0300 Subject: [PATCH] Correctly restore HTTP protocol on doing a request --- headers/headers_test.go | 4 ++-- headers/wrappers.go | 12 +++++++++++- headers/wrappers_test.go | 7 ++++--- layers/ctx_test.go | 19 ------------------- server_test.go | 7 +++++++ 5 files changed, 24 insertions(+), 25 deletions(-) diff --git a/headers/headers_test.go b/headers/headers_test.go index d50dbae1..66b2c73e 100644 --- a/headers/headers_test.go +++ b/headers/headers_test.go @@ -38,7 +38,7 @@ func (suite *HeadersTestSuite) TearDownTest() { } func (suite *HeadersTestSuite) TestCheckHeaders() { - suite.Len(suite.hdrs.Headers, 2) + suite.Len(suite.hdrs.Headers, 3) headerNames := map[string]bool{ "Accept-Encoding": true, @@ -147,7 +147,7 @@ func (suite *HeadersTestSuite) TestSetNoCleanup() { func (suite *HeadersTestSuite) TestSetUnknown() { suite.hdrs.Set("hello", "NewValue", false) - suite.Len(suite.hdrs.Headers, 3) + suite.Len(suite.hdrs.Headers, 4) header := suite.hdrs.GetFirst("hello") diff --git a/headers/wrappers.go b/headers/wrappers.go index 9d672dd7..21ed5694 100644 --- a/headers/wrappers.go +++ b/headers/wrappers.go @@ -15,6 +15,7 @@ func (r requestHeaderWrapper) Read(rd io.Reader) error { method := append([]byte(nil), r.ref.Method()...) requestURI := append([]byte(nil), r.ref.RequestURI()...) host := append([]byte(nil), r.ref.Host()...) + protocol := append([]byte(nil), r.ref.Protocol()...) r.ref.Reset() r.ref.DisableNormalizing() @@ -26,6 +27,7 @@ func (r requestHeaderWrapper) Read(rd io.Reader) error { return errors.Annotate(err, "cannot read request headers", "headers_sync", 0) } + r.ref.SetProtocolBytes(protocol) r.ref.SetHostBytes(host) r.ref.SetMethodBytes(method) r.ref.SetRequestURIBytes(requestURI) @@ -42,7 +44,15 @@ func (r requestHeaderWrapper) ResetConnectionClose() { } func (r requestHeaderWrapper) Headers() []byte { - return r.ref.RawHeaders() + buf := append([]byte(nil), r.ref.Method()...) + buf = append(buf, ' ') + buf = append(buf, r.ref.RequestURI()...) + buf = append(buf, ' ') + buf = append(buf, r.ref.Protocol()...) + buf = append(buf, '\r', '\n') + buf = append(buf, r.ref.RawHeaders()...) + + return buf } type responseHeaderWrapper struct { diff --git a/headers/wrappers_test.go b/headers/wrappers_test.go index c4548e1b..425ab7ff 100644 --- a/headers/wrappers_test.go +++ b/headers/wrappers_test.go @@ -84,11 +84,12 @@ func (suite *RequestHeaderWrapperTestSuite) TestRawHeaders() { suite.hdr.SetConnectionClose() request := []string{ + "GET http://example.com HTTP/1.1", "Host: example.com", "accept: deflate", "connection: close", } - fullRequest := strings.Join(append([]string{"GET / HTTP/1.1"}, request...), "\r\n") + "\r\n\r\n" + fullRequest := strings.Join(request, "\r\n") + "\r\n\r\n" suite.NoError(suite.wrp.Read(strings.NewReader(fullRequest))) suite.Equal([]byte(strings.Join(request, "\r\n")+"\r\n\r\n"), suite.wrp.Headers()) @@ -116,7 +117,7 @@ func (suite *ResponseWrapperTestSuite) TestCorrectRestore() { }, "\r\n") + "\r\n\r\n" suite.NoError(suite.wrp.Read(strings.NewReader(request))) - suite.Equal(fasthttp.StatusCreated, suite.hdr.StatusCode()) + suite.Equal(fasthttp.StatusCreated, suite.hdr.StatusCode()) } func (suite *ResponseWrapperTestSuite) TestDisableNormalizing() { @@ -137,5 +138,5 @@ func TestRequestHeaderWrapper(t *testing.T) { } func TestResponseHeaderWrapper(t *testing.T) { - suite.Run(t, &ResponseWrapperTestSuite{}) + suite.Run(t, &ResponseWrapperTestSuite{}) } diff --git a/layers/ctx_test.go b/layers/ctx_test.go index f3e6b499..87bb1843 100644 --- a/layers/ctx_test.go +++ b/layers/ctx_test.go @@ -116,25 +116,6 @@ func (suite *ContextTestSuite) TestRespond() { suite.Equal("text/plain", string(resp.Header.ContentType())) } -func (suite *ContextTestSuite) TestErrorRequest() { - ctx := layers.AcquireContext() - defer layers.ReleaseContext(ctx) - - fhttpCtx := &fasthttp.RequestCtx{} - remoteAddr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 65342, - } - - fhttpCtx.Init(&fasthttp.Request{}, remoteAddr, nil) - - suite.Error(ctx.Init(fhttpCtx, - "127.0.0.1:8000", - suite.eventsChannel, - "user", - events.RequestTypeTLS)) -} - func (suite *ContextTestSuite) TestErrorGeneril() { suite.ctx.Error(io.EOF) diff --git a/server_test.go b/server_test.go index e2feae92..242f802a 100644 --- a/server_test.go +++ b/server_test.go @@ -220,6 +220,13 @@ func (suite *ServerTestSuite) TestHTTPSAuthRequired() { suite.Error(err) } +func (suite *ServerTestSuite) TestGolangOrg() { + resp, err := suite.http.Get("https://golang.org") + + suite.NoError(err) + suite.Equal(http.StatusOK, resp.StatusCode) +} + func TestServer(t *testing.T) { suite.Run(t, &ServerTestSuite{}) }