From a645839425f1ab77ec143a82f7c3dfc07f742548 Mon Sep 17 00:00:00 2001 From: Sophie Turner Date: Thu, 5 Dec 2024 18:03:16 +0500 Subject: [PATCH] Unit Test ConnectionCodeContext --- auth/auth.go | 6 +++ auth/auth_test.go | 135 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+) diff --git a/auth/auth.go b/auth/auth.go index 71a698b44..d46364549 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -149,6 +149,12 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler { // ConnectionContext parses token for connection code func ConnectionCodeContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if r == nil { + http.Error(w, http.StatusText(500), http.StatusInternalServerError) + return + } + token := r.Header.Get("token") if token == "" { diff --git a/auth/auth_test.go b/auth/auth_test.go index d806ea928..7e0fd4a2f 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -5,6 +5,8 @@ import ( "encoding/hex" "errors" "fmt" + "net/http" + "net/http/httptest" "strings" "testing" @@ -531,3 +533,136 @@ func TestVerifyAndExtract(t *testing.T) { }) } } + +func TestConnectionCodeContext(t *testing.T) { + + config.Connection_Auth = "valid_token" + + tests := []struct { + name string + token string + expectedStatus int + expectNextCall bool + }{ + { + name: "Valid Token in Header", + token: "valid_token", + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Invalid Token in Header", + token: "invalid_token", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Empty Token in Header", + token: "", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "No Token Header Present", + token: "", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Malformed Header", + token: "malformed_header", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Token with Special Characters", + token: "special!@#token", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Token with Whitespace", + token: " " + config.Connection_Auth + " ", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Case Sensitivity in Token", + token: strings.ToUpper(config.Connection_Auth), + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tt.token != "" { + req.Header.Set("token", tt.token) + } + + rr := httptest.NewRecorder() + + handler := ConnectionCodeContext(next) + handler.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + + assert.Equal(t, tt.expectNextCall, nextCalled) + }) + } + + t.Run("Null Request Object", func(t *testing.T) { + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + handler := ConnectionCodeContext(next) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, nil) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + assert.False(t, nextCalled) + }) + + t.Run("Large Number of Requests", func(t *testing.T) { + + nextCalled := 0 + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled++ + w.WriteHeader(http.StatusOK) + }) + + for i := 0; i < 1000; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if i%2 == 0 { + req.Header.Set("token", "valid_token") + } else { + req.Header.Set("token", "invalid_token") + } + + rr := httptest.NewRecorder() + handler := ConnectionCodeContext(next) + handler.ServeHTTP(rr, req) + + if i%2 == 0 { + assert.Equal(t, http.StatusOK, rr.Code) + } else { + assert.Equal(t, http.StatusUnauthorized, rr.Code) + } + } + + assert.Equal(t, 500, nextCalled) + }) +}