From 1d01877b486dde69ecd51acb80931897ec1af535 Mon Sep 17 00:00:00 2001 From: cosmonawt Date: Wed, 6 Jun 2018 09:11:48 +0200 Subject: [PATCH 1/2] http: logging middleware and function to conveniently chain multiple middlewares --- http/middleware.go | 51 +++++++++++++++++++++++++++++------------ http/middleware_test.go | 2 +- http/server.go | 6 +++-- 3 files changed, 41 insertions(+), 18 deletions(-) diff --git a/http/middleware.go b/http/middleware.go index 4565e7a..85be9ce 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -2,24 +2,45 @@ package http import ( "github.com/eduboard/backend" + "log" "net/http" ) -func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider, nextHandler http.Handler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - cookie, err := r.Cookie("sessionID") - if err != nil { - w.WriteHeader(http.StatusForbidden) - return - } +type Middlware func(handler http.Handler) http.Handler - sessionID := cookie.Value - err, ok := provider.CheckAuthentication(sessionID) - if err != nil || !ok { - w.WriteHeader(http.StatusForbidden) - return - } +// Chain takes a final http.Handler and a list of Middlewares and builds a call chain such that +// an incoming request passes all Middlwares in the order they were appended and finally reaches final. +func Chain(final http.Handler, m ...Middlware) http.Handler { + for i := len(m) - 1; i >= 0; i-- { + final = m[i](final) + } + return final +} + +func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider) Middlware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("sessionID") + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } - nextHandler.ServeHTTP(w, r) + sessionID := cookie.Value + err, ok := provider.CheckAuthentication(sessionID) + if err != nil || !ok { + w.WriteHeader(http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) } -} \ No newline at end of file +} + +func Logger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s: %s %s", r.RemoteAddr, r.Method, r.URL.Path) + next.ServeHTTP(w, r) + }) +} diff --git a/http/middleware_test.go b/http/middleware_test.go index a9405de..c4b7631 100644 --- a/http/middleware_test.go +++ b/http/middleware_test.go @@ -44,7 +44,7 @@ func TestAppServer_NewAuthMiddleware(t *testing.T) { } t.Run(v.name, func(t *testing.T) { - handler := NewAuthMiddleware(as, testHandler) + handler := NewAuthMiddleware(as)(testHandler) if v.enter { assert.True(t, handlerEntered, "handler was not entered") } diff --git a/http/server.go b/http/server.go index f739b8f..4fa768b 100644 --- a/http/server.go +++ b/http/server.go @@ -19,10 +19,12 @@ func (a *AppServer) initialize() { protected := a.authenticatedRoutes() public := a.publicRoutes() + privateChain := Chain(protected, Logger, NewAuthMiddleware(a.UserService)) + mux := http.NewServeMux() - mux.Handle("/api/v1/", NewAuthMiddleware(a.UserService, protected)) + mux.Handle("/api/v1/", privateChain) mux.Handle("/api/", public) - mux.Handle("/", http.FileServer(http.Dir(a.Static))) + mux.Handle("/", Logger(http.FileServer(http.Dir(a.Static)))) a.httpServer = &http.Server{ Addr: a.Host, From d9bbc830c04d2949fe4c36068203094aab3134bd Mon Sep 17 00:00:00 2001 From: cosmonawt Date: Thu, 7 Jun 2018 07:13:24 +0200 Subject: [PATCH 2/2] http: test Chain in middleware_test.go --- http/middleware.go | 6 ++---- http/middleware_test.go | 29 ++++++++++++++++++++++++++++- mock/middleware.go | 22 ++++++++++++++++++++++ 3 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 mock/middleware.go diff --git a/http/middleware.go b/http/middleware.go index 85be9ce..32e30fe 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -6,18 +6,16 @@ import ( "net/http" ) -type Middlware func(handler http.Handler) http.Handler - // Chain takes a final http.Handler and a list of Middlewares and builds a call chain such that // an incoming request passes all Middlwares in the order they were appended and finally reaches final. -func Chain(final http.Handler, m ...Middlware) http.Handler { +func Chain(final http.Handler, m ...func(handler http.Handler) http.Handler) http.Handler { for i := len(m) - 1; i >= 0; i-- { final = m[i](final) } return final } -func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider) Middlware { +func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider) func(handler http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("sessionID") diff --git a/http/middleware_test.go b/http/middleware_test.go index c4b7631..6b732b9 100644 --- a/http/middleware_test.go +++ b/http/middleware_test.go @@ -4,11 +4,38 @@ import ( "errors" "github.com/eduboard/backend/mock" "github.com/stretchr/testify/assert" + "io/ioutil" "net/http" "net/http/httptest" "testing" ) +func TestChain(t *testing.T) { + var final, oneCalled, twoCalled, threeCalled = &mock.Check{}, &mock.Check{}, &mock.Check{}, &mock.Check{} + finalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.True(t, oneCalled.Passed, "first wrapped function not called") + assert.True(t, twoCalled.Passed, "second wrapped function not called") + assert.True(t, threeCalled.Passed, "third wrapped function not called") + final.Passed = true + n, err := w.Write([]byte("ok")) + assert.Equal(t, 2, n) + assert.Nil(t, err) + }) + + handlers := mock.GenerateCheckedMiddlewares(oneCalled, twoCalled, threeCalled) + c := Chain(finalHandler, handlers...) + req := httptest.NewRequest("", "/", nil) + rr := httptest.NewRecorder() + c.ServeHTTP(rr, req) + + res := rr.Result() + defer res.Body.Close() + resBody, _ := ioutil.ReadAll(res.Body) + + assert.True(t, final.Passed, "inner wrapped function not called") + assert.Equal(t, "ok", string(resBody), "response not correct") +} + func TestAppServer_NewAuthMiddleware(t *testing.T) { var as = &mock.UserAuthenticationProvider{ CheckAuthenticationFn: func(sessionID string) (err error, ok bool) { @@ -61,4 +88,4 @@ func TestAppServer_NewAuthMiddleware(t *testing.T) { assert.True(t, as.CheckAuthenticationFnInvoked, "authentication was not actually checked") }) } -} \ No newline at end of file +} diff --git a/mock/middleware.go b/mock/middleware.go new file mode 100644 index 0000000..e97212c --- /dev/null +++ b/mock/middleware.go @@ -0,0 +1,22 @@ +package mock + +import "net/http" + +type Check struct { + Passed bool +} + +func GenerateCheckedMiddlewares(checks ...*Check) []func(http.Handler) http.Handler { + var handlers = make([]func(http.Handler) http.Handler, len(checks)) + for k, c := range checks { + func(c *Check) { // Wrapping the function call like this creates a closure, making sure `c` does not change before evaluation. + handlers[k] = func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Passed = true + handler.ServeHTTP(w, r) + }) + } + }(c) + } + return handlers +}