diff --git a/pkg/config/jwtmiddleware/jwt_middleware_test.go b/pkg/config/jwtmiddleware/jwt_middleware_test.go index 5ff4a9f2..a0bb01d0 100644 --- a/pkg/config/jwtmiddleware/jwt_middleware_test.go +++ b/pkg/config/jwtmiddleware/jwt_middleware_test.go @@ -2,6 +2,7 @@ package jwtmiddleware import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -332,3 +333,79 @@ func TestJWTMiddleware_Handler(t *testing.T) { }) } } + +func TestFromFirst(t *testing.T) { + // Mock TokenExtractor that returns a token or error based on input. + mockExtractor := func(token string, err error) TokenExtractor { + return func(r *http.Request) (string, error) { + return token, err + } + } + + tests := []struct { + name string + extractors []TokenExtractor + request *http.Request + wantToken string + wantErr bool + }{ + { + name: "First extractor returns valid token", + extractors: []TokenExtractor{ + mockExtractor("token1", nil), + mockExtractor("token2", nil), + }, + request: httptest.NewRequest("GET", "/", nil), + wantToken: "token1", + wantErr: false, + }, + { + name: "First extractor returns error, second returns valid token", + extractors: []TokenExtractor{ + mockExtractor("", errors.New("error")), + mockExtractor("token2", nil), + }, + request: httptest.NewRequest("GET", "/", nil), + wantToken: "", + wantErr: true, + }, + { + name: "All extractors return empty token", + extractors: []TokenExtractor{ + mockExtractor("", nil), + mockExtractor("", nil), + }, + request: httptest.NewRequest("GET", "/", nil), + wantToken: "", + wantErr: false, + }, + { + name: "First extractor returns error, second returns empty token", + extractors: []TokenExtractor{ + mockExtractor("", errors.New("error")), + mockExtractor("", nil), + }, + request: httptest.NewRequest("GET", "/", nil), + wantToken: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call FromFirst with the mock extractors + token, err := FromFirst(tt.extractors...)(tt.request) + + // Check if error matches the expectation + if (err != nil) != tt.wantErr { + t.Errorf("FromFirst() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Check if token matches the expectation + if token != tt.wantToken { + t.Errorf("FromFirst() token = %v, want %v", token, tt.wantToken) + } + }) + } +}