diff --git a/http/middleware.go b/http/middleware.go index 4565e7a..32e30fe 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -2,24 +2,43 @@ package http import ( "github.com/eduboard/backend" + "log" "net/http" ) -func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider, nextHandler http.Handler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - cookie, err := r.Cookie("sessionID") - if err != nil { - w.WriteHeader(http.StatusForbidden) - return - } +// Chain takes a final http.Handler and a list of Middlewares and builds a call chain such that +// an incoming request passes all Middlwares in the order they were appended and finally reaches final. +func Chain(final http.Handler, m ...func(handler http.Handler) http.Handler) http.Handler { + for i := len(m) - 1; i >= 0; i-- { + final = m[i](final) + } + return final +} + +func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider) func(handler http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("sessionID") + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } - sessionID := cookie.Value - err, ok := provider.CheckAuthentication(sessionID) - if err != nil || !ok { - w.WriteHeader(http.StatusForbidden) - return - } + sessionID := cookie.Value + err, ok := provider.CheckAuthentication(sessionID) + if err != nil || !ok { + w.WriteHeader(http.StatusForbidden) + return + } - nextHandler.ServeHTTP(w, r) + next.ServeHTTP(w, r) + }) } -} \ No newline at end of file +} + +func Logger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s: %s %s", r.RemoteAddr, r.Method, r.URL.Path) + next.ServeHTTP(w, r) + }) +} diff --git a/http/middleware_test.go b/http/middleware_test.go index a9405de..6b732b9 100644 --- a/http/middleware_test.go +++ b/http/middleware_test.go @@ -4,11 +4,38 @@ import ( "errors" "github.com/eduboard/backend/mock" "github.com/stretchr/testify/assert" + "io/ioutil" "net/http" "net/http/httptest" "testing" ) +func TestChain(t *testing.T) { + var final, oneCalled, twoCalled, threeCalled = &mock.Check{}, &mock.Check{}, &mock.Check{}, &mock.Check{} + finalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.True(t, oneCalled.Passed, "first wrapped function not called") + assert.True(t, twoCalled.Passed, "second wrapped function not called") + assert.True(t, threeCalled.Passed, "third wrapped function not called") + final.Passed = true + n, err := w.Write([]byte("ok")) + assert.Equal(t, 2, n) + assert.Nil(t, err) + }) + + handlers := mock.GenerateCheckedMiddlewares(oneCalled, twoCalled, threeCalled) + c := Chain(finalHandler, handlers...) + req := httptest.NewRequest("", "/", nil) + rr := httptest.NewRecorder() + c.ServeHTTP(rr, req) + + res := rr.Result() + defer res.Body.Close() + resBody, _ := ioutil.ReadAll(res.Body) + + assert.True(t, final.Passed, "inner wrapped function not called") + assert.Equal(t, "ok", string(resBody), "response not correct") +} + func TestAppServer_NewAuthMiddleware(t *testing.T) { var as = &mock.UserAuthenticationProvider{ CheckAuthenticationFn: func(sessionID string) (err error, ok bool) { @@ -44,7 +71,7 @@ func TestAppServer_NewAuthMiddleware(t *testing.T) { } t.Run(v.name, func(t *testing.T) { - handler := NewAuthMiddleware(as, testHandler) + handler := NewAuthMiddleware(as)(testHandler) if v.enter { assert.True(t, handlerEntered, "handler was not entered") } @@ -61,4 +88,4 @@ func TestAppServer_NewAuthMiddleware(t *testing.T) { assert.True(t, as.CheckAuthenticationFnInvoked, "authentication was not actually checked") }) } -} \ No newline at end of file +} diff --git a/http/server.go b/http/server.go index f739b8f..4fa768b 100644 --- a/http/server.go +++ b/http/server.go @@ -19,10 +19,12 @@ func (a *AppServer) initialize() { protected := a.authenticatedRoutes() public := a.publicRoutes() + privateChain := Chain(protected, Logger, NewAuthMiddleware(a.UserService)) + mux := http.NewServeMux() - mux.Handle("/api/v1/", NewAuthMiddleware(a.UserService, protected)) + mux.Handle("/api/v1/", privateChain) mux.Handle("/api/", public) - mux.Handle("/", http.FileServer(http.Dir(a.Static))) + mux.Handle("/", Logger(http.FileServer(http.Dir(a.Static)))) a.httpServer = &http.Server{ Addr: a.Host, diff --git a/mock/middleware.go b/mock/middleware.go new file mode 100644 index 0000000..e97212c --- /dev/null +++ b/mock/middleware.go @@ -0,0 +1,22 @@ +package mock + +import "net/http" + +type Check struct { + Passed bool +} + +func GenerateCheckedMiddlewares(checks ...*Check) []func(http.Handler) http.Handler { + var handlers = make([]func(http.Handler) http.Handler, len(checks)) + for k, c := range checks { + func(c *Check) { // Wrapping the function call like this creates a closure, making sure `c` does not change before evaluation. + handlers[k] = func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Passed = true + handler.ServeHTTP(w, r) + }) + } + }(c) + } + return handlers +}