diff --git a/internal/auth/auth.go b/internal/auth/auth.go index ff13218..7e60b83 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -2,17 +2,19 @@ package auth import ( "net/http" + "strings" ) func URLTokenAuth(token string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authToken := r.URL.Query().Get("auth_token") - if authToken == "" || authToken != token { + pathParts := strings.Split(r.URL.Path, "/") + if len(pathParts) < 2 || pathParts[len(pathParts)-1] != token { w.WriteHeader(http.StatusUnauthorized) - return } + // Remove the token part from the path to forward the request to the next handler + r.URL.Path = strings.Join(pathParts[:len(pathParts)-1], "/") next.ServeHTTP(w, r) }) } diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 1952e0c..c4ee7f5 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -6,8 +6,9 @@ import ( "testing" ) -func TestUrlTokenAuth(t *testing.T) { +func TestURLTokenAuth(t *testing.T) { validToken := "valid_token" + invalidToken := "invalid_token" middleware := URLTokenAuth(validToken) tests := []struct { @@ -17,17 +18,17 @@ func TestUrlTokenAuth(t *testing.T) { }{ { name: "Valid token", - url: "/?auth_token=valid_token", + url: "/some/path/valid_token", expectedStatus: http.StatusOK, }, { name: "Invalid token", - url: "/?auth_token=invalid_token", + url: "/some/path/invalid_token", expectedStatus: http.StatusUnauthorized, }, { name: "Missing token", - url: "/", + url: "/some/path/", expectedStatus: http.StatusUnauthorized, }, }