diff --git a/http/middleware.go b/http/middleware.go index 85be9ce..32e30fe 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -6,18 +6,16 @@ import ( "net/http" ) -type Middlware func(handler http.Handler) http.Handler - // Chain takes a final http.Handler and a list of Middlewares and builds a call chain such that // an incoming request passes all Middlwares in the order they were appended and finally reaches final. -func Chain(final http.Handler, m ...Middlware) http.Handler { +func Chain(final http.Handler, m ...func(handler http.Handler) http.Handler) http.Handler { for i := len(m) - 1; i >= 0; i-- { final = m[i](final) } return final } -func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider) Middlware { +func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider) func(handler http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("sessionID") diff --git a/http/middleware_test.go b/http/middleware_test.go index c4b7631..6b732b9 100644 --- a/http/middleware_test.go +++ b/http/middleware_test.go @@ -4,11 +4,38 @@ import ( "errors" "github.com/eduboard/backend/mock" "github.com/stretchr/testify/assert" + "io/ioutil" "net/http" "net/http/httptest" "testing" ) +func TestChain(t *testing.T) { + var final, oneCalled, twoCalled, threeCalled = &mock.Check{}, &mock.Check{}, &mock.Check{}, &mock.Check{} + finalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.True(t, oneCalled.Passed, "first wrapped function not called") + assert.True(t, twoCalled.Passed, "second wrapped function not called") + assert.True(t, threeCalled.Passed, "third wrapped function not called") + final.Passed = true + n, err := w.Write([]byte("ok")) + assert.Equal(t, 2, n) + assert.Nil(t, err) + }) + + handlers := mock.GenerateCheckedMiddlewares(oneCalled, twoCalled, threeCalled) + c := Chain(finalHandler, handlers...) + req := httptest.NewRequest("", "/", nil) + rr := httptest.NewRecorder() + c.ServeHTTP(rr, req) + + res := rr.Result() + defer res.Body.Close() + resBody, _ := ioutil.ReadAll(res.Body) + + assert.True(t, final.Passed, "inner wrapped function not called") + assert.Equal(t, "ok", string(resBody), "response not correct") +} + func TestAppServer_NewAuthMiddleware(t *testing.T) { var as = &mock.UserAuthenticationProvider{ CheckAuthenticationFn: func(sessionID string) (err error, ok bool) { @@ -61,4 +88,4 @@ func TestAppServer_NewAuthMiddleware(t *testing.T) { assert.True(t, as.CheckAuthenticationFnInvoked, "authentication was not actually checked") }) } -} \ No newline at end of file +} diff --git a/mock/middleware.go b/mock/middleware.go new file mode 100644 index 0000000..e97212c --- /dev/null +++ b/mock/middleware.go @@ -0,0 +1,22 @@ +package mock + +import "net/http" + +type Check struct { + Passed bool +} + +func GenerateCheckedMiddlewares(checks ...*Check) []func(http.Handler) http.Handler { + var handlers = make([]func(http.Handler) http.Handler, len(checks)) + for k, c := range checks { + func(c *Check) { // Wrapping the function call like this creates a closure, making sure `c` does not change before evaluation. + handlers[k] = func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Passed = true + handler.ServeHTTP(w, r) + }) + } + }(c) + } + return handlers +}