Skip to content

Commit

Permalink
CBG-4373: implement flusher interface (#7256)
Browse files Browse the repository at this point in the history
* CBG-4373: implement flusher interface on counted and non counted response writer

* add interface check in handler

* updates to address comments

* update remove the safety around flusher calls

* address comments
  • Loading branch information
gregns1 authored Jan 7, 2025
1 parent d61c392 commit cf667d8
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 4 deletions.
29 changes: 29 additions & 0 deletions rest/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2857,3 +2857,32 @@ func TestAllDbs(t *testing.T) {
RequireStatus(t, resp, http.StatusOK)
require.Equal(t, fmt.Sprintf(`[{"db_name":"%s","bucket":"%s","state":"Online"}]`, rt.GetDatabase().Name, rt.GetDatabase().Bucket.GetName()), resp.Body.String())
}

// TestBufferFlush will test for http.ResponseWriter implements Flusher interface
func TestBufferFlush(t *testing.T) {
rt := NewRestTester(t, &RestTesterConfig{
SyncFn: channels.DocChannelsSyncFunction,
})
defer rt.Close()
ctx := base.TestCtx(t)

a := rt.ServerContext().Database(ctx, "db").Authenticator(ctx)

// Create a test user
user, err := a.NewUser("foo", "letmein", channels.BaseSetOf(t, "foo"))
require.NoError(t, err)
require.NoError(t, a.Save(user))

var wg sync.WaitGroup
var resp *TestResponse
wg.Add(1)
go func() {
defer wg.Done()
resp = rt.SendUserRequest(http.MethodGet, "/{{.keyspace}}/_changes?feed=continuous&since=0&timeout=500&include_docs=true", "", "foo")
RequireStatus(t, resp, http.StatusOK)
}()
wg.Wait()

// assert that the response is a flushed response
assert.True(t, resp.Flushed)
}
4 changes: 4 additions & 0 deletions rest/counted_response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,7 @@ func (w *CountedResponseWriter) isHijackable() bool {
_, ok := w.writer.(http.Hijacker)
return ok
}

func (w *CountedResponseWriter) Flush() {
w.writer.(http.Flusher).Flush()
}
15 changes: 15 additions & 0 deletions rest/counted_response_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/couchbase/sync_gateway/base"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -169,3 +170,17 @@ func TestCountableResponseWriterWithDelay(t *testing.T) {
}

}

func TestResponseWriterSupportsFLush(t *testing.T) {
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {

stat, err := base.NewIntStat(base.SubsystemDatabaseKey, "http_bytes_written", base.StatUnitBytes, base.PublicRestBytesWrittenDesc, base.StatAddedVersion3dot1dot0, base.StatDeprecatedVersionNotDeprecated, base.StatStabilityCommitted, nil, nil, prometheus.CounterValue, 0)
require.NoError(t, err)
responseWriter := getResponseWriter(t, stat, test.name, 0)

_, ok := responseWriter.(http.Flusher)
assert.True(t, ok)
})
}
}
17 changes: 13 additions & 4 deletions rest/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ const (
minCompressibleJSONSize = 1000
)

var _ http.Flusher = &CountedResponseWriter{}
var _ http.Flusher = &NonCountedResponseWriter{}
var _ http.Flusher = &EncodedResponseWriter{}

var _ http.Hijacker = &CountedResponseWriter{}
var _ http.Hijacker = &NonCountedResponseWriter{}

var ErrInvalidLogin = base.HTTPErrorf(http.StatusUnauthorized, "Invalid login")
var ErrLoginRequired = base.HTTPErrorf(http.StatusUnauthorized, "Login required")

Expand Down Expand Up @@ -671,6 +678,11 @@ func (h *handler) validateAndWriteHeaders(method handlerMethod, accessPermission
}
}
h.updateResponseWriter()
// ensure wrapped ResponseWriter implements http.Flusher
_, ok := h.response.(http.Flusher)
if !ok {
return fmt.Errorf("http.ResponseWriter %T does not implement Flusher interface", h.response)
}
return nil
}

Expand Down Expand Up @@ -1595,10 +1607,7 @@ func (h *handler) writeMultipart(subtype string, callback func(*multipart.Writer
}

func (h *handler) flush() {
switch r := h.response.(type) {
case http.Flusher:
r.Flush()
}
h.response.(http.Flusher).Flush()
}

// If the error parameter is non-nil, sets the response status code appropriately and
Expand Down
4 changes: 4 additions & 0 deletions rest/non_counted_response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ func (w *NonCountedResponseWriter) isHijackable() bool {
_, ok := w.ResponseWriter.(http.Hijacker)
return ok
}

func (w *NonCountedResponseWriter) Flush() {
w.ResponseWriter.(http.Flusher).Flush()
}

0 comments on commit cf667d8

Please sign in to comment.