Skip to content

Commit

Permalink
Merge pull request #27 from eduboard/leo/http-logging
Browse files Browse the repository at this point in the history
http: logging middleware and Chain middlwares
  • Loading branch information
Cosmonawt authored Jun 7, 2018
2 parents b081e46 + d9bbc83 commit 20e7c21
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 19 deletions.
49 changes: 34 additions & 15 deletions http/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,43 @@ package http

import (
"github.com/eduboard/backend"
"log"
"net/http"
)

func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider, nextHandler http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("sessionID")
if err != nil {
w.WriteHeader(http.StatusForbidden)
return
}
// Chain takes a final http.Handler and a list of Middlewares and builds a call chain such that
// an incoming request passes all Middlwares in the order they were appended and finally reaches final.
func Chain(final http.Handler, m ...func(handler http.Handler) http.Handler) http.Handler {
for i := len(m) - 1; i >= 0; i-- {
final = m[i](final)
}
return final
}

func NewAuthMiddleware(provider eduboard.UserAuthenticationProvider) func(handler http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("sessionID")
if err != nil {
w.WriteHeader(http.StatusForbidden)
return
}

sessionID := cookie.Value
err, ok := provider.CheckAuthentication(sessionID)
if err != nil || !ok {
w.WriteHeader(http.StatusForbidden)
return
}
sessionID := cookie.Value
err, ok := provider.CheckAuthentication(sessionID)
if err != nil || !ok {
w.WriteHeader(http.StatusForbidden)
return
}

nextHandler.ServeHTTP(w, r)
next.ServeHTTP(w, r)
})
}
}
}

func Logger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%s: %s %s", r.RemoteAddr, r.Method, r.URL.Path)
next.ServeHTTP(w, r)
})
}
31 changes: 29 additions & 2 deletions http/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,38 @@ import (
"errors"
"github.com/eduboard/backend/mock"
"github.com/stretchr/testify/assert"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)

func TestChain(t *testing.T) {
var final, oneCalled, twoCalled, threeCalled = &mock.Check{}, &mock.Check{}, &mock.Check{}, &mock.Check{}
finalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.True(t, oneCalled.Passed, "first wrapped function not called")
assert.True(t, twoCalled.Passed, "second wrapped function not called")
assert.True(t, threeCalled.Passed, "third wrapped function not called")
final.Passed = true
n, err := w.Write([]byte("ok"))
assert.Equal(t, 2, n)
assert.Nil(t, err)
})

handlers := mock.GenerateCheckedMiddlewares(oneCalled, twoCalled, threeCalled)
c := Chain(finalHandler, handlers...)
req := httptest.NewRequest("", "/", nil)
rr := httptest.NewRecorder()
c.ServeHTTP(rr, req)

res := rr.Result()
defer res.Body.Close()
resBody, _ := ioutil.ReadAll(res.Body)

assert.True(t, final.Passed, "inner wrapped function not called")
assert.Equal(t, "ok", string(resBody), "response not correct")
}

func TestAppServer_NewAuthMiddleware(t *testing.T) {
var as = &mock.UserAuthenticationProvider{
CheckAuthenticationFn: func(sessionID string) (err error, ok bool) {
Expand Down Expand Up @@ -44,7 +71,7 @@ func TestAppServer_NewAuthMiddleware(t *testing.T) {
}

t.Run(v.name, func(t *testing.T) {
handler := NewAuthMiddleware(as, testHandler)
handler := NewAuthMiddleware(as)(testHandler)
if v.enter {
assert.True(t, handlerEntered, "handler was not entered")
}
Expand All @@ -61,4 +88,4 @@ func TestAppServer_NewAuthMiddleware(t *testing.T) {
assert.True(t, as.CheckAuthenticationFnInvoked, "authentication was not actually checked")
})
}
}
}
6 changes: 4 additions & 2 deletions http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ func (a *AppServer) initialize() {
protected := a.authenticatedRoutes()
public := a.publicRoutes()

privateChain := Chain(protected, Logger, NewAuthMiddleware(a.UserService))

mux := http.NewServeMux()
mux.Handle("/api/v1/", NewAuthMiddleware(a.UserService, protected))
mux.Handle("/api/v1/", privateChain)
mux.Handle("/api/", public)
mux.Handle("/", http.FileServer(http.Dir(a.Static)))
mux.Handle("/", Logger(http.FileServer(http.Dir(a.Static))))

a.httpServer = &http.Server{
Addr: a.Host,
Expand Down
22 changes: 22 additions & 0 deletions mock/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package mock

import "net/http"

type Check struct {
Passed bool
}

func GenerateCheckedMiddlewares(checks ...*Check) []func(http.Handler) http.Handler {
var handlers = make([]func(http.Handler) http.Handler, len(checks))
for k, c := range checks {
func(c *Check) { // Wrapping the function call like this creates a closure, making sure `c` does not change before evaluation.
handlers[k] = func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.Passed = true
handler.ServeHTTP(w, r)
})
}
}(c)
}
return handlers
}

0 comments on commit 20e7c21

Please sign in to comment.