From 8f1a59780508ed3c55a08ba4faee13526e2612d6 Mon Sep 17 00:00:00 2001 From: MakMuftic Date: Thu, 23 May 2024 11:16:01 +0200 Subject: [PATCH] add unit tests --- internal/auth/auth.go | 18 +++++++++++++ internal/auth/auth_test.go | 54 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 internal/auth/auth.go create mode 100644 internal/auth/auth_test.go diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..efcc294 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,18 @@ +package auth + +import ( + "net/http" +) + +func UrlTokenAuth(token string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authToken := r.URL.Query().Get("auth_token") + if authToken == "" || authToken != token { + w.WriteHeader(http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..f809b64 --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,54 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestUrlTokenAuth(t *testing.T) { + validToken := "valid_token" + middleware := UrlTokenAuth(validToken) + + tests := []struct { + name string + url string + expectedStatus int + }{ + { + name: "Valid token", + url: "/?auth_token=valid_token", + expectedStatus: http.StatusOK, + }, + { + name: "Invalid token", + url: "/?auth_token=invalid_token", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Missing token", + url: "/", + expectedStatus: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest("GET", tt.url, nil) + if err != nil { + t.Fatalf("could not create request: %v", err) + } + + rr := httptest.NewRecorder() + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %v; got %v", tt.expectedStatus, rr.Code) + } + }) + } +}