diff --git a/cmd/argocd/commands/login.go b/cmd/argocd/commands/login.go index 2843972e7e476..7a8fbac81b698 100644 --- a/cmd/argocd/commands/login.go +++ b/cmd/argocd/commands/login.go @@ -20,6 +20,7 @@ import ( "golang.org/x/oauth2" "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/headless" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" argocdclient "github.com/argoproj/argo-cd/v2/pkg/apiclient" sessionpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/session" settingspkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/settings" @@ -196,9 +197,21 @@ func userDisplayName(claims jwt.MapClaims) string { if name := jwtutil.StringField(claims, "name"); name != "" { return name } - return jwtutil.StringField(claims, "sub") + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: claims["sub"].(string), + }, + } + if fedClaims, ok := claims["federated_claims"].(map[string]any); ok { + argoClaims.FederatedClaims = &utils.FederatedClaims{ + ConnectorID: fedClaims["connector_id"].(string), + UserID: fedClaims["user_id"].(string), + } + } + return utils.GetUserIdentifier(argoClaims) } +// oauth2Login opens a browser, runs a temporary HTTP server to delegate OAuth2 login flow and // oauth2Login opens a browser, runs a temporary HTTP server to delegate OAuth2 login flow and // returns the JWT token and a refresh token (if supported) func oauth2Login( diff --git a/cmd/argocd/commands/login_test.go b/cmd/argocd/commands/login_test.go index 91cf3e11388b3..f18fd4e3f0bcc 100644 --- a/cmd/argocd/commands/login_test.go +++ b/cmd/argocd/commands/login_test.go @@ -62,3 +62,18 @@ func Test_ssoAuthFlow_ssoLaunchBrowser_false(t *testing.T) { assert.Contains(t, out, "To authenticate, copy-and-paste the following URL into your preferred browser: http://test-sso-browser-flow.com") } + +func Test_userDisplayName_federatedClaims(t *testing.T) { + claims := jwt.MapClaims{ + "iss": "qux", + "sub": "foo", + "groups": []string{"baz"}, + "federated_claims": map[string]any{ + "connector_id": "dex", + "user_id": "ldap-123", + }, + } + actualName := userDisplayName(claims) + expectedName := "ldap-123" + assert.Equal(t, expectedName, actualName) +} diff --git a/cmd/argocd/commands/project_role.go b/cmd/argocd/commands/project_role.go index 759608575908b..0775acabf4c1e 100644 --- a/cmd/argocd/commands/project_role.go +++ b/cmd/argocd/commands/project_role.go @@ -279,6 +279,22 @@ func tokenTimeToString(t int64) string { return tokenTimeToString } +func mapClaimsToArgoClaims(claims jwtgo.MapClaims) *utils.ArgoClaims { + sub := jwt.StringField(claims, "sub") + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwtgo.RegisteredClaims{ + Subject: sub, + }, + } + if fedClaims, ok := claims["federated_claims"].(map[string]any); ok { + argoClaims.FederatedClaims = &utils.FederatedClaims{ + ConnectorID: fmt.Sprint(fedClaims["connector_id"]), + UserID: fmt.Sprint(fedClaims["user_id"]), + } + } + return argoClaims +} + // NewProjectRoleCreateTokenCommand returns a new instance of an `argocd proj role create-token` command func NewProjectRoleCreateTokenCommand(clientOpts *argocdclient.ClientOptions) *cobra.Command { var ( @@ -332,7 +348,7 @@ Create token succeeded for proj:test-project:test-role. issuedAt, _ := jwt.IssuedAt(claims) expiresAt := int64(jwt.Float64Field(claims, "exp")) id := jwt.StringField(claims, "jti") - subject := jwt.StringField(claims, "sub") + subject := utils.GetUserIdentifier(mapClaimsToArgoClaims(claims)) if !outputTokenOnly { fmt.Printf("Create token succeeded for %s.\n", subject) diff --git a/cmd/argocd/commands/utils/claims.go b/cmd/argocd/commands/utils/claims.go new file mode 100644 index 0000000000000..969fc95180eaf --- /dev/null +++ b/cmd/argocd/commands/utils/claims.go @@ -0,0 +1,48 @@ +package utils + +import ( + "fmt" + + "github.com/golang-jwt/jwt/v5" +) + +// ArgoClaims defines the claims structure based on Dex's documented claims +type ArgoClaims struct { + jwt.RegisteredClaims + Email string `json:"email,omitempty"` + EmailVerified bool `json:"email_verified,omitempty"` + Name string `json:"name,omitempty"` + Groups []string `json:"groups,omitempty"` + // As per Dex docs, federated_claims has a specific structure + FederatedClaims *FederatedClaims `json:"federated_claims,omitempty"` +} + +// FederatedClaims represents the structure documented by Dex +type FederatedClaims struct { + ConnectorID string `json:"connector_id"` + UserID string `json:"user_id"` +} + +// GetFederatedClaims extracts federated claims from jwt.MapClaims +func GetFederatedClaims(claims jwt.MapClaims) *FederatedClaims { + if federated, ok := claims["federated_claims"].(map[string]any); ok { + return &FederatedClaims{ + ConnectorID: fmt.Sprint(federated["connector_id"]), + UserID: fmt.Sprint(federated["user_id"]), + } + } + return nil +} + +// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use +func GetUserIdentifier(claims *ArgoClaims) string { + // Check federated claims first + if claims.FederatedClaims != nil && claims.FederatedClaims.UserID != "" { + return claims.FederatedClaims.UserID + } + // Fallback to sub + if claims.Subject != "" { + return claims.Subject + } + return "" +} diff --git a/cmd/argocd/commands/utils/claims_test.go b/cmd/argocd/commands/utils/claims_test.go new file mode 100644 index 0000000000000..c3a9ed2f1c3c6 --- /dev/null +++ b/cmd/argocd/commands/utils/claims_test.go @@ -0,0 +1,62 @@ +package utils + +import ( + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestGetUserIdentifier(t *testing.T) { + tests := []struct { + name string + claims *ArgoClaims + want string + }{ + { + name: "when both dex and sub defined - prefer dex user_id", + claims: &ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: "ignored:login", + }, + FederatedClaims: &FederatedClaims{ + UserID: "dex-user", + }, + }, + want: "dex-user", + }, + { + name: "when both dex and sub defined but dex user_id empty - fallback to sub", + claims: &ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: "test:apiKey", + }, + FederatedClaims: &FederatedClaims{ + UserID: "", + }, + }, + want: "test:apiKey", + }, + { + name: "when only sub is defined (no dex) - use sub", + claims: &ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: "admin:login", + }, + }, + want: "admin:login", + }, + { + name: "when neither dex nor sub defined - return empty", + claims: &ArgoClaims{}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetUserIdentifier(tt.claims) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/server/account/account.go b/server/account/account.go index 94394e323ed63..f6b1f69ac130f 100644 --- a/server/account/account.go +++ b/server/account/account.go @@ -38,7 +38,7 @@ func NewServer(sessionMgr *session.SessionManager, settingsMgr *settings.Setting // UpdatePassword updates the password of the currently authenticated account or the account specified in the request. func (s *Server) UpdatePassword(ctx context.Context, q *account.UpdatePasswordRequest) (*account.UpdatePasswordResponse, error) { issuer := session.Iss(ctx) - username := session.Sub(ctx) + username := session.GetUserIdentifier(ctx) updatedUsername := username if q.Name != "" { @@ -169,7 +169,7 @@ func toApiAccount(name string, a settings.Account) *account.Account { func (s *Server) ensureHasAccountPermission(ctx context.Context, action string, account string) error { // account has always has access to itself - if session.Sub(ctx) == account && session.Iss(ctx) == session.SessionManagerClaimsIssuer { + if session.GetUserIdentifier(ctx) == account && session.Iss(ctx) == session.SessionManagerClaimsIssuer { return nil } if err := s.enf.EnforceErr(ctx.Value("claims"), rbacpolicy.ResourceAccounts, action, account); err != nil { diff --git a/server/account/account_test.go b/server/account/account_test.go index 272a0297c9c4f..f4933ce4c4e25 100644 --- a/server/account/account_test.go +++ b/server/account/account_test.go @@ -82,30 +82,54 @@ func getAdminAccount(mgr *settings.SettingsManager) (*settings.Account, error) { } func adminContext(ctx context.Context) context.Context { - // nolint:staticcheck - return context.WithValue(ctx, "claims", &jwt.RegisteredClaims{Subject: "admin", Issuer: sessionutil.SessionManagerClaimsIssuer}) + claims := jwt.MapClaims{ + "sub": "admin", + "iss": sessionutil.SessionManagerClaimsIssuer, + "groups": []string{"role:admin"}, + "federated_claims": map[string]any{ + "user_id": "admin", + }, + } + ctx = context.WithValue(ctx, sessionutil.ClaimsKey(), claims) + //nolint:staticcheck + ctx = context.WithValue(ctx, "claims", claims) + return ctx } func ssoAdminContext(ctx context.Context, iat time.Time) context.Context { - // nolint:staticcheck - return context.WithValue(ctx, "claims", &jwt.RegisteredClaims{ - Subject: "admin", - Issuer: "https://myargocdhost.com/api/dex", - IssuedAt: jwt.NewNumericDate(iat), - }) + claims := jwt.MapClaims{ + "sub": "admin", + "iss": "https://myargocdhost.com/api/dex", + "iat": jwt.NewNumericDate(iat), + "groups": []string{"role:admin"}, // Add admin group + "federated_claims": map[string]any{ + "user_id": "admin", + }, + } + // Set both context values + ctx = context.WithValue(ctx, sessionutil.ClaimsKey(), claims) + //nolint:staticcheck + ctx = context.WithValue(ctx, "claims", claims) + + return ctx } func projTokenContext(ctx context.Context) context.Context { + claims := jwt.MapClaims{ + "sub": "proj:demo:deployer", + "iss": sessionutil.SessionManagerClaimsIssuer, + "groups": []string{"proj:demo:deployer"}, + } + ctx = context.WithValue(ctx, sessionutil.ClaimsKey(), claims) // nolint:staticcheck - return context.WithValue(ctx, "claims", &jwt.RegisteredClaims{ - Subject: "proj:demo:deployer", - Issuer: sessionutil.SessionManagerClaimsIssuer, - }) + ctx = context.WithValue(ctx, "claims", claims) + return ctx } func TestUpdatePassword(t *testing.T) { accountServer, sessionServer := newTestAccountServer(context.Background()) ctx := adminContext(context.Background()) + var err error // ensure password is not allowed to be updated if given bad password diff --git a/server/rbacpolicy/rbacpolicy.go b/server/rbacpolicy/rbacpolicy.go index 250be3fdea3f3..0e1420ed0a09b 100644 --- a/server/rbacpolicy/rbacpolicy.go +++ b/server/rbacpolicy/rbacpolicy.go @@ -6,6 +6,7 @@ import ( "github.com/golang-jwt/jwt/v5" log "github.com/sirupsen/logrus" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1" applister "github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1" jwtutil "github.com/argoproj/argo-cd/v2/util/jwt" @@ -115,8 +116,14 @@ func (p *RBACPolicyEnforcer) EnforceClaims(claims jwt.Claims, rvals ...any) bool if err != nil { return false } + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(mapClaims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(mapClaims), + } - subject := jwtutil.StringField(mapClaims, "sub") + subject := utils.GetUserIdentifier(argoClaims) // Check if the request is for an application resource. We have special enforcement which takes // into consideration the project's token and group bindings var runtimePolicy string diff --git a/server/server.go b/server/server.go index 63d935fe5c893..45d20cfe091fc 100644 --- a/server/server.go +++ b/server/server.go @@ -63,6 +63,7 @@ import ( "k8s.io/client-go/tools/cache" "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/pkg/apiclient" accountpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/account" @@ -1532,6 +1533,15 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, groupClaims = *tmpClaims } } + + // Convert to ArgoClaims for user identifier comparison + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(groupClaims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(groupClaims), + } + iss := jwtutil.StringField(groupClaims, "iss") if iss != util_session.SessionManagerClaimsIssuer && server.settings.UserInfoGroupsEnabled() && server.settings.UserInfoPath() != "" { userInfo, unauthorized, err := server.ssoClientApp.GetUserInfo(groupClaims, server.settings.IssuerURL(), server.settings.UserInfoPath()) @@ -1543,7 +1553,13 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, log.Errorf("error fetching user info endpoint: %v", err) return claims, "", status.Errorf(codes.Internal, "invalid userinfo response") } - if groupClaims["sub"] != userInfo["sub"] { + userInfoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(userInfo, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(userInfo), + } + if utils.GetUserIdentifier(argoClaims) != utils.GetUserIdentifier(userInfoClaims) { return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo") } groupClaims["groups"] = userInfo["groups"] diff --git a/server/server_test.go b/server/server_test.go index 7906c1f899072..7b7624300645e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -815,7 +815,7 @@ func TestAuthenticate_3rd_party_JWTs(t *testing.T) { anonymousEnabled: false, claims: jwt.RegisteredClaims{Audience: jwt.ClaimStrings{common.ArgoCDClientAppID}, Subject: "admin", ExpiresAt: jwt.NewNumericDate(time.Now())}, expectedErrorContains: common.TokenVerificationError, - expectedClaims: jwt.RegisteredClaims{Issuer: "sso"}, + expectedClaims: jwt.MapClaims{"iss": "sso"}, }, { test: "anonymous enabled, expired token, admin claim", @@ -870,7 +870,7 @@ func TestAuthenticate_3rd_party_JWTs(t *testing.T) { claims: jwt.RegisteredClaims{Audience: jwt.ClaimStrings{common.ArgoCDClientAppID}, Subject: "admin", ExpiresAt: jwt.NewNumericDate(time.Now())}, useDex: true, expectedErrorContains: common.TokenVerificationError, - expectedClaims: jwt.RegisteredClaims{Issuer: "sso"}, + expectedClaims: jwt.MapClaims{"iss": "sso"}, }, { test: "external OIDC: anonymous enabled, expired token, admin claim", @@ -1643,3 +1643,115 @@ func Test_enforceContentTypes(t *testing.T) { assert.Equal(t, http.StatusUnsupportedMediaType, resp.StatusCode, "should not have passed, since a disallowed content type was provided") }) } + +func TestGetClaimsWithFederatedIdentity(t *testing.T) { + defaultExpiry := jwt.NewNumericDate(time.Now().Add(time.Hour * 24)) + + tests := []struct { + name string + claims jwt.MapClaims + expectedIdentifier string + expectedErrorContains string + }{ + { + name: "federated claims present - should use federated user_id", + claims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiry, + "sub": "different-id", + "federated_claims": map[string]any{ + "connector_id": "github", + "user_id": "federated-user-12345", + }, + }, + expectedIdentifier: "federated-user-12345", + expectedErrorContains: "", + }, + { + name: "no federated claims - should fallback to sub", + claims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiry, + "sub": "fallback-sub-id", + }, + expectedIdentifier: "fallback-sub-id", + expectedErrorContains: "", + }, + { + name: "empty federated claims - should fallback to sub", + claims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiry, + "sub": "fallback-sub-id", + "federated_claims": map[string]any{}, + }, + expectedIdentifier: "fallback-sub-id", + expectedErrorContains: "", + }, + { + name: "federated claims without user_id - should fallback to sub", + claims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiry, + "sub": "fallback-sub-id", + "federated_claims": map[string]any{ + "connector_id": "github", + }, + }, + expectedIdentifier: "fallback-sub-id", + expectedErrorContains: "", + }, + { + name: "no sub and no federated claims", + claims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiry, + }, + expectedIdentifier: "", + expectedErrorContains: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test server + argocd, oidcURL := getTestServer(t, false, true, false, settings_util.OIDCConfig{}) + + // Add issuer to claims + tt.claims["iss"] = oidcURL + + // Create token and context + token := jwt.NewWithClaims(jwt.SigningMethodRS512, tt.claims) + key, err := jwt.ParseRSAPrivateKeyFromPEM(testutil.PrivateKey) + require.NoError(t, err) + tokenString, err := token.SignedString(key) + require.NoError(t, err) + ctx := metadata.NewIncomingContext(context.Background(), + metadata.Pairs(apiclient.MetaDataTokenKey, tokenString)) + + // Test claims retrieval + gotClaims, _, err := argocd.getClaims(ctx) + + if tt.expectedErrorContains != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErrorContains) + return + } + require.NoError(t, err) + mapClaims, ok := gotClaims.(jwt.MapClaims) + require.True(t, ok) + + var actualIdentifier string + if fedClaims, ok := mapClaims["federated_claims"].(map[string]any); ok { + if userID, exists := fedClaims["user_id"]; exists && userID != "" { + actualIdentifier = userID.(string) + } + } + if actualIdentifier == "" && mapClaims["sub"] != nil { + actualIdentifier = mapClaims["sub"].(string) + } + + assert.Equal(t, tt.expectedIdentifier, actualIdentifier) + }) + } +} diff --git a/util/oidc/oidc.go b/util/oidc/oidc.go index 45c47ad8689d6..1ec5ed2a6185a 100644 --- a/util/oidc/oidc.go +++ b/util/oidc/oidc.go @@ -21,6 +21,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/oauth2" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/server/settings/oidc" "github.com/argoproj/argo-cd/v2/util/cache" @@ -402,9 +403,8 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON) return } - sub := jwtutil.StringField(claims, "sub") err = a.clientCache.Set(&cache.Item{ - Key: formatAccessTokenCacheKey(sub), + Key: formatAccessTokenCacheKey(claims), Object: encToken, CacheActionOpts: cache.CacheActionOpts{ Expiration: getTokenExpiration(claims), @@ -552,12 +552,18 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc // GetUserInfo queries the IDP userinfo endpoint for claims func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) { - sub := jwtutil.StringField(actualClaims, "sub") + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(actualClaims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(actualClaims), + } + sub := utils.GetUserIdentifier(argoClaims) var claims jwt.MapClaims var encClaims []byte // in case we got it in the cache, we just return the item - clientCacheKey := formatUserInfoResponseCacheKey(sub) + clientCacheKey := formatUserInfoResponseCacheKey(actualClaims) if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil { claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey) if err != nil { @@ -575,7 +581,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP // check if the accessToken for the user is still present var encAccessToken []byte - err := a.clientCache.Get(formatAccessTokenCacheKey(sub), &encAccessToken) + err := a.clientCache.Get(formatAccessTokenCacheKey(actualClaims), &encAccessToken) // without an accessToken we can't query the user info endpoint // thus the user needs to reauthenticate for argocd to get a new accessToken if errors.Is(err, cache.ErrCacheMiss) { @@ -607,6 +613,9 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP if response.StatusCode == http.StatusUnauthorized { return claims, true, err } + if response.StatusCode == http.StatusNotFound { + return jwt.MapClaims{}, true, fmt.Errorf("user info path not found: %s", userInfoPath) + } // according to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponseValidation // the response should be validated @@ -684,11 +693,25 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration { } // formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache -func formatUserInfoResponseCacheKey(sub string) string { - return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub) +func formatUserInfoResponseCacheKey(claims jwt.MapClaims) string { + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(claims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(claims), + } + userID := utils.GetUserIdentifier(argoClaims) + return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, userID) } // formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache -func formatAccessTokenCacheKey(sub string) string { - return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub) +func formatAccessTokenCacheKey(claims jwt.MapClaims) string { + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(claims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(claims), + } + userID := utils.GetUserIdentifier(argoClaims) + return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, userID) } diff --git a/util/oidc/oidc_test.go b/util/oidc/oidc_test.go index b31bb407af5ad..63f6ead1ff5ce 100644 --- a/util/oidc/oidc_test.go +++ b/util/oidc/oidc_test.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "log" "net/http" "net/http/httptest" "net/url" @@ -629,9 +630,9 @@ func TestGetUserInfo(t *testing.T) { { name: "call UserInfo with wrong userInfoPath", userInfoPath: "/user", - expectedOutput: jwt.MapClaims(nil), + expectedOutput: jwt.MapClaims{}, expectError: true, - expectUnauthenticated: false, + expectUnauthenticated: true, expectedCacheItems: []struct { key string value string @@ -639,11 +640,11 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}}), expectError: true, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) }, @@ -654,7 +655,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}}), value: "FakeAccessToken", encrypt: true, }, @@ -673,11 +674,11 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "fallbackUser"}), expectError: true, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "fallbackUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusUnauthorized) }, @@ -688,7 +689,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "fallbackUser"}), value: "FakeAccessToken", encrypt: true, }, @@ -707,11 +708,11 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}}), expectError: true, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, _ *http.Request) { userInfoBytes := ` notevenJsongarbage @@ -730,7 +731,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}}), value: "FakeAccessToken", encrypt: true, }, @@ -749,11 +750,11 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}}), expectError: true, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, _ *http.Request) { userInfoBytes := ` { @@ -782,13 +783,13 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}}), value: "{\"groups\":[\"githubOrg:engineers\"]}", expectEncrypted: true, expectError: false, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, _ *http.Request) { userInfoBytes := ` { @@ -809,7 +810,177 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]any{"user_id": "randomUser"}}), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with different sub and federated_claims", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims{ + "sub": "different-sub", + "federated_claims": map[string]any{ + "connector_id": "github", + "user_id": "preferred-id", + }, + "groups": []any{"githubOrg:engineers"}, + }, + expectError: false, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + // Key should use federated_claims.user_id (preferred-id) instead of sub + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "different-sub", "federated_claims": map[string]any{"user_id": "preferred-id"}}), + value: `{"sub":"different-sub","federated_claims":{"connector_id":"github","user_id":"preferred-id"},"groups":["githubOrg:engineers"]}`, + expectEncrypted: true, + expectError: false, + }, + }, + idpClaims: jwt.MapClaims{ + "sub": "different-sub", + "federated_claims": map[string]any{ + "connector_id": "github", + "user_id": "preferred-id", + }, + "exp": float64(time.Now().Add(5 * time.Minute).Unix()), + }, + idpHandler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusOK) + response := jwt.MapClaims{ + "sub": "different-sub", + "federated_claims": map[string]any{ + "connector_id": "github", + "user_id": "preferred-id", + }, + "groups": []any{"githubOrg:engineers"}, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Printf("Failed to encode response: %v", err) + } + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + // Access token cache key should also use federated_claims.user_id + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "different-sub", "federated_claims": map[string]any{"user_id": "preferred-id"}}), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with only sub claim", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims{"sub": "sub-only-user", "groups": []any{"githubOrg:engineers"}}, + expectError: false, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "sub-only-user"}), + value: `{"sub":"sub-only-user","groups":["githubOrg:engineers"]}`, + expectEncrypted: true, + expectError: false, + }, + }, + idpClaims: jwt.MapClaims{ + "sub": "sub-only-user", + "exp": float64(time.Now().Add(5 * time.Minute).Unix()), + }, + idpHandler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusOK) + response := jwt.MapClaims{ + "sub": "sub-only-user", + "groups": []any{"githubOrg:engineers"}, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Printf("Failed to encode response: %v", err) + } + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "sub-only-user"}), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with only federated claims", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims{ + "federated_claims": map[string]any{ + "connector_id": "github", + "user_id": "federated-only-user", + }, + "groups": []any{"githubOrg:engineers"}, + }, + expectError: false, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"federated_claims": map[string]any{"user_id": "federated-only-user"}}), + value: `{"federated_claims":{"connector_id":"github","user_id":"federated-only-user"},"groups":["githubOrg:engineers"]}`, + expectEncrypted: true, + expectError: false, + }, + }, + idpClaims: jwt.MapClaims{ + "federated_claims": map[string]any{ + "connector_id": "github", + "user_id": "federated-only-user", + }, + "exp": float64(time.Now().Add(5 * time.Minute).Unix()), + }, + idpHandler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusOK) + response := jwt.MapClaims{ + "federated_claims": map[string]any{ + "connector_id": "github", + "user_id": "federated-only-user", + }, + "groups": []any{"githubOrg:engineers"}, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Printf("Failed to encode response: %v", err) + } + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: formatAccessTokenCacheKey(jwt.MapClaims{"federated_claims": map[string]any{"user_id": "federated-only-user"}}), value: "FakeAccessToken", encrypt: true, }, @@ -848,6 +1019,9 @@ func TestGetUserInfo(t *testing.T) { assert.Equal(t, tt.expectUnauthenticated, unauthenticated) if tt.expectError { require.Error(t, err) + if tt.userInfoPath != "/user-info" { + assert.Contains(t, err.Error(), "user info path not found") + } } else { require.NoError(t, err) } @@ -862,7 +1036,13 @@ func TestGetUserInfo(t *testing.T) { tmpValue, err = crypto.Decrypt(tmpValue, encryptionKey) require.NoError(t, err) } - assert.Equal(t, item.value, string(tmpValue)) + // Compare as objects instead of strings + var expected, actual map[string]any + err = json.Unmarshal([]byte(item.value), &expected) + require.NoError(t, err) + err = json.Unmarshal(tmpValue, &actual) + require.NoError(t, err) + assert.Equal(t, expected, actual) } } }) diff --git a/util/rbac/rbac.go b/util/rbac/rbac.go index 300de8c92fb0c..13b45e9d3d149 100644 --- a/util/rbac/rbac.go +++ b/util/rbac/rbac.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/util/assets" "github.com/argoproj/argo-cd/v2/util/glob" jwtutil "github.com/argoproj/argo-cd/v2/util/jwt" @@ -244,16 +245,21 @@ func (e *Enforcer) Enforce(rvals ...any) bool { func (e *Enforcer) EnforceErr(rvals ...any) error { if !e.Enforce(rvals...) { errMsg := "permission denied" + if len(rvals) > 0 { - rvalsStrs := make([]string, len(rvals)-1) - for i, rval := range rvals[1:] { - rvalsStrs[i] = fmt.Sprintf("%s", rval) + rvalsStrs := []string{} + for _, rval := range rvals[1:] { + rvalsStrs = append(rvalsStrs, fmt.Sprintf("%v", rval)) } if s, ok := rvals[0].(jwt.Claims); ok { claims, err := jwtutil.MapClaims(s) if err == nil { - if sub := jwtutil.StringField(claims, "sub"); sub != "" { - rvalsStrs = append(rvalsStrs, "sub: "+sub) + if argoClaims := (&utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(claims, "sub"), + }, + }); utils.GetUserIdentifier(argoClaims) != "" { + rvalsStrs = append(rvalsStrs, "sub: "+utils.GetUserIdentifier(argoClaims)) } if issuedAtTime, err := jwtutil.IssuedAtTime(claims); err == nil { rvalsStrs = append(rvalsStrs, "iat: "+issuedAtTime.Format(time.RFC3339)) @@ -262,6 +268,7 @@ func (e *Enforcer) EnforceErr(rvals ...any) error { } errMsg = fmt.Sprintf("%s: %s", errMsg, strings.Join(rvalsStrs, ", ")) } + return status.Error(codes.PermissionDenied, errMsg) } return nil diff --git a/util/session/sessionmanager.go b/util/session/sessionmanager.go index dfb0053bfea74..6eb79d4f996be 100644 --- a/util/session/sessionmanager.go +++ b/util/session/sessionmanager.go @@ -20,6 +20,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1" "github.com/argoproj/argo-cd/v2/server/rbacpolicy" @@ -156,15 +157,18 @@ func NewSessionManager(settingsMgr *settings.SettingsManager, projectsLister v1a // Passing a value of `0` for secondsBeforeExpiry creates a token that never expires. // The id parameter holds an optional unique JWT token identifier and stored as a standard claim "jti" in the JWT token. func (mgr *SessionManager) Create(subject string, secondsBeforeExpiry int64, id string) (string, error) { - // Create a new token object, specifying signing method and the claims - // you would like it to contain. now := time.Now().UTC() - claims := jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(now), - Issuer: SessionManagerClaimsIssuer, - NotBefore: jwt.NewNumericDate(now), - Subject: subject, - ID: id, + claims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(now), + Issuer: SessionManagerClaimsIssuer, + NotBefore: jwt.NewNumericDate(now), + Subject: subject, + ID: id, + }, + FederatedClaims: &utils.FederatedClaims{ + UserID: "", // Empty for local auth + }, } if secondsBeforeExpiry > 0 { expires := now.Add(time.Duration(secondsBeforeExpiry) * time.Second) @@ -226,7 +230,19 @@ func (mgr *SessionManager) Parse(tokenString string) (jwt.Claims, string, error) return nil, "", err } - subject := jwtutil.StringField(claims, "sub") + // Convert MapClaims to ArgoClaims + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(claims, "sub"), + }, + } + if fedClaims, ok := claims["federated_claims"].(map[string]any); ok { + argoClaims.FederatedClaims = &utils.FederatedClaims{ + UserID: jwtutil.StringField(fedClaims, "user_id"), + } + } + + subject := utils.GetUserIdentifier(argoClaims) id := jwtutil.StringField(claims, "jti") if projName, role, ok := rbacpolicy.GetProjectRoleFromSubject(subject); ok { @@ -502,9 +518,24 @@ func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) h return } ctx := r.Context() + + // Convert claims to MapClaims + mapClaims, err := jwtutil.MapClaims(claims) + if err != nil { + http.Error(w, "Invalid claims format", http.StatusUnauthorized) + return + } + + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(mapClaims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(mapClaims), + } + // Add claims to the context to inspect for RBAC // nolint:staticcheck - ctx = context.WithValue(ctx, "claims", claims) + ctx = context.WithValue(ctx, "user_id", utils.GetUserIdentifier(argoClaims)) r = r.WithContext(ctx) } next.ServeHTTP(w, r) @@ -515,12 +546,14 @@ func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) h // We choose how to verify based on the issuer. func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, string, error) { parser := jwt.NewParser(jwt.WithoutClaimsValidation()) - var claims jwt.RegisteredClaims + claims := jwt.MapClaims{} _, _, err := parser.ParseUnverified(tokenString, &claims) if err != nil { return nil, "", err } - switch claims.Issuer { + // Get issuer from MapClaims + issuer, _ := claims["iss"].(string) + switch issuer { case SessionManagerClaimsIssuer: // Argo CD signed token return mgr.Parse(tokenString) @@ -547,8 +580,8 @@ func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, string, log.Warnf("Failed to verify token: %s", err) tokenExpiredError := &oidc.TokenExpiredError{} if errors.As(err, &tokenExpiredError) { - claims = jwt.RegisteredClaims{ - Issuer: "sso", + claims = jwt.MapClaims{ + "iss": "sso", } return claims, "", common.TokenVerificationErr } @@ -584,7 +617,7 @@ func (mgr *SessionManager) RevokeToken(ctx context.Context, id string, expiringA } func LoggedIn(ctx context.Context) bool { - return Sub(ctx) != "" && ctx.Value(AuthErrorCtxKey) == nil + return GetUserIdentifier(ctx) != "" && ctx.Value(AuthErrorCtxKey) == nil } // Username is a helper to extract a human readable username from a context @@ -593,12 +626,12 @@ func Username(ctx context.Context) string { if !ok { return "" } - switch jwtutil.StringField(mapClaims, "iss") { - case SessionManagerClaimsIssuer: - return jwtutil.StringField(mapClaims, "sub") - default: - return jwtutil.StringField(mapClaims, "email") + subject := jwtutil.StringField(mapClaims, "sub") + if strings.Contains(subject, ":") { + parts := strings.Split(subject, ":") + return parts[0] // Return just the username part } + return subject } func Iss(ctx context.Context) string { @@ -617,12 +650,19 @@ func Iat(ctx context.Context) (time.Time, error) { return jwtutil.IssuedAtTime(mapClaims) } -func Sub(ctx context.Context) string { +// GetUserIdentifier returns the user identifier from context, prioritizing federated claims over subject +func GetUserIdentifier(ctx context.Context) string { mapClaims, ok := mapClaims(ctx) if !ok { return "" } - return jwtutil.StringField(mapClaims, "sub") + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(mapClaims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(mapClaims), + } + return utils.GetUserIdentifier(argoClaims) } func Groups(ctx context.Context, scopes []string) []string { @@ -633,11 +673,32 @@ func Groups(ctx context.Context, scopes []string) []string { return jwtutil.GetGroups(mapClaims, scopes) } +type contextKey struct{} + +var claimsKey = contextKey{} + +// ClaimsKey returns the context key used for claims +func ClaimsKey() any { + return claimsKey +} + func mapClaims(ctx context.Context) (jwt.MapClaims, bool) { - claims, ok := ctx.Value("claims").(jwt.Claims) + claims, ok := ctx.Value(claimsKey).(jwt.Claims) if !ok { - return nil, false + claims, ok = ctx.Value("claims").(jwt.Claims) + if !ok { + // Try direct MapClaims from both keys + mapClaims, ok := ctx.Value(claimsKey).(jwt.MapClaims) + if !ok { + mapClaims, ok = ctx.Value("claims").(jwt.MapClaims) + } + if ok { + return mapClaims, true + } + return nil, false + } } + mapClaims, err := jwtutil.MapClaims(claims) if err != nil { return nil, false diff --git a/util/session/sessionmanager_test.go b/util/session/sessionmanager_test.go index 1fe50fb74a843..940b3e5d89ef9 100644 --- a/util/session/sessionmanager_test.go +++ b/util/session/sessionmanager_test.go @@ -24,12 +24,14 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/kubernetes/fake" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/common" appv1 "github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1" apps "github.com/argoproj/argo-cd/v2/pkg/client/clientset/versioned/fake" "github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1" "github.com/argoproj/argo-cd/v2/test" "github.com/argoproj/argo-cd/v2/util/errors" + jwtutil "github.com/argoproj/argo-cd/v2/util/jwt" "github.com/argoproj/argo-cd/v2/util/password" "github.com/argoproj/argo-cd/v2/util/settings" utiltest "github.com/argoproj/argo-cd/v2/util/test" @@ -98,13 +100,81 @@ func TestSessionManager_AdminToken(t *testing.T) { require.NoError(t, err) assert.Empty(t, newToken) - mapClaims := *(claims.(*jwt.MapClaims)) - subject := mapClaims["sub"].(string) + // Convert claims to ArgoClaims + mapClaims, err := jwtutil.MapClaims(claims) + require.NoError(t, err) + + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(mapClaims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(mapClaims), + } + subject := utils.GetUserIdentifier(argoClaims) if subject != "admin" { t.Errorf("Token claim subject %q does not match expected subject %q.", subject, "admin") } } +// This test verifies both cases: +// When no federated claims exist, it falls back to the subject +// When federated claims exist with a user_id, it uses that instead +func TestSessionManager_TokenIdentifier(t *testing.T) { + redisClient, closer := test.NewInMemoryRedis() + defer closer() + settingsMgr := settings.NewSettingsManager(context.Background(), getKubeClient("pass", true), "argocd") + mgr := newSessionManager(settingsMgr, getProjLister(), NewUserStateStorage(redisClient)) + + tests := []struct { + name string + subject string + fedID string + expected string + }{ + { + name: "Falls back to subject when no federated claims", + subject: "admin:login", + fedID: "", + expected: "admin", + }, + { + name: "Uses federated user_id when present", + subject: "admin:login", + fedID: "fed-admin", + expected: "fed-admin", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, err := mgr.Create(tt.subject, 0, "123") + require.NoError(t, err) + + claims, _, err := mgr.Parse(token) + require.NoError(t, err) + + mapClaims, err := jwtutil.MapClaims(claims) + require.NoError(t, err) + + if tt.fedID != "" { + mapClaims["federated_claims"] = map[string]any{ + "user_id": tt.fedID, + } + } + + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(mapClaims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(mapClaims), + } + + subject := utils.GetUserIdentifier(argoClaims) + assert.Equal(t, tt.expected, subject) + }) + } +} + func TestSessionManager_AdminToken_ExpiringSoon(t *testing.T) { redisClient, closer := test.NewInMemoryRedis() defer closer() @@ -125,8 +195,17 @@ func TestSessionManager_AdminToken_ExpiringSoon(t *testing.T) { // verify that new token is valid and for the same user claims, _, err := mgr.Parse(newToken) require.NoError(t, err) - mapClaims := *(claims.(*jwt.MapClaims)) - subject := mapClaims["sub"].(string) + + mapClaims, err := jwtutil.MapClaims(claims) + require.NoError(t, err) + + argoClaims := &utils.ArgoClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: jwtutil.StringField(mapClaims, "sub"), + }, + FederatedClaims: utils.GetFederatedClaims(mapClaims), + } + subject := utils.GetUserIdentifier(argoClaims) assert.Equal(t, "admin", subject) } @@ -226,10 +305,17 @@ type tokenVerifierMock struct { } func (tm *tokenVerifierMock) VerifyToken(_ string) (jwt.Claims, string, error) { - if tm.claims == nil { + if tm.err != nil { return nil, "", tm.err } - return tm.claims, "", tm.err + mapClaims := jwt.MapClaims{ + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + } + if tm.claims == nil { + return jwt.MapClaims{}, "", nil + } + return mapClaims, "", nil } func strPointer(str string) *string { @@ -338,29 +424,36 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) { } } -var loggedOutContext = context.Background() - -// nolint:staticcheck -var loggedInContext = context.WithValue(context.Background(), "claims", &jwt.MapClaims{"iss": "qux", "sub": "foo", "email": "bar", "groups": []string{"baz"}}) +var ( + loggedOutContext = context.Background() + // nolint:staticcheck + loggedInContext = context.WithValue(context.Background(), claimsKey, &jwt.MapClaims{"iss": "qux", "sub": "foo", "email": "bar", "groups": []string{"baz"}, "federated_claims": map[string]any{"user_id": "foo"}}) + // for testing without federated claims + loggedInContextNoFederated = context.WithValue(context.Background(), claimsKey, &jwt.MapClaims{"iss": "qux", "sub": "foo", "email": "bar", "groups": []string{"baz"}}) +) func TestIss(t *testing.T) { assert.Empty(t, Iss(loggedOutContext)) assert.Equal(t, "qux", Iss(loggedInContext)) + assert.Equal(t, "foo", GetUserIdentifier(loggedInContextNoFederated)) // Without federated claims, falls back to sub } func TestLoggedIn(t *testing.T) { assert.False(t, LoggedIn(loggedOutContext)) assert.True(t, LoggedIn(loggedInContext)) + assert.Equal(t, "foo", Username(loggedInContextNoFederated)) } func TestUsername(t *testing.T) { assert.Empty(t, Username(loggedOutContext)) - assert.Equal(t, "bar", Username(loggedInContext)) + assert.Equal(t, "foo", Username(loggedInContext)) + assert.Equal(t, "foo", Username(loggedInContextNoFederated)) } func TestSub(t *testing.T) { - assert.Empty(t, Sub(loggedOutContext)) - assert.Equal(t, "foo", Sub(loggedInContext)) + assert.Empty(t, GetUserIdentifier(loggedOutContext)) + assert.Equal(t, "foo", GetUserIdentifier(loggedInContext)) + assert.Equal(t, "foo", Username(loggedInContextNoFederated)) } func TestGroups(t *testing.T) {