diff --git a/README.md b/README.md index 0cf5274d..f9ec7ce4 100644 --- a/README.md +++ b/README.md @@ -44,9 +44,9 @@ Check the `/example` folder where example code for different scenarios is locate ```bash # start oidc op server # oidc discovery http://localhost:9998/.well-known/openid-configuration -go run github.com/zitadel/oidc/v2/example/server +go run github.com/zitadel/oidc/v3/example/server # start oidc web client (in a new terminal) -CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app +CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app ``` - open http://localhost:9999/login in your browser @@ -56,11 +56,11 @@ CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid for the dynamic issuer, just start it with: ```bash -go run github.com/zitadel/oidc/v2/example/server/dynamic +go run github.com/zitadel/oidc/v3/example/server/dynamic ``` the oidc web client above will still work, but if you add `oidc.local` (pointing to 127.0.0.1) in your hosts file you can also start it with: ```bash -CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app +CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app ``` > Note: Usernames are suffixed with the hostname (`test-user@localhost` or `test-user@oidc.local`) diff --git a/example/client/api/api.go b/example/client/api/api.go index 8093b636..2f81c07b 100644 --- a/example/client/api/api.go +++ b/example/client/api/api.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "log" @@ -9,11 +10,11 @@ import ( "strings" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/client/rs" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client/rs" + "github.com/zitadel/oidc/v3/pkg/oidc" ) const ( @@ -27,12 +28,12 @@ func main() { port := os.Getenv("PORT") issuer := os.Getenv("ISSUER") - provider, err := rs.NewResourceServerFromKeyFile(issuer, keyPath) + provider, err := rs.NewResourceServerFromKeyFile(context.TODO(), issuer, keyPath) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } - router := mux.NewRouter() + router := chi.NewRouter() // public url accessible without any authorization // will print `OK` and current timestamp @@ -47,7 +48,7 @@ func main() { if !ok { return } - resp, err := rs.Introspect(r.Context(), provider, token) + resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token) if err != nil { http.Error(w, err.Error(), http.StatusForbidden) return @@ -68,14 +69,14 @@ func main() { if !ok { return } - resp, err := rs.Introspect(r.Context(), provider, token) + resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token) if err != nil { http.Error(w, err.Error(), http.StatusForbidden) return } - params := mux.Vars(r) - requestedClaim := params["claim"] - requestedValue := params["value"] + requestedClaim := chi.URLParam(r, "claim") + requestedValue := chi.URLParam(r, "value") + value, ok := resp.Claims[requestedClaim].(string) if !ok || value == "" || value != requestedValue { http.Error(w, "claim does not match", http.StatusForbidden) diff --git a/example/client/app/app.go b/example/client/app/app.go index 0c324d20..0e339f40 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -1,19 +1,23 @@ package main import ( + "context" "encoding/json" "fmt" "net/http" "os" "strings" + "sync/atomic" "time" "github.com/google/uuid" "github.com/sirupsen/logrus" + "golang.org/x/exp/slog" - "github.com/zitadel/oidc/v2/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/logging" + "github.com/zitadel/oidc/v3/pkg/client/rp" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) var ( @@ -32,9 +36,25 @@ func main() { redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) + logger := slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + ) + client := &http.Client{ + Timeout: time.Minute, + } + // enable outgoing request logging + logging.EnableHTTPClient(client, + logging.WithClientGroup("client"), + ) + options := []rp.Option{ rp.WithCookieHandler(cookieHandler), rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)), + rp.WithHTTPClient(client), + rp.WithLogger(logger), } if clientSecret == "" { options = append(options, rp.WithPKCE(cookieHandler)) @@ -43,7 +63,10 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...) + // One can add a logger to the context, + // pre-defining log attributes as required. + ctx := logging.ToContext(context.TODO(), logger) + provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } @@ -118,8 +141,22 @@ func main() { // // http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider)) + // simple counter for request IDs + var counter atomic.Int64 + // enable incomming request logging + mw := logging.Middleware( + logging.WithLogger(logger), + logging.WithGroup("server"), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + ) + lis := fmt.Sprintf("127.0.0.1:%s", port) - logrus.Infof("listening on http://%s/", lis) - logrus.Info("press ctrl+c to stop") - logrus.Fatal(http.ListenAndServe(lis, nil)) + logger.Info("server listening, press ctrl+c to stop", "addr", lis) + err = http.ListenAndServe(lis, mw(http.DefaultServeMux)) + if err != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) + } } diff --git a/example/client/device/device.go b/example/client/device/device.go index 284ba372..bea61345 100644 --- a/example/client/device/device.go +++ b/example/client/device/device.go @@ -11,8 +11,8 @@ import ( "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v2/pkg/http" + "github.com/zitadel/oidc/v3/pkg/client/rp" + httphelper "github.com/zitadel/oidc/v3/pkg/http" ) var ( @@ -39,13 +39,13 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...) + provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, "", scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } logrus.Info("starting device authorization flow") - resp, err := rp.DeviceAuthorization(scopes, provider) + resp, err := rp.DeviceAuthorization(ctx, scopes, provider, nil) if err != nil { logrus.Fatal(err) } diff --git a/example/client/github/github.go b/example/client/github/github.go index 9cb813c0..7d069d49 100644 --- a/example/client/github/github.go +++ b/example/client/github/github.go @@ -10,10 +10,10 @@ import ( "golang.org/x/oauth2" githubOAuth "golang.org/x/oauth2/github" - "github.com/zitadel/oidc/v2/pkg/client/rp" - "github.com/zitadel/oidc/v2/pkg/client/rp/cli" - "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/client/rp/cli" + "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) var ( diff --git a/example/client/service/service.go b/example/client/service/service.go index 7d91f31a..865a4e00 100644 --- a/example/client/service/service.go +++ b/example/client/service/service.go @@ -13,7 +13,7 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/pkg/client/profile" + "github.com/zitadel/oidc/v3/pkg/client/profile" ) var client = http.DefaultClient @@ -25,7 +25,7 @@ func main() { scopes := strings.Split(os.Getenv("SCOPES"), " ") if keyPath != "" { - ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath, scopes) + ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), issuer, keyPath, scopes) if err != nil { logrus.Fatalf("error creating token source %s", err.Error()) } @@ -76,7 +76,7 @@ func main() { http.Error(w, err.Error(), http.StatusInternalServerError) return } - ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(issuer, key, scopes) + ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), issuer, key, scopes) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/example/server/dynamic/login.go b/example/server/dynamic/login.go index e7c6e5fd..d90fb8e5 100644 --- a/example/server/dynamic/login.go +++ b/example/server/dynamic/login.go @@ -6,9 +6,9 @@ import ( "html/template" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op" ) const ( @@ -43,7 +43,7 @@ var ( type login struct { authenticate authenticate - router *mux.Router + router chi.Router callback func(context.Context, string) string } @@ -57,9 +57,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string) } func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) { - l.router = mux.NewRouter() - l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) - l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler)) + l.router = chi.NewRouter() + l.router.Get("/username", l.loginHandler) + l.router.With(issuerInterceptor.Handler).Post("/username", l.checkLoginHandler) } type authenticate interface { diff --git a/example/server/dynamic/op.go b/example/server/dynamic/op.go index 783c75cf..16627296 100644 --- a/example/server/dynamic/op.go +++ b/example/server/dynamic/op.go @@ -7,11 +7,11 @@ import ( "log" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/example/server/storage" + "github.com/zitadel/oidc/v3/pkg/op" ) const ( @@ -47,7 +47,7 @@ func main() { //be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) - router := mux.NewRouter() + router := chi.NewRouter() //for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { @@ -76,7 +76,7 @@ func main() { //regardless of how many pages / steps there are in the process, the UI must be registered in the router, //so we will direct all calls to /login to the login UI - router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + router.Mount("/login/", http.StripPrefix("/login", l.router)) //we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) //is served on the correct path @@ -84,7 +84,7 @@ func main() { //if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), //then you would have to set the path prefix (/custom/path/): //router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler())) - router.PathPrefix("/").Handler(provider.HttpHandler()) + router.Mount("/", provider) server := &http.Server{ Addr: ":" + port, diff --git a/example/server/exampleop/device.go b/example/server/exampleop/device.go index ae2e8f29..7478750c 100644 --- a/example/server/exampleop/device.go +++ b/example/server/exampleop/device.go @@ -1,21 +1,34 @@ package exampleop import ( + "context" "errors" "fmt" "io" "net/http" "net/url" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/gorilla/securecookie" "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op" ) type deviceAuthenticate interface { CheckUsernamePasswordSimple(username, password string) error op.DeviceAuthorizationStorage + + // GetDeviceAuthorizationByUserCode resturns the current state of the device authorization flow, + // identified by the user code. + GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) + + // CompleteDeviceAuthorization marks a device authorization entry as Completed, + // identified by userCode. The Subject is added to the state, so that + // GetDeviceAuthorizatonState can use it to create a new Access Token. + CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error + + // DenyDeviceAuthorization marks a device authorization entry as Denied. + DenyDeviceAuthorization(ctx context.Context, userCode string) error } type deviceLogin struct { @@ -23,14 +36,14 @@ type deviceLogin struct { cookie *securecookie.SecureCookie } -func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) { +func registerDeviceAuth(storage deviceAuthenticate, router chi.Router) { l := &deviceLogin{ storage: storage, cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil), } - router.HandleFunc("", l.userCodeHandler) - router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler) + router.HandleFunc("/", l.userCodeHandler) + router.Post("/login", l.loginHandler) router.HandleFunc("/confirm", l.confirmHandler) } diff --git a/example/server/exampleop/login.go b/example/server/exampleop/login.go index 64045b8a..053525d5 100644 --- a/example/server/exampleop/login.go +++ b/example/server/exampleop/login.go @@ -5,13 +5,13 @@ import ( "fmt" "net/http" - "github.com/gorilla/mux" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/go-chi/chi" + "github.com/zitadel/oidc/v3/pkg/op" ) type login struct { authenticate authenticate - router *mux.Router + router chi.Router callback func(context.Context, string) string } @@ -25,9 +25,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string) } func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) { - l.router = mux.NewRouter() - l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) - l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler)) + l.router = chi.NewRouter() + l.router.Get("/username", l.loginHandler) + l.router.Post("/username", issuerInterceptor.HandlerFunc(l.checkLoginHandler)) } type authenticate interface { diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 20190ca9..830f7f68 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -4,13 +4,16 @@ import ( "crypto/sha256" "log" "net/http" + "sync/atomic" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi" + "github.com/zitadel/logging" + "golang.org/x/exp/slog" "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/example/server/storage" + "github.com/zitadel/oidc/v3/pkg/op" ) const ( @@ -31,26 +34,33 @@ type Storage interface { deviceAuthenticate } +// simple counter for request IDs +var counter atomic.Int64 + // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. -func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux.Router { +func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool, extraOptions ...op.Option) chi.Router { // the OpenID Provider requires a 32-byte key for (token) encryption // be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) - router := mux.NewRouter() + router := chi.NewRouter() + router.Use(logging.Middleware( + logging.WithLogger(logger), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + )) // for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { - _, err := w.Write([]byte("signed out successfully")) - if err != nil { - log.Printf("error serving logged out page: %v", err) - } + w.Write([]byte("signed out successfully")) + // no need to check/log error, this will be handeled by the middleware. }) // creation of the OpenIDProvider with the just created in-memory Storage - provider, err := newOP(storage, issuer, key, extraOptions...) + provider, err := newOP(storage, issuer, key, logger, extraOptions...) if err != nil { log.Fatal(err) } @@ -62,17 +72,23 @@ func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux // regardless of how many pages / steps there are in the process, the UI must be registered in the router, // so we will direct all calls to /login to the login UI - router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + router.Mount("/login/", http.StripPrefix("/login", l.router)) + + router.Route("/device", func(r chi.Router) { + registerDeviceAuth(storage, r) + }) - router.PathPrefix("/device").Subrouter() - registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter()) + handler := http.Handler(provider) + if wrapServer { + handler = op.NewLegacyServer(provider, *op.DefaultEndpoints) + } // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // is served on the correct path // // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), // then you would have to set the path prefix (/custom/path/) - router.PathPrefix("/").Handler(provider.HttpHandler()) + router.Mount("/", handler) return router } @@ -80,7 +96,7 @@ func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux // newOP will create an OpenID Provider for localhost on a specified port with a given encryption key // and a predefined default logout uri // it will enable all options (see descriptions) -func newOP(storage op.Storage, issuer string, key [32]byte, extraOptions ...op.Option) (op.OpenIDProvider, error) { +func newOP(storage op.Storage, issuer string, key [32]byte, logger *slog.Logger, extraOptions ...op.Option) (op.OpenIDProvider, error) { config := &op.Config{ CryptoKey: key, @@ -114,10 +130,12 @@ func newOP(storage op.Storage, issuer string, key [32]byte, extraOptions ...op.O } handler, err := op.NewOpenIDProvider(issuer, config, storage, append([]op.Option{ - // we must explicitly allow the use of the http issuer + //we must explicitly allow the use of the http issuer op.WithAllowInsecure(), // as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), + // Pass our logger to the OP + op.WithLogger(logger.WithGroup("op")), }, extraOptions...)..., ) if err != nil { diff --git a/example/server/main.go b/example/server/main.go index a2836eac..38057fb7 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -2,11 +2,12 @@ package main import ( "fmt" - "log" "net/http" + "os" - "github.com/zitadel/oidc/v2/example/server/exampleop" - "github.com/zitadel/oidc/v2/example/server/storage" + "github.com/zitadel/oidc/v3/example/server/exampleop" + "github.com/zitadel/oidc/v3/example/server/storage" + "golang.org/x/exp/slog" ) func main() { @@ -20,16 +21,22 @@ func main() { // in this example it will be handled in-memory storage := storage.NewStorage(storage.NewUserStore(issuer)) - router := exampleop.SetupServer(issuer, storage) + logger := slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + ) + router := exampleop.SetupServer(issuer, storage, logger, false) server := &http.Server{ Addr: ":" + port, Handler: router, } - log.Printf("server listening on http://localhost:%s/", port) - log.Println("press ctrl+c to stop") + logger.Info("server listening, press ctrl+c to stop", "addr", fmt.Sprintf("http://localhost:%s/", port)) err := server.ListenAndServe() - if err != nil { - log.Fatal(err) + if err != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) } } diff --git a/example/server/storage/client.go b/example/server/storage/client.go index b28d9d4c..2e57cc5b 100644 --- a/example/server/storage/client.go +++ b/example/server/storage/client.go @@ -3,8 +3,8 @@ package storage import ( "time" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) var ( @@ -185,7 +185,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client { authMethod: oidc.AuthMethodBasic, loginURL: defaultLoginURL, responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, - grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken}, + grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken, oidc.GrantTypeTokenExchange}, accessTokenType: op.AccessTokenTypeBearer, devMode: false, idTokenUserinfoClaimsAssertion: false, diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go index f5412cf1..63afcf93 100644 --- a/example/server/storage/oidc.go +++ b/example/server/storage/oidc.go @@ -3,10 +3,11 @@ package storage import ( "time" + "golang.org/x/exp/slog" "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) const ( @@ -41,6 +42,19 @@ type AuthRequest struct { authTime time.Time } +// LogValue allows you to define which fields will be logged. +// Implements the [slog.LogValuer] +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.String("id", a.ID), + slog.Time("creation_date", a.CreationDate), + slog.Any("scopes", a.Scopes), + slog.String("response_type", string(a.ResponseType)), + slog.String("app_id", a.ApplicationID), + slog.String("callback_uri", a.CallbackURI), + ) +} + func (a *AuthRequest) GetID() string { return a.ID } diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go index 30156267..1a04f4fc 100644 --- a/example/server/storage/storage.go +++ b/example/server/storage/storage.go @@ -11,11 +11,11 @@ import ( "sync" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/google/uuid" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) // serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant diff --git a/example/server/storage/storage_dynamic.go b/example/server/storage/storage_dynamic.go index cb16c020..a08f60e5 100644 --- a/example/server/storage/storage_dynamic.go +++ b/example/server/storage/storage_dynamic.go @@ -4,10 +4,10 @@ import ( "context" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) type multiStorage struct { diff --git a/go.mod b/go.mod index 6cdd591f..fee739b6 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,13 @@ -module github.com/zitadel/oidc/v2 +module github.com/zitadel/oidc/v3 go 1.19 require ( + github.com/go-chi/chi v1.5.4 + github.com/go-jose/go-jose/v3 v3.0.0 github.com/golang/mock v1.6.0 github.com/google/go-github/v31 v31.0.0 github.com/google/uuid v1.3.1 - github.com/gorilla/mux v1.8.0 - github.com/gorilla/schema v1.2.0 github.com/gorilla/securecookie v1.1.1 github.com/jeremija/gosubmit v0.2.7 github.com/muhlemmer/gu v0.3.1 @@ -15,11 +15,13 @@ require ( github.com/rs/cors v1.10.1 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 + github.com/zitadel/logging v0.4.0 + github.com/zitadel/schema v1.3.0 go.opentelemetry.io/otel v1.19.0 go.opentelemetry.io/otel/trace v1.19.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/oauth2 v0.13.0 golang.org/x/text v0.13.0 - gopkg.in/square/go-jose.v2 v2.6.0 ) require ( diff --git a/go.sum b/go.sum index a3b20018..e1acdaed 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs= +github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg= +github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo= +github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -13,6 +17,7 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -23,10 +28,6 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= -github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc= @@ -44,10 +45,15 @@ github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zitadel/logging v0.4.0 h1:lRAIFgaRoJpLNbsL7jtIYHcMDoEJP9QZB4GqMfl4xaA= +github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnIrRelUFc= +github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= +github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs= go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY= go.opentelemetry.io/otel/metric v1.19.0 h1:aTzpGtV0ar9wlV4Sna9sdJyII5jTVJEvKETPiOKwvpE= @@ -55,9 +61,12 @@ go.opentelemetry.io/otel/metric v1.19.0/go.mod h1:L5rUsV9kM1IxCj1MmSdS+JQAcVm319 go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg= go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -102,8 +111,7 @@ google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= -gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/testutil/gen/gen.go b/internal/testutil/gen/gen.go index a9f5925e..e4a57189 100644 --- a/internal/testutil/gen/gen.go +++ b/internal/testutil/gen/gen.go @@ -8,8 +8,8 @@ import ( "fmt" "os" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" ) var custom = map[string]any{ diff --git a/internal/testutil/token.go b/internal/testutil/token.go index 121aa0ba..2dd788f5 100644 --- a/internal/testutil/token.go +++ b/internal/testutil/token.go @@ -8,8 +8,9 @@ import ( "errors" "time" - "github.com/zitadel/oidc/v2/pkg/oidc" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" + "github.com/muhlemmer/gu" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // KeySet implements oidc.Keys @@ -17,7 +18,7 @@ type KeySet struct{} // VerifySignature implments op.KeySet. func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { - if ctx.Err() != nil { + if err = ctx.Err(); err != nil { return nil, err } @@ -45,6 +46,16 @@ func init() { } } +type JWTProfileKeyStorage struct{} + +func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + return gu.Ptr(WebKey.Public()), nil +} + func signEncodeTokenClaims(claims any) string { payload, err := json.Marshal(claims) if err != nil { @@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil) } +func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) { + req := &oidc.JWTTokenRequest{ + Issuer: issuer, + Subject: clientID, + Audience: audience, + ExpiresAt: oidc.FromTime(expiration), + IssuedAt: oidc.FromTime(issuedAt), + } + // make sure the private claim map is set correctly + data, err := json.Marshal(req) + if err != nil { + panic(err) + } + if err = json.Unmarshal(data, req); err != nil { + panic(err) + } + return signEncodeTokenClaims(req), req +} + const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg` // These variables always result in a valid token @@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) { return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew) } +func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) { + return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration) +} + // ACRVerify is a oidc.ACRVerifier func. func ACRVerify(acr string) error { if acr != ValidACR { diff --git a/pkg/client/client.go b/pkg/client/client.go index 7486ef17..b329b3d8 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -11,24 +11,25 @@ import ( "strings" "time" + jose "github.com/go-jose/go-jose/v3" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/crypto" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/logging" + "github.com/zitadel/oidc/v3/pkg/crypto" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) var Encoder = httphelper.Encoder(oidc.NewEncoder()) // Discover calls the discovery endpoint of the provided issuer and returns its configuration // It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url -func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { +func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" { wellKnown = wellKnownUrl[0] } - req, err := http.NewRequest("GET", wellKnown, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil) if err != nil { return nil, err } @@ -37,6 +38,10 @@ func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (* if err != nil { return nil, err } + if logger, ok := logging.FromContext(ctx); ok { + logger.Debug("discover", "config", discoveryConfig) + } + if discoveryConfig.Issuer != issuer { return nil, oidc.ErrIssuerInvalid } @@ -48,12 +53,12 @@ type TokenEndpointCaller interface { HttpClient() *http.Client } -func CallTokenEndpoint(request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { - return callTokenEndpoint(request, nil, caller) +func CallTokenEndpoint(ctx context.Context, request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { + return callTokenEndpoint(ctx, request, nil, caller) } -func callTokenEndpoint(request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) +func callTokenEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -80,8 +85,8 @@ type EndSessionCaller interface { HttpClient() *http.Client } -func CallEndSessionEndpoint(request any, authFn any, caller EndSessionCaller) (*url.URL, error) { - req, err := httphelper.FormRequest(caller.GetEndSessionEndpoint(), request, Encoder, authFn) +func CallEndSessionEndpoint(ctx context.Context, request any, authFn any, caller EndSessionCaller) (*url.URL, error) { + req, err := httphelper.FormRequest(ctx, caller.GetEndSessionEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -123,8 +128,8 @@ type RevokeRequest struct { ClientSecret string `schema:"client_secret"` } -func CallRevokeEndpoint(request any, authFn any, caller RevokeCaller) error { - req, err := httphelper.FormRequest(caller.GetRevokeEndpoint(), request, Encoder, authFn) +func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller RevokeCaller) error { + req, err := httphelper.FormRequest(ctx, caller.GetRevokeEndpoint(), request, Encoder, authFn) if err != nil { return err } @@ -151,8 +156,8 @@ func CallRevokeEndpoint(request any, authFn any, caller RevokeCaller) error { return nil } -func CallTokenExchangeEndpoint(request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) +func CallTokenExchangeEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -192,8 +197,8 @@ type DeviceAuthorizationCaller interface { HttpClient() *http.Client } -func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { - req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil) +func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -214,7 +219,7 @@ type DeviceAccessTokenRequest struct { } func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil) + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil) if err != nil { return nil, err } diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 02c408b1..e06c8252 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "net/http" "testing" @@ -36,7 +37,7 @@ func TestDiscover(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := Discover(tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...) + got, err := Discover(context.Background(), tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...) if tt.wantErr { assert.Error(t, err) return diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index ea7225d2..ec4d57b2 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -2,33 +2,64 @@ package client_test import ( "bytes" + "context" + "fmt" "io" - "io/ioutil" "math/rand" "net/http" "net/http/cookiejar" "net/http/httptest" "net/url" "os" + "os/signal" "strconv" + "syscall" "testing" "time" "github.com/jeremija/gosubmit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/slog" + + "github.com/zitadel/oidc/v3/example/server/exampleop" + "github.com/zitadel/oidc/v3/example/server/storage" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/client/rs" + "github.com/zitadel/oidc/v3/pkg/client/tokenexchange" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" +) - "github.com/zitadel/oidc/v2/example/server/exampleop" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/client/rp" - "github.com/zitadel/oidc/v2/pkg/client/rs" - "github.com/zitadel/oidc/v2/pkg/client/tokenexchange" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" +var Logger = slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), ) +var CTX context.Context + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT) + defer cancel() + CTX, cancel = context.WithTimeout(ctx, time.Minute) + defer cancel() + return m.Run() + }()) +} + func TestRelyingPartySession(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testRelyingPartySession(t, wrapServer) + }) + } +} + +func testRelyingPartySession(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -36,17 +67,17 @@ func TestRelyingPartySession(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) t.Log("------- run authorization code flow ------") - provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") t.Log("------- refresh tokens ------") - newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "") + newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") require.NoError(t, err, "refresh token") assert.NotNil(t, newTokens, "access token") t.Logf("new access token %s", newTokens.AccessToken) @@ -54,11 +85,13 @@ func TestRelyingPartySession(t *testing.T) { t.Logf("new token type %s", newTokens.TokenType) t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) require.NotEmpty(t, newTokens.AccessToken, "new accessToken") - assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken") + assert.NotEmpty(t, newTokens.IDToken, "new idToken") + assert.NotNil(t, newTokens.IDTokenClaims) + assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject) t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -67,17 +100,25 @@ func TestRelyingPartySession(t *testing.T) { } t.Log("------ attempt refresh again (should fail) ------") - t.Log("trying original refresh token", refreshToken) - _, err = rp.RefreshAccessToken(provider, refreshToken, "", "") + t.Log("trying original refresh token", tokens.RefreshToken) + _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with original") if newTokens.RefreshToken != "" { t.Log("trying replacement refresh token", newTokens.RefreshToken) - _, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "") + _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with replacement") } } func TestResourceServerTokenExchange(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testResourceServerTokenExchange(t, wrapServer) + }) + } +} + +func testResourceServerTokenExchange(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -85,23 +126,24 @@ func TestResourceServerTokenExchange(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) clientSecret := "secret" t.Log("------- run authorization code flow ------") - provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) - resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret) + resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret) require.NoError(t, err, "new resource server") t.Log("------- exchage refresh tokens (impersonation) ------") tokenExchangeResponse, err := tokenexchange.ExchangeToken( + CTX, resourceServer, - refreshToken, + tokens.RefreshToken, oidc.RefreshTokenType, "", "", @@ -119,7 +161,7 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -130,8 +172,9 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------- attempt exchage again (should fail) ------") tokenExchangeResponse, err = tokenexchange.ExchangeToken( + CTX, resourceServer, - refreshToken, + tokens.RefreshToken, oidc.RefreshTokenType, "", "", @@ -145,7 +188,7 @@ func TestResourceServerTokenExchange(t *testing.T) { require.Nil(t, tokenExchangeResponse, "token exchange response") } -func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) { +func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) { targetURL := "http://local-site" localURL, err := url.Parse(targetURL + "/login?requestID=1234") require.NoError(t, err, "local url") @@ -167,6 +210,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, key := []byte("test1234test1234") cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) provider, err = rp.NewRelyingPartyOIDC( + CTX, opServer.URL, clientID, clientSecret, @@ -241,7 +285,8 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, } var email string - redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + tokens = newTokens require.NotNil(t, tokens, "tokens") require.NotNil(t, info, "info") t.Log("access token", tokens.AccessToken) @@ -249,9 +294,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, t.Log("id token", tokens.IDToken) t.Log("email", info.Email) - accessToken = tokens.AccessToken - refreshToken = tokens.RefreshToken - idToken = tokens.IDToken email = info.Email http.Redirect(w, r, targetURL, 302) } @@ -273,12 +315,12 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, require.NoError(t, err, "get fully-authorizied redirect location") require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location") - require.NotEmpty(t, idToken, "id token") - assert.NotEmpty(t, refreshToken, "refresh token") - assert.NotEmpty(t, accessToken, "access token") + require.NotEmpty(t, tokens.IDToken, "id token") + assert.NotEmpty(t, tokens.RefreshToken, "refresh token") + assert.NotEmpty(t, tokens.AccessToken, "access token") assert.NotEmpty(t, email, "email") - return provider, accessToken, refreshToken, idToken + return provider, tokens } func TestErrorFromPromptNone(t *testing.T) { @@ -299,7 +341,7 @@ func TestErrorFromPromptNone(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, op.WithHttpInterceptors( + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, false, op.WithHttpInterceptors( func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Logf("request to %s", r.URL) @@ -317,6 +359,7 @@ func TestErrorFromPromptNone(t *testing.T) { key := []byte("test1234test1234") cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) provider, err := rp.NewRelyingPartyOIDC( + CTX, opServer.URL, clientID, clientSecret, @@ -412,7 +455,7 @@ func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) [ func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL { // TODO: switch to io.NopCloser when go1.15 support is dropped - req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( + req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)..., ) if req.URL.Scheme == "" { diff --git a/pkg/client/jwt_profile.go b/pkg/client/jwt_profile.go index 1686de62..0a5d9ec5 100644 --- a/pkg/client/jwt_profile.go +++ b/pkg/client/jwt_profile.go @@ -1,17 +1,18 @@ package client import ( + "context" "net/url" "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // JWTProfileExchange handles the oauth2 jwt profile exchange -func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) { - return CallTokenEndpoint(jwtProfileGrantRequest, caller) +func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) { + return CallTokenEndpoint(ctx, jwtProfileGrantRequest, caller) } func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption { diff --git a/pkg/client/key.go b/pkg/client/key.go index 740c6d33..0c01dd22 100644 --- a/pkg/client/key.go +++ b/pkg/client/key.go @@ -10,7 +10,7 @@ const ( applicationKey = "application" ) -type keyFile struct { +type KeyFile struct { Type string `json:"type"` // serviceaccount or application KeyID string `json:"keyId"` Key string `json:"key"` @@ -23,7 +23,7 @@ type keyFile struct { ClientID string `json:"clientId"` } -func ConfigFromKeyFile(path string) (*keyFile, error) { +func ConfigFromKeyFile(path string) (*KeyFile, error) { data, err := ioutil.ReadFile(path) if err != nil { return nil, err @@ -31,8 +31,8 @@ func ConfigFromKeyFile(path string) (*keyFile, error) { return ConfigFromKeyFileData(data) } -func ConfigFromKeyFileData(data []byte) (*keyFile, error) { - var f keyFile +func ConfigFromKeyFileData(data []byte) (*KeyFile, error) { + var f KeyFile if err := json.Unmarshal(data, &f); err != nil { return nil, err } diff --git a/pkg/client/profile/jwt_profile.go b/pkg/client/profile/jwt_profile.go index a220dc5c..a24033c9 100644 --- a/pkg/client/profile/jwt_profile.go +++ b/pkg/client/profile/jwt_profile.go @@ -1,16 +1,22 @@ package profile import ( + "context" "net/http" "time" + jose "github.com/go-jose/go-jose/v3" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/client" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + "github.com/zitadel/oidc/v3/pkg/oidc" ) +type TokenSource interface { + oauth2.TokenSource + TokenCtx(context.Context) (*oauth2.Token, error) +} + // jwtProfileTokenSource implement the oauth2.TokenSource // it will request a token using the OAuth2 JWT Profile Grant // therefore sending an `assertion` by signing a JWT with the provided private key @@ -23,23 +29,38 @@ type jwtProfileTokenSource struct { tokenEndpoint string } -func NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath string, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { - keyData, err := client.ConfigFromKeyFile(keyPath) +// NewJWTProfileTokenSourceFromKeyFile returns an implementation of TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key from jsonFile. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSourceFromKeyFile(ctx context.Context, issuer, jsonFile string, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { + keyData, err := client.ConfigFromKeyFile(jsonFile) if err != nil { return nil, err } - return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) + return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) } -func NewJWTProfileTokenSourceFromKeyFileData(issuer string, data []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { - keyData, err := client.ConfigFromKeyFileData(data) +// NewJWTProfileTokenSourceFromKeyFileData returns an implementation of oauth2.TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key in jsonData. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSourceFromKeyFileData(ctx context.Context, issuer string, jsonData []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { + keyData, err := client.ConfigFromKeyFileData(jsonData) if err != nil { return nil, err } - return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) + return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) } -func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { +// NewJWTProfileSource returns an implementation of oauth2.TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) if err != nil { return nil, err @@ -55,7 +76,7 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes opt(source) } if source.tokenEndpoint == "" { - config, err := client.Discover(issuer, source.httpClient) + config, err := client.Discover(ctx, issuer, source.httpClient) if err != nil { return nil, err } @@ -64,13 +85,13 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes return source, nil } -func WithHTTPClient(client *http.Client) func(*jwtProfileTokenSource) { +func WithHTTPClient(client *http.Client) func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) { source.httpClient = client } } -func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*jwtProfileTokenSource) { +func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) { source.tokenEndpoint = tokenEndpoint } @@ -85,9 +106,13 @@ func (j *jwtProfileTokenSource) HttpClient() *http.Client { } func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) { + return j.TokenCtx(context.Background()) +} + +func (j *jwtProfileTokenSource) TokenCtx(ctx context.Context) (*oauth2.Token, error) { assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer) if err != nil { return nil, err } - return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) + return client.JWTProfileExchange(ctx, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) } diff --git a/pkg/client/rp/cli/cli.go b/pkg/client/rp/cli/cli.go index 91b200d8..eeb90112 100644 --- a/pkg/client/rp/cli/cli.go +++ b/pkg/client/rp/cli/cli.go @@ -4,9 +4,9 @@ import ( "context" "net/http" - "github.com/zitadel/oidc/v2/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client/rp" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) const ( diff --git a/pkg/client/rp/delegation.go b/pkg/client/rp/delegation.go index b16a39e2..23ecffd0 100644 --- a/pkg/client/rp/delegation.go +++ b/pkg/client/rp/delegation.go @@ -1,7 +1,7 @@ package rp import ( - "github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange" + "github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange" ) // DelegationTokenRequest is an implementation of TokenExchangeRequest diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index a397f141..02c647e3 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -5,8 +5,8 @@ import ( "fmt" "time" - "github.com/zitadel/oidc/v2/pkg/client" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) { @@ -32,19 +32,21 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc. // DeviceAuthorization starts a new Device Authorization flow as defined // in RFC 8628, section 3.1 and 3.2: // https://www.rfc-editor.org/rfc/rfc8628#section-3.1 -func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { +func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization") req, err := newDeviceClientCredentialsRequest(scopes, rp) if err != nil { return nil, err } - return client.CallDeviceAuthorizationEndpoint(req, rp) + return client.CallDeviceAuthorizationEndpoint(ctx, req, rp, authFn) } // DeviceAccessToken attempts to obtain tokens from a Device Authorization, // by means of polling as defined in RFC, section 3.3 and 3.4: // https://www.rfc-editor.org/rfc/rfc8628#section-3.4 func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken") req := &client.DeviceAccessTokenRequest{ DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{ GrantType: oidc.GrantTypeDeviceCode, diff --git a/pkg/client/rp/jwks.go b/pkg/client/rp/jwks.go index 3438bd6f..28aec9b9 100644 --- a/pkg/client/rp/jwks.go +++ b/pkg/client/rp/jwks.go @@ -7,10 +7,10 @@ import ( "net/http" "sync" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet { diff --git a/pkg/client/rp/log.go b/pkg/client/rp/log.go new file mode 100644 index 00000000..6056fa2e --- /dev/null +++ b/pkg/client/rp/log.go @@ -0,0 +1,17 @@ +package rp + +import ( + "context" + + "github.com/zitadel/logging" + "golang.org/x/exp/slog" +) + +func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context { + logger, ok := rp.Logger(ctx) + if !ok { + return ctx + } + logger = logger.With(slog.Group("rp", attrs...)) + return logging.ToContext(ctx, logger) +} diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 051b8c83..c6ae2db0 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -7,16 +7,17 @@ import ( "fmt" "net/http" "net/url" - "strings" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/google/uuid" + "github.com/zitadel/logging" + "golang.org/x/exp/slog" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/client" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) const ( @@ -63,11 +64,14 @@ type RelyingParty interface { // be used to start a DeviceAuthorization flow. GetDeviceAuthorizationEndpoint() string - // IDTokenVerifier returns the verifier interface used for oidc id_token verification - IDTokenVerifier() IDTokenVerifier + // IDTokenVerifier returns the verifier used for oidc id_token verification + IDTokenVerifier() *IDTokenVerifier // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) + + // Logger from the context, or a fallback if set. + Logger(context.Context) (logger *slog.Logger, ok bool) } type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) @@ -88,9 +92,10 @@ type relyingParty struct { cookieHandler *httphelper.CookieHandler errorHandler func(http.ResponseWriter, *http.Request, string, string, string) - idTokenVerifier IDTokenVerifier + idTokenVerifier *IDTokenVerifier verifierOpts []VerifierOption signer jose.Signer + logger *slog.Logger } func (rp *relyingParty) OAuthConfig() *oauth2.Config { @@ -137,7 +142,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string { return rp.endpoints.RevokeURL } -func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier { +func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier { if rp.idTokenVerifier == nil { rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...) } @@ -151,6 +156,14 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, return rp.errorHandler } +func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) { + logger, ok = logging.FromContext(ctx) + if ok { + return logger, ok + } + return rp.logger, rp.logger != nil +} + // NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given // OAuth2 Config and possible configOptions // it will use the AuthURL and TokenURL set in config @@ -177,7 +190,7 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart // NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given // issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions // it will run discovery on the provided issuer and use the found endpoints -func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) { +func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) { rp := &relyingParty{ issuer: issuer, oauthConfig: &oauth2.Config{ @@ -195,7 +208,8 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco return nil, err } } - discoveryConfiguration, err := client.Discover(rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) + ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC") + discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) if err != nil { return nil, err } @@ -282,6 +296,15 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option { } } +// WithLogger sets a logger that is used +// in case the request context does not contain a logger. +func WithLogger(logger *slog.Logger) Option { + return func(rp *relyingParty) error { + rp.logger = logger + return nil + } +} + type SignerFromKey func() (jose.Signer, error) func SignerFromKeyPath(path string) SignerFromKey { @@ -310,26 +333,6 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey { } } -// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints -// -// deprecated: use client.Discover -func Discover(issuer string, httpClient *http.Client) (Endpoints, error) { - wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint - req, err := http.NewRequest("GET", wellKnown, nil) - if err != nil { - return Endpoints{}, err - } - discoveryConfig := new(oidc.DiscoveryConfiguration) - err = httphelper.HttpRequest(httpClient, req, &discoveryConfig) - if err != nil { - return Endpoints{}, err - } - if discoveryConfig.Issuer != issuer { - return Endpoints{}, fmt.Errorf("%w: Expected: %s, got: %s", oidc.ErrIssuerInvalid, discoveryConfig.Issuer, issuer) - } - return GetEndpoints(discoveryConfig), nil -} - // AuthURL returns the auth request url // (wrapping the oauth2 `AuthCodeURL`) func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string { @@ -377,9 +380,29 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri return oidc.NewSHACodeChallenge(codeVerifier), nil } +// ErrMissingIDToken is returned when an id_token was expected, +// but not received in the token response. +var ErrMissingIDToken = errors.New("id_token missing") + +func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) { + if rp.IsOAuth2Only() { + return &oidc.Tokens[C]{Token: token}, nil + } + idTokenString, ok := token.Extra(idTokenKey).(string) + if !ok { + return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken + } + idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) + if err != nil { + return nil, err + } + return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil +} + // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // returning it parsed together with the oauth2 tokens (access, refresh) func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange") ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) codeOpts := make([]oauth2.AuthCodeOption, 0) for _, opt := range opts { @@ -390,22 +413,7 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP if err != nil { return nil, err } - - if rp.IsOAuth2Only() { - return &oidc.Tokens[C]{Token: token}, nil - } - - idTokenString, ok := token.Extra(idTokenKey).(string) - if !ok { - return nil, errors.New("id_token missing") - } - - idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) - if err != nil { - return nil, err - } - - return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil + return verifyTokenResponse[C](ctx, token, rp) } type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) @@ -457,14 +465,18 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R } } -type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo) +type SubjectGetter interface { + GetSubject() string +} + +type CodeExchangeUserinfoCallback[C oidc.IDClaims, U SubjectGetter] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info U) // UserinfoCallback wraps the callback function of the CodeExchangeHandler // and calls the userinfo endpoint with the access token // on success it will pass the userinfo into its callback function as well -func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] { +func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { - info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) + info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) return @@ -473,19 +485,26 @@ func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeEx } } -// Userinfo will call the OIDC Userinfo Endpoint with the provided token -func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) { - req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil) +// Userinfo will call the OIDC [UserInfo] Endpoint with the provided token and returns +// the response in an instance of type U. +// [*oidc.UserInfo] can be used as a good example, or use a custom type if type-safe +// access to custom claims is needed. +// +// [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo +func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) { + var nilU U + ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo") + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil) if err != nil { - return nil, err + return nilU, err } req.Header.Set("authorization", tokenType+" "+token) - userinfo := new(oidc.UserInfo) if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil { - return nil, err + return nilU, err } - if userinfo.Subject != subject { - return nil, ErrUserInfoSubNotMatching + if userinfo.GetSubject() != subject { + return nilU, ErrUserInfoSubNotMatching } return userinfo, nil } @@ -554,7 +573,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption { // This is the generalized, unexported, function used by both // URLParamOpt and AuthURLOpt. func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption { - return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()) + return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String()) } type URLParamOpt func() []oauth2.AuthCodeOption @@ -626,11 +645,15 @@ type RefreshTokenRequest struct { GrantType oidc.GrantType `schema:"grant_type"` } -// RefreshAccessToken performs a token refresh. If it doesn't error, it will always +// RefreshTokens performs a token refresh. If it doesn't error, it will always // provide a new AccessToken. It may provide a new RefreshToken, and if it does, then -// the old one should be considered invalid. It may also provide a new IDToken. The -// new IDToken can be retrieved with token.Extra("id_token"). -func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { +// the old one should be considered invalid. +// +// In case the RP is not OAuth2 only and an IDToken was part of the response, +// the IDToken and AccessToken will be verfied +// and the IDToken and IDTokenClaims fields will be populated in the returned object. +func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) { + ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens") request := RefreshTokenRequest{ RefreshToken: refreshToken, Scopes: rp.OAuthConfig().Scopes, @@ -640,17 +663,28 @@ func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAs ClientAssertionType: clientAssertionType, GrantType: oidc.GrantTypeRefreshToken, } - return client.CallTokenEndpoint(request, tokenEndpointCaller{RelyingParty: rp}) + newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp}) + if err != nil { + return nil, err + } + tokens, err := verifyTokenResponse[C](ctx, newToken, rp) + if err == nil || errors.Is(err, ErrMissingIDToken) { + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse + // ...except that it might not contain an id_token. + return tokens, nil + } + return nil, err } -func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { +func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { + ctx = logCtxWithRPData(ctx, rp, "function", "EndSession") request := oidc.EndSessionRequest{ IdTokenHint: idToken, ClientID: rp.OAuthConfig().ClientID, PostLogoutRedirectURI: optionalRedirectURI, State: optionalState, } - return client.CallEndSessionEndpoint(request, nil, rp) + return client.CallEndSessionEndpoint(ctx, request, nil, rp) } // RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty @@ -658,7 +692,8 @@ func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState str // NewRelyingPartyOAuth() does not. // // tokenTypeHint should be either "id_token" or "refresh_token". -func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error { +func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error { + ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken") request := client.RevokeRequest{ Token: token, TokenTypeHint: tokenTypeHint, @@ -666,7 +701,7 @@ func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error { ClientSecret: rp.OAuthConfig().ClientSecret, } if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" { - return client.CallRevokeEndpoint(request, nil, rc) + return client.CallRevokeEndpoint(ctx, request, nil, rc) } return fmt.Errorf("RelyingParty does not support RevokeCaller") } diff --git a/pkg/client/rp/relying_party_test.go b/pkg/client/rp/relying_party_test.go new file mode 100644 index 00000000..4c5a1b31 --- /dev/null +++ b/pkg/client/rp/relying_party_test.go @@ -0,0 +1,107 @@ +package rp + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/oauth2" +) + +func Test_verifyTokenResponse(t *testing.T) { + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + ClientID: tu.ValidClientID, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, + } + tests := []struct { + name string + oauth2Only bool + tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims]) + wantErr error + }{ + { + name: "succes, oauth2 only", + oauth2Only: true, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + } + }, + }, + { + name: "id_token missing error", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + } + }, + wantErr: ErrMissingIDToken, + }, + { + name: "verify tokens error", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + token = token.WithExtra(map[string]any{ + "id_token": "foobar", + }) + return token, nil + }, + wantErr: oidc.ErrParse, + }, + { + name: "success, with id_token", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + idToken, claims := tu.ValidIDToken() + token = token.WithExtra(map[string]any{ + "id_token": idToken, + }) + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + IDTokenClaims: claims, + IDToken: idToken, + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rp := &relyingParty{ + oauth2Only: tt.oauth2Only, + idTokenVerifier: verifier, + } + token, want := tt.tokens() + got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, want, got) + }) + } +} diff --git a/pkg/client/rp/tockenexchange.go b/pkg/client/rp/tockenexchange.go index c1ac88d2..c8ca048f 100644 --- a/pkg/client/rp/tockenexchange.go +++ b/pkg/client/rp/tockenexchange.go @@ -5,7 +5,7 @@ import ( "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange" + "github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange" ) // TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange` diff --git a/pkg/client/rp/userinfo_example_test.go b/pkg/client/rp/userinfo_example_test.go new file mode 100644 index 00000000..2cc52228 --- /dev/null +++ b/pkg/client/rp/userinfo_example_test.go @@ -0,0 +1,45 @@ +package rp_test + +import ( + "context" + "fmt" + + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +type UserInfo struct { + Subject string `json:"sub,omitempty"` + oidc.UserInfoProfile + oidc.UserInfoEmail + oidc.UserInfoPhone + Address *oidc.UserInfoAddress `json:"address,omitempty"` + + // Foo and Bar are custom claims + Foo string `json:"foo,omitempty"` + Bar struct { + Val1 string `json:"val_1,omitempty"` + Val2 string `json:"val_2,omitempty"` + } `json:"bar,omitempty"` + + // Claims are all the combined claims, including custom. + Claims map[string]any `json:"-,omitempty"` +} + +func (u *UserInfo) GetSubject() string { + return u.Subject +} + +func ExampleUserinfo_custom() { + rpo, err := rp.NewRelyingPartyOIDC(context.TODO(), "http://localhost:8080", "clientid", "clientsecret", "http://example.com/redirect", []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone}) + if err != nil { + panic(err) + } + + info, err := rp.Userinfo[*UserInfo](context.TODO(), "accesstokenstring", "Bearer", "userid", rpo) + if err != nil { + panic(err) + } + + fmt.Println(info) +} diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index 75d149bd..adf88725 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -4,24 +4,14 @@ import ( "context" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) -type IDTokenVerifier interface { - oidc.Verifier - ClientID() string - SupportedSignAlgs() []string - KeySet() oidc.KeySet - Nonce(context.Context) string - ACR() oidc.ACRVerifier - MaxAge() time.Duration -} - // VerifyTokens implement the Token Response Validation as defined in OIDC specification // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation -func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) { +func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) { var nilClaims C claims, err = VerifyIDToken[C](ctx, idToken, v) @@ -36,7 +26,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str // VerifyIDToken validates the id token according to // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) { +func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) { var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -52,27 +42,27 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe return nilClaims, err } - if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err = oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckAudience(claims, v.ClientID()); err != nil { + if err = oidc.CheckAudience(claims, v.ClientID); err != nil { return nilClaims, err } - if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil { + if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { return nilClaims, err } @@ -80,16 +70,18 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe return nilClaims, err } - if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil { return nilClaims, err } - if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { + if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { return nilClaims, err } return claims, nil } +type IDTokenVerifier oidc.Verifier + // VerifyAccessToken validates the access token according to // https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { @@ -107,15 +99,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl return nil } -// NewIDTokenVerifier returns an implementation of `IDTokenVerifier` -// for `VerifyTokens` and `VerifyIDToken` -func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier { - v := &idTokenVerifier{ - issuer: issuer, - clientID: clientID, - keySet: keySet, - offset: time.Second, - nonce: func(_ context.Context) string { +// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification. +func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier { + v := &IDTokenVerifier{ + Issuer: issuer, + ClientID: clientID, + KeySet: keySet, + Offset: time.Second, + Nonce: func(_ context.Context) string { return "" }, } @@ -128,95 +119,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ... } // VerifierOption is the type for providing dynamic options to the IDTokenVerifier -type VerifierOption func(*idTokenVerifier) +type VerifierOption func(*IDTokenVerifier) // WithIssuedAtOffset mitigates the risk of iat to be in the future // because of clock skews with the ability to add an offset to the current time -func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) { - return func(v *idTokenVerifier) { - v.offset = offset +func WithIssuedAtOffset(offset time.Duration) VerifierOption { + return func(v *IDTokenVerifier) { + v.Offset = offset } } // WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now -func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) { - return func(v *idTokenVerifier) { - v.maxAgeIAT = maxAge +func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption { + return func(v *IDTokenVerifier) { + v.MaxAgeIAT = maxAge } } // WithNonce sets the function to check the nonce func WithNonce(nonce func(context.Context) string) VerifierOption { - return func(v *idTokenVerifier) { - v.nonce = nonce + return func(v *IDTokenVerifier) { + v.Nonce = nonce } } // WithACRVerifier sets the verifier for the acr claim func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption { - return func(v *idTokenVerifier) { - v.acr = verifier + return func(v *IDTokenVerifier) { + v.ACR = verifier } } // WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption { - return func(v *idTokenVerifier) { - v.maxAge = maxAge + return func(v *IDTokenVerifier) { + v.MaxAge = maxAge } } // WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm func WithSupportedSigningAlgorithms(algs ...string) VerifierOption { - return func(v *idTokenVerifier) { - v.supportedSignAlgs = algs + return func(v *IDTokenVerifier) { + v.SupportedSignAlgs = algs } } - -type idTokenVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - clientID string - supportedSignAlgs []string - keySet oidc.KeySet - acr oidc.ACRVerifier - maxAge time.Duration - nonce func(ctx context.Context) string -} - -func (i *idTokenVerifier) Issuer() string { - return i.issuer -} - -func (i *idTokenVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -func (i *idTokenVerifier) Offset() time.Duration { - return i.offset -} - -func (i *idTokenVerifier) ClientID() string { - return i.clientID -} - -func (i *idTokenVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -func (i *idTokenVerifier) KeySet() oidc.KeySet { - return i.keySet -} - -func (i *idTokenVerifier) Nonce(ctx context.Context) string { - return i.nonce(ctx) -} - -func (i *idTokenVerifier) ACR() oidc.ACRVerifier { - return i.acr -} - -func (i *idTokenVerifier) MaxAge() time.Duration { - return i.maxAge -} diff --git a/pkg/client/rp/verifier_test.go b/pkg/client/rp/verifier_test.go index f4e0f9d0..ea15c21f 100644 --- a/pkg/client/rp/verifier_test.go +++ b/pkg/client/rp/verifier_test.go @@ -5,24 +5,24 @@ import ( "testing" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" - "gopkg.in/square/go-jose.v2" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func TestVerifyTokens(t *testing.T) { - verifier := &idTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - nonce: func(context.Context) string { return tu.ValidNonce }, - clientID: tu.ValidClientID, + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, + ClientID: tu.ValidClientID, } accessToken, _ := tu.ValidAccessToken() atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm) @@ -91,15 +91,15 @@ func TestVerifyTokens(t *testing.T) { } func TestVerifyIDToken(t *testing.T) { - verifier := &idTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - nonce: func(context.Context) string { return tu.ValidNonce }, + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, } tests := []struct { @@ -231,7 +231,7 @@ func TestVerifyIDToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { token, want := tt.tokenClaims() - verifier.clientID = tt.clientID + verifier.ClientID = tt.clientID got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier) if tt.wantErr { assert.Error(t, err) @@ -312,7 +312,7 @@ func TestNewIDTokenVerifier(t *testing.T) { tests := []struct { name string args args - want IDTokenVerifier + want *IDTokenVerifier }{ { name: "nil nonce", // otherwise assert.Equal will fail on the function @@ -329,16 +329,16 @@ func TestNewIDTokenVerifier(t *testing.T) { WithSupportedSigningAlgorithms("ABC", "DEF"), }, }, - want: &idTokenVerifier{ - issuer: tu.ValidIssuer, - offset: time.Minute, - maxAgeIAT: time.Hour, - clientID: tu.ValidClientID, - keySet: tu.KeySet{}, - nonce: nil, - acr: nil, - maxAge: 2 * time.Hour, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + Offset: time.Minute, + MaxAgeIAT: time.Hour, + ClientID: tu.ValidClientID, + KeySet: tu.KeySet{}, + Nonce: nil, + ACR: nil, + MaxAge: 2 * time.Hour, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } diff --git a/pkg/client/rp/verifier_tokens_example_test.go b/pkg/client/rp/verifier_tokens_example_test.go index c297efe4..892eb235 100644 --- a/pkg/client/rp/verifier_tokens_example_test.go +++ b/pkg/client/rp/verifier_tokens_example_test.go @@ -4,9 +4,9 @@ import ( "context" "fmt" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/client/rp" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // MyCustomClaims extends the TokenClaims base, diff --git a/pkg/client/rs/introspect_example_test.go b/pkg/client/rs/introspect_example_test.go new file mode 100644 index 00000000..eac8be27 --- /dev/null +++ b/pkg/client/rs/introspect_example_test.go @@ -0,0 +1,52 @@ +package rs_test + +import ( + "context" + "fmt" + + "github.com/zitadel/oidc/v3/pkg/client/rs" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +type IntrospectionResponse struct { + Active bool `json:"active"` + Scope oidc.SpaceDelimitedArray `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + TokenType string `json:"token_type,omitempty"` + Expiration oidc.Time `json:"exp,omitempty"` + IssuedAt oidc.Time `json:"iat,omitempty"` + NotBefore oidc.Time `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` + Audience oidc.Audience `json:"aud,omitempty"` + Issuer string `json:"iss,omitempty"` + JWTID string `json:"jti,omitempty"` + Username string `json:"username,omitempty"` + oidc.UserInfoProfile + oidc.UserInfoEmail + oidc.UserInfoPhone + Address *oidc.UserInfoAddress `json:"address,omitempty"` + + // Foo and Bar are custom claims + Foo string `json:"foo,omitempty"` + Bar struct { + Val1 string `json:"val_1,omitempty"` + Val2 string `json:"val_2,omitempty"` + } `json:"bar,omitempty"` + + // Claims are all the combined claims, including custom. + Claims map[string]any `json:"-,omitempty"` +} + +func ExampleIntrospect_custom() { + rss, err := rs.NewResourceServerClientCredentials(context.TODO(), "http://localhost:8080", "clientid", "clientsecret") + if err != nil { + panic(err) + } + + resp, err := rs.Introspect[*IntrospectionResponse](context.TODO(), rss, "accesstokenstring") + if err != nil { + panic(err) + } + + fmt.Println(resp) +} diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go index 95b6e2e4..57925d54 100644 --- a/pkg/client/rs/resource_server.go +++ b/pkg/client/rs/resource_server.go @@ -6,9 +6,9 @@ import ( "net/http" "time" - "github.com/zitadel/oidc/v2/pkg/client" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type ResourceServer interface { @@ -42,14 +42,14 @@ func (r *resourceServer) AuthFn() (any, error) { return r.authFn() } -func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { +func NewResourceServerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { authorizer := func() (any, error) { return httphelper.AuthorizeBasic(clientID, clientSecret), nil } - return newResourceServer(issuer, authorizer, option...) + return newResourceServer(ctx, issuer, authorizer, option...) } -func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) { +func NewResourceServerJWTProfile(ctx context.Context, issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) { signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) if err != nil { return nil, err @@ -61,10 +61,10 @@ func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, opt } return client.ClientAssertionFormAuthorization(assertion), nil } - return newResourceServer(issuer, authorizer, options...) + return newResourceServer(ctx, issuer, authorizer, options...) } -func newResourceServer(issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) { +func newResourceServer(ctx context.Context, issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) { rs := &resourceServer{ issuer: issuer, httpClient: httphelper.DefaultHTTPClient, @@ -73,7 +73,7 @@ func newResourceServer(issuer string, authorizer func() (any, error), options .. optFunc(rs) } if rs.introspectURL == "" || rs.tokenURL == "" { - config, err := client.Discover(rs.issuer, rs.httpClient) + config, err := client.Discover(ctx, rs.issuer, rs.httpClient) if err != nil { return nil, err } @@ -91,12 +91,12 @@ func newResourceServer(issuer string, authorizer func() (any, error), options .. return rs, nil } -func NewResourceServerFromKeyFile(issuer, path string, options ...Option) (ResourceServer, error) { +func NewResourceServerFromKeyFile(ctx context.Context, issuer, path string, options ...Option) (ResourceServer, error) { c, err := client.ConfigFromKeyFile(path) if err != nil { return nil, err } - return NewResourceServerJWTProfile(issuer, c.ClientID, c.KeyID, []byte(c.Key), options...) + return NewResourceServerJWTProfile(ctx, issuer, c.ClientID, c.KeyID, []byte(c.Key), options...) } type Option func(*resourceServer) @@ -116,21 +116,27 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option { } } -func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) { +// Introspect calls the [RFC7662] Token Introspection +// endpoint and returns the response in an instance of type R. +// [*oidc.IntrospectionResponse] can be used as a good example, or use a custom type if type-safe +// access to custom claims is needed. +// +// [RFC7662]: https://www.rfc-editor.org/rfc/rfc7662 +func Introspect[R any](ctx context.Context, rp ResourceServer, token string) (resp R, err error) { if rp.IntrospectionURL() == "" { - return nil, errors.New("resource server: introspection URL is empty") + return resp, errors.New("resource server: introspection URL is empty") } authFn, err := rp.AuthFn() if err != nil { - return nil, err + return resp, err } - req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) + req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) if err != nil { - return nil, err + return resp, err } - resp := new(oidc.IntrospectionResponse) - if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil { - return nil, err + + if err := httphelper.HttpRequest(rp.HttpClient(), req, &resp); err != nil { + return resp, err } return resp, nil } diff --git a/pkg/client/rs/resource_server_test.go b/pkg/client/rs/resource_server_test.go index 16cb6adf..bb17c641 100644 --- a/pkg/client/rs/resource_server_test.go +++ b/pkg/client/rs/resource_server_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func TestNewResourceServer(t *testing.T) { @@ -164,7 +165,7 @@ func TestNewResourceServer(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := newResourceServer(tt.args.issuer, tt.args.authorizer, tt.args.options...) + got, err := newResourceServer(context.Background(), tt.args.issuer, tt.args.authorizer, tt.args.options...) if tt.wantErr { assert.Error(t, err) return @@ -187,6 +188,7 @@ func TestIntrospect(t *testing.T) { token string } rp, err := newResourceServer( + context.Background(), "https://accounts.spotify.com", nil, ) @@ -208,7 +210,7 @@ func TestIntrospect(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := Introspect(tt.args.ctx, tt.args.rp, tt.args.token) + _, err := Introspect[*oidc.IntrospectionResponse](tt.args.ctx, tt.args.rp, tt.args.token) if tt.wantErr { assert.Error(t, err) return diff --git a/pkg/client/tokenexchange/tokenexchange.go b/pkg/client/tokenexchange/tokenexchange.go index 4ae55075..fdac833b 100644 --- a/pkg/client/tokenexchange/tokenexchange.go +++ b/pkg/client/tokenexchange/tokenexchange.go @@ -1,12 +1,13 @@ package tokenexchange import ( + "context" "errors" "net/http" - "github.com/zitadel/oidc/v2/pkg/client" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type TokenExchanger interface { @@ -21,18 +22,18 @@ type OAuthTokenExchange struct { authFn func() (any, error) } -func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { - return newOAuthTokenExchange(issuer, nil, options...) +func NewTokenExchanger(ctx context.Context, issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { + return newOAuthTokenExchange(ctx, issuer, nil, options...) } -func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { +func NewTokenExchangerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { authorizer := func() (any, error) { return httphelper.AuthorizeBasic(clientID, clientSecret), nil } - return newOAuthTokenExchange(issuer, authorizer, options...) + return newOAuthTokenExchange(ctx, issuer, authorizer, options...) } -func newOAuthTokenExchange(issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { +func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { te := &OAuthTokenExchange{ httpClient: httphelper.DefaultHTTPClient, } @@ -41,7 +42,7 @@ func newOAuthTokenExchange(issuer string, authorizer func() (any, error), option } if te.tokenEndpoint == "" { - config, err := client.Discover(issuer, te.httpClient) + config, err := client.Discover(ctx, issuer, te.httpClient) if err != nil { return nil, err } @@ -89,6 +90,7 @@ func (te *OAuthTokenExchange) AuthFn() (any, error) { // ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint. // SubjectToken and SubjectTokenType are required parameters. func ExchangeToken( + ctx context.Context, te TokenExchanger, SubjectToken string, SubjectTokenType oidc.TokenType, @@ -123,5 +125,5 @@ func ExchangeToken( RequestedTokenType: RequestedTokenType, } - return client.CallTokenExchangeEndpoint(request, authFn, te) + return client.CallTokenExchangeEndpoint(ctx, request, authFn, te) } diff --git a/pkg/crypto/hash.go b/pkg/crypto/hash.go index 6fcc71fd..0ed2774d 100644 --- a/pkg/crypto/hash.go +++ b/pkg/crypto/hash.go @@ -8,7 +8,7 @@ import ( "fmt" "hash" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) var ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm") diff --git a/pkg/crypto/sign.go b/pkg/crypto/sign.go index 90e4c0e8..a197955c 100644 --- a/pkg/crypto/sign.go +++ b/pkg/crypto/sign.go @@ -4,7 +4,7 @@ import ( "encoding/json" "errors" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) func Sign(object any, signer jose.Signer) (string, error) { diff --git a/pkg/http/http.go b/pkg/http/http.go index 46f8250e..fd196b06 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -33,7 +33,7 @@ func AuthorizeBasic(user, password string) RequestAuthorization { } } -func FormRequest(endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) { +func FormRequest(ctx context.Context, endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) { form := url.Values{} if err := encoder.Encode(request, form); err != nil { return nil, err @@ -42,7 +42,7 @@ func FormRequest(endpoint string, request any, encoder Encoder, authFn any) (*ht fn(form) } body := strings.NewReader(form.Encode()) - req, err := http.NewRequest("POST", endpoint, body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body) if err != nil { return nil, err } diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index ace1de11..511e3962 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -1,5 +1,9 @@ package oidc +import ( + "golang.org/x/exp/slog" +) + const ( // ScopeOpenID defines the scope `openid` // OpenID Connect requests MUST contain the `openid` scope value @@ -77,7 +81,7 @@ type AuthRequest struct { UILocales Locales `json:"ui_locales" schema:"ui_locales"` IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"` LoginHint string `json:"login_hint" schema:"login_hint"` - ACRValues []string `json:"acr_values" schema:"acr_values"` + ACRValues SpaceDelimitedArray `json:"acr_values" schema:"acr_values"` CodeChallenge string `json:"code_challenge" schema:"code_challenge"` CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"` @@ -86,6 +90,15 @@ type AuthRequest struct { RequestParam string `schema:"request"` } +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.Any("scopes", a.Scopes), + slog.String("response_type", string(a.ResponseType)), + slog.String("client_id", a.ClientID), + slog.String("redirect_uri", a.RedirectURI), + ) +} + // GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface func (a *AuthRequest) GetRedirectURI() string { return a.RedirectURI diff --git a/pkg/oidc/authorization_test.go b/pkg/oidc/authorization_test.go new file mode 100644 index 00000000..573d65c3 --- /dev/null +++ b/pkg/oidc/authorization_test.go @@ -0,0 +1,27 @@ +//go:build go1.20 + +package oidc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestAuthRequest_LogValue(t *testing.T) { + a := &AuthRequest{ + Scopes: SpaceDelimitedArray{"a", "b"}, + ResponseType: "respType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + } + want := slog.GroupValue( + slog.Any("scopes", SpaceDelimitedArray{"a", "b"}), + slog.String("response_type", "respType"), + slog.String("client_id", "123"), + slog.String("redirect_uri", "http://example.com/callback"), + ) + got := a.LogValue() + assert.Equal(t, want, got) +} diff --git a/pkg/oidc/code_challenge.go b/pkg/oidc/code_challenge.go index 37c17830..32963629 100644 --- a/pkg/oidc/code_challenge.go +++ b/pkg/oidc/code_challenge.go @@ -3,7 +3,7 @@ package oidc import ( "crypto/sha256" - "github.com/zitadel/oidc/v2/pkg/crypto" + "github.com/zitadel/oidc/v3/pkg/crypto" ) const ( diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 9e265b34..b690a239 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -3,6 +3,8 @@ package oidc import ( "errors" "fmt" + + "golang.org/x/exp/slog" ) type errorType string @@ -171,3 +173,34 @@ func DefaultToServerError(err error, description string) *Error { } return oauth } + +func (e *Error) LogLevel() slog.Level { + level := slog.LevelWarn + if e.ErrorType == ServerError { + level = slog.LevelError + } + if e.ErrorType == AuthorizationPending { + level = slog.LevelInfo + } + return level +} + +func (e *Error) LogValue() slog.Value { + attrs := make([]slog.Attr, 0, 5) + if e.Parent != nil { + attrs = append(attrs, slog.Any("parent", e.Parent)) + } + if e.Description != "" { + attrs = append(attrs, slog.String("description", e.Description)) + } + if e.ErrorType != "" { + attrs = append(attrs, slog.String("type", string(e.ErrorType))) + } + if e.State != "" { + attrs = append(attrs, slog.String("state", e.State)) + } + if e.redirectDisabled { + attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled)) + } + return slog.GroupValue(attrs...) +} diff --git a/pkg/oidc/error_go120_test.go b/pkg/oidc/error_go120_test.go new file mode 100644 index 00000000..399d7f71 --- /dev/null +++ b/pkg/oidc/error_go120_test.go @@ -0,0 +1,83 @@ +//go:build go1.20 + +package oidc + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestError_LogValue(t *testing.T) { + type fields struct { + Parent error + ErrorType errorType + Description string + State string + redirectDisabled bool + } + tests := []struct { + name string + fields fields + want slog.Value + }{ + { + name: "parent", + fields: fields{ + Parent: io.EOF, + }, + want: slog.GroupValue(slog.Any("parent", io.EOF)), + }, + { + name: "description", + fields: fields{ + Description: "oops", + }, + want: slog.GroupValue(slog.String("description", "oops")), + }, + { + name: "errorType", + fields: fields{ + ErrorType: ExpiredToken, + }, + want: slog.GroupValue(slog.String("type", string(ExpiredToken))), + }, + { + name: "state", + fields: fields{ + State: "123", + }, + want: slog.GroupValue(slog.String("state", "123")), + }, + { + name: "all fields", + fields: fields{ + Parent: io.EOF, + Description: "oops", + ErrorType: ExpiredToken, + State: "123", + }, + want: slog.GroupValue( + slog.Any("parent", io.EOF), + slog.String("description", "oops"), + slog.String("type", string(ExpiredToken)), + slog.String("state", "123"), + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Error{ + Parent: tt.fields.Parent, + ErrorType: tt.fields.ErrorType, + Description: tt.fields.Description, + State: tt.fields.State, + redirectDisabled: tt.fields.redirectDisabled, + } + got := e.LogValue() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/oidc/error_test.go b/pkg/oidc/error_test.go new file mode 100644 index 00000000..0554c8fb --- /dev/null +++ b/pkg/oidc/error_test.go @@ -0,0 +1,81 @@ +package oidc + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestDefaultToServerError(t *testing.T) { + type args struct { + err error + description string + } + tests := []struct { + name string + args args + want *Error + }{ + { + name: "default", + args: args{ + err: io.ErrClosedPipe, + description: "oops", + }, + want: &Error{ + ErrorType: ServerError, + Description: "oops", + Parent: io.ErrClosedPipe, + }, + }, + { + name: "our Error", + args: args{ + err: ErrAccessDenied(), + description: "oops", + }, + want: &Error{ + ErrorType: AccessDenied, + Description: "The authorization request was denied.", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DefaultToServerError(tt.args.err, tt.args.description) + assert.ErrorIs(t, got, tt.want) + }) + } +} + +func TestError_LogLevel(t *testing.T) { + tests := []struct { + name string + err *Error + want slog.Level + }{ + { + name: "server error", + err: ErrServerError(), + want: slog.LevelError, + }, + { + name: "authorization pending", + err: ErrAuthorizationPending(), + want: slog.LevelInfo, + }, + { + name: "some other error", + err: ErrAccessDenied(), + want: slog.LevelWarn, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.err.LogLevel() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/oidc/keyset.go b/pkg/oidc/keyset.go index 7b766a55..6031c01c 100644 --- a/pkg/oidc/keyset.go +++ b/pkg/oidc/keyset.go @@ -7,7 +7,7 @@ import ( "crypto/rsa" "errors" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) const ( diff --git a/pkg/oidc/keyset_test.go b/pkg/oidc/keyset_test.go index 82b3ee83..f8641f2a 100644 --- a/pkg/oidc/keyset_test.go +++ b/pkg/oidc/keyset_test.go @@ -7,7 +7,7 @@ import ( "reflect" "testing" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) func TestFindKey(t *testing.T) { diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 36d546c8..b4cb6b67 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -5,11 +5,11 @@ import ( "os" "time" + jose "github.com/go-jose/go-jose/v3" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" "github.com/muhlemmer/gu" - "github.com/zitadel/oidc/v2/pkg/crypto" + "github.com/zitadel/oidc/v3/pkg/crypto" ) const ( diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 07c4ca0f..b43b249b 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) const ( diff --git a/pkg/oidc/token_test.go b/pkg/oidc/token_test.go index f3ea8d21..9f9ee2d0 100644 --- a/pkg/oidc/token_test.go +++ b/pkg/oidc/token_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/stretchr/testify/assert" "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" ) var ( diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index 6ab7469d..d8372b86 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -8,10 +8,10 @@ import ( "strings" "time" - "github.com/gorilla/schema" + jose "github.com/go-jose/go-jose/v3" "github.com/muhlemmer/gu" + "github.com/zitadel/schema" "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" ) type Audience []string @@ -151,7 +151,7 @@ type ResponseType string type ResponseMode string -func (s SpaceDelimitedArray) Encode() string { +func (s SpaceDelimitedArray) String() string { return strings.Join(s, " ") } @@ -161,11 +161,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error { } func (s SpaceDelimitedArray) MarshalText() ([]byte, error) { - return []byte(s.Encode()), nil + return []byte(s.String()), nil } func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) { - return json.Marshal((s).Encode()) + return json.Marshal((s).String()) } func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error { @@ -210,7 +210,7 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) { func NewEncoder() *schema.Encoder { e := schema.NewEncoder() e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string { - return value.Interface().(SpaceDelimitedArray).Encode() + return value.Interface().(SpaceDelimitedArray).String() }) return e } diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index 69540c22..af4f113f 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -9,9 +9,9 @@ import ( "testing" "time" - "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/schema" "golang.org/x/text/language" ) diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go index caff58e9..ef8ebe46 100644 --- a/pkg/oidc/userinfo.go +++ b/pkg/oidc/userinfo.go @@ -29,6 +29,11 @@ func (u *UserInfo) GetAddress() *UserInfoAddress { return u.Address } +// GetSubject implements [rp.SubjectGetter] +func (u *UserInfo) GetSubject() string { + return u.Subject +} + type uiAlias UserInfo func (u *UserInfo) MarshalJSON() ([]byte, error) { diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index 1af1ebb8..42fbb20b 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -10,9 +10,9 @@ import ( "strings" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" - str "github.com/zitadel/oidc/v2/pkg/strings" + str "github.com/zitadel/oidc/v3/pkg/strings" ) type Claims interface { @@ -61,10 +61,19 @@ var ( ErrAtHash = errors.New("at_hash does not correspond to access token") ) -type Verifier interface { - Issuer() string - MaxAgeIAT() time.Duration - Offset() time.Duration +// Verifier caries configuration for the various token verification +// functions. Use package specific constructor functions to know +// which values need to be set. +type Verifier struct { + Issuer string + MaxAgeIAT time.Duration + Offset time.Duration + ClientID string + SupportedSignAlgs []string + MaxAge time.Duration + ACR ACRVerifier + KeySet KeySet + Nonce func(ctx context.Context) string } // ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim @@ -121,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error { return nil } +// CheckAuthorizedParty checks azp (authorized party) claim requirements. +// +// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present. +// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value. +// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation func CheckAuthorizedParty(claims Claims, clientID string) error { if len(claims.GetAudience()) > 1 { if claims.GetAuthorizedParty() == "" { @@ -167,26 +181,26 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl } func CheckExpiration(claims Claims, offset time.Duration) error { - expiration := claims.GetExpiration().Round(time.Second) - if !time.Now().UTC().Add(offset).Before(expiration) { + expiration := claims.GetExpiration() + if !time.Now().Add(offset).Before(expiration) { return ErrExpired } return nil } func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error { - issuedAt := claims.GetIssuedAt().Round(time.Second) + issuedAt := claims.GetIssuedAt() if issuedAt.IsZero() { return ErrIatMissing } - nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second) + nowWithOffset := time.Now().Add(offset).Round(time.Second) if issuedAt.After(nowWithOffset) { return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset) } if maxAgeIAT == 0 { return nil } - maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second) + maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second) if issuedAt.Before(maxAge) { return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt)) } @@ -216,8 +230,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error { if claims.GetAuthTime().IsZero() { return ErrAuthTimeNotPresent } - authTime := claims.GetAuthTime().Round(time.Second) - maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second) + authTime := claims.GetAuthTime() + maxAuthTime := time.Now().Add(-maxAge).Round(time.Second) if authTime.Before(maxAuthTime) { return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime)) } diff --git a/pkg/oidc/verifier_parse_test.go b/pkg/oidc/verifier_parse_test.go new file mode 100644 index 00000000..105650f0 --- /dev/null +++ b/pkg/oidc/verifier_parse_test.go @@ -0,0 +1,128 @@ +package oidc_test + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +func TestParseToken(t *testing.T) { + token, wantClaims := tu.ValidIDToken() + wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload + + wantPayload, err := json.Marshal(wantClaims) + require.NoError(t, err) + + tests := []struct { + name string + tokenString string + wantErr bool + }{ + { + name: "split error", + tokenString: "nope", + wantErr: true, + }, + { + name: "base64 error", + tokenString: "foo.~.bar", + wantErr: true, + }, + { + name: "success", + tokenString: token, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotClaims := new(oidc.IDTokenClaims) + gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, wantClaims, gotClaims) + assert.JSONEq(t, string(wantPayload), string(gotPayload)) + }) + } +} + +func TestCheckSignature(t *testing.T) { + errCtx, cancel := context.WithCancel(context.Background()) + cancel() + + token, _ := tu.ValidIDToken() + payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{}) + require.NoError(t, err) + + type args struct { + ctx context.Context + token string + payload []byte + supportedSigAlgs []string + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "parse error", + args: args{ + ctx: context.Background(), + token: "~", + payload: payload, + }, + wantErr: oidc.ErrParse, + }, + { + name: "default sigAlg", + args: args{ + ctx: context.Background(), + token: token, + payload: payload, + }, + }, + { + name: "unsupported sigAlg", + args: args{ + ctx: context.Background(), + token: token, + payload: payload, + supportedSigAlgs: []string{"foo", "bar"}, + }, + wantErr: oidc.ErrSignatureUnsupportedAlg, + }, + { + name: "verify error", + args: args{ + ctx: errCtx, + token: token, + payload: payload, + }, + wantErr: oidc.ErrSignatureInvalid, + }, + { + name: "inequal payloads", + args: args{ + ctx: context.Background(), + token: token, + payload: []byte{0, 1, 2}, + }, + wantErr: oidc.ErrSignatureInvalidPayload, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := new(oidc.TokenClaims) + err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{}) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} diff --git a/pkg/oidc/verifier_test.go b/pkg/oidc/verifier_test.go new file mode 100644 index 00000000..93e71575 --- /dev/null +++ b/pkg/oidc/verifier_test.go @@ -0,0 +1,374 @@ +package oidc + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecryptToken(t *testing.T) { + const tokenString = "ABC" + got, err := DecryptToken(tokenString) + require.NoError(t, err) + assert.Equal(t, tokenString, got) +} + +func TestDefaultACRVerifier(t *testing.T) { + acrVerfier := DefaultACRVerifier([]string{"foo", "bar"}) + + tests := []struct { + name string + acr string + wantErr string + }{ + { + name: "ok", + acr: "bar", + }, + { + name: "error", + acr: "hello", + wantErr: "expected one of: [foo bar], got: \"hello\"", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := acrVerfier(tt.acr) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + return + } + require.NoError(t, err) + }) + } +} + +func TestCheckSubject(t *testing.T) { + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrSubjectMissing, + }, + { + name: "ok", + claims: &TokenClaims{ + Subject: "foo", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckSubject(tt.claims) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckIssuer(t *testing.T) { + const issuer = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrIssuerInvalid, + }, + { + name: "wrong", + claims: &TokenClaims{ + Issuer: "wrong", + }, + wantErr: ErrIssuerInvalid, + }, + { + name: "ok", + claims: &TokenClaims{ + Issuer: issuer, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckIssuer(tt.claims, issuer) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAudience(t *testing.T) { + const clientID = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrAudience, + }, + { + name: "wrong", + claims: &TokenClaims{ + Audience: []string{"wrong"}, + }, + wantErr: ErrAudience, + }, + { + name: "ok", + claims: &TokenClaims{ + Audience: []string{clientID}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAudience(tt.claims, clientID) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthorizedParty(t *testing.T) { + const clientID = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "single audience, no azp", + claims: &TokenClaims{ + Audience: []string{clientID}, + }, + }, + { + name: "multiple audience, no azp", + claims: &TokenClaims{ + Audience: []string{clientID, "other"}, + }, + wantErr: ErrAzpMissing, + }, + { + name: "single audience, with azp", + claims: &TokenClaims{ + Audience: []string{clientID}, + AuthorizedParty: clientID, + }, + }, + { + name: "multiple audience, with azp", + claims: &TokenClaims{ + Audience: []string{clientID, "other"}, + AuthorizedParty: clientID, + }, + }, + { + name: "wrong azp", + claims: &TokenClaims{ + AuthorizedParty: "wrong", + }, + wantErr: ErrAzpInvalid, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthorizedParty(tt.claims, clientID) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckExpiration(t *testing.T) { + const offset = time.Minute + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrExpired, + }, + { + name: "expired", + claims: &TokenClaims{ + Expiration: FromTime(time.Now().Add(-2 * offset)), + }, + wantErr: ErrExpired, + }, + { + name: "valid", + claims: &TokenClaims{ + Expiration: FromTime(time.Now().Add(2 * offset)), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckExpiration(tt.claims, offset) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckIssuedAt(t *testing.T) { + const offset = time.Minute + tests := []struct { + name string + maxAgeIAT time.Duration + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrIatMissing, + }, + { + name: "future", + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now().Add(time.Hour)), + }, + wantErr: ErrIatInFuture, + }, + { + name: "no max", + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now()), + }, + }, + { + name: "past max", + maxAgeIAT: time.Minute, + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now().Add(-time.Hour)), + }, + wantErr: ErrIatToOld, + }, + { + name: "within max", + maxAgeIAT: time.Hour, + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now()), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckNonce(t *testing.T) { + const nonce = "123" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrNonceInvalid, + }, + { + name: "wrong", + claims: &TokenClaims{ + Nonce: "wrong", + }, + wantErr: ErrNonceInvalid, + }, + { + name: "ok", + claims: &TokenClaims{ + Nonce: nonce, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckNonce(tt.claims, nonce) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthorizationContextClassReference(t *testing.T) { + tests := []struct { + name string + acr ACRVerifier + wantErr error + }{ + { + name: "error", + acr: func(s string) error { return errors.New("oops") }, + wantErr: ErrAcrInvalid, + }, + { + name: "ok", + acr: func(s string) error { return nil }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthTime(t *testing.T) { + tests := []struct { + name string + claims Claims + maxAge time.Duration + wantErr error + }{ + { + name: "no max age", + claims: &TokenClaims{}, + }, + { + name: "missing", + claims: &TokenClaims{}, + maxAge: time.Minute, + wantErr: ErrAuthTimeNotPresent, + }, + { + name: "expired", + maxAge: time.Minute, + claims: &TokenClaims{ + AuthTime: FromTime(time.Now().Add(-time.Hour)), + }, + wantErr: ErrAuthTimeToOld, + }, + { + name: "ok", + maxAge: time.Minute, + claims: &TokenClaims{ + AuthTime: NowTime(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthTime(tt.claims, tt.maxAge) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 7d9f264a..7ef06a8b 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -2,6 +2,7 @@ package op import ( "context" + "errors" "fmt" "net" "net/http" @@ -10,11 +11,10 @@ import ( "strings" "time" - "github.com/gorilla/mux" - - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - str "github.com/zitadel/oidc/v2/pkg/strings" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + str "github.com/zitadel/oidc/v3/pkg/strings" + "golang.org/x/exp/slog" ) type AuthRequest interface { @@ -39,16 +39,17 @@ type Authorizer interface { Storage() Storage Decoder() httphelper.Decoder Encoder() httphelper.Encoder - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier Crypto() Crypto RequestObjectSupported() bool + Logger() *slog.Logger } // AuthorizeValidator is an extension of Authorizer interface // implementing its own validation mechanism for the auth request type AuthorizeValidator interface { Authorizer - ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) + ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error) } func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) { @@ -68,23 +69,23 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, * func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } ctx := r.Context() if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { - authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) + err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } } if authReq.ClientID == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer) return } if authReq.RedirectURI == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer) return } validation := ValidateAuthRequest @@ -93,21 +94,21 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { } userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.RequestParam != "" { - AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer) return } req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID) if err != nil { - AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer) return } client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID()) if err != nil { - AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder()) + AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer) return } RedirectToLogin(req.GetID(), client, w, r) @@ -129,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A // ParseRequestObject parse the `request` parameter, validates the token including the signature // and copies the token claims into the auth request -func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) { +func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error { requestObject := new(oidc.RequestObject) payload, err := oidc.ParseToken(authReq.RequestParam, requestObject) if err != nil { - return nil, err + return err } if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if requestObject.Issuer != requestObject.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if !str.Contains(requestObject.Audience, issuer) { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer} if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { - return authReq, err + return err } CopyRequestObjectToAuthRequest(authReq, requestObject) - return authReq, nil + return nil } // CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request @@ -205,7 +206,7 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi } // ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed -func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) { +func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) { authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) if err != nil { return "", err @@ -385,7 +386,7 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) // ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request) // and returns the `sub` claim -func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) { +func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) { if idTokenHint == "" { return "", nil } @@ -405,32 +406,41 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r * // AuthorizeCallback handles the callback after authentication in the Login UI func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { - params := mux.Vars(r) - id := params["id"] - if id == "" { - AuthRequestError(w, r, nil, fmt.Errorf("auth request callback is missing id"), authorizer.Encoder()) + id, err := ParseAuthorizeCallbackRequest(r) + if err != nil { + AuthRequestError(w, r, nil, err, authorizer) return } - authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } if !authReq.Done() { AuthRequestError(w, r, authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."), - authorizer.Encoder()) + authorizer) return } AuthResponse(authReq, authorizer, w, r) } +func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) { + if err = r.ParseForm(); err != nil { + return "", fmt.Errorf("cannot parse form: %w", err) + } + id = r.Form.Get("id") + if id == "" { + return "", errors.New("auth request callback is missing id") + } + return id, nil +} + // AuthResponse creates the successful authentication response (either code or tokens) func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.GetResponseType() == oidc.ResponseTypeCode { @@ -444,7 +454,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } codeResponse := struct { @@ -456,7 +466,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques } callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } http.Redirect(w, r, callback, http.StatusFound) @@ -471,12 +481,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "") if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } http.Redirect(w, r, callback, http.StatusFound) diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index e8c9085e..db70fd74 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -11,14 +11,16 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/oidc/v2/example/server/storage" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + "github.com/zitadel/oidc/v3/example/server/storage" + tu "github.com/zitadel/oidc/v3/internal/testutil" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op/mock" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" ) func TestAuthorize(t *testing.T) { @@ -39,7 +41,7 @@ func TestAuthorize(t *testing.T) { expect := authorizer.EXPECT() expect.Decoder().Return(schema.NewDecoder()) - expect.Encoder().Return(schema.NewEncoder()) + expect.Logger().Return(slog.Default()) if tt.expect != nil { tt.expect(expect) @@ -123,7 +125,7 @@ func TestValidateAuthRequest(t *testing.T) { type args struct { authRequest *oidc.AuthRequest storage op.Storage - verifier op.IDTokenHintVerifier + verifier *op.IDTokenHintVerifier } tests := []struct { name string @@ -996,7 +998,7 @@ func TestAuthResponseCode(t *testing.T) { authorizer.EXPECT().Crypto().Return(&mockCrypto{ returnErr: io.ErrClosedPipe, }) - authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + authorizer.EXPECT().Logger().Return(slog.Default()) return authorizer }, }, @@ -1071,3 +1073,71 @@ func TestAuthResponseCode(t *testing.T) { }) } } + +func Test_parseAuthorizeCallbackRequest(t *testing.T) { + tests := []struct { + name string + url string + wantId string + wantErr bool + }{ + { + name: "parse error", + url: "/?id;=99", + wantErr: true, + }, + { + name: "missing id", + url: "/", + wantErr: true, + }, + { + name: "ok", + url: "/?id=99", + wantId: "99", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, tt.url, nil) + gotId, err := op.ParseAuthorizeCallbackRequest(r) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantId, gotId) + }) + } +} + +func TestValidateAuthReqIDTokenHint(t *testing.T) { + token, _ := tu.ValidIDToken() + tests := []struct { + name string + idTokenHint string + want string + wantErr error + }{ + { + name: "empty", + }, + { + name: "verify err", + idTokenHint: "foo", + wantErr: oidc.ErrLoginRequired(), + }, + { + name: "ok", + idTokenHint: token, + want: tu.ValidSubject, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{})) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/client.go b/pkg/op/client.go index 32a3dd1f..04ef3c71 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -7,8 +7,8 @@ import ( "net/url" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) //go:generate go get github.com/dmarkham/enumer @@ -87,7 +87,7 @@ var ( ) type ClientJWTProfile interface { - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { @@ -180,3 +180,10 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au } return data.ClientID, false, nil } + +type ClientCredentials struct { + ClientID string `schema:"client_id"` + ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body + ClientAssertion string `schema:"client_assertion"` // JWT + ClientAssertionType string `schema:"client_assertion_type"` +} diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go index 1af4157e..0321f88a 100644 --- a/pkg/op/client_test.go +++ b/pkg/op/client_test.go @@ -11,18 +11,18 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op/mock" + "github.com/zitadel/schema" ) type testClientJWTProfile struct{} -func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil } +func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil } func TestClientJWTAuth(t *testing.T) { type args struct { diff --git a/pkg/op/config.go b/pkg/op/config.go index e5789f71..c3834804 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -22,14 +22,14 @@ var ( type Configuration interface { IssuerFromRequest(r *http.Request) string Insecure() bool - AuthorizationEndpoint() Endpoint - TokenEndpoint() Endpoint - IntrospectionEndpoint() Endpoint - UserinfoEndpoint() Endpoint - RevocationEndpoint() Endpoint - EndSessionEndpoint() Endpoint - KeysEndpoint() Endpoint - DeviceAuthorizationEndpoint() Endpoint + AuthorizationEndpoint() *Endpoint + TokenEndpoint() *Endpoint + IntrospectionEndpoint() *Endpoint + UserinfoEndpoint() *Endpoint + RevocationEndpoint() *Endpoint + EndSessionEndpoint() *Endpoint + KeysEndpoint() *Endpoint + DeviceAuthorizationEndpoint() *Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool diff --git a/pkg/op/crypto.go b/pkg/op/crypto.go index 6786022e..6ab1e0a6 100644 --- a/pkg/op/crypto.go +++ b/pkg/op/crypto.go @@ -1,7 +1,7 @@ package op import ( - "github.com/zitadel/oidc/v2/pkg/crypto" + "github.com/zitadel/oidc/v3/pkg/crypto" ) type Crypto interface { diff --git a/pkg/op/device.go b/pkg/op/device.go index 2d3a8c91..813c3f57 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -12,8 +12,8 @@ import ( "strings" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type DeviceAuthorizationConfig struct { @@ -57,47 +57,57 @@ var ( func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { if err := DeviceAuthorization(w, r, o); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, o.Logger()) } } } func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error { - storage, err := assertDeviceStorage(o.Storage()) + req, err := ParseDeviceCodeRequest(r, o) if err != nil { return err } - - req, err := ParseDeviceCodeRequest(r, o) + response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o) if err != nil { return err } + httphelper.MarshalJSON(w, response) + return nil +} + +func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) { + storage, err := assertDeviceStorage(o.Storage()) + if err != nil { + return nil, err + } config := o.DeviceAuthorization() deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } expires := time.Now().Add(config.Lifetime) - err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes) + err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } var verification *url.URL if config.UserFormURL != "" { if verification, err = url.Parse(config.UserFormURL); err != nil { - return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + return nil, NewStatusError(err, http.StatusInternalServerError) } } else { - if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil { - return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil { + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + return nil, NewStatusError(err, http.StatusInternalServerError) } verification.Path = config.UserFormPath } @@ -112,9 +122,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide verification.RawQuery = "user_code=" + userCode response.VerificationURIComplete = verification.String() - - httphelper.MarshalJSON(w, response) - return nil + return response, nil } func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) { @@ -201,7 +209,7 @@ func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang r = r.WithContext(ctx) if err := deviceAccessToken(w, r, exchanger); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } } diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index 4b3c98c8..f5452f9d 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -16,8 +16,9 @@ import ( "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/example/server/storage" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) func Test_deviceAuthorizationHandler(t *testing.T) { @@ -319,7 +320,7 @@ func BenchmarkNewUserCode(b *testing.B) { } func TestDeviceAccessToken(t *testing.T) { - storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage := testProvider.Storage().(*storage.Storage) storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"}) storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim") @@ -344,7 +345,7 @@ func TestDeviceAccessToken(t *testing.T) { func TestCheckDeviceAuthorizationState(t *testing.T) { now := time.Now() - storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage := testProvider.Storage().(*storage.Storage) storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"}) diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 26f89eb1..82512615 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -4,10 +4,10 @@ import ( "context" "net/http" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type DiscoverStorage interface { @@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{ func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - Discover(w, CreateDiscoveryConfig(r, c, s)) + Discover(w, CreateDiscoveryConfig(r.Context(), c, s)) } } @@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { httphelper.MarshalJSON(w, config) } -func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { - issuer := config.IssuerFromRequest(r) +func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) return &oidc.DiscoveryConfiguration{ Issuer: issuer, AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer), @@ -49,7 +49,38 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov ResponseTypesSupported: ResponseTypes(config), GrantTypesSupported: GrantTypes(config), SubjectTypesSupported: SubjectTypes(config), - IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), + RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), + TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), + TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), + IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config), + IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config), + RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config), + RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config), + ClaimsSupported: SupportedClaims(config), + CodeChallengeMethodsSupported: CodeChallengeMethods(config), + UILocalesSupported: config.SupportedUILocales(), + RequestParameterSupported: config.RequestObjectSupported(), + } +} + +func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) + return &oidc.DiscoveryConfiguration{ + Issuer: issuer, + AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer), + TokenEndpoint: endpoints.Token.Absolute(issuer), + IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer), + UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer), + RevocationEndpoint: endpoints.Revocation.Absolute(issuer), + EndSessionEndpoint: endpoints.EndSession.Absolute(issuer), + JwksURI: endpoints.JwksURI.Absolute(issuer), + DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer), + ScopesSupported: Scopes(config), + ResponseTypesSupported: ResponseTypes(config), + GrantTypesSupported: GrantTypes(config), + SubjectTypesSupported: SubjectTypes(config), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 2d0b8af5..84e12165 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -6,14 +6,14 @@ import ( "net/http/httptest" "testing" + jose "github.com/go-jose/go-jose/v3" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op/mock" ) func TestDiscover(t *testing.T) { @@ -48,9 +48,9 @@ func TestDiscover(t *testing.T) { func TestCreateDiscoveryConfig(t *testing.T) { type args struct { - request *http.Request - c op.Configuration - s op.DiscoverStorage + ctx context.Context + c op.Configuration + s op.DiscoverStorage } tests := []struct { name string @@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s) + got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s) assert.Equal(t, tt.want, got) }) } diff --git a/pkg/op/endpoint.go b/pkg/op/endpoint.go index b1e15073..1ac1cad7 100644 --- a/pkg/op/endpoint.go +++ b/pkg/op/endpoint.go @@ -1,32 +1,46 @@ package op -import "strings" +import ( + "errors" + "strings" +) type Endpoint struct { path string url string } -func NewEndpoint(path string) Endpoint { - return Endpoint{path: path} +func NewEndpoint(path string) *Endpoint { + return &Endpoint{path: path} } -func NewEndpointWithURL(path, url string) Endpoint { - return Endpoint{path: path, url: url} +func NewEndpointWithURL(path, url string) *Endpoint { + return &Endpoint{path: path, url: url} } -func (e Endpoint) Relative() string { +func (e *Endpoint) Relative() string { + if e == nil { + return "" + } return relativeEndpoint(e.path) } -func (e Endpoint) Absolute(host string) string { +func (e *Endpoint) Absolute(host string) string { + if e == nil { + return "" + } if e.url != "" { return e.url } return absoluteEndpoint(host, e.path) } -func (e Endpoint) Validate() error { +var ErrNilEndpoint = errors.New("nil endpoint") + +func (e *Endpoint) Validate() error { + if e == nil { + return ErrNilEndpoint + } return nil // TODO: } diff --git a/pkg/op/endpoint_test.go b/pkg/op/endpoint_test.go index 50de89cf..bf112eff 100644 --- a/pkg/op/endpoint_test.go +++ b/pkg/op/endpoint_test.go @@ -3,13 +3,14 @@ package op_test import ( "testing" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/op" ) func TestEndpoint_Path(t *testing.T) { tests := []struct { name string - e op.Endpoint + e *op.Endpoint want string }{ { @@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) { op.NewEndpointWithURL("/test", "http://test.com/test"), "/test", }, + { + "nil", + nil, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) { } tests := []struct { name string - e op.Endpoint + e *op.Endpoint args args want string }{ @@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) { args{"https://host"}, "https://test.com/test", }, + { + "nil", + nil, + args{"https://host"}, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) { func TestEndpoint_Validate(t *testing.T) { tests := []struct { name string - e op.Endpoint - wantErr bool + e *op.Endpoint + wantErr error }{ - // TODO: Add test cases. + { + "nil", + nil, + op.ErrNilEndpoint, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.e.Validate(); (err != nil) != tt.wantErr { - t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr) - } + err := tt.e.Validate() + require.ErrorIs(t, err, tt.wantErr) }) } } diff --git a/pkg/op/error.go b/pkg/op/error.go index acca4ab9..0cac14b9 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -1,10 +1,14 @@ package op import ( + "context" + "errors" + "fmt" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type ErrAuthRequest interface { @@ -13,13 +17,31 @@ type ErrAuthRequest interface { GetState() string } -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) { +// LogAuthRequest is an optional interface, +// that allows logging AuthRequest fields. +// If the AuthRequest does not implement this interface, +// no details shall be printed to the logs. +type LogAuthRequest interface { + ErrAuthRequest + slog.LogValuer +} + +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) { + e := oidc.DefaultToServerError(err, err.Error()) + logger := authorizer.Logger().With("oidc_error", e) + if authReq == nil { + logger.Log(r.Context(), e.LogLevel(), "auth request") http.Error(w, err.Error(), http.StatusBadRequest) return } - e := oidc.DefaultToServerError(err, err.Error()) + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting") http.Error(w, e.Description, http.StatusBadRequest) return } @@ -28,19 +50,120 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { responseMode = rm.GetResponseMode() } - url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder()) if err != nil { + logger.ErrorContext(r.Context(), "auth response URL", "error", err) http.Error(w, err.Error(), http.StatusBadRequest) return } + logger.Log(r.Context(), e.LogLevel(), "auth request") http.Redirect(w, r, url, http.StatusFound) } -func RequestError(w http.ResponseWriter, r *http.Request, err error) { +func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { e := oidc.DefaultToServerError(err, err.Error()) status := http.StatusBadRequest if e.ErrorType == oidc.InvalidClient { - status = 401 + status = http.StatusUnauthorized } + logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) httphelper.MarshalJSONWithStatus(w, e, status) } + +// TryErrorRedirect tries to handle an error by redirecting a client. +// If this attempt fails, an error is returned that must be returned +// to the client instead. +func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) { + e := oidc.DefaultToServerError(parent, parent.Error()) + logger = logger.With("oidc_error", e) + + if authReq == nil { + logger.Log(ctx, e.LogLevel(), "auth request") + return nil, AsStatusError(e, http.StatusBadRequest) + } + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } + + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(ctx, e.LogLevel(), "auth request: not redirecting") + return nil, AsStatusError(e, http.StatusBadRequest) + } + + e.State = authReq.GetState() + var responseMode oidc.ResponseMode + if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { + responseMode = rm.GetResponseMode() + } + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) + if err != nil { + logger.ErrorContext(ctx, "auth response URL", "error", err) + return nil, AsStatusError(err, http.StatusBadRequest) + } + logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url) + return NewRedirect(url), nil +} + +// StatusError wraps an error with a HTTP status code. +// The status code is passed to the handler's writer. +type StatusError struct { + parent error + statusCode int +} + +// NewStatusError sets the parent and statusCode to a new StatusError. +// It is recommended for parent to be an [oidc.Error]. +// +// Typically implementations should only use this to signal something +// very specific, like an internal server error. +// If a returned error is not a StatusError, the framework +// will set a statusCode based on what the standard specifies, +// which is [http.StatusBadRequest] for most of the time. +// If the error encountered can described clearly with a [oidc.Error], +// do not use this function, as it might break standard rules! +func NewStatusError(parent error, statusCode int) StatusError { + return StatusError{ + parent: parent, + statusCode: statusCode, + } +} + +// AsStatusError unwraps a StatusError from err +// and returns it unmodified if found. +// If no StatuError was found, a new one is returned +// with statusCode set to it as a default. +func AsStatusError(err error, statusCode int) (target StatusError) { + if errors.As(err, &target) { + return target + } + return NewStatusError(err, statusCode) +} + +func (e StatusError) Error() string { + return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error()) +} + +func (e StatusError) Unwrap() error { + return e.parent +} + +func (e StatusError) Is(err error) bool { + var target StatusError + if !errors.As(err, &target) { + return false + } + return errors.Is(e.parent, target.parent) && + e.statusCode == target.statusCode +} + +// WriteError asserts for a StatusError containing an [oidc.Error]. +// If no StatusError is found, the status code will default to [http.StatusBadRequest]. +// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError]. +func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { + statusError := AsStatusError(err, http.StatusBadRequest) + e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()) + + logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) + httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode) +} diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go new file mode 100644 index 00000000..689ee5ab --- /dev/null +++ b/pkg/op/error_test.go @@ -0,0 +1,677 @@ +package op + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +func TestAuthRequestError(t *testing.T) { + type args struct { + authReq ErrAuthRequest + err error + } + tests := []struct { + name string + args args + wantCode int + wantHeaders map[string]string + wantBody string + wantLog string + }{ + { + name: "nil auth request", + args: args{ + authReq: nil, + err: io.ErrClosedPipe, + }, + wantCode: http.StatusBadRequest, + wantBody: "io: read/write on closed pipe\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth request", + "time":"not", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + } + }`, + }, + { + name: "auth request, no redirect URI", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "sign in\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request, redirect disabled", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), + }, + wantCode: http.StatusBadRequest, + wantBody: "oops\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"oops", + "type":"invalid_request", + "redirect_disabled":true + } + }`, + }, + { + name: "auth request, url parse error", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "can't parse this!\n", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth response URL", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"can't parse this!\n", + "response_type":"responseType", + "scopes":"a b" + }, + "error":{ + "type":"server_error", + "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request redirect", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusFound, + wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"}, + wantLog: `{ + "level":"WARN", + "msg":"auth request", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + authorizer := &Provider{ + encoder: schema.NewEncoder(), + logger: slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ), + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode) + for key, wantHeader := range tt.wantHeaders { + gotHeader := res.Header.Get(key) + assert.Equalf(t, wantHeader, gotHeader, "header %q", key) + } + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.Equal(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestRequestError(t *testing.T) { + tests := []struct { + name string + err error + wantCode int + wantBody string + wantLog string + }{ + { + name: "server error", + err: io.ErrClosedPipe, + wantCode: http.StatusBadRequest, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "time":"not", + "oidc_error":{ + "parent":"io: read/write on closed pipe", + "description":"io: read/write on closed pipe", + "type":"server_error"} + }`, + }, + { + name: "invalid client", + err: oidc.ErrInvalidClient().WithDescription("not good"), + wantCode: http.StatusUnauthorized, + wantBody: `{"error":"invalid_client", "error_description":"not good"}`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "time":"not", + "oidc_error":{ + "description":"not good", + "type":"invalid_client"} + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + RequestError(w, r, tt.err, logger) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode, "status code") + + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.JSONEq(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestTryErrorRedirect(t *testing.T) { + type args struct { + ctx context.Context + authReq ErrAuthRequest + parent error + } + tests := []struct { + name string + args args + want *Redirect + wantErr error + wantLog string + }{ + { + name: "nil auth request", + args: args{ + ctx: context.Background(), + authReq: nil, + parent: io.ErrClosedPipe, + }, + wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest), + wantLog: `{ + "level":"ERROR", + "msg":"auth request", + "time":"not", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + } + }`, + }, + { + name: "auth request, no redirect URI", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest), + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request, redirect disabled", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), + }, + wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest), + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"oops", + "type":"invalid_request", + "redirect_disabled":true + } + }`, + }, + { + name: "auth request, url parse error", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "can't parse this!\n", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantErr: func() error { + //lint:ignore SA1007 just recreating the error for testing + _, err := url.Parse("can't parse this!\n") + err = oidc.ErrServerError().WithParent(err) + return NewStatusError(err, http.StatusBadRequest) + }(), + wantLog: `{ + "level":"ERROR", + "msg":"auth response URL", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"can't parse this!\n", + "response_type":"responseType", + "scopes":"a b" + }, + "error":{ + "type":"server_error", + "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request redirect", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + want: &Redirect{ + URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1", + }, + wantLog: `{ + "level":"WARN", + "msg":"auth request redirect", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + }, + "url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1" + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + encoder := schema.NewEncoder() + + got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestNewStatusError(t *testing.T) { + err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + + want := "Internal Server Error: io: read/write on closed pipe" + got := fmt.Sprint(err) + assert.Equal(t, want, got) +} + +func TestAsStatusError(t *testing.T) { + type args struct { + err error + statusCode int + } + tests := []struct { + name string + args args + want string + }{ + { + name: "already status error", + args: args{ + err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError), + statusCode: http.StatusBadRequest, + }, + want: "Internal Server Error: io: read/write on closed pipe", + }, + { + name: "oidc error", + args: args{ + err: oidc.ErrAcrInvalid, + statusCode: http.StatusBadRequest, + }, + want: "Bad Request: acr is invalid", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := AsStatusError(tt.args.err, tt.args.statusCode) + got := fmt.Sprint(err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestStatusError_Unwrap(t *testing.T) { + err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + require.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestStatusError_Is(t *testing.T) { + type args struct { + err error + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "nil error", + args: args{err: nil}, + want: false, + }, + { + name: "other error", + args: args{err: io.EOF}, + want: false, + }, + { + name: "other parent", + args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)}, + want: false, + }, + { + name: "other status", + args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)}, + want: false, + }, + { + name: "same", + args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)}, + want: true, + }, + { + name: "wrapped", + args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + if got := e.Is(tt.args.err); got != tt.want { + t.Errorf("StatusError.Is() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteError(t *testing.T) { + tests := []struct { + name string + err error + wantStatus int + wantBody string + wantLog string + }{ + { + name: "not a status or oidc error", + err: io.ErrClosedPipe, + wantStatus: http.StatusBadRequest, + wantBody: `{ + "error":"server_error", + "error_description":"io: read/write on closed pipe" + }`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + }, + "time":"not" + }`, + }, + { + name: "status error w/o oidc", + err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError), + wantStatus: http.StatusInternalServerError, + wantBody: `{ + "error":"server_error", + "error_description":"io: read/write on closed pipe" + }`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + }, + "time":"not" + }`, + }, + { + name: "oidc error w/o status", + err: oidc.ErrInvalidRequest().WithDescription("oops"), + wantStatus: http.StatusBadRequest, + wantBody: `{ + "error":"invalid_request", + "error_description":"oops" + }`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "oidc_error":{ + "description":"oops", + "type":"invalid_request" + }, + "time":"not" + }`, + }, + { + name: "status with oidc error", + err: NewStatusError( + oidc.ErrUnauthorizedClient().WithDescription("oops"), + http.StatusUnauthorized, + ), + wantStatus: http.StatusUnauthorized, + wantBody: `{ + "error":"unauthorized_client", + "error_description":"oops" + }`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "oidc_error":{ + "description":"oops", + "type":"unauthorized_client" + }, + "time":"not" + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + r := httptest.NewRequest("GET", "/target", nil) + w := httptest.NewRecorder() + + WriteError(w, r, tt.err, logger) + res := w.Result() + assert.Equal(t, tt.wantStatus, res.StatusCode, "status code") + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, tt.wantBody, string(gotBody), "body") + assert.JSONEq(t, tt.wantLog, logOut.String()) + }) + } +} diff --git a/pkg/op/keys.go b/pkg/op/keys.go index 239ecbda..fe111f0f 100644 --- a/pkg/op/keys.go +++ b/pkg/op/keys.go @@ -4,9 +4,9 @@ import ( "context" "net/http" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" - httphelper "github.com/zitadel/oidc/v2/pkg/http" + httphelper "github.com/zitadel/oidc/v3/pkg/http" ) type KeyProvider interface { diff --git a/pkg/op/keys_test.go b/pkg/op/keys_test.go index 2e56b781..e1a38512 100644 --- a/pkg/op/keys_test.go +++ b/pkg/op/keys_test.go @@ -7,13 +7,13 @@ import ( "net/http/httptest" "testing" + jose "github.com/go-jose/go-jose/v3" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op/mock" ) func TestKeys(t *testing.T) { diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index cc913eef..e4297cb8 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Authorizer) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Authorizer) // Package mock is a generated GoMock package. package mock @@ -9,8 +9,9 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - http "github.com/zitadel/oidc/v2/pkg/http" - op "github.com/zitadel/oidc/v2/pkg/op" + http "github.com/zitadel/oidc/v3/pkg/http" + op "github.com/zitadel/oidc/v3/pkg/op" + slog "golang.org/x/exp/slog" ) // MockAuthorizer is a mock of Authorizer interface. @@ -79,10 +80,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call { } // IDTokenHintVerifier mocks base method. -func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier { +func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0) - ret0, _ := ret[0].(op.IDTokenHintVerifier) + ret0, _ := ret[0].(*op.IDTokenHintVerifier) return ret0 } @@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0) } +// Logger mocks base method. +func (m *MockAuthorizer) Logger() *slog.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logger") + ret0, _ := ret[0].(*slog.Logger) + return ret0 +} + +// Logger indicates an expected call of Logger. +func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger)) +} + // RequestObjectSupported mocks base method. func (m *MockAuthorizer) RequestObjectSupported() bool { m.ctrl.T.Helper() diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 3f1d525e..ba5082f0 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -4,12 +4,12 @@ import ( "context" "testing" + jose "github.com/go-jose/go-jose/v3" "github.com/golang/mock/gomock" - "github.com/gorilla/schema" - "gopkg.in/square/go-jose.v2" + "github.com/zitadel/schema" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) func NewAuthorizer(t *testing.T) op.Authorizer { @@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) { func ExpectVerifier(a op.Authorizer, t *testing.T) { mockA := a.(*MockAuthorizer) mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn( - func() op.IDTokenHintVerifier { + func() *op.IDTokenHintVerifier { return op.NewIDTokenHintVerifier("", nil) }) } diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go index 36df84a9..f01e3ec6 100644 --- a/pkg/op/mock/client.go +++ b/pkg/op/mock/client.go @@ -5,8 +5,8 @@ import ( "github.com/golang/mock/gomock" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) func NewClient(t *testing.T) op.Client { diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index e3d19fbc..9be08075 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Client) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Client) // Package mock is a generated GoMock package. package mock @@ -9,8 +9,8 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - oidc "github.com/zitadel/oidc/v2/pkg/oidc" - op "github.com/zitadel/oidc/v2/pkg/op" + oidc "github.com/zitadel/oidc/v3/pkg/oidc" + op "github.com/zitadel/oidc/v3/pkg/op" ) // MockClient is a mock of Client interface. diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index fe7d4da6..f392a455 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Configuration) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Configuration) // Package mock is a generated GoMock package. package mock @@ -9,7 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - op "github.com/zitadel/oidc/v2/pkg/op" + op "github.com/zitadel/oidc/v3/pkg/op" language "golang.org/x/text/language" ) @@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom } // AuthorizationEndpoint mocks base method. -func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call { } // DeviceAuthorizationEndpoint mocks base method. -func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C } // EndSessionEndpoint mocks base method. -func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { +func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EndSessionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup } // IntrospectionEndpoint mocks base method. -func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint { +func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IntrospectionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go } // KeysEndpoint mocks base method. -func (m *MockConfiguration) KeysEndpoint() op.Endpoint { +func (m *MockConfiguration) KeysEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "KeysEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor } // RevocationEndpoint mocks base method. -func (m *MockConfiguration) RevocationEndpoint() op.Endpoint { +func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RevocationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call { } // TokenEndpoint mocks base method. -func (m *MockConfiguration) TokenEndpoint() op.Endpoint { +func (m *MockConfiguration) TokenEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TokenEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported } // UserinfoEndpoint mocks base method. -func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { +func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UserinfoEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } diff --git a/pkg/op/mock/discovery.mock.go b/pkg/op/mock/discovery.mock.go index 0c78d525..c5d3d3a6 100644 --- a/pkg/op/mock/discovery.mock.go +++ b/pkg/op/mock/discovery.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: DiscoverStorage) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: DiscoverStorage) // Package mock is a generated GoMock package. package mock @@ -8,8 +8,8 @@ import ( context "context" reflect "reflect" + jose "github.com/go-jose/go-jose/v3" gomock "github.com/golang/mock/gomock" - jose "gopkg.in/square/go-jose.v2" ) // MockDiscoverStorage is a mock of DiscoverStorage interface. diff --git a/pkg/op/mock/generate.go b/pkg/op/mock/generate.go index ca288d20..590356cf 100644 --- a/pkg/op/mock/generate.go +++ b/pkg/op/mock/generate.go @@ -1,10 +1,10 @@ package mock //go:generate go install github.com/golang/mock/mockgen@v1.6.0 -//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v2/pkg/op Storage -//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v2/pkg/op Authorizer -//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v2/pkg/op Client -//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v2/pkg/op Configuration -//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v2/pkg/op DiscoverStorage -//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v2/pkg/op SigningKey,Key -//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v2/pkg/op KeyProvider +//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v3/pkg/op Storage +//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v3/pkg/op Authorizer +//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v3/pkg/op Client +//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v3/pkg/op Configuration +//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v3/pkg/op DiscoverStorage +//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v3/pkg/op SigningKey,Key +//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v3/pkg/op KeyProvider diff --git a/pkg/op/mock/key.mock.go b/pkg/op/mock/key.mock.go index 88316517..122e852c 100644 --- a/pkg/op/mock/key.mock.go +++ b/pkg/op/mock/key.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: KeyProvider) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: KeyProvider) // Package mock is a generated GoMock package. package mock @@ -9,7 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - op "github.com/zitadel/oidc/v2/pkg/op" + op "github.com/zitadel/oidc/v3/pkg/op" ) // MockKeyProvider is a mock of KeyProvider interface. diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index 78c0efe3..15718e07 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: SigningKey,Key) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: SigningKey,Key) // Package mock is a generated GoMock package. package mock @@ -7,8 +7,8 @@ package mock import ( reflect "reflect" + jose "github.com/go-jose/go-jose/v3" gomock "github.com/golang/mock/gomock" - jose "gopkg.in/square/go-jose.v2" ) // MockSigningKey is a mock of SigningKey interface. diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 85afb2a5..a1ce598b 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Storage) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Storage) // Package mock is a generated GoMock package. package mock @@ -9,10 +9,10 @@ import ( reflect "reflect" time "time" + jose "github.com/go-jose/go-jose/v3" gomock "github.com/golang/mock/gomock" - oidc "github.com/zitadel/oidc/v2/pkg/oidc" - op "github.com/zitadel/oidc/v2/pkg/op" - jose "gopkg.in/square/go-jose.v2" + oidc "github.com/zitadel/oidc/v3/pkg/oidc" + op "github.com/zitadel/oidc/v3/pkg/op" ) // MockStorage is a mock of Storage interface. diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 9269f891..002da7ec 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -8,8 +8,8 @@ import ( "github.com/golang/mock/gomock" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) func NewStorage(t *testing.T) op.Storage { diff --git a/pkg/op/op.go b/pkg/op/op.go index c4be14f1..5abe08f1 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -6,16 +6,17 @@ import ( "net/http" "time" - "github.com/gorilla/mux" - "github.com/gorilla/schema" + "github.com/go-chi/chi" + jose "github.com/go-jose/go-jose/v3" "github.com/rs/cors" + "github.com/zitadel/schema" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" + "golang.org/x/exp/slog" "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) const ( @@ -33,7 +34,7 @@ const ( ) var ( - DefaultEndpoints = &endpoints{ + DefaultEndpoints = &Endpoints{ Authorization: NewEndpoint(defaultAuthorizationEndpoint), Token: NewEndpoint(defaultTokenEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint), @@ -76,29 +77,35 @@ func init() { } type OpenIDProvider interface { + http.Handler Configuration Storage() Storage Decoder() httphelper.Decoder Encoder() httphelper.Encoder - IDTokenHintVerifier(context.Context) IDTokenHintVerifier - AccessTokenVerifier(context.Context) AccessTokenVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier Crypto() Crypto DefaultLogoutRedirectURI() string Probes() []ProbesFn + + // EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 + Logger() *slog.Logger + + // Deprecated: Provider now implements http.Handler directly. HttpHandler() http.Handler } type HttpInterceptor func(http.Handler) http.Handler -func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router { - router := mux.NewRouter() +func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { + router := chi.NewRouter() router.Use(cors.New(defaultCORSOptions).Handler) router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage())) router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o)) - router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o)) + router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o)) router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o)) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) @@ -132,16 +139,17 @@ type Config struct { DeviceAuthorization DeviceAuthorizationConfig } -type endpoints struct { - Authorization Endpoint - Token Endpoint - Introspection Endpoint - Userinfo Endpoint - Revocation Endpoint - EndSession Endpoint - CheckSessionIframe Endpoint - JwksURI Endpoint - DeviceAuthorization Endpoint +// Endpoints defines endpoint routes. +type Endpoints struct { + Authorization *Endpoint + Token *Endpoint + Introspection *Endpoint + Userinfo *Endpoint + Revocation *Endpoint + EndSession *Endpoint + CheckSessionIframe *Endpoint + JwksURI *Endpoint + DeviceAuthorization *Endpoint } // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) @@ -186,6 +194,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR storage: storage, endpoints: DefaultEndpoints, timer: make(<-chan time.Time), + logger: slog.Default(), } for _, optFunc := range opOpts { @@ -199,7 +208,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR return nil, err } - o.httpHandler = CreateRouter(o, o.interceptors...) + o.Handler = CreateRouter(o, o.interceptors...) o.decoder = schema.NewDecoder() o.decoder.IgnoreUnknownKeys(true) @@ -215,20 +224,21 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR } type Provider struct { + http.Handler config *Config issuer IssuerFromRequest insecure bool - endpoints *endpoints + endpoints *Endpoints storage Storage keySet *openIDKeySet crypto Crypto - httpHandler http.Handler decoder *schema.Decoder encoder *schema.Encoder interceptors []HttpInterceptor timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt + logger *slog.Logger } func (o *Provider) IssuerFromRequest(r *http.Request) string { @@ -239,35 +249,35 @@ func (o *Provider) Insecure() bool { return o.insecure } -func (o *Provider) AuthorizationEndpoint() Endpoint { +func (o *Provider) AuthorizationEndpoint() *Endpoint { return o.endpoints.Authorization } -func (o *Provider) TokenEndpoint() Endpoint { +func (o *Provider) TokenEndpoint() *Endpoint { return o.endpoints.Token } -func (o *Provider) IntrospectionEndpoint() Endpoint { +func (o *Provider) IntrospectionEndpoint() *Endpoint { return o.endpoints.Introspection } -func (o *Provider) UserinfoEndpoint() Endpoint { +func (o *Provider) UserinfoEndpoint() *Endpoint { return o.endpoints.Userinfo } -func (o *Provider) RevocationEndpoint() Endpoint { +func (o *Provider) RevocationEndpoint() *Endpoint { return o.endpoints.Revocation } -func (o *Provider) EndSessionEndpoint() Endpoint { +func (o *Provider) EndSessionEndpoint() *Endpoint { return o.endpoints.EndSession } -func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { +func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint { return o.endpoints.DeviceAuthorization } -func (o *Provider) KeysEndpoint() Endpoint { +func (o *Provider) KeysEndpoint() *Endpoint { return o.endpoints.JwksURI } @@ -354,15 +364,15 @@ func (o *Provider) Encoder() httphelper.Encoder { return o.encoder } -func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier { +func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier { return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...) } -func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier { +func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier { return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second) } -func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier { +func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier { return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...) } @@ -387,8 +397,13 @@ func (o *Provider) Probes() []ProbesFn { } } +func (o *Provider) Logger() *slog.Logger { + return o.logger +} + +// Deprecated: Provider now implements http.Handler directly. func (o *Provider) HttpHandler() http.Handler { - return o.httpHandler + return o } type openIDKeySet struct { @@ -421,7 +436,7 @@ func WithAllowInsecure() Option { } } -func WithCustomAuthEndpoint(endpoint Endpoint) Option { +func WithCustomAuthEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -431,7 +446,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option { } } -func WithCustomTokenEndpoint(endpoint Endpoint) Option { +func WithCustomTokenEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -441,7 +456,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option { } } -func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { +func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -451,7 +466,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { } } -func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { +func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -461,7 +476,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { } } -func WithCustomRevocationEndpoint(endpoint Endpoint) Option { +func WithCustomRevocationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -471,7 +486,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { +func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -481,7 +496,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { } } -func WithCustomKeysEndpoint(endpoint Endpoint) Option { +func WithCustomKeysEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -491,7 +506,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option { } } -func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { +func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -501,8 +516,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { +// WithCustomEndpoints sets multiple endpoints at once. +// Non of the endpoints may be nil, or an error will +// be returned when the Option used by the Provider. +func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option { return func(o *Provider) error { + for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} { + if err := e.Validate(); err != nil { + return err + } + } o.endpoints.Authorization = auth o.endpoints.Token = token o.endpoints.Userinfo = userInfo @@ -534,6 +557,16 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +// WithLogger lets a logger other than slog.Default(). +// +// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 +func WithLogger(logger *slog.Logger) Option { + return func(o *Provider) error { + o.logger = logger + return nil + } +} + func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler { issuerInterceptor := NewIssuerInterceptor(i) return func(handler http.Handler) http.Handler { diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index b637e031..062fcfe8 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -14,9 +14,9 @@ import ( "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/example/server/storage" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" "golang.org/x/text/language" ) @@ -157,7 +157,7 @@ func TestRoutes(t *testing.T) { values: map[string]string{ "client_id": client.GetID(), "redirect_uri": "https://example.com", - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "response_type": string(oidc.ResponseTypeCode), }, wantCode: http.StatusFound, @@ -194,7 +194,7 @@ func TestRoutes(t *testing.T) { path: testProvider.TokenEndpoint().Relative(), values: map[string]string{ "grant_type": string(oidc.GrantTypeBearer), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "assertion": jwtToken, }, wantCode: http.StatusBadRequest, @@ -207,7 +207,7 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"web", "secret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeTokenExchange), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "subject_token": jwtToken, "subject_token_type": string(oidc.AccessTokenType), }, @@ -224,7 +224,7 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"sid1", "verysecret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeClientCredentials), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, @@ -339,7 +339,7 @@ func TestRoutes(t *testing.T) { path: testProvider.DeviceAuthorizationEndpoint().Relative(), basicAuth: &basicAuth{"device", "secret"}, values: map[string]string{ - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, contains: []string{ @@ -371,7 +371,7 @@ func TestRoutes(t *testing.T) { } rec := httptest.NewRecorder() - testProvider.HttpHandler().ServeHTTP(rec, req) + testProvider.ServeHTTP(rec, req) resp := rec.Result() require.NoError(t, err) @@ -396,3 +396,54 @@ func TestRoutes(t *testing.T) { }) } } + +func TestWithCustomEndpoints(t *testing.T) { + type args struct { + auth *op.Endpoint + token *op.Endpoint + userInfo *op.Endpoint + revocation *op.Endpoint + endSession *op.Endpoint + keys *op.Endpoint + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "all nil", + args: args{}, + wantErr: op.ErrNilEndpoint, + }, + { + name: "all set", + args: args{ + auth: op.NewEndpoint("/authorize"), + token: op.NewEndpoint("/oauth/token"), + userInfo: op.NewEndpoint("/userinfo"), + revocation: op.NewEndpoint("/revoke"), + endSession: op.NewEndpoint("/end_session"), + keys: op.NewEndpoint("/keys"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := op.NewOpenIDProvider(testIssuer, testConfig, + storage.NewStorage(storage.NewUserStore(testIssuer)), + op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys), + ) + require.ErrorIs(t, err, tt.wantErr) + if tt.wantErr != nil { + return + } + assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint()) + assert.Equal(t, tt.args.token, provider.TokenEndpoint()) + assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint()) + assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint()) + assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint()) + assert.Equal(t, tt.args.keys, provider.KeysEndpoint()) + }) + } +} diff --git a/pkg/op/probes.go b/pkg/op/probes.go index a56c92bc..cb3853d8 100644 --- a/pkg/op/probes.go +++ b/pkg/op/probes.go @@ -5,7 +5,7 @@ import ( "errors" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" + httphelper "github.com/zitadel/oidc/v3/pkg/http" ) type ProbesFn func(context.Context) error @@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn { } func ok(w http.ResponseWriter) { - httphelper.MarshalJSON(w, status{"ok"}) + httphelper.MarshalJSON(w, Status{"ok"}) } -type status struct { +type Status struct { Status string `json:"status,omitempty"` } diff --git a/pkg/op/server.go b/pkg/op/server.go new file mode 100644 index 00000000..a9cdcf5f --- /dev/null +++ b/pkg/op/server.go @@ -0,0 +1,346 @@ +package op + +import ( + "context" + "net/http" + "net/url" + + "github.com/muhlemmer/gu" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// Server describes the interface that needs to be implemented to serve +// OpenID Connect and Oauth2 standard requests. +// +// Methods are called after the HTTP route is resolved and +// the request body is parsed into the Request's Data field. +// When a method is called, it can be assumed that required fields, +// as described in their relevant standard, are validated already. +// The Response Data field may be of any type to allow flexibility +// to extend responses with custom fields. There are however requirements +// in the standards regarding the response models. Where applicable +// the method documentation gives a recommended type which can be used +// directly or extended upon. +// +// The addition of new methods is not considered a breaking change +// as defined by semver rules. +// Implementations MUST embed [UnimplementedServer] to maintain +// forward compatibility. +// +// EXPERIMENTAL: may change until v4 +type Server interface { + // Health returns a status of "ok" once the Server is listening. + // The recommended Response Data type is [Status]. + Health(context.Context, *Request[struct{}]) (*Response, error) + + // Ready returns a status of "ok" once all dependencies, + // such as database storage, are ready. + // An error can be returned to explain what is not ready. + // The recommended Response Data type is [Status]. + Ready(context.Context, *Request[struct{}]) (*Response, error) + + // Discovery returns the OpenID Provider Configuration Information for this server. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig + // The recommended Response Data type is [oidc.DiscoveryConfiguration]. + Discovery(context.Context, *Request[struct{}]) (*Response, error) + + // Keys serves the JWK set which the client can use verify signatures from the op. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key. + // The recommended Response Data type is [jose.JSONWebKeySet]. + Keys(context.Context, *Request[struct{}]) (*Response, error) + + // VerifyAuthRequest verifies the Auth Request and + // adds the Client to the request. + // + // When the `request` field is populated with a + // "Request Object" JWT, it needs to be Validated + // and its claims overwrite any fields in the AuthRequest. + // If the implementation does not support "Request Object", + // it MUST return an [oidc.ErrRequestNotSupported]. + // https://openid.net/specs/openid-connect-core-1_0.html#RequestObject + VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) + + // Authorize initiates the authorization flow and redirects to a login page. + // See the various https://openid.net/specs/openid-connect-core-1_0.html + // authorize endpoint sections (one for each type of flow). + Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error) + + // DeviceAuthorization initiates the device authorization flow. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.1 + // The recommended Response Data type is [oidc.DeviceAuthorizationResponse]. + DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) + + // VerifyClient is called on most oauth/token handlers to authenticate, + // using either a secret (POST, Basic) or assertion (JWT). + // If no secrets are provided, the client must be public. + // This method is called before each method that takes a + // [ClientRequest] argument. + VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error) + + // CodeExchange returns Tokens after an authorization code + // is obtained in a successful Authorize flow. + // It is called by the Token endpoint handler when + // grant_type has the value authorization_code + // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + // The recommended Response Data type is [oidc.AccessTokenResponse]. + CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) + + // RefreshToken returns new Tokens after verifying a Refresh token. + // It is called by the Token endpoint handler when + // grant_type has the value refresh_token + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens + // The recommended Response Data type is [oidc.AccessTokenResponse]. + RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) + + // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer + // https://datatracker.ietf.org/doc/html/rfc7523#section-2.1 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error) + + // TokenExchange handles the OAuth 2.0 token exchange grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange + // https://datatracker.ietf.org/doc/html/rfc8693 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) + + // ClientCredentialsExchange handles the OAuth 2.0 client credentials grant + // It is called by the Token endpoint handler when + // grant_type has the value client_credentials + // https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) + + // DeviceToken handles the OAuth 2.0 Device Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:device_code. + // It is typically called in a polling fashion and appropriate errors + // should be returned to signal authorization_pending or access_denied etc. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.4, + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5. + // The recommended Response Data type is [oidc.AccessTokenResponse]. + DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) + + // Introspect handles the OAuth 2.0 Token Introspection endpoint. + // https://datatracker.ietf.org/doc/html/rfc7662 + // The recommended Response Data type is [oidc.IntrospectionResponse]. + Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) + + // UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User. + // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo + // The recommended Response Data type is [oidc.UserInfo]. + UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error) + + // Revocation handles token revocation using an access or refresh token. + // https://datatracker.ietf.org/doc/html/rfc7009 + // There are no response requirements. Data may remain empty. + Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error) + + // EndSession handles the OpenID Connect RP-Initiated Logout. + // https://openid.net/specs/openid-connect-rpinitiated-1_0.html + // There are no response requirements. Data may remain empty. + EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error) + + // mustImpl forces implementations to embed the UnimplementedServer for forward + // compatibility with the interface. + mustImpl() +} + +// Request contains the [http.Request] informational fields +// and parsed Data from the request body (POST) or URL parameters (GET). +// Data can be assumed to be validated according to the applicable +// standard for the specific endpoints. +// +// EXPERIMENTAL: may change until v4 +type Request[T any] struct { + Method string + URL *url.URL + Header http.Header + Form url.Values + PostForm url.Values + Data *T +} + +func (r *Request[_]) path() string { + return r.URL.Path +} + +func newRequest[T any](r *http.Request, data *T) *Request[T] { + return &Request[T]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + PostForm: r.PostForm, + Data: data, + } +} + +// ClientRequest is a Request with a verified client attached to it. +// Methods that receive this argument may assume the client was authenticated, +// or verified to be a public client. +// +// EXPERIMENTAL: may change until v4 +type ClientRequest[T any] struct { + *Request[T] + Client Client +} + +func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] { + return &ClientRequest[T]{ + Request: newRequest[T](r, data), + Client: client, + } +} + +// Response object for most [Server] methods. +// +// EXPERIMENTAL: may change until v4 +type Response struct { + // Header map will be merged with the + // header on the [http.ResponseWriter]. + Header http.Header + + // Data will be JSON marshaled to + // the response body. + // We allow any type, so that implementations + // can extend the standard types as they wish. + // However, each method will recommend which + // (base) type to use as model, in order to + // be compliant with the standards. + Data any +} + +// NewResponse creates a new response for data, +// without custom headers. +func NewResponse(data any) *Response { + return &Response{ + Data: data, + } +} + +func (resp *Response) writeOut(w http.ResponseWriter) { + gu.MapMerge(resp.Header, w.Header()) + httphelper.MarshalJSON(w, resp.Data) +} + +// Redirect is a special response type which will +// initiate a [http.StatusFound] redirect. +// The Params field will be encoded and set to the +// URL's RawQuery field before building the URL. +// +// EXPERIMENTAL: may change until v4 +type Redirect struct { + // Header map will be merged with the + // header on the [http.ResponseWriter]. + Header http.Header + + URL string +} + +func NewRedirect(url string) *Redirect { + return &Redirect{URL: url} +} + +func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) { + gu.MapMerge(r.Header, w.Header()) + http.Redirect(w, r, red.URL, http.StatusFound) +} + +type UnimplementedServer struct{} + +// UnimplementedStatusCode is the status code returned for methods +// that are not yet implemented. +// Note that this means methods in the sense of the Go interface, +// and not http methods covered by "501 Not Implemented". +var UnimplementedStatusCode = http.StatusNotFound + +func unimplementedError(r interface{ path() string }) StatusError { + err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path()) + return NewStatusError(err, UnimplementedStatusCode) +} + +func unimplementedGrantError(gt oidc.GrantType) StatusError { + err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt) + return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 +} + +func (UnimplementedServer) mustImpl() {} + +func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if r.Data.RequestParam != "" { + return nil, oidc.ErrRequestNotSupported() + } + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeCode) +} + +func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) +} + +func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeBearer) +} + +func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) +} + +func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) +} + +func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) +} + +func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { + return nil, unimplementedError(r) +} diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go new file mode 100644 index 00000000..3fb481d1 --- /dev/null +++ b/pkg/op/server_http.go @@ -0,0 +1,480 @@ +package op + +import ( + "context" + "net/http" + "net/url" + + "github.com/go-chi/chi" + "github.com/rs/cors" + "github.com/zitadel/logging" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +// RegisterServer registers an implementation of Server. +// The resulting handler takes care of routing and request parsing, +// with some basic validation of required fields. +// The routes can be customized with [WithEndpoints]. +// +// EXPERIMENTAL: may change until v4 +func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + + ws := &webServer{ + server: server, + endpoints: endpoints, + decoder: decoder, + logger: slog.Default(), + } + + for _, option := range options { + option(ws) + } + + ws.createRouter() + return ws +} + +type ServerOption func(s *webServer) + +// WithHTTPMiddleware sets the passed middleware chain to the root of +// the Server's router. +func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { + return func(s *webServer) { + s.middleware = m + } +} + +// WithDecoder overrides the default decoder, +// which is a [schema.Decoder] with IgnoreUnknownKeys set to true. +func WithDecoder(decoder httphelper.Decoder) ServerOption { + return func(s *webServer) { + s.decoder = decoder + } +} + +// WithFallbackLogger overrides the fallback logger, which +// is used when no logger was found in the context. +// Defaults to [slog.Default]. +func WithFallbackLogger(logger *slog.Logger) ServerOption { + return func(s *webServer) { + s.logger = logger + } +} + +type webServer struct { + http.Handler + server Server + middleware []func(http.Handler) http.Handler + endpoints Endpoints + decoder httphelper.Decoder + logger *slog.Logger +} + +func (s *webServer) getLogger(ctx context.Context) *slog.Logger { + if logger, ok := logging.FromContext(ctx); ok { + return logger + } + return s.logger +} + +func (s *webServer) createRouter() { + router := chi.NewRouter() + router.Use(cors.New(defaultCORSOptions).Handler) + router.Use(s.middleware...) + router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) + router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) + router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) + + s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler) + s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) + s.endpointRoute(router, s.endpoints.Token, s.tokensHandler) + s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler)) + s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler) + s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler)) + s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler) + s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) + s.Handler = router +} + +func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) { + if e != nil { + router.HandleFunc(e.Relative(), hf) + s.logger.Info("registered route", "endpoint", e.Relative()) + } +} + +type clientHandler func(w http.ResponseWriter, r *http.Request, client Client) + +func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" { + if !ValidateGrantType(client, grantType) { + WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context())) + return + } + } + handler(w, r, client) + } +} + +func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { + if err = r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + cc := new(ClientCredentials) + if err = s.decoder.Decode(cc, r.Form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + // Basic auth takes precedence, so if set it overwrites the form data. + if clientID, clientSecret, ok := r.BasicAuth(); ok { + cc.ClientID, err = url.QueryUnescape(clientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + cc.ClientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + } + if cc.ClientID == "" && cc.ClientAssertion == "" { + return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided") + } + if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion { + return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType) + } + return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + Data: cc, + }) +} + +func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + redirect, err := s.authorize(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + redirect.writeOut(w, r) +} + +func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) { + cr, err := s.server.VerifyAuthRequest(ctx, r) + if err != nil { + return nil, err + } + authReq := cr.Data + if authReq.RedirectURI == "" { + return nil, ErrAuthReqMissingRedirectURI + } + authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) + if err != nil { + return nil, err + } + authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes) + if err != nil { + return nil, err + } + if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil { + return nil, err + } + if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil { + return nil, err + } + return s.server.Authorize(ctx, cr) +} + +func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) + return + } + + switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType { + case oidc.GrantTypeCode: + s.withClient(s.codeExchangeHandler)(w, r) + case oidc.GrantTypeRefreshToken: + s.withClient(s.refreshTokenHandler)(w, r) + case oidc.GrantTypeClientCredentials: + s.withClient(s.clientCredentialsHandler)(w, r) + case oidc.GrantTypeBearer: + s.jwtProfileHandler(w, r) + case oidc.GrantTypeTokenExchange: + s.withClient(s.tokenExchangeHandler)(w, r) + case oidc.GrantTypeDeviceCode: + s.withClient(s.deviceTokenHandler)(w, r) + case "": + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context())) + default: + WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context())) + } +} + +func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Assertion == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Code == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context())) + return + } + if request.RedirectURI == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.RefreshToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.SubjectToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context())) + return + } + if request.SubjectTokenType == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context())) + return + } + if !request.SubjectTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context())) + return + } + if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context())) + return + } + if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context())) + return + } + resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) { + if client.AuthMethod() == oidc.AuthMethodNone { + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) + return + } + + request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.DeviceCode == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) { + if client.AuthMethod() == oidc.AuthMethodNone { + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) + return + } + request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if token, err := getAccessToken(r); err == nil { + request.AccessToken = token + } + if request.AccessToken == "" { + err = NewStatusError( + oidc.ErrInvalidRequest().WithDescription("access token missing"), + http.StatusUnauthorized, + ) + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.UserInfo(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.EndSession(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w, r) +} + +func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) + return + } + resp, err := method(r.Context(), newRequest(r, &struct{}{})) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) + } +} + +func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) { + dst := new(R) + if err := r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + form := r.Form + if postOnly { + form = r.PostForm + } + if err := decoder.Decode(dst, form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + return dst, nil +} diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go new file mode 100644 index 00000000..c7767d25 --- /dev/null +++ b/pkg/op/server_http_routes_test.go @@ -0,0 +1,345 @@ +package op_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/oidc/v3/pkg/client" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" +) + +func jwtProfile() (string, error) { + keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json") + if err != nil { + return "", err + } + signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID) + if err != nil { + return "", err + } + return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer) +} + +func TestServerRoutes(t *testing.T) { + server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints) + + storage := testProvider.Storage().(routesTestStorage) + ctx := op.ContextWithIssuer(context.Background(), testIssuer) + + client, err := storage.GetClientByClientID(ctx, "web") + require.NoError(t, err) + + oidcAuthReq := &oidc.AuthRequest{ + ClientID: client.GetID(), + RedirectURI: "https://example.com", + MaxAge: gu.Ptr[uint](300), + Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone}, + ResponseType: oidc.ResponseTypeCode, + } + + authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1") + require.NoError(t, err) + storage.AuthRequestDone(authReq.GetID()) + + accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") + require.NoError(t, err) + accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") + require.NoError(t, err) + idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client) + require.NoError(t, err) + jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "") + require.NoError(t, err) + jwtProfileToken, err := jwtProfile() + require.NoError(t, err) + + oidcAuthReq.IDTokenHint = idToken + + serverURL, err := url.Parse(testIssuer) + require.NoError(t, err) + + type basicAuth struct { + username, password string + } + + tests := []struct { + name string + method string + path string + basicAuth *basicAuth + header map[string]string + values map[string]string + body map[string]string + wantCode int + headerContains map[string]string + json string // test for exact json output + contains []string // when the body output is not constant, we just check for snippets to be present in the response + }{ + { + name: "health", + method: http.MethodGet, + path: "/healthz", + wantCode: http.StatusOK, + json: `{"status":"ok"}`, + }, + { + name: "ready", + method: http.MethodGet, + path: "/ready", + wantCode: http.StatusOK, + json: `{"status":"ok"}`, + }, + { + name: "discovery", + method: http.MethodGet, + path: oidc.DiscoveryEndpoint, + wantCode: http.StatusOK, + json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`, + }, + { + name: "authorization", + method: http.MethodGet, + path: testProvider.AuthorizationEndpoint().Relative(), + values: map[string]string{ + "client_id": client.GetID(), + "redirect_uri": "https://example.com", + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "response_type": string(oidc.ResponseTypeCode), + }, + wantCode: http.StatusFound, + headerContains: map[string]string{"Location": "/login/username?authRequestID="}, + }, + { + // This call will fail. A successfull test is already + // part of client/integration_test.go + name: "code exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeCode), + "client_id": client.GetID(), + "client_secret": "secret", + "redirect_uri": "https://example.com", + "code": "123", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"invalid_grant", "error_description":"invalid code"}`, + }, + { + name: "JWT authorization", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeBearer), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "assertion": jwtProfileToken, + }, + wantCode: http.StatusOK, + contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299}`}, + }, + { + name: "Token exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "grant_type": string(oidc.GrantTypeTokenExchange), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "subject_token": jwtToken, + "subject_token_type": string(oidc.AccessTokenType), + }, + wantCode: http.StatusOK, + contains: []string{ + `{"access_token":"`, + `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`, + }, + }, + { + name: "Client credentials exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"sid1", "verysecret"}, + values: map[string]string{ + "grant_type": string(oidc.GrantTypeClientCredentials), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + }, + wantCode: http.StatusOK, + contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, + }, + { + // This call will fail. A successfull test is already + // part of device_test.go + name: "device token", + method: http.MethodPost, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"device", "secret"}, + header: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + body: map[string]string{ + "grant_type": string(oidc.GrantTypeDeviceCode), + "device_code": "123", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"access_denied","error_description":"The authorization request was denied."}`, + }, + { + name: "missing grant type", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + wantCode: http.StatusBadRequest, + json: `{"error":"invalid_request","error_description":"grant_type missing"}`, + }, + { + name: "unsupported grant type", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": "foo", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`, + }, + { + name: "introspection", + method: http.MethodGet, + path: testProvider.IntrospectionEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "token": accessToken, + }, + wantCode: http.StatusOK, + json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`, + }, + { + name: "user info", + method: http.MethodGet, + path: testProvider.UserinfoEndpoint().Relative(), + header: map[string]string{ + "authorization": "Bearer " + accessToken, + }, + wantCode: http.StatusOK, + json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`, + }, + { + name: "refresh token", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeRefreshToken), + "refresh_token": refreshToken, + "client_id": client.GetID(), + "client_secret": "secret", + }, + wantCode: http.StatusOK, + contains: []string{ + `{"access_token":"`, + `","token_type":"Bearer","refresh_token":"`, + `","expires_in":299,"id_token":"`, + }, + }, + { + name: "revoke", + method: http.MethodGet, + path: testProvider.RevocationEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "token": accessTokenRevoke, + }, + wantCode: http.StatusOK, + }, + { + name: "end session", + method: http.MethodGet, + path: testProvider.EndSessionEndpoint().Relative(), + values: map[string]string{ + "id_token_hint": idToken, + "client_id": "web", + }, + wantCode: http.StatusFound, + headerContains: map[string]string{"Location": "/logged-out"}, + contains: []string{`Found.`}, + }, + { + name: "keys", + method: http.MethodGet, + path: testProvider.KeysEndpoint().Relative(), + wantCode: http.StatusOK, + contains: []string{ + `{"keys":[{"use":"sig","kty":"RSA","kid":"`, + `","alg":"RS256","n":"`, `","e":"AQAB"}]}`, + }, + }, + { + name: "device authorization", + method: http.MethodGet, + path: testProvider.DeviceAuthorizationEndpoint().Relative(), + basicAuth: &basicAuth{"device", "secret"}, + values: map[string]string{ + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + }, + wantCode: http.StatusOK, + contains: []string{ + `{"device_code":"`, `","user_code":"`, + `","verification_uri":"https://localhost:9998/device"`, + `"verification_uri_complete":"https://localhost:9998/device?user_code=`, + `","expires_in":300,"interval":5}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := gu.PtrCopy(serverURL) + u.Path = tt.path + if tt.values != nil { + u.RawQuery = mapAsValues(tt.values) + } + var body io.Reader + if tt.body != nil { + body = strings.NewReader(mapAsValues(tt.body)) + } + + req := httptest.NewRequest(tt.method, u.String(), body) + for k, v := range tt.header { + req.Header.Set(k, v) + } + if tt.basicAuth != nil { + req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password) + } + + rec := httptest.NewRecorder() + server.ServeHTTP(rec, req) + + resp := rec.Result() + require.NoError(t, err) + assert.Equal(t, tt.wantCode, resp.StatusCode) + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + respBodyString := string(respBody) + t.Log(respBodyString) + t.Log(resp.Header) + + if tt.json != "" { + assert.JSONEq(t, tt.json, respBodyString) + } + for _, c := range tt.contains { + assert.Contains(t, respBodyString, c) + } + for k, v := range tt.headerContains { + assert.Contains(t, resp.Header.Get(k), v) + } + }) + } +} diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go new file mode 100644 index 00000000..86fe7ed8 --- /dev/null +++ b/pkg/op/server_http_test.go @@ -0,0 +1,1333 @@ +package op + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +func TestRegisterServer(t *testing.T) { + server := UnimplementedServer{} + endpoints := Endpoints{ + Authorization: &Endpoint{ + path: "/auth", + }, + } + decoder := schema.NewDecoder() + logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + + h := RegisterServer(server, endpoints, + WithDecoder(decoder), + WithFallbackLogger(logger), + ) + got := h.(*webServer) + assert.Equal(t, got.server, server) + assert.Equal(t, got.endpoints, endpoints) + assert.Equal(t, got.decoder, decoder) + assert.Equal(t, got.logger, logger) +} + +type testClient struct { + id string + appType ApplicationType + authMethod oidc.AuthMethod + accessTokenType AccessTokenType + responseTypes []oidc.ResponseType + grantTypes []oidc.GrantType + devMode bool +} + +type clientType string + +const ( + clientTypeWeb clientType = "web" + clientTypeNative clientType = "native" + clientTypeUserAgent clientType = "useragent" +) + +func newClient(kind clientType) *testClient { + client := &testClient{ + id: string(kind), + } + + switch kind { + case clientTypeWeb: + client.appType = ApplicationTypeWeb + client.authMethod = oidc.AuthMethodBasic + client.accessTokenType = AccessTokenTypeBearer + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} + case clientTypeNative: + client.appType = ApplicationTypeNative + client.authMethod = oidc.AuthMethodNone + client.accessTokenType = AccessTokenTypeBearer + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} + case clientTypeUserAgent: + client.appType = ApplicationTypeUserAgent + client.authMethod = oidc.AuthMethodBasic + client.accessTokenType = AccessTokenTypeJWT + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken} + default: + panic(fmt.Errorf("invalid client type %s", kind)) + } + return client +} + +func (c *testClient) RedirectURIs() []string { + return []string{ + "https://registered.com/callback", + "http://registered.com/callback", + "http://localhost:9999/callback", + "custom://callback", + } +} + +func (c *testClient) PostLogoutRedirectURIs() []string { + return []string{} +} + +func (c *testClient) LoginURL(id string) string { + return "login?id=" + id +} + +func (c *testClient) ApplicationType() ApplicationType { + return c.appType +} + +func (c *testClient) AuthMethod() oidc.AuthMethod { + return c.authMethod +} + +func (c *testClient) GetID() string { + return c.id +} + +func (c *testClient) AccessTokenLifetime() time.Duration { + return 5 * time.Minute +} + +func (c *testClient) IDTokenLifetime() time.Duration { + return 5 * time.Minute +} + +func (c *testClient) AccessTokenType() AccessTokenType { + return c.accessTokenType +} + +func (c *testClient) ResponseTypes() []oidc.ResponseType { + return c.responseTypes +} + +func (c *testClient) GrantTypes() []oidc.GrantType { + return c.grantTypes +} + +func (c *testClient) DevMode() bool { + return c.devMode +} + +func (c *testClient) AllowedScopes() []string { + return nil +} + +func (c *testClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +func (c *testClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +func (c *testClient) IsScopeAllowed(scope string) bool { + return false +} + +func (c *testClient) IDTokenUserinfoClaimsAssertion() bool { + return false +} + +func (c *testClient) ClockSkew() time.Duration { + return 0 +} + +type requestVerifier struct { + UnimplementedServer + client Client +} + +func (s *requestVerifier) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if s.client == nil { + return nil, oidc.ErrServerError() + } + return &ClientRequest[oidc.AuthRequest]{ + Request: r, + Client: s.client, + }, nil +} + +func (s *requestVerifier) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + if s.client == nil { + return nil, oidc.ErrServerError() + } + return s.client, nil +} + +var testDecoder = func() *schema.Decoder { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + return decoder +}() + +type webServerResult struct { + wantStatus int + wantBody string +} + +func runWebServerTest(t *testing.T, handler http.HandlerFunc, r *http.Request, want webServerResult) { + t.Helper() + if r.Method == http.MethodPost { + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + w := httptest.NewRecorder() + handler(w, r) + res := w.Result() + assert.Equal(t, want.wantStatus, res.StatusCode) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, want.wantBody, string(body)) +} + +func Test_webServer_withClient(t *testing.T) { + tests := []struct { + name string + r *http.Request + want webServerResult + }{ + { + name: "parse error", + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "invalid grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&grant_type=bad&foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unauthorized_client", "error_description":"grant_type \"bad\" not allowed"}`, + }, + }, + { + name: "no grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&foo=bar")), + want: webServerResult{ + wantStatus: http.StatusOK, + wantBody: `{"foo":"bar"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: &requestVerifier{ + client: newClient(clientTypeNative), + }, + decoder: testDecoder, + logger: slog.Default(), + } + handler := func(w http.ResponseWriter, r *http.Request, client Client) { + fmt.Fprintf(w, `{"foo":%q}`, r.FormValue("foo")) + } + runWebServerTest(t, s.withClient(handler), tt.r, tt.want) + }) + } +} + +func Test_webServer_verifyRequestClient(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want Client + wantErr error + }{ + { + name: "parse form error", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), + }, + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), + }, + { + name: "basic auth, client_id error", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) + r.SetBasicAuth(`%%%`, "secret") + return r + }(), + wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"), + }, + { + name: "basic auth, client_secret error", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) + r.SetBasicAuth("web", `%%%`) + return r + }(), + wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"), + }, + { + name: "missing client id and assertion", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + wantErr: oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided"), + }, + { + name: "wrong assertion type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_assertion=xxx&client_assertion_type=wrong")), + wantErr: oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type wrong"), + }, + { + name: "unimplemented verify client called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_id=web")), + wantErr: StatusError{ + parent: oidc.ErrServerError().WithDescription("/ not implemented on this server"), + statusCode: UnimplementedStatusCode, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + tt.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + got, err := s.verifyRequestClient(tt.r) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webServer_authorizeHandler(t *testing.T) { + type fields struct { + server Server + decoder httphelper.Decoder + } + tests := []struct { + name string + fields fields + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + fields: fields{ + server: &requestVerifier{}, + decoder: schema.NewDecoder(), + }, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "authorize error", + fields: fields{ + server: &requestVerifier{}, + decoder: testDecoder, + }, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"server_error"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.fields.server, + decoder: tt.fields.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.authorizeHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_authorize(t *testing.T) { + type args struct { + ctx context.Context + r *Request[oidc.AuthRequest] + } + tests := []struct { + name string + server Server + args args + want *Redirect + wantErr error + }{ + { + name: "verify error", + server: &requestVerifier{}, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + }, + }, + }, + wantErr: oidc.ErrServerError(), + }, + { + name: "missing redirect", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + MaxAge: gu.Ptr[uint](300), + }, + }, + }, + wantErr: ErrAuthReqMissingRedirectURI, + }, + { + name: "invalid prompt", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone, oidc.PromptLogin}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value"), + }, + { + name: "missing scopes", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequest(). + WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "invalid redirect", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://example.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequestRedirectURI(). + WithDescription("The requested redirect_uri is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "invalid response type", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeIDToken, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "unimplemented Authorize called", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + URL: &url.URL{ + Path: "/authorize", + }, + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: StatusError{ + parent: oidc.ErrServerError().WithDescription("/authorize not implemented on this server"), + statusCode: UnimplementedStatusCode, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.server, + decoder: testDecoder, + logger: slog.Default(), + } + got, err := s.authorize(tt.args.ctx, tt.args.r) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webServer_deviceAuthorizationHandler(t *testing.T) { + type fields struct { + server Server + decoder httphelper.Decoder + } + tests := []struct { + name string + fields fields + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + fields: fields{ + server: &requestVerifier{}, + decoder: schema.NewDecoder(), + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented DeviceAuthorization called", + fields: fields{ + server: &requestVerifier{ + client: newClient(clientTypeNative), + }, + decoder: testDecoder, + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native_client")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.fields.server, + decoder: tt.fields.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceAuthorizationHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokensHandler(t *testing.T) { + tests := []struct { + name string + r *http.Request + want webServerResult + }{ + { + name: "parse form error", + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "missing grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"grant_type missing"}`, + }, + }, + { + name: "invalid grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("grant_type=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"bar not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + logger: slog.Default(), + } + runWebServerTest(t, s.tokensHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_jwtProfileHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "assertion missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"assertion missing"}`, + }, + }, + { + name: "unimplemented JWTProfile called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("assertion=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:jwt-bearer not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.jwtProfileHandler, tt.r, tt.want) + }) + } +} + +func runWebServerClientTest(t *testing.T, handler func(http.ResponseWriter, *http.Request, Client), r *http.Request, client Client, want webServerResult) { + t.Helper() + runWebServerTest(t, func(client Client) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(w, r, client) + } + }(client), r, want) +} + +func Test_webServer_codeExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"code missing"}`, + }, + }, + { + name: "redirect missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"redirect_uri missing"}`, + }, + }, + { + name: "unimplemented CodeExchange called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123&redirect_uri=https://example.com/callback")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"authorization_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.codeExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_refreshTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "refresh token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"refresh_token missing"}`, + }, + }, + { + name: "unimplemented RefreshToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("refresh_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"refresh_token not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.refreshTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokenExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "subject token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token missing"}`, + }, + }, + { + name: "subject token type missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type missing"}`, + }, + }, + { + name: "subject token type unsupported", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type is not supported"}`, + }, + }, + { + name: "unsupported requested token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"requested_token_type is not supported"}`, + }, + }, + { + name: "unsupported actor token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"actor_token_type is not supported"}`, + }, + }, + { + name: "unimplemented TokenExchange called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=urn:ietf:params:oauth:token-type:access_token")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:token-exchange not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.tokenExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_clientCredentialsHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "unimplemented ClientCredentialsExchange called", + decoder: testDecoder, + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"client_credentials not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.clientCredentialsHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_deviceTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "device code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"device_code missing"}`, + }, + }, + { + name: "unimplemented DeviceToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("device_code=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:device_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_introspectionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Introspect called", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.introspectionHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_userInfoHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "access token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusUnauthorized, + wantBody: `{"error":"invalid_request", "error_description":"access token missing"}`, + }, + }, + { + name: "unimplemented UserInfo called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("access_token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "bearer", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("authorization", strings.Join([]string{"Bearer", "xxx"}, " ")) + return r + }(), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.userInfoHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_revocationHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Revocation called, confidential client", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "unimplemented Revocation called, public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.revocationHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_endSessionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented EndSession called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("id_token_hint=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.endSessionHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_simpleHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + method func(context.Context, *Request[struct{}]) (*Response, error) + r *http.Request + want webServerResult + }{ + { + name: "parse error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "method error", + decoder: schema.NewDecoder(), + method: func(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, io.ErrClosedPipe + }, + r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, simpleHandler(s, tt.method), tt.r, tt.want) + }) + } +} + +func Test_decodeRequest(t *testing.T) { + type dst struct { + A string `schema:"a"` + B string `schema:"b"` + } + type args struct { + r *http.Request + postOnly bool + } + tests := []struct { + name string + args args + want *dst + wantErr error + }{ + { + name: "parse error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), + }, + { + name: "decode error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), + }, + { + name: "success, get", + args: args{ + r: httptest.NewRequest(http.MethodGet, "/?a=b&b=a", nil), + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + { + name: "success, post only", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: true, + }, + want: &dst{ + A: "b", + }, + }, + { + name: "success, post mixed", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: false, + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.r.Method == http.MethodPost { + tt.args.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + got, err := decodeRequest[dst](schema.NewDecoder(), tt.args.r, tt.args.postOnly) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go new file mode 100644 index 00000000..0a7de855 --- /dev/null +++ b/pkg/op/server_legacy.go @@ -0,0 +1,344 @@ +package op + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/go-chi/chi" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// LegacyServer is an implementation of [Server[] that +// simply wraps a [OpenIDProvider]. +// It can be used to transition from the former Provider/Storage +// interfaces to the new Server interface. +type LegacyServer struct { + UnimplementedServer + provider OpenIDProvider + endpoints Endpoints +} + +// NewLegacyServer wraps provider in a `Server` and returns a handler which is +// the Server's router. +// +// Only non-nil endpoints will be registered on the router. +// Nil endpoints are disabled. +// +// The passed endpoints is also set to the provider, +// to be consistent with the discovery config. +// Any `With*Endpoint()` option used on the provider is +// therefore ineffective. +func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler { + server := RegisterServer(&LegacyServer{ + provider: provider, + endpoints: endpoints, + }, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) + + router := chi.NewRouter() + router.Mount("/", server) + router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider)) + + return router +} + +func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { + for _, probe := range s.provider.Probes() { + // shouldn't we run probes in Go routines? + if err := probe(ctx); err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + } + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return NewResponse( + createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints), + ), nil +} + +func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { + keys, err := s.provider.Storage().KeySet(ctx) + if err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + return NewResponse(jsonWebKeySet(keys)), nil +} + +var ( + ErrAuthReqMissingClientID = errors.New("auth request is missing client_id") + ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri") +) + +func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if r.Data.RequestParam != "" { + if !s.provider.RequestObjectSupported() { + return nil, oidc.ErrRequestNotSupported() + } + err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx)) + if err != nil { + return nil, err + } + } + if r.Data.ClientID == "" { + return nil, ErrAuthReqMissingClientID + } + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id") + } + + return &ClientRequest[oidc.AuthRequest]{ + Request: r, + Client: client, + }, nil +} + +func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) { + userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx)) + if err != nil { + return nil, err + } + req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID) + if err != nil { + return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger()) + } + return NewRedirect(r.Client.LoginURL(req.GetID())), nil +} + +func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider) + if err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + return NewResponse(response), nil +} + +func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials { + storage, ok := s.provider.Storage().(ClientCredentialsStorage) + if !ok { + return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported") + } + return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret) + } + + if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { + jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported") + } + return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger) + } + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithParent(err) + } + + switch client.AuthMethod() { + case oidc.AuthMethodNone: + return client, nil + case oidc.AuthMethodPrivateKeyJWT: + return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") + case oidc.AuthMethodPost: + if !s.provider.AuthMethodPostSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") + } + } + + err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage()) + if err != nil { + return nil, err + } + + return client, nil +} + +func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { + authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code) + if err != nil { + return nil, err + } + if r.Client.AuthMethod() == oidc.AuthMethodNone { + if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil { + return nil, err + } + } + resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "") + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + if !s.provider.GrantTypeRefreshTokenSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) + } + request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken) + if err != nil { + return nil, err + } + if r.Client.GetID() != request.GetClientID() { + return nil, oidc.ErrInvalidGrant() + } + if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil { + return nil, err + } + resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeBearer) + } + tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx)) + if err != nil { + return nil, err + } + + tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + if !s.provider.GrantTypeTokenExchangeSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider) + if err != nil { + return nil, err + } + resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + storage, ok := s.provider.Storage().(ClientCredentialsStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) + } + tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + if !s.provider.GrantTypeClientCredentialsSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) + } + // use a limited context timeout shorter as the default + // poll interval of 5 seconds. + ctx, cancel := context.WithTimeout(ctx, 4*time.Second) + defer cancel() + + state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider) + if err != nil { + return nil, err + } + tokenRequest := &deviceAccessTokenRequest{ + subject: state.Subject, + audience: []string{r.Client.GetID()}, + scopes: state.Scopes, + } + resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { + response := new(oidc.IntrospectionResponse) + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token) + if !ok { + return NewResponse(response), nil + } + err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID()) + if err != nil { + return NewResponse(response), nil + } + response.Active = true + return NewResponse(response), nil +} + +func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken) + if !ok { + return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized) + } + info := new(oidc.UserInfo) + err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin")) + if err != nil { + return nil, NewStatusError(err, http.StatusForbidden) + } + return NewResponse(info), nil +} + +func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { + var subject string + doDecrypt := true + if r.Data.TokenTypeHint != "access_token" { + userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token) + if err != nil { + // An invalid refresh token means that we'll try other things (leaving doDecrypt==true) + if !errors.Is(err, ErrInvalidRefreshToken) { + return nil, RevocationError(oidc.ErrServerError().WithParent(err)) + } + } else { + r.Data.Token = tokenID + subject = userID + doDecrypt = false + } + } + if doDecrypt { + tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token) + if ok { + r.Data.Token = tokenID + subject = userID + } + } + if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil { + return nil, RevocationError(err) + } + return NewResponse(nil), nil +} + +func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { + session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider) + if err != nil { + return nil, err + } + err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID) + if err != nil { + return nil, err + } + return NewRedirect(session.RedirectURI), nil +} diff --git a/pkg/op/server_test.go b/pkg/op/server_test.go new file mode 100644 index 00000000..0cad8fd5 --- /dev/null +++ b/pkg/op/server_test.go @@ -0,0 +1,5 @@ +package op + +// implementation check +var _ Server = &UnimplementedServer{} +var _ Server = &LegacyServer{} diff --git a/pkg/op/session.go b/pkg/op/session.go index 90142555..c33627f2 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -6,15 +6,17 @@ import ( "net/url" "path" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type SessionEnder interface { Decoder() httphelper.Decoder Storage() Storage - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier DefaultLogoutRedirectURI() string + Logger() *slog.Logger } func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) { @@ -31,7 +33,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { } session, err := ValidateEndSessionRequest(r.Context(), req, ender) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, ender.Logger()) return } redirect := session.RedirectURI @@ -41,7 +43,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID) } if err != nil { - RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session")) + RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger()) return } http.Redirect(w, r, redirect, http.StatusFound) diff --git a/pkg/op/signer.go b/pkg/op/signer.go index 6cef2883..b2207399 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -3,7 +3,7 @@ package op import ( "errors" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) var ErrSignerCreationFailed = errors.New("signer creation failed") diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 17aa0b49..d083a31c 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -5,9 +5,9 @@ import ( "errors" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type AuthStorage interface { @@ -191,18 +191,6 @@ type DeviceAuthorizationStorage interface { // GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database. // The method is polled untill the the authorization is eighter Completed, Expired or Denied. GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error) - - // GetDeviceAuthorizationByUserCode resturn the current state of the device authorization flow, - // identified by the user code. - GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*DeviceAuthorizationState, error) - - // CompleteDeviceAuthorization marks a device authorization entry as Completed, - // identified by userCode. The Subject is added to the state, so that - // GetDeviceAuthorizatonState can use it to create a new Access Token. - CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error - - // DenyDeviceAuthorization marks a device authorization entry as Denied. - DenyDeviceAuthorization(ctx context.Context, userCode string) error } func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) { diff --git a/pkg/op/token.go b/pkg/op/token.go index 001023ce..bc45c298 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -4,9 +4,9 @@ import ( "context" "time" - "github.com/zitadel/oidc/v2/pkg/crypto" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/strings" + "github.com/zitadel/oidc/v3/pkg/crypto" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/strings" ) type TokenCreator interface { diff --git a/pkg/op/token_client_credentials.go b/pkg/op/token_client_credentials.go index 0b91a36e..7f1debed 100644 --- a/pkg/op/token_client_credentials.go +++ b/pkg/op/token_client_credentials.go @@ -5,8 +5,8 @@ import ( "net/http" "net/url" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // ClientCredentialsExchange handles the OAuth 2.0 client_credentials grant, including @@ -18,18 +18,18 @@ func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger request, err := ParseClientCredentialsRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateClientCredentialsRequest(r.Context(), request, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateClientCredentialsTokenResponse(r.Context(), validatedRequest, exchanger, client) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index f2162ef7..3612240f 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -4,8 +4,8 @@ import ( "context" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // CodeExchange handles the OAuth 2.0 authorization_code grant, including @@ -17,20 +17,20 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } if tokenReq.Code == "" { - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), exchanger.Logger()) return } authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "") if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) @@ -98,7 +98,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, if err != nil { return nil, nil, err } - err = AuthorizeCodeChallenge(tokenReq, request.GetCodeChallenge()) + err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) return request, client, err } if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index e64ce80e..db3e468f 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -7,8 +7,8 @@ import ( "strings" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type TokenExchangeRequest interface { @@ -140,17 +140,17 @@ func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) @@ -201,12 +201,6 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing") } - storage := exchanger.Storage() - teStorage, ok := storage.(TokenExchangeStorage) - if !ok { - return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("token_exchange grant not supported") - } - client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger) if err != nil { return nil, nil, err @@ -224,10 +218,28 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported") } + req, err := CreateTokenExchangeRequest(ctx, oidcTokenExchangeRequest, client, exchanger) + if err != nil { + return nil, nil, err + } + return req, client, nil +} + +func CreateTokenExchangeRequest( + ctx context.Context, + oidcTokenExchangeRequest *oidc.TokenExchangeRequest, + client Client, + exchanger Exchanger, +) (TokenExchangeRequest, error) { + teStorage, ok := exchanger.Storage().(TokenExchangeStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") } var ( @@ -238,7 +250,7 @@ func ValidateTokenExchangeRequest( exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") } } @@ -262,17 +274,17 @@ func ValidateTokenExchangeRequest( authTime: time.Now(), } - err = teStorage.ValidateTokenExchangeRequest(ctx, req) + err := teStorage.ValidateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } err = teStorage.CreateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } - return req, client, nil + return req, nil } func GetTokenIDAndSubjectFromToken( diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index 85823883..21b79c3b 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -5,15 +5,15 @@ import ( "errors" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type Introspector interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier } type IntrospectorJWTProfile interface { diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index 5a94d99f..96ce1ede 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -5,13 +5,13 @@ import ( "net/http" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type JWTAuthorizationGrantExchanger interface { Exchanger - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1 @@ -22,23 +22,23 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizati profileRequest, err := ParseJWTProfileGrantRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context())) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(r.Context(), tokenRequest.Issuer, profileRequest.Scope) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index efb6fc08..afca3bf7 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -6,9 +6,9 @@ import ( "net/http" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/strings" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/strings" ) type RefreshTokenRequest interface { @@ -30,16 +30,16 @@ func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exch tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 12073756..f00b294b 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -5,8 +5,9 @@ import ( "net/http" "net/url" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type Exchanger interface { @@ -20,8 +21,9 @@ type Exchanger interface { GrantTypeJWTAuthorizationSupported() bool GrantTypeClientCredentialsSupported() bool GrantTypeDeviceCodeSupported() bool - AccessTokenVerifier(context.Context) AccessTokenVerifier - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier + Logger() *slog.Logger } func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { @@ -66,10 +68,10 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { return } case "": - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), exchanger.Logger()) return } - RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType)) + RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), exchanger.Logger()) } // AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest @@ -122,11 +124,11 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, // AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent // code_challenge of the auth request (PKCE) -func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error { - if tokenReq.CodeVerifier == "" { +func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error { + if codeVerifier == "" { return oidc.ErrInvalidRequest().WithDescription("code_challenge required") } - if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) { + if !oidc.VerifyCodeChallenge(challenge, codeVerifier) { return oidc.ErrInvalidGrant().WithDescription("invalid code challenge") } return nil diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go index 58332c33..d19c7f7f 100644 --- a/pkg/op/token_revocation.go +++ b/pkg/op/token_revocation.go @@ -7,22 +7,22 @@ import ( "net/url" "strings" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type Revoker interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier AuthMethodPrivateKeyJWTSupported() bool AuthMethodPostSupported() bool } type RevokerJWTProfile interface { Revoker - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) { @@ -131,6 +131,11 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token } func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { + statusErr := RevocationError(err) + httphelper.MarshalJSONWithStatus(w, statusErr.parent, statusErr.statusCode) +} + +func RevocationError(err error) StatusError { e := oidc.DefaultToServerError(err, err.Error()) status := http.StatusBadRequest switch e.ErrorType { @@ -139,7 +144,7 @@ func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { case oidc.ServerError: status = 500 } - httphelper.MarshalJSONWithStatus(w, e, status) + return NewStatusError(e, status) } func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 21a0af48..86205b5f 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -6,15 +6,15 @@ import ( "net/http" "strings" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type UserinfoProvider interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier } func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go index 9a8b9128..120bfa71 100644 --- a/pkg/op/verifier_access_token.go +++ b/pkg/op/verifier_access_token.go @@ -2,62 +2,25 @@ package op import ( "context" - "time" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) -type AccessTokenVerifier interface { - oidc.Verifier - SupportedSignAlgs() []string - KeySet() oidc.KeySet -} - -type accessTokenVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - supportedSignAlgs []string - keySet oidc.KeySet -} - -// Issuer implements oidc.Verifier interface -func (i *accessTokenVerifier) Issuer() string { - return i.issuer -} - -// MaxAgeIAT implements oidc.Verifier interface -func (i *accessTokenVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -// Offset implements oidc.Verifier interface -func (i *accessTokenVerifier) Offset() time.Duration { - return i.offset -} - -// SupportedSignAlgs implements AccessTokenVerifier interface -func (i *accessTokenVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -// KeySet implements AccessTokenVerifier interface -func (i *accessTokenVerifier) KeySet() oidc.KeySet { - return i.keySet -} +type AccessTokenVerifier oidc.Verifier -type AccessTokenVerifierOpt func(*accessTokenVerifier) +type AccessTokenVerifierOpt func(*AccessTokenVerifier) func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt { - return func(verifier *accessTokenVerifier) { - verifier.supportedSignAlgs = algs + return func(verifier *AccessTokenVerifier) { + verifier.SupportedSignAlgs = algs } } -func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) AccessTokenVerifier { - verifier := &accessTokenVerifier{ - issuer: issuer, - keySet: keySet, +// NewAccessTokenVerifier returns a AccessTokenVerifier suitable for access token verification. +func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) *AccessTokenVerifier { + verifier := &AccessTokenVerifier{ + Issuer: issuer, + KeySet: keySet, } for _, opt := range opts { opt(verifier) @@ -66,7 +29,7 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok } // VerifyAccessToken validates the access token (issuer, signature and expiration). -func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) { +func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *AccessTokenVerifier) (claims C, err error) { var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -78,15 +41,15 @@ func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v Acces return nilClaims, err } - if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err := oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } diff --git a/pkg/op/verifier_access_token_example_test.go b/pkg/op/verifier_access_token_example_test.go index effdd587..397a2d35 100644 --- a/pkg/op/verifier_access_token_example_test.go +++ b/pkg/op/verifier_access_token_example_test.go @@ -4,9 +4,9 @@ import ( "context" "fmt" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) // MyCustomClaims extends the TokenClaims base, diff --git a/pkg/op/verifier_access_token_test.go b/pkg/op/verifier_access_token_test.go index 62c26a94..66e32ceb 100644 --- a/pkg/op/verifier_access_token_test.go +++ b/pkg/op/verifier_access_token_test.go @@ -7,8 +7,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func TestNewAccessTokenVerifier(t *testing.T) { @@ -20,7 +20,7 @@ func TestNewAccessTokenVerifier(t *testing.T) { tests := []struct { name string args args - want AccessTokenVerifier + want *AccessTokenVerifier }{ { name: "simple", @@ -28,9 +28,9 @@ func TestNewAccessTokenVerifier(t *testing.T) { issuer: tu.ValidIssuer, keySet: tu.KeySet{}, }, - want: &accessTokenVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, + want: &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, }, }, { @@ -42,10 +42,10 @@ func TestNewAccessTokenVerifier(t *testing.T) { WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"), }, }, - want: &accessTokenVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } @@ -58,12 +58,12 @@ func TestNewAccessTokenVerifier(t *testing.T) { } func TestVerifyAccessToken(t *testing.T) { - verifier := &accessTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, + verifier := &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, } tests := []struct { diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go index d906075d..61432527 100644 --- a/pkg/op/verifier_id_token_hint.go +++ b/pkg/op/verifier_id_token_hint.go @@ -2,69 +2,24 @@ package op import ( "context" - "time" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) -type IDTokenHintVerifier interface { - oidc.Verifier - SupportedSignAlgs() []string - KeySet() oidc.KeySet - ACR() oidc.ACRVerifier - MaxAge() time.Duration -} - -type idTokenHintVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - supportedSignAlgs []string - maxAge time.Duration - acr oidc.ACRVerifier - keySet oidc.KeySet -} - -func (i *idTokenHintVerifier) Issuer() string { - return i.issuer -} - -func (i *idTokenHintVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -func (i *idTokenHintVerifier) Offset() time.Duration { - return i.offset -} - -func (i *idTokenHintVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -func (i *idTokenHintVerifier) KeySet() oidc.KeySet { - return i.keySet -} - -func (i *idTokenHintVerifier) ACR() oidc.ACRVerifier { - return i.acr -} - -func (i *idTokenHintVerifier) MaxAge() time.Duration { - return i.maxAge -} +type IDTokenHintVerifier oidc.Verifier -type IDTokenHintVerifierOpt func(*idTokenHintVerifier) +type IDTokenHintVerifierOpt func(*IDTokenHintVerifier) func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt { - return func(verifier *idTokenHintVerifier) { - verifier.supportedSignAlgs = algs + return func(verifier *IDTokenHintVerifier) { + verifier.SupportedSignAlgs = algs } } -func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) IDTokenHintVerifier { - verifier := &idTokenHintVerifier{ - issuer: issuer, - keySet: keySet, +func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) *IDTokenHintVerifier { + verifier := &IDTokenHintVerifier{ + Issuer: issuer, + KeySet: keySet, } for _, opt := range opts { opt(verifier) @@ -74,7 +29,7 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi // VerifyIDTokenHint validates the id token according to // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) { +func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) { var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -86,27 +41,27 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTok return nilClaims, err } - if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err := oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil { return nilClaims, err } - if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { + if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { return nilClaims, err } return claims, nil diff --git a/pkg/op/verifier_id_token_hint_test.go b/pkg/op/verifier_id_token_hint_test.go index f4d0b0c6..e514a76e 100644 --- a/pkg/op/verifier_id_token_hint_test.go +++ b/pkg/op/verifier_id_token_hint_test.go @@ -7,8 +7,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func TestNewIDTokenHintVerifier(t *testing.T) { @@ -20,7 +20,7 @@ func TestNewIDTokenHintVerifier(t *testing.T) { tests := []struct { name string args args - want IDTokenHintVerifier + want *IDTokenHintVerifier }{ { name: "simple", @@ -28,9 +28,9 @@ func TestNewIDTokenHintVerifier(t *testing.T) { issuer: tu.ValidIssuer, keySet: tu.KeySet{}, }, - want: &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, + want: &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, }, }, { @@ -42,10 +42,10 @@ func TestNewIDTokenHintVerifier(t *testing.T) { WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"), }, }, - want: &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } @@ -58,14 +58,14 @@ func TestNewIDTokenHintVerifier(t *testing.T) { } func TestVerifyIDTokenHint(t *testing.T) { - verifier := &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - keySet: tu.KeySet{}, + verifier := &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + KeySet: tu.KeySet{}, } tests := []struct { diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index e7c96113..3b136651 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -6,33 +6,30 @@ import ( "fmt" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) -type JWTProfileVerifier interface { +// JWTProfileVerfiier extends oidc.Verifier with +// a jwtProfileKeyStorage and a function to check +// the subject in a token. +type JWTProfileVerifier struct { oidc.Verifier - Storage() jwtProfileKeyStorage - CheckSubject(request *oidc.JWTTokenRequest) error -} - -type jwtProfileVerifier struct { - storage jwtProfileKeyStorage - subjectCheck func(request *oidc.JWTTokenRequest) error - issuer string - maxAgeIAT time.Duration - offset time.Duration + Storage JWTProfileKeyStorage + CheckSubject func(request *oidc.JWTTokenRequest) error } // NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication) -func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) JWTProfileVerifier { - j := &jwtProfileVerifier{ - storage: storage, - subjectCheck: SubjectIsIssuer, - issuer: issuer, - maxAgeIAT: maxAgeIAT, - offset: offset, +func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier { + j := &JWTProfileVerifier{ + Verifier: oidc.Verifier{ + Issuer: issuer, + MaxAgeIAT: maxAgeIAT, + Offset: offset, + }, + Storage: storage, + CheckSubject: SubjectIsIssuer, } for _, opt := range opts { @@ -42,38 +39,20 @@ func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIA return j } -type JWTProfileVerifierOption func(*jwtProfileVerifier) +type JWTProfileVerifierOption func(*JWTProfileVerifier) +// SubjectCheck sets a custom function to check the subject. +// Defaults to SubjectIsIssuer() func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption { - return func(verifier *jwtProfileVerifier) { - verifier.subjectCheck = check + return func(verifier *JWTProfileVerifier) { + verifier.CheckSubject = check } } -func (v *jwtProfileVerifier) Issuer() string { - return v.issuer -} - -func (v *jwtProfileVerifier) Storage() jwtProfileKeyStorage { - return v.storage -} - -func (v *jwtProfileVerifier) MaxAgeIAT() time.Duration { - return v.maxAgeIAT -} - -func (v *jwtProfileVerifier) Offset() time.Duration { - return v.offset -} - -func (v *jwtProfileVerifier) CheckSubject(request *oidc.JWTTokenRequest) error { - return v.subjectCheck(request) -} - // VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication) // // checks audience, exp, iat, signature and that issuer and sub are the same -func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { +func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { ctx, span := tracer.Start(ctx, "VerifyJWTAssertion") defer span.End() @@ -83,15 +62,15 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif return nil, err } - if err = oidc.CheckAudience(request, v.Issuer()); err != nil { + if err = oidc.CheckAudience(request, v.Issuer); err != nil { return nil, err } - if err = oidc.CheckExpiration(request, v.Offset()); err != nil { + if err = oidc.CheckExpiration(request, v.Offset); err != nil { return nil, err } - if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT, v.Offset); err != nil { return nil, err } @@ -99,17 +78,18 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif return nil, err } - keySet := &jwtProfileKeySet{storage: v.Storage(), clientID: request.Issuer} + keySet := &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer} if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil { return nil, err } return request, nil } -type jwtProfileKeyStorage interface { +type JWTProfileKeyStorage interface { GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } +// SubjectIsIssuer func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { if request.Issuer != request.Subject { return errors.New("delegation not allowed, issuer and sub must be identical") @@ -118,7 +98,7 @@ func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { } type jwtProfileKeySet struct { - storage jwtProfileKeyStorage + storage JWTProfileKeyStorage clientID string } diff --git a/pkg/op/verifier_jwt_profile_test.go b/pkg/op/verifier_jwt_profile_test.go new file mode 100644 index 00000000..d96cbb43 --- /dev/null +++ b/pkg/op/verifier_jwt_profile_test.go @@ -0,0 +1,117 @@ +package op_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" +) + +func TestNewJWTProfileVerifier(t *testing.T) { + want := &op.JWTProfileVerifier{ + Verifier: oidc.Verifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: time.Minute, + Offset: time.Second, + }, + Storage: tu.JWTProfileKeyStorage{}, + } + got := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, time.Second, op.SubjectCheck(func(request *oidc.JWTTokenRequest) error { + return oidc.ErrSubjectMissing + })) + assert.Equal(t, want.Verifier, got.Verifier) + assert.Equal(t, want.Storage, got.Storage) + assert.ErrorIs(t, got.CheckSubject(nil), oidc.ErrSubjectMissing) +} + +func TestVerifyJWTAssertion(t *testing.T) { + errCtx, cancel := context.WithCancel(context.Background()) + cancel() + + verifier := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, 0) + tests := []struct { + name string + ctx context.Context + newToken func() (string, *oidc.JWTTokenRequest) + wantErr bool + }{ + { + name: "parse error", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { return "!", nil }, + wantErr: true, + }, + { + name: "wrong audience", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{"wrong"}, + time.Now(), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "expired", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer}, + time.Now(), time.Now().Add(-time.Hour), + ) + }, + wantErr: true, + }, + { + name: "invalid iat", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer}, + time.Now().Add(time.Hour), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "invalid subject", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, "wrong", []string{tu.ValidIssuer}, + time.Now(), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "check signature fail", + ctx: errCtx, + newToken: tu.ValidJWTProfileAssertion, + wantErr: true, + }, + { + name: "ok", + ctx: context.Background(), + newToken: tu.ValidJWTProfileAssertion, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertion, want := tt.newToken() + got, err := op.VerifyJWTAssertion(tt.ctx, assertion, verifier) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, want, got) + }) + } +}