From 57fb9f77aa5db4486af67a67132063a3b841abae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 17 Mar 2023 17:36:02 +0200 Subject: [PATCH 01/14] chore: replace gorilla/mux with go-chi/chi (#332) BREAKING CHANGE: The returned router from `op.CreateRouter()` is now a `chi.Router` Closes #301 --- example/client/api/api.go | 10 ++++---- example/server/dynamic/login.go | 10 ++++---- example/server/dynamic/op.go | 8 +++---- example/server/exampleop/device.go | 8 +++---- example/server/exampleop/login.go | 10 ++++---- example/server/exampleop/op.go | 15 ++++++------ go.mod | 2 +- go.sum | 15 ++---------- pkg/op/auth_request.go | 22 +++++++++++------ pkg/op/auth_request_test.go | 38 +++++++++++++++++++++++++++++- pkg/op/op.go | 18 ++++++++------ pkg/op/op_test.go | 2 +- 12 files changed, 98 insertions(+), 60 deletions(-) diff --git a/example/client/api/api.go b/example/client/api/api.go index 8093b636..9f654a9e 100644 --- a/example/client/api/api.go +++ b/example/client/api/api.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/sirupsen/logrus" "github.com/zitadel/oidc/v2/pkg/client/rs" @@ -32,7 +32,7 @@ func main() { logrus.Fatalf("error creating provider %s", err.Error()) } - router := mux.NewRouter() + router := chi.NewRouter() // public url accessible without any authorization // will print `OK` and current timestamp @@ -73,9 +73,9 @@ func main() { http.Error(w, err.Error(), http.StatusForbidden) return } - params := mux.Vars(r) - requestedClaim := params["claim"] - requestedValue := params["value"] + requestedClaim := chi.URLParam(r, "claim") + requestedValue := chi.URLParam(r, "value") + value, ok := resp.Claims[requestedClaim].(string) if !ok || value == "" || value != requestedValue { http.Error(w, "claim does not match", http.StatusForbidden) diff --git a/example/server/dynamic/login.go b/example/server/dynamic/login.go index e7c6e5fd..eb5340e3 100644 --- a/example/server/dynamic/login.go +++ b/example/server/dynamic/login.go @@ -6,7 +6,7 @@ import ( "html/template" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/zitadel/oidc/v2/pkg/op" ) @@ -43,7 +43,7 @@ var ( type login struct { authenticate authenticate - router *mux.Router + router chi.Router callback func(context.Context, string) string } @@ -57,9 +57,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string) } func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) { - l.router = mux.NewRouter() - l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) - l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler)) + l.router = chi.NewRouter() + l.router.Get("/username", l.loginHandler) + l.router.With(issuerInterceptor.Handler).Post("/username", l.checkLoginHandler) } type authenticate interface { diff --git a/example/server/dynamic/op.go b/example/server/dynamic/op.go index 783c75cf..2bb68329 100644 --- a/example/server/dynamic/op.go +++ b/example/server/dynamic/op.go @@ -7,7 +7,7 @@ import ( "log" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "golang.org/x/text/language" "github.com/zitadel/oidc/v2/example/server/storage" @@ -47,7 +47,7 @@ func main() { //be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) - router := mux.NewRouter() + router := chi.NewRouter() //for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { @@ -76,7 +76,7 @@ func main() { //regardless of how many pages / steps there are in the process, the UI must be registered in the router, //so we will direct all calls to /login to the login UI - router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + router.Mount("/login/", http.StripPrefix("/login", l.router)) //we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) //is served on the correct path @@ -84,7 +84,7 @@ func main() { //if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), //then you would have to set the path prefix (/custom/path/): //router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler())) - router.PathPrefix("/").Handler(provider.HttpHandler()) + router.Mount("/", provider) server := &http.Server{ Addr: ":" + port, diff --git a/example/server/exampleop/device.go b/example/server/exampleop/device.go index ae2e8f29..59c2196f 100644 --- a/example/server/exampleop/device.go +++ b/example/server/exampleop/device.go @@ -7,7 +7,7 @@ import ( "net/http" "net/url" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/gorilla/securecookie" "github.com/sirupsen/logrus" "github.com/zitadel/oidc/v2/pkg/op" @@ -23,14 +23,14 @@ type deviceLogin struct { cookie *securecookie.SecureCookie } -func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) { +func registerDeviceAuth(storage deviceAuthenticate, router chi.Router) { l := &deviceLogin{ storage: storage, cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil), } - router.HandleFunc("", l.userCodeHandler) - router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler) + router.HandleFunc("/", l.userCodeHandler) + router.Post("/login", l.loginHandler) router.HandleFunc("/confirm", l.confirmHandler) } diff --git a/example/server/exampleop/login.go b/example/server/exampleop/login.go index c014c9ad..9facb902 100644 --- a/example/server/exampleop/login.go +++ b/example/server/exampleop/login.go @@ -5,12 +5,12 @@ import ( "fmt" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi" ) type login struct { authenticate authenticate - router *mux.Router + router chi.Router callback func(context.Context, string) string } @@ -24,9 +24,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string) } func (l *login) createRouter() { - l.router = mux.NewRouter() - l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) - l.router.Path("/username").Methods("POST").HandlerFunc(l.checkLoginHandler) + l.router = chi.NewRouter() + l.router.Get("/username", l.loginHandler) + l.router.Post("/username", l.checkLoginHandler) } type authenticate interface { diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 56044839..077244c4 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -6,7 +6,7 @@ import ( "net/http" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "golang.org/x/text/language" "github.com/zitadel/oidc/v2/example/server/storage" @@ -34,12 +34,12 @@ type Storage interface { // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. -func SetupServer(issuer string, storage Storage) *mux.Router { +func SetupServer(issuer string, storage Storage) chi.Router { // the OpenID Provider requires a 32-byte key for (token) encryption // be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) - router := mux.NewRouter() + router := chi.NewRouter() // for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { @@ -61,17 +61,18 @@ func SetupServer(issuer string, storage Storage) *mux.Router { // regardless of how many pages / steps there are in the process, the UI must be registered in the router, // so we will direct all calls to /login to the login UI - router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + router.Mount("/login/", http.StripPrefix("/login", l.router)) - router.PathPrefix("/device").Subrouter() - registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter()) + router.Route("/device", func(r chi.Router) { + registerDeviceAuth(storage, r) + }) // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // is served on the correct path // // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), // then you would have to set the path prefix (/custom/path/) - router.PathPrefix("/").Handler(provider.HttpHandler()) + router.Mount("/", provider) return router } diff --git a/go.mod b/go.mod index 75942643..a6362504 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module github.com/zitadel/oidc/v2 go 1.18 require ( + github.com/go-chi/chi v1.5.4 github.com/golang/mock v1.6.0 github.com/google/go-github/v31 v31.0.0 github.com/google/uuid v1.3.0 - github.com/gorilla/mux v1.8.0 github.com/gorilla/schema v1.2.0 github.com/gorilla/securecookie v1.1.1 github.com/jeremija/gosubmit v0.2.7 diff --git a/go.sum b/go.sum index e4e5c6c8..a5ba6426 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs= +github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -19,8 +21,6 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= @@ -34,8 +34,6 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM= github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/cors v1.8.3 h1:O+qNyWn7Z+F9M0ILBHgMVPuB1xTOucVd5gtaYyXBpRo= @@ -50,9 +48,6 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -84,12 +79,6 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index b3120988..4c483637 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -2,6 +2,7 @@ package op import ( "context" + "errors" "fmt" "net" "net/http" @@ -10,8 +11,6 @@ import ( "strings" "time" - "github.com/gorilla/mux" - httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" str "github.com/zitadel/oidc/v2/pkg/strings" @@ -405,13 +404,11 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r * // AuthorizeCallback handles the callback after authentication in the Login UI func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { - params := mux.Vars(r) - id := params["id"] - if id == "" { - AuthRequestError(w, r, nil, fmt.Errorf("auth request callback is missing id"), authorizer.Encoder()) + id, err := ParseAuthorizeCallbackRequest(r) + if err != nil { + AuthRequestError(w, r, nil, err, authorizer.Encoder()) return } - authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) if err != nil { AuthRequestError(w, r, nil, err, authorizer.Encoder()) @@ -426,6 +423,17 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author AuthResponse(authReq, authorizer, w, r) } +func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) { + if err = r.ParseForm(); err != nil { + return "", fmt.Errorf("cannot parse form: %w", err) + } + id = r.Form.Get("id") + if id == "" { + return "", errors.New("auth request callback is missing id") + } + return id, nil +} + // AuthResponse creates the successful authentication response (either code or tokens) func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 7a9701bd..542f2e2c 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -12,7 +12,6 @@ import ( "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/op" @@ -967,3 +966,40 @@ func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error { } return nil } + +func Test_parseAuthorizeCallbackRequest(t *testing.T) { + tests := []struct { + name string + url string + wantId string + wantErr bool + }{ + { + name: "parse error", + url: "/?id;=99", + wantErr: true, + }, + { + name: "missing id", + url: "/", + wantErr: true, + }, + { + name: "ok", + url: "/?id=99", + wantId: "99", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, tt.url, nil) + gotId, err := op.ParseAuthorizeCallbackRequest(r) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantId, gotId) + }) + } +} diff --git a/pkg/op/op.go b/pkg/op/op.go index ecb753ec..0536bbc4 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -6,7 +6,7 @@ import ( "net/http" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/gorilla/schema" "github.com/rs/cors" "golang.org/x/text/language" @@ -68,6 +68,7 @@ var ( ) type OpenIDProvider interface { + http.Handler Configuration Storage() Storage Decoder() httphelper.Decoder @@ -77,20 +78,22 @@ type OpenIDProvider interface { Crypto() Crypto DefaultLogoutRedirectURI() string Probes() []ProbesFn + + // Deprecated: Provider now implements http.Handler directly. HttpHandler() http.Handler } type HttpInterceptor func(http.Handler) http.Handler -func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router { - router := mux.NewRouter() +func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { + router := chi.NewRouter() router.Use(cors.New(defaultCORSOptions).Handler) router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage())) router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o)) - router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o)) + router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o)) router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o)) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) @@ -184,7 +187,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR return nil, err } - o.httpHandler = CreateRouter(o, o.interceptors...) + o.Handler = CreateRouter(o, o.interceptors...) o.decoder = schema.NewDecoder() o.decoder.IgnoreUnknownKeys(true) @@ -200,6 +203,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR } type Provider struct { + http.Handler config *Config issuer IssuerFromRequest insecure bool @@ -207,7 +211,6 @@ type Provider struct { storage Storage keySet *openIDKeySet crypto Crypto - httpHandler http.Handler decoder *schema.Decoder encoder *schema.Encoder interceptors []HttpInterceptor @@ -372,8 +375,9 @@ func (o *Provider) Probes() []ProbesFn { } } +// Deprecated: Provider now implements http.Handler directly. func (o *Provider) HttpHandler() http.Handler { - return o.httpHandler + return o } type openIDKeySet struct { diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index ba3570bd..8429212a 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -365,7 +365,7 @@ func TestRoutes(t *testing.T) { } rec := httptest.NewRecorder() - testProvider.HttpHandler().ServeHTTP(rec, req) + testProvider.ServeHTTP(rec, req) resp := rec.Result() require.NoError(t, err) From c8cf15e26609b6fe63c9464ac9e6095dd34eabc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 17 Mar 2023 18:41:41 +0200 Subject: [PATCH 02/14] upgrade this module to v3 --- README.md | 8 ++++---- example/client/api/api.go | 4 ++-- example/client/app/app.go | 6 +++--- example/client/device/device.go | 4 ++-- example/client/github/github.go | 8 ++++---- example/client/service/service.go | 2 +- example/server/dynamic/login.go | 2 +- example/server/dynamic/op.go | 4 ++-- example/server/exampleop/device.go | 2 +- example/server/exampleop/op.go | 4 ++-- example/server/main.go | 4 ++-- example/server/storage/client.go | 4 ++-- example/server/storage/oidc.go | 4 ++-- example/server/storage/storage.go | 4 ++-- example/server/storage/storage_dynamic.go | 4 ++-- go.mod | 2 +- internal/testutil/gen/gen.go | 4 ++-- internal/testutil/token.go | 2 +- pkg/client/client.go | 6 +++--- pkg/client/integration_test.go | 14 +++++++------- pkg/client/jwt_profile.go | 4 ++-- pkg/client/profile/jwt_profile.go | 4 ++-- pkg/client/rp/cli/cli.go | 6 +++--- pkg/client/rp/delegation.go | 2 +- pkg/client/rp/device.go | 4 ++-- pkg/client/rp/jwks.go | 4 ++-- pkg/client/rp/relying_party.go | 6 +++--- pkg/client/rp/tockenexchange.go | 2 +- pkg/client/rp/verifier.go | 2 +- pkg/client/rp/verifier_test.go | 4 ++-- pkg/client/rp/verifier_tokens_example_test.go | 6 +++--- pkg/client/rs/resource_server.go | 6 +++--- pkg/client/tokenexchange/tokenexchange.go | 6 +++--- pkg/oidc/code_challenge.go | 2 +- pkg/oidc/token.go | 2 +- pkg/oidc/verifier.go | 2 +- pkg/op/auth_request.go | 6 +++--- pkg/op/auth_request_test.go | 8 ++++---- pkg/op/client.go | 4 ++-- pkg/op/client_test.go | 8 ++++---- pkg/op/crypto.go | 2 +- pkg/op/device.go | 4 ++-- pkg/op/device_test.go | 4 ++-- pkg/op/discovery.go | 4 ++-- pkg/op/discovery_test.go | 6 +++--- pkg/op/endpoint_test.go | 2 +- pkg/op/error.go | 4 ++-- pkg/op/keys.go | 2 +- pkg/op/keys_test.go | 6 +++--- pkg/op/mock/authorizer.mock.go | 6 +++--- pkg/op/mock/authorizer.mock.impl.go | 4 ++-- pkg/op/mock/client.go | 4 ++-- pkg/op/mock/client.mock.go | 6 +++--- pkg/op/mock/configuration.mock.go | 4 ++-- pkg/op/mock/discovery.mock.go | 2 +- pkg/op/mock/generate.go | 14 +++++++------- pkg/op/mock/key.mock.go | 4 ++-- pkg/op/mock/signer.mock.go | 2 +- pkg/op/mock/storage.mock.go | 6 +++--- pkg/op/mock/storage.mock.impl.go | 4 ++-- pkg/op/op.go | 4 ++-- pkg/op/op_test.go | 6 +++--- pkg/op/probes.go | 2 +- pkg/op/session.go | 4 ++-- pkg/op/storage.go | 2 +- pkg/op/token.go | 6 +++--- pkg/op/token_client_credentials.go | 4 ++-- pkg/op/token_code.go | 4 ++-- pkg/op/token_exchange.go | 4 ++-- pkg/op/token_intospection.go | 4 ++-- pkg/op/token_jwt_profile.go | 4 ++-- pkg/op/token_refresh.go | 6 +++--- pkg/op/token_request.go | 4 ++-- pkg/op/token_revocation.go | 4 ++-- pkg/op/userinfo.go | 4 ++-- pkg/op/verifier_access_token.go | 2 +- pkg/op/verifier_access_token_example_test.go | 6 +++--- pkg/op/verifier_access_token_test.go | 4 ++-- pkg/op/verifier_id_token_hint.go | 2 +- pkg/op/verifier_id_token_hint_test.go | 4 ++-- pkg/op/verifier_jwt_profile.go | 2 +- 81 files changed, 176 insertions(+), 176 deletions(-) diff --git a/README.md b/README.md index f369a5c7..b7993e69 100644 --- a/README.md +++ b/README.md @@ -44,9 +44,9 @@ Check the `/example` folder where example code for different scenarios is locate ```bash # start oidc op server # oidc discovery http://localhost:9998/.well-known/openid-configuration -go run github.com/zitadel/oidc/v2/example/server +go run github.com/zitadel/oidc/v3/example/server # start oidc web client (in a new terminal) -CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app +CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app ``` - open http://localhost:9999/login in your browser @@ -56,11 +56,11 @@ CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid for the dynamic issuer, just start it with: ```bash -go run github.com/zitadel/oidc/v2/example/server/dynamic +go run github.com/zitadel/oidc/v3/example/server/dynamic ``` the oidc web client above will still work, but if you add `oidc.local` (pointing to 127.0.0.1) in your hosts file you can also start it with: ```bash -CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app +CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app ``` > Note: Usernames are suffixed with the hostname (`test-user@localhost` or `test-user@oidc.local`) diff --git a/example/client/api/api.go b/example/client/api/api.go index 9f654a9e..95e84e7e 100644 --- a/example/client/api/api.go +++ b/example/client/api/api.go @@ -12,8 +12,8 @@ import ( "github.com/go-chi/chi" "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/client/rs" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client/rs" + "github.com/zitadel/oidc/v3/pkg/oidc" ) const ( diff --git a/example/client/app/app.go b/example/client/app/app.go index 0c324d20..446c17be 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -11,9 +11,9 @@ import ( "github.com/google/uuid" "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client/rp" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) var ( diff --git a/example/client/device/device.go b/example/client/device/device.go index 284ba372..88ecfe99 100644 --- a/example/client/device/device.go +++ b/example/client/device/device.go @@ -11,8 +11,8 @@ import ( "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v2/pkg/http" + "github.com/zitadel/oidc/v3/pkg/client/rp" + httphelper "github.com/zitadel/oidc/v3/pkg/http" ) var ( diff --git a/example/client/github/github.go b/example/client/github/github.go index 9cb813c0..7d069d49 100644 --- a/example/client/github/github.go +++ b/example/client/github/github.go @@ -10,10 +10,10 @@ import ( "golang.org/x/oauth2" githubOAuth "golang.org/x/oauth2/github" - "github.com/zitadel/oidc/v2/pkg/client/rp" - "github.com/zitadel/oidc/v2/pkg/client/rp/cli" - "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/client/rp/cli" + "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) var ( diff --git a/example/client/service/service.go b/example/client/service/service.go index 95261743..4908b095 100644 --- a/example/client/service/service.go +++ b/example/client/service/service.go @@ -13,7 +13,7 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/pkg/client/profile" + "github.com/zitadel/oidc/v3/pkg/client/profile" ) var client = http.DefaultClient diff --git a/example/server/dynamic/login.go b/example/server/dynamic/login.go index eb5340e3..d90fb8e5 100644 --- a/example/server/dynamic/login.go +++ b/example/server/dynamic/login.go @@ -8,7 +8,7 @@ import ( "github.com/go-chi/chi" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op" ) const ( diff --git a/example/server/dynamic/op.go b/example/server/dynamic/op.go index 2bb68329..16627296 100644 --- a/example/server/dynamic/op.go +++ b/example/server/dynamic/op.go @@ -10,8 +10,8 @@ import ( "github.com/go-chi/chi" "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/example/server/storage" + "github.com/zitadel/oidc/v3/pkg/op" ) const ( diff --git a/example/server/exampleop/device.go b/example/server/exampleop/device.go index 59c2196f..0dda3d53 100644 --- a/example/server/exampleop/device.go +++ b/example/server/exampleop/device.go @@ -10,7 +10,7 @@ import ( "github.com/go-chi/chi" "github.com/gorilla/securecookie" "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op" ) type deviceAuthenticate interface { diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 077244c4..120590fa 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -9,8 +9,8 @@ import ( "github.com/go-chi/chi" "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/example/server/storage" + "github.com/zitadel/oidc/v3/pkg/op" ) const ( diff --git a/example/server/main.go b/example/server/main.go index a2836eac..ee27bbab 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -5,8 +5,8 @@ import ( "log" "net/http" - "github.com/zitadel/oidc/v2/example/server/exampleop" - "github.com/zitadel/oidc/v2/example/server/storage" + "github.com/zitadel/oidc/v3/example/server/exampleop" + "github.com/zitadel/oidc/v3/example/server/storage" ) func main() { diff --git a/example/server/storage/client.go b/example/server/storage/client.go index b850053d..300ce0a2 100644 --- a/example/server/storage/client.go +++ b/example/server/storage/client.go @@ -3,8 +3,8 @@ package storage import ( "time" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) var ( diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go index f5412cf1..b56ad090 100644 --- a/example/server/storage/oidc.go +++ b/example/server/storage/oidc.go @@ -5,8 +5,8 @@ import ( "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) const ( diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go index 7e1afbd5..2aeefe71 100644 --- a/example/server/storage/storage.go +++ b/example/server/storage/storage.go @@ -14,8 +14,8 @@ import ( "github.com/google/uuid" "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) // serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant diff --git a/example/server/storage/storage_dynamic.go b/example/server/storage/storage_dynamic.go index 6e5ee321..3aec9d72 100644 --- a/example/server/storage/storage_dynamic.go +++ b/example/server/storage/storage_dynamic.go @@ -6,8 +6,8 @@ import ( "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) type multiStorage struct { diff --git a/go.mod b/go.mod index a6362504..7cee26ef 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/zitadel/oidc/v2 +module github.com/zitadel/oidc/v3 go 1.18 diff --git a/internal/testutil/gen/gen.go b/internal/testutil/gen/gen.go index a9f5925e..e4a57189 100644 --- a/internal/testutil/gen/gen.go +++ b/internal/testutil/gen/gen.go @@ -8,8 +8,8 @@ import ( "fmt" "os" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" ) var custom = map[string]any{ diff --git a/internal/testutil/token.go b/internal/testutil/token.go index 121aa0ba..27cab5d1 100644 --- a/internal/testutil/token.go +++ b/internal/testutil/token.go @@ -8,7 +8,7 @@ import ( "errors" "time" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" "gopkg.in/square/go-jose.v2" ) diff --git a/pkg/client/client.go b/pkg/client/client.go index 9eda973e..e9af8ce7 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -14,9 +14,9 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/crypto" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/crypto" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) var Encoder = httphelper.Encoder(oidc.NewEncoder()) diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index e19a7202..709d5a13 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -18,13 +18,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/oidc/v2/example/server/exampleop" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/client/rp" - "github.com/zitadel/oidc/v2/pkg/client/rs" - "github.com/zitadel/oidc/v2/pkg/client/tokenexchange" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/example/server/exampleop" + "github.com/zitadel/oidc/v3/example/server/storage" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/client/rs" + "github.com/zitadel/oidc/v3/pkg/client/tokenexchange" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func TestRelyingPartySession(t *testing.T) { diff --git a/pkg/client/jwt_profile.go b/pkg/client/jwt_profile.go index 1686de62..486d998e 100644 --- a/pkg/client/jwt_profile.go +++ b/pkg/client/jwt_profile.go @@ -5,8 +5,8 @@ import ( "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // JWTProfileExchange handles the oauth2 jwt profile exchange diff --git a/pkg/client/profile/jwt_profile.go b/pkg/client/profile/jwt_profile.go index a934f7d5..bb185707 100644 --- a/pkg/client/profile/jwt_profile.go +++ b/pkg/client/profile/jwt_profile.go @@ -7,8 +7,8 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/client" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // jwtProfileTokenSource implement the oauth2.TokenSource diff --git a/pkg/client/rp/cli/cli.go b/pkg/client/rp/cli/cli.go index 91b200d8..eeb90112 100644 --- a/pkg/client/rp/cli/cli.go +++ b/pkg/client/rp/cli/cli.go @@ -4,9 +4,9 @@ import ( "context" "net/http" - "github.com/zitadel/oidc/v2/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client/rp" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) const ( diff --git a/pkg/client/rp/delegation.go b/pkg/client/rp/delegation.go index b16a39e2..23ecffd0 100644 --- a/pkg/client/rp/delegation.go +++ b/pkg/client/rp/delegation.go @@ -1,7 +1,7 @@ package rp import ( - "github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange" + "github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange" ) // DelegationTokenRequest is an implementation of TokenExchangeRequest diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index 73b67cae..9cfc41e5 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -5,8 +5,8 @@ import ( "fmt" "time" - "github.com/zitadel/oidc/v2/pkg/client" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) { diff --git a/pkg/client/rp/jwks.go b/pkg/client/rp/jwks.go index 3438bd6f..79cf2322 100644 --- a/pkg/client/rp/jwks.go +++ b/pkg/client/rp/jwks.go @@ -9,8 +9,8 @@ import ( "gopkg.in/square/go-jose.v2" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet { diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index ede74538..725715e4 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -14,9 +14,9 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/client" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) const ( diff --git a/pkg/client/rp/tockenexchange.go b/pkg/client/rp/tockenexchange.go index c1ac88d2..c8ca048f 100644 --- a/pkg/client/rp/tockenexchange.go +++ b/pkg/client/rp/tockenexchange.go @@ -5,7 +5,7 @@ import ( "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange" + "github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange" ) // TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange` diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index 75d149bd..0cf427a7 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -6,7 +6,7 @@ import ( "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type IDTokenVerifier interface { diff --git a/pkg/client/rp/verifier_test.go b/pkg/client/rp/verifier_test.go index 7588c1ff..002d65d3 100644 --- a/pkg/client/rp/verifier_test.go +++ b/pkg/client/rp/verifier_test.go @@ -7,8 +7,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" "gopkg.in/square/go-jose.v2" ) diff --git a/pkg/client/rp/verifier_tokens_example_test.go b/pkg/client/rp/verifier_tokens_example_test.go index c297efe4..892eb235 100644 --- a/pkg/client/rp/verifier_tokens_example_test.go +++ b/pkg/client/rp/verifier_tokens_example_test.go @@ -4,9 +4,9 @@ import ( "context" "fmt" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/client/rp" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // MyCustomClaims extends the TokenClaims base, diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go index 4e0353c2..f0e0e0ab 100644 --- a/pkg/client/rs/resource_server.go +++ b/pkg/client/rs/resource_server.go @@ -6,9 +6,9 @@ import ( "net/http" "time" - "github.com/zitadel/oidc/v2/pkg/client" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type ResourceServer interface { diff --git a/pkg/client/tokenexchange/tokenexchange.go b/pkg/client/tokenexchange/tokenexchange.go index 1375f687..ce665cdc 100644 --- a/pkg/client/tokenexchange/tokenexchange.go +++ b/pkg/client/tokenexchange/tokenexchange.go @@ -4,9 +4,9 @@ import ( "errors" "net/http" - "github.com/zitadel/oidc/v2/pkg/client" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/client" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type TokenExchanger interface { diff --git a/pkg/oidc/code_challenge.go b/pkg/oidc/code_challenge.go index 37c17830..32963629 100644 --- a/pkg/oidc/code_challenge.go +++ b/pkg/oidc/code_challenge.go @@ -3,7 +3,7 @@ package oidc import ( "crypto/sha256" - "github.com/zitadel/oidc/v2/pkg/crypto" + "github.com/zitadel/oidc/v3/pkg/crypto" ) const ( diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index b017023a..83f3805d 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -8,7 +8,7 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/crypto" + "github.com/zitadel/oidc/v3/pkg/crypto" ) const ( diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index c4ee95eb..ad82617d 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -12,7 +12,7 @@ import ( "gopkg.in/square/go-jose.v2" - str "github.com/zitadel/oidc/v2/pkg/strings" + str "github.com/zitadel/oidc/v3/pkg/strings" ) type Claims interface { diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 4c483637..1d1add53 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -11,9 +11,9 @@ import ( "strings" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - str "github.com/zitadel/oidc/v2/pkg/strings" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + str "github.com/zitadel/oidc/v3/pkg/strings" ) type AuthRequest interface { diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 542f2e2c..65c65dab 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -12,10 +12,10 @@ import ( "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op/mock" ) // diff --git a/pkg/op/client.go b/pkg/op/client.go index af4724a8..175caec2 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -7,8 +7,8 @@ import ( "net/url" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) //go:generate go get github.com/dmarkham/enumer diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go index 1af4157e..2e40d9af 100644 --- a/pkg/op/client_test.go +++ b/pkg/op/client_test.go @@ -14,10 +14,10 @@ import ( "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op/mock" ) type testClientJWTProfile struct{} diff --git a/pkg/op/crypto.go b/pkg/op/crypto.go index 6786022e..6ab1e0a6 100644 --- a/pkg/op/crypto.go +++ b/pkg/op/crypto.go @@ -1,7 +1,7 @@ package op import ( - "github.com/zitadel/oidc/v2/pkg/crypto" + "github.com/zitadel/oidc/v3/pkg/crypto" ) type Crypto interface { diff --git a/pkg/op/device.go b/pkg/op/device.go index 04c06f27..e54da706 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -11,8 +11,8 @@ import ( "strings" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type DeviceAuthorizationConfig struct { diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index 69ba1024..ab117002 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -15,8 +15,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) func Test_deviceAuthorizationHandler(t *testing.T) { diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 26f89eb1..38afeab7 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -6,8 +6,8 @@ import ( "gopkg.in/square/go-jose.v2" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type DiscoverStorage interface { diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 2d0b8af5..e55e9051 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -11,9 +11,9 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op/mock" ) func TestDiscover(t *testing.T) { diff --git a/pkg/op/endpoint_test.go b/pkg/op/endpoint_test.go index 50de89cf..46e5d478 100644 --- a/pkg/op/endpoint_test.go +++ b/pkg/op/endpoint_test.go @@ -3,7 +3,7 @@ package op_test import ( "testing" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op" ) func TestEndpoint_Path(t *testing.T) { diff --git a/pkg/op/error.go b/pkg/op/error.go index acca4ab9..b2d84ae1 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -3,8 +3,8 @@ package op import ( "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type ErrAuthRequest interface { diff --git a/pkg/op/keys.go b/pkg/op/keys.go index 239ecbda..418dcb59 100644 --- a/pkg/op/keys.go +++ b/pkg/op/keys.go @@ -6,7 +6,7 @@ import ( "gopkg.in/square/go-jose.v2" - httphelper "github.com/zitadel/oidc/v2/pkg/http" + httphelper "github.com/zitadel/oidc/v3/pkg/http" ) type KeyProvider interface { diff --git a/pkg/op/keys_test.go b/pkg/op/keys_test.go index 2e56b781..259b87c6 100644 --- a/pkg/op/keys_test.go +++ b/pkg/op/keys_test.go @@ -11,9 +11,9 @@ import ( "github.com/stretchr/testify/assert" "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op/mock" ) func TestKeys(t *testing.T) { diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index cc913eef..931b8969 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Authorizer) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Authorizer) // Package mock is a generated GoMock package. package mock @@ -9,8 +9,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - http "github.com/zitadel/oidc/v2/pkg/http" - op "github.com/zitadel/oidc/v2/pkg/op" + http "github.com/zitadel/oidc/v3/pkg/http" + op "github.com/zitadel/oidc/v3/pkg/op" ) // MockAuthorizer is a mock of Authorizer interface. diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 3f1d525e..6a5bdfd3 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -8,8 +8,8 @@ import ( "github.com/gorilla/schema" "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) func NewAuthorizer(t *testing.T) op.Authorizer { diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go index 36df84a9..f01e3ec6 100644 --- a/pkg/op/mock/client.go +++ b/pkg/op/mock/client.go @@ -5,8 +5,8 @@ import ( "github.com/golang/mock/gomock" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) func NewClient(t *testing.T) op.Client { diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index e3d19fbc..9be08075 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Client) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Client) // Package mock is a generated GoMock package. package mock @@ -9,8 +9,8 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - oidc "github.com/zitadel/oidc/v2/pkg/oidc" - op "github.com/zitadel/oidc/v2/pkg/op" + oidc "github.com/zitadel/oidc/v3/pkg/oidc" + op "github.com/zitadel/oidc/v3/pkg/op" ) // MockClient is a mock of Client interface. diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index fe7d4da6..96429ddb 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Configuration) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Configuration) // Package mock is a generated GoMock package. package mock @@ -9,7 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - op "github.com/zitadel/oidc/v2/pkg/op" + op "github.com/zitadel/oidc/v3/pkg/op" language "golang.org/x/text/language" ) diff --git a/pkg/op/mock/discovery.mock.go b/pkg/op/mock/discovery.mock.go index 0c78d525..4c33953d 100644 --- a/pkg/op/mock/discovery.mock.go +++ b/pkg/op/mock/discovery.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: DiscoverStorage) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: DiscoverStorage) // Package mock is a generated GoMock package. package mock diff --git a/pkg/op/mock/generate.go b/pkg/op/mock/generate.go index ca288d20..590356cf 100644 --- a/pkg/op/mock/generate.go +++ b/pkg/op/mock/generate.go @@ -1,10 +1,10 @@ package mock //go:generate go install github.com/golang/mock/mockgen@v1.6.0 -//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v2/pkg/op Storage -//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v2/pkg/op Authorizer -//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v2/pkg/op Client -//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v2/pkg/op Configuration -//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v2/pkg/op DiscoverStorage -//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v2/pkg/op SigningKey,Key -//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v2/pkg/op KeyProvider +//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v3/pkg/op Storage +//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v3/pkg/op Authorizer +//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v3/pkg/op Client +//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v3/pkg/op Configuration +//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v3/pkg/op DiscoverStorage +//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v3/pkg/op SigningKey,Key +//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v3/pkg/op KeyProvider diff --git a/pkg/op/mock/key.mock.go b/pkg/op/mock/key.mock.go index 88316517..122e852c 100644 --- a/pkg/op/mock/key.mock.go +++ b/pkg/op/mock/key.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: KeyProvider) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: KeyProvider) // Package mock is a generated GoMock package. package mock @@ -9,7 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - op "github.com/zitadel/oidc/v2/pkg/op" + op "github.com/zitadel/oidc/v3/pkg/op" ) // MockKeyProvider is a mock of KeyProvider interface. diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index 78c0efe3..7075241d 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: SigningKey,Key) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: SigningKey,Key) // Package mock is a generated GoMock package. package mock diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 85afb2a5..6bfb1c98 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Storage) +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Storage) // Package mock is a generated GoMock package. package mock @@ -10,8 +10,8 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - oidc "github.com/zitadel/oidc/v2/pkg/oidc" - op "github.com/zitadel/oidc/v2/pkg/op" + oidc "github.com/zitadel/oidc/v3/pkg/oidc" + op "github.com/zitadel/oidc/v3/pkg/op" jose "gopkg.in/square/go-jose.v2" ) diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 9269f891..002da7ec 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -8,8 +8,8 @@ import ( "github.com/golang/mock/gomock" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) func NewStorage(t *testing.T) op.Storage { diff --git a/pkg/op/op.go b/pkg/op/op.go index 0536bbc4..27c14103 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -12,8 +12,8 @@ import ( "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) const ( diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index 8429212a..3958b89b 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -14,9 +14,9 @@ import ( "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v3/example/server/storage" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" "golang.org/x/text/language" ) diff --git a/pkg/op/probes.go b/pkg/op/probes.go index a56c92bc..9ef5bb56 100644 --- a/pkg/op/probes.go +++ b/pkg/op/probes.go @@ -5,7 +5,7 @@ import ( "errors" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" + httphelper "github.com/zitadel/oidc/v3/pkg/http" ) type ProbesFn func(context.Context) error diff --git a/pkg/op/session.go b/pkg/op/session.go index c4f76f32..fbce125f 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -6,8 +6,8 @@ import ( "net/url" "path" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type SessionEnder interface { diff --git a/pkg/op/storage.go b/pkg/op/storage.go index e36eac7a..25444ddb 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -7,7 +7,7 @@ import ( "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type AuthStorage interface { diff --git a/pkg/op/token.go b/pkg/op/token.go index 58568a78..44648aac 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -4,9 +4,9 @@ import ( "context" "time" - "github.com/zitadel/oidc/v2/pkg/crypto" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/strings" + "github.com/zitadel/oidc/v3/pkg/crypto" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/strings" ) type TokenCreator interface { diff --git a/pkg/op/token_client_credentials.go b/pkg/op/token_client_credentials.go index fc31d579..0cf77961 100644 --- a/pkg/op/token_client_credentials.go +++ b/pkg/op/token_client_credentials.go @@ -5,8 +5,8 @@ import ( "net/http" "net/url" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // ClientCredentialsExchange handles the OAuth 2.0 client_credentials grant, including diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index 565a4776..b5e892af 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -4,8 +4,8 @@ import ( "context" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // CodeExchange handles the OAuth 2.0 authorization_code grant, including diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 055ff139..93aa9b24 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -7,8 +7,8 @@ import ( "strings" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type TokenExchangeRequest interface { diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index 85823883..28df2175 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -5,8 +5,8 @@ import ( "errors" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type Introspector interface { diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index 23bac9ac..4563e16f 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -5,8 +5,8 @@ import ( "net/http" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type JWTAuthorizationGrantExchanger interface { diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index 148d2a4f..aeaa5b4b 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -6,9 +6,9 @@ import ( "net/http" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/strings" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/strings" ) type RefreshTokenRequest interface { diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index b9e9805f..058a2029 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -5,8 +5,8 @@ import ( "net/http" "net/url" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type Exchanger interface { diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go index 58332c33..34f8746f 100644 --- a/pkg/op/token_revocation.go +++ b/pkg/op/token_revocation.go @@ -7,8 +7,8 @@ import ( "net/url" "strings" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type Revoker interface { diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 21a0af48..52a2aa20 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -6,8 +6,8 @@ import ( "net/http" "strings" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type UserinfoProvider interface { diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go index 9a8b9128..7527ea69 100644 --- a/pkg/op/verifier_access_token.go +++ b/pkg/op/verifier_access_token.go @@ -4,7 +4,7 @@ import ( "context" "time" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type AccessTokenVerifier interface { diff --git a/pkg/op/verifier_access_token_example_test.go b/pkg/op/verifier_access_token_example_test.go index effdd587..397a2d35 100644 --- a/pkg/op/verifier_access_token_example_test.go +++ b/pkg/op/verifier_access_token_example_test.go @@ -4,9 +4,9 @@ import ( "context" "fmt" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" ) // MyCustomClaims extends the TokenClaims base, diff --git a/pkg/op/verifier_access_token_test.go b/pkg/op/verifier_access_token_test.go index 62c26a94..a1972f1c 100644 --- a/pkg/op/verifier_access_token_test.go +++ b/pkg/op/verifier_access_token_test.go @@ -7,8 +7,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func TestNewAccessTokenVerifier(t *testing.T) { diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go index d906075d..50c3ff6a 100644 --- a/pkg/op/verifier_id_token_hint.go +++ b/pkg/op/verifier_id_token_hint.go @@ -4,7 +4,7 @@ import ( "context" "time" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type IDTokenHintVerifier interface { diff --git a/pkg/op/verifier_id_token_hint_test.go b/pkg/op/verifier_id_token_hint_test.go index f4d0b0c6..9f4c6c18 100644 --- a/pkg/op/verifier_id_token_hint_test.go +++ b/pkg/op/verifier_id_token_hint_test.go @@ -7,8 +7,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" ) func TestNewIDTokenHintVerifier(t *testing.T) { diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index 4d83c590..b7dfec71 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -8,7 +8,7 @@ import ( "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/oidc" ) type JWTProfileVerifier interface { From 33c716ddcfef94c88945bc268d89840cb7caff6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Wed, 22 Mar 2023 19:18:41 +0200 Subject: [PATCH 03/14] feat: merge the verifier types (#336) BREAKING CHANGE: - The various verifier types are merged into a oidc.Verifir. - oidc.Verfier became a struct with exported fields * use type aliases for oidc.Verifier this binds the correct contstructor to each verifier usecase. * fix: handle the zero cases for oidc.Time * add unit tests to oidc verifier * fix: correct returned field for JWTTokenRequest JWTTokenRequest.GetIssuedAt() was returning the ExpiresAt field. This change corrects that by returning IssuedAt instead. --- internal/testutil/token.go | 36 ++- pkg/client/rp/relying_party.go | 8 +- pkg/client/rp/verifier.go | 127 +++------ pkg/client/rp/verifier_test.go | 62 ++--- pkg/oidc/token_request.go | 2 +- pkg/oidc/types.go | 6 + pkg/oidc/types_test.go | 51 ++++ pkg/oidc/verifier.go | 36 ++- pkg/oidc/verifier_parse_test.go | 128 +++++++++ pkg/oidc/verifier_test.go | 374 ++++++++++++++++++++++++++ pkg/op/auth_request.go | 8 +- pkg/op/auth_request_test.go | 34 ++- pkg/op/client.go | 2 +- pkg/op/client_test.go | 2 +- pkg/op/mock/authorizer.mock.go | 4 +- pkg/op/mock/authorizer.mock.impl.go | 2 +- pkg/op/op.go | 10 +- pkg/op/session.go | 2 +- pkg/op/token_intospection.go | 2 +- pkg/op/token_jwt_profile.go | 2 +- pkg/op/token_request.go | 4 +- pkg/op/token_revocation.go | 4 +- pkg/op/userinfo.go | 2 +- pkg/op/verifier_access_token.go | 63 +---- pkg/op/verifier_access_token_test.go | 28 +- pkg/op/verifier_id_token_hint.go | 75 ++---- pkg/op/verifier_id_token_hint_test.go | 32 +-- pkg/op/verifier_jwt_profile.go | 76 ++---- pkg/op/verifier_jwt_profile_test.go | 117 ++++++++ 29 files changed, 948 insertions(+), 351 deletions(-) create mode 100644 pkg/oidc/verifier_parse_test.go create mode 100644 pkg/oidc/verifier_test.go create mode 100644 pkg/op/verifier_jwt_profile_test.go diff --git a/internal/testutil/token.go b/internal/testutil/token.go index 27cab5d1..41778de7 100644 --- a/internal/testutil/token.go +++ b/internal/testutil/token.go @@ -8,6 +8,7 @@ import ( "errors" "time" + "github.com/muhlemmer/gu" "github.com/zitadel/oidc/v3/pkg/oidc" "gopkg.in/square/go-jose.v2" ) @@ -17,7 +18,7 @@ type KeySet struct{} // VerifySignature implments op.KeySet. func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { - if ctx.Err() != nil { + if err = ctx.Err(); err != nil { return nil, err } @@ -45,6 +46,16 @@ func init() { } } +type JWTProfileKeyStorage struct{} + +func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + return gu.Ptr(WebKey.Public()), nil +} + func signEncodeTokenClaims(claims any) string { payload, err := json.Marshal(claims) if err != nil { @@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil) } +func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) { + req := &oidc.JWTTokenRequest{ + Issuer: issuer, + Subject: clientID, + Audience: audience, + ExpiresAt: oidc.FromTime(expiration), + IssuedAt: oidc.FromTime(issuedAt), + } + // make sure the private claim map is set correctly + data, err := json.Marshal(req) + if err != nil { + panic(err) + } + if err = json.Unmarshal(data, req); err != nil { + panic(err) + } + return signEncodeTokenClaims(req), req +} + const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg` // These variables always result in a valid token @@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) { return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew) } +func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) { + return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration) +} + // ACRVerify is a oidc.ACRVerifier func. func ACRVerify(acr string) error { if acr != ValidACR { diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 725715e4..bd96e160 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -63,8 +63,8 @@ type RelyingParty interface { // be used to start a DeviceAuthorization flow. GetDeviceAuthorizationEndpoint() string - // IDTokenVerifier returns the verifier interface used for oidc id_token verification - IDTokenVerifier() IDTokenVerifier + // IDTokenVerifier returns the verifier used for oidc id_token verification + IDTokenVerifier() *IDTokenVerifier // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) @@ -88,7 +88,7 @@ type relyingParty struct { cookieHandler *httphelper.CookieHandler errorHandler func(http.ResponseWriter, *http.Request, string, string, string) - idTokenVerifier IDTokenVerifier + idTokenVerifier *IDTokenVerifier verifierOpts []VerifierOption signer jose.Signer } @@ -137,7 +137,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string { return rp.endpoints.RevokeURL } -func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier { +func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier { if rp.idTokenVerifier == nil { rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...) } diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index 0cf427a7..3294f407 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -9,19 +9,9 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) -type IDTokenVerifier interface { - oidc.Verifier - ClientID() string - SupportedSignAlgs() []string - KeySet() oidc.KeySet - Nonce(context.Context) string - ACR() oidc.ACRVerifier - MaxAge() time.Duration -} - // VerifyTokens implement the Token Response Validation as defined in OIDC specification // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation -func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) { +func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) { var nilClaims C claims, err = VerifyIDToken[C](ctx, idToken, v) @@ -36,7 +26,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str // VerifyIDToken validates the id token according to // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) { +func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) { var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -52,27 +42,27 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe return nilClaims, err } - if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err = oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckAudience(claims, v.ClientID()); err != nil { + if err = oidc.CheckAudience(claims, v.ClientID); err != nil { return nilClaims, err } - if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil { + if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { return nilClaims, err } @@ -80,16 +70,18 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe return nilClaims, err } - if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil { return nilClaims, err } - if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { + if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { return nilClaims, err } return claims, nil } +type IDTokenVerifier oidc.Verifier + // VerifyAccessToken validates the access token according to // https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { @@ -107,15 +99,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl return nil } -// NewIDTokenVerifier returns an implementation of `IDTokenVerifier` -// for `VerifyTokens` and `VerifyIDToken` -func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier { - v := &idTokenVerifier{ - issuer: issuer, - clientID: clientID, - keySet: keySet, - offset: time.Second, - nonce: func(_ context.Context) string { +// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification. +func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier { + v := &IDTokenVerifier{ + Issuer: issuer, + ClientID: clientID, + KeySet: keySet, + Offset: time.Second, + Nonce: func(_ context.Context) string { return "" }, } @@ -128,95 +119,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ... } // VerifierOption is the type for providing dynamic options to the IDTokenVerifier -type VerifierOption func(*idTokenVerifier) +type VerifierOption func(*IDTokenVerifier) // WithIssuedAtOffset mitigates the risk of iat to be in the future // because of clock skews with the ability to add an offset to the current time -func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) { - return func(v *idTokenVerifier) { - v.offset = offset +func WithIssuedAtOffset(offset time.Duration) VerifierOption { + return func(v *IDTokenVerifier) { + v.Offset = offset } } // WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now -func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) { - return func(v *idTokenVerifier) { - v.maxAgeIAT = maxAge +func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption { + return func(v *IDTokenVerifier) { + v.MaxAgeIAT = maxAge } } // WithNonce sets the function to check the nonce func WithNonce(nonce func(context.Context) string) VerifierOption { - return func(v *idTokenVerifier) { - v.nonce = nonce + return func(v *IDTokenVerifier) { + v.Nonce = nonce } } // WithACRVerifier sets the verifier for the acr claim func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption { - return func(v *idTokenVerifier) { - v.acr = verifier + return func(v *IDTokenVerifier) { + v.ACR = verifier } } // WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption { - return func(v *idTokenVerifier) { - v.maxAge = maxAge + return func(v *IDTokenVerifier) { + v.MaxAge = maxAge } } // WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm func WithSupportedSigningAlgorithms(algs ...string) VerifierOption { - return func(v *idTokenVerifier) { - v.supportedSignAlgs = algs + return func(v *IDTokenVerifier) { + v.SupportedSignAlgs = algs } } - -type idTokenVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - clientID string - supportedSignAlgs []string - keySet oidc.KeySet - acr oidc.ACRVerifier - maxAge time.Duration - nonce func(ctx context.Context) string -} - -func (i *idTokenVerifier) Issuer() string { - return i.issuer -} - -func (i *idTokenVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -func (i *idTokenVerifier) Offset() time.Duration { - return i.offset -} - -func (i *idTokenVerifier) ClientID() string { - return i.clientID -} - -func (i *idTokenVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -func (i *idTokenVerifier) KeySet() oidc.KeySet { - return i.keySet -} - -func (i *idTokenVerifier) Nonce(ctx context.Context) string { - return i.nonce(ctx) -} - -func (i *idTokenVerifier) ACR() oidc.ACRVerifier { - return i.acr -} - -func (i *idTokenVerifier) MaxAge() time.Duration { - return i.maxAge -} diff --git a/pkg/client/rp/verifier_test.go b/pkg/client/rp/verifier_test.go index 002d65d3..11bf2f9f 100644 --- a/pkg/client/rp/verifier_test.go +++ b/pkg/client/rp/verifier_test.go @@ -13,16 +13,16 @@ import ( ) func TestVerifyTokens(t *testing.T) { - verifier := &idTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - nonce: func(context.Context) string { return tu.ValidNonce }, - clientID: tu.ValidClientID, + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, + ClientID: tu.ValidClientID, } accessToken, _ := tu.ValidAccessToken() atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm) @@ -91,15 +91,15 @@ func TestVerifyTokens(t *testing.T) { } func TestVerifyIDToken(t *testing.T) { - verifier := &idTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - nonce: func(context.Context) string { return tu.ValidNonce }, + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, } tests := []struct { @@ -219,7 +219,7 @@ func TestVerifyIDToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { token, want := tt.tokenClaims() - verifier.clientID = tt.clientID + verifier.ClientID = tt.clientID got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier) if tt.wantErr { assert.Error(t, err) @@ -300,7 +300,7 @@ func TestNewIDTokenVerifier(t *testing.T) { tests := []struct { name string args args - want IDTokenVerifier + want *IDTokenVerifier }{ { name: "nil nonce", // otherwise assert.Equal will fail on the function @@ -317,16 +317,16 @@ func TestNewIDTokenVerifier(t *testing.T) { WithSupportedSigningAlgorithms("ABC", "DEF"), }, }, - want: &idTokenVerifier{ - issuer: tu.ValidIssuer, - offset: time.Minute, - maxAgeIAT: time.Hour, - clientID: tu.ValidClientID, - keySet: tu.KeySet{}, - nonce: nil, - acr: nil, - maxAge: 2 * time.Hour, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + Offset: time.Minute, + MaxAgeIAT: time.Hour, + ClientID: tu.ValidClientID, + KeySet: tu.KeySet{}, + Nonce: nil, + ACR: nil, + MaxAge: 2 * time.Hour, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index e63e0e51..6b6945a1 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -192,7 +192,7 @@ func (j *JWTTokenRequest) GetExpiration() time.Time { // GetIssuedAt implements the Claims interface func (j *JWTTokenRequest) GetIssuedAt() time.Time { - return j.ExpiresAt.AsTime() + return j.IssuedAt.AsTime() } // GetNonce implements the Claims interface diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index cb513a09..167f8b78 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -173,10 +173,16 @@ func NewEncoder() *schema.Encoder { type Time int64 func (ts Time) AsTime() time.Time { + if ts == 0 { + return time.Time{} + } return time.Unix(int64(ts), 0) } func FromTime(tt time.Time) Time { + if tt.IsZero() { + return 0 + } return Time(tt.Unix()) } diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index 2721e0b7..64f07f16 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/gorilla/schema" "github.com/stretchr/testify/assert" @@ -467,6 +468,56 @@ func TestNewEncoder(t *testing.T) { assert.Equal(t, a, b) } +func TestTime_AsTime(t *testing.T) { + tests := []struct { + name string + ts Time + want time.Time + }{ + { + name: "unset", + ts: 0, + want: time.Time{}, + }, + { + name: "set", + ts: 1, + want: time.Unix(1, 0), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ts.AsTime() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestTime_FromTime(t *testing.T) { + tests := []struct { + name string + tt time.Time + want Time + }{ + { + name: "zero", + tt: time.Time{}, + want: 0, + }, + { + name: "set", + tt: time.Unix(1, 0), + want: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FromTime(tt.tt) + assert.Equal(t, tt.want, got) + }) + } +} + func TestTime_UnmarshalJSON(t *testing.T) { type dst struct { UpdatedAt Time `json:"updated_at"` diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index ad82617d..2d4e7a67 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -61,10 +61,19 @@ var ( ErrAtHash = errors.New("at_hash does not correspond to access token") ) -type Verifier interface { - Issuer() string - MaxAgeIAT() time.Duration - Offset() time.Duration +// Verifier caries configuration for the various token verification +// functions. Use package specific constructor functions to know +// which values need to be set. +type Verifier struct { + Issuer string + MaxAgeIAT time.Duration + Offset time.Duration + ClientID string + SupportedSignAlgs []string + MaxAge time.Duration + ACR ACRVerifier + KeySet KeySet + Nonce func(ctx context.Context) string } // ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim @@ -121,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error { return nil } +// CheckAuthorizedParty checks azp (authorized party) claim requirements. +// +// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present. +// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value. +// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation func CheckAuthorizedParty(claims Claims, clientID string) error { if len(claims.GetAudience()) > 1 { if claims.GetAuthorizedParty() == "" { @@ -167,26 +181,26 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl } func CheckExpiration(claims Claims, offset time.Duration) error { - expiration := claims.GetExpiration().Round(time.Second) - if !time.Now().UTC().Add(offset).Before(expiration) { + expiration := claims.GetExpiration() + if !time.Now().Add(offset).Before(expiration) { return ErrExpired } return nil } func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error { - issuedAt := claims.GetIssuedAt().Round(time.Second) + issuedAt := claims.GetIssuedAt() if issuedAt.IsZero() { return ErrIatMissing } - nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second) + nowWithOffset := time.Now().Add(offset).Round(time.Second) if issuedAt.After(nowWithOffset) { return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset) } if maxAgeIAT == 0 { return nil } - maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second) + maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second) if issuedAt.Before(maxAge) { return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt)) } @@ -216,8 +230,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error { if claims.GetAuthTime().IsZero() { return ErrAuthTimeNotPresent } - authTime := claims.GetAuthTime().Round(time.Second) - maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second) + authTime := claims.GetAuthTime() + maxAuthTime := time.Now().Add(-maxAge).Round(time.Second) if authTime.Before(maxAuthTime) { return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime)) } diff --git a/pkg/oidc/verifier_parse_test.go b/pkg/oidc/verifier_parse_test.go new file mode 100644 index 00000000..105650f0 --- /dev/null +++ b/pkg/oidc/verifier_parse_test.go @@ -0,0 +1,128 @@ +package oidc_test + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +func TestParseToken(t *testing.T) { + token, wantClaims := tu.ValidIDToken() + wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload + + wantPayload, err := json.Marshal(wantClaims) + require.NoError(t, err) + + tests := []struct { + name string + tokenString string + wantErr bool + }{ + { + name: "split error", + tokenString: "nope", + wantErr: true, + }, + { + name: "base64 error", + tokenString: "foo.~.bar", + wantErr: true, + }, + { + name: "success", + tokenString: token, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotClaims := new(oidc.IDTokenClaims) + gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, wantClaims, gotClaims) + assert.JSONEq(t, string(wantPayload), string(gotPayload)) + }) + } +} + +func TestCheckSignature(t *testing.T) { + errCtx, cancel := context.WithCancel(context.Background()) + cancel() + + token, _ := tu.ValidIDToken() + payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{}) + require.NoError(t, err) + + type args struct { + ctx context.Context + token string + payload []byte + supportedSigAlgs []string + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "parse error", + args: args{ + ctx: context.Background(), + token: "~", + payload: payload, + }, + wantErr: oidc.ErrParse, + }, + { + name: "default sigAlg", + args: args{ + ctx: context.Background(), + token: token, + payload: payload, + }, + }, + { + name: "unsupported sigAlg", + args: args{ + ctx: context.Background(), + token: token, + payload: payload, + supportedSigAlgs: []string{"foo", "bar"}, + }, + wantErr: oidc.ErrSignatureUnsupportedAlg, + }, + { + name: "verify error", + args: args{ + ctx: errCtx, + token: token, + payload: payload, + }, + wantErr: oidc.ErrSignatureInvalid, + }, + { + name: "inequal payloads", + args: args{ + ctx: context.Background(), + token: token, + payload: []byte{0, 1, 2}, + }, + wantErr: oidc.ErrSignatureInvalidPayload, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := new(oidc.TokenClaims) + err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{}) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} diff --git a/pkg/oidc/verifier_test.go b/pkg/oidc/verifier_test.go new file mode 100644 index 00000000..93e71575 --- /dev/null +++ b/pkg/oidc/verifier_test.go @@ -0,0 +1,374 @@ +package oidc + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecryptToken(t *testing.T) { + const tokenString = "ABC" + got, err := DecryptToken(tokenString) + require.NoError(t, err) + assert.Equal(t, tokenString, got) +} + +func TestDefaultACRVerifier(t *testing.T) { + acrVerfier := DefaultACRVerifier([]string{"foo", "bar"}) + + tests := []struct { + name string + acr string + wantErr string + }{ + { + name: "ok", + acr: "bar", + }, + { + name: "error", + acr: "hello", + wantErr: "expected one of: [foo bar], got: \"hello\"", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := acrVerfier(tt.acr) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + return + } + require.NoError(t, err) + }) + } +} + +func TestCheckSubject(t *testing.T) { + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrSubjectMissing, + }, + { + name: "ok", + claims: &TokenClaims{ + Subject: "foo", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckSubject(tt.claims) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckIssuer(t *testing.T) { + const issuer = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrIssuerInvalid, + }, + { + name: "wrong", + claims: &TokenClaims{ + Issuer: "wrong", + }, + wantErr: ErrIssuerInvalid, + }, + { + name: "ok", + claims: &TokenClaims{ + Issuer: issuer, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckIssuer(tt.claims, issuer) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAudience(t *testing.T) { + const clientID = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrAudience, + }, + { + name: "wrong", + claims: &TokenClaims{ + Audience: []string{"wrong"}, + }, + wantErr: ErrAudience, + }, + { + name: "ok", + claims: &TokenClaims{ + Audience: []string{clientID}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAudience(tt.claims, clientID) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthorizedParty(t *testing.T) { + const clientID = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "single audience, no azp", + claims: &TokenClaims{ + Audience: []string{clientID}, + }, + }, + { + name: "multiple audience, no azp", + claims: &TokenClaims{ + Audience: []string{clientID, "other"}, + }, + wantErr: ErrAzpMissing, + }, + { + name: "single audience, with azp", + claims: &TokenClaims{ + Audience: []string{clientID}, + AuthorizedParty: clientID, + }, + }, + { + name: "multiple audience, with azp", + claims: &TokenClaims{ + Audience: []string{clientID, "other"}, + AuthorizedParty: clientID, + }, + }, + { + name: "wrong azp", + claims: &TokenClaims{ + AuthorizedParty: "wrong", + }, + wantErr: ErrAzpInvalid, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthorizedParty(tt.claims, clientID) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckExpiration(t *testing.T) { + const offset = time.Minute + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrExpired, + }, + { + name: "expired", + claims: &TokenClaims{ + Expiration: FromTime(time.Now().Add(-2 * offset)), + }, + wantErr: ErrExpired, + }, + { + name: "valid", + claims: &TokenClaims{ + Expiration: FromTime(time.Now().Add(2 * offset)), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckExpiration(tt.claims, offset) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckIssuedAt(t *testing.T) { + const offset = time.Minute + tests := []struct { + name string + maxAgeIAT time.Duration + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrIatMissing, + }, + { + name: "future", + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now().Add(time.Hour)), + }, + wantErr: ErrIatInFuture, + }, + { + name: "no max", + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now()), + }, + }, + { + name: "past max", + maxAgeIAT: time.Minute, + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now().Add(-time.Hour)), + }, + wantErr: ErrIatToOld, + }, + { + name: "within max", + maxAgeIAT: time.Hour, + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now()), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckNonce(t *testing.T) { + const nonce = "123" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrNonceInvalid, + }, + { + name: "wrong", + claims: &TokenClaims{ + Nonce: "wrong", + }, + wantErr: ErrNonceInvalid, + }, + { + name: "ok", + claims: &TokenClaims{ + Nonce: nonce, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckNonce(tt.claims, nonce) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthorizationContextClassReference(t *testing.T) { + tests := []struct { + name string + acr ACRVerifier + wantErr error + }{ + { + name: "error", + acr: func(s string) error { return errors.New("oops") }, + wantErr: ErrAcrInvalid, + }, + { + name: "ok", + acr: func(s string) error { return nil }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthTime(t *testing.T) { + tests := []struct { + name string + claims Claims + maxAge time.Duration + wantErr error + }{ + { + name: "no max age", + claims: &TokenClaims{}, + }, + { + name: "missing", + claims: &TokenClaims{}, + maxAge: time.Minute, + wantErr: ErrAuthTimeNotPresent, + }, + { + name: "expired", + maxAge: time.Minute, + claims: &TokenClaims{ + AuthTime: FromTime(time.Now().Add(-time.Hour)), + }, + wantErr: ErrAuthTimeToOld, + }, + { + name: "ok", + maxAge: time.Minute, + claims: &TokenClaims{ + AuthTime: NowTime(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthTime(tt.claims, tt.maxAge) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 1d1add53..b516909b 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -38,7 +38,7 @@ type Authorizer interface { Storage() Storage Decoder() httphelper.Decoder Encoder() httphelper.Encoder - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier Crypto() Crypto RequestObjectSupported() bool } @@ -47,7 +47,7 @@ type Authorizer interface { // implementing its own validation mechanism for the auth request type AuthorizeValidator interface { Authorizer - ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) + ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error) } func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) { @@ -204,7 +204,7 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi } // ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed -func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) { +func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) { authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) if err != nil { return "", err @@ -384,7 +384,7 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) // ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request) // and returns the `sub` claim -func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) { +func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) { if idTokenHint == "" { return "", nil } diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 65c65dab..3179e258 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -12,6 +12,7 @@ import ( "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" @@ -146,7 +147,7 @@ func TestValidateAuthRequest(t *testing.T) { type args struct { authRequest *oidc.AuthRequest storage op.Storage - verifier op.IDTokenHintVerifier + verifier *op.IDTokenHintVerifier } tests := []struct { name string @@ -1003,3 +1004,34 @@ func Test_parseAuthorizeCallbackRequest(t *testing.T) { }) } } + +func TestValidateAuthReqIDTokenHint(t *testing.T) { + token, _ := tu.ValidIDToken() + tests := []struct { + name string + idTokenHint string + want string + wantErr error + }{ + { + name: "empty", + }, + { + name: "verify err", + idTokenHint: "foo", + wantErr: oidc.ErrLoginRequired(), + }, + { + name: "ok", + idTokenHint: token, + want: tu.ValidSubject, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{})) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/client.go b/pkg/op/client.go index 175caec2..754636cc 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -81,7 +81,7 @@ var ( ) type ClientJWTProfile interface { - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go index 2e40d9af..bb17192a 100644 --- a/pkg/op/client_test.go +++ b/pkg/op/client_test.go @@ -22,7 +22,7 @@ import ( type testClientJWTProfile struct{} -func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil } +func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil } func TestClientJWTAuth(t *testing.T) { type args struct { diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index 931b8969..a0c67e3d 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -79,10 +79,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call { } // IDTokenHintVerifier mocks base method. -func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier { +func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0) - ret0, _ := ret[0].(op.IDTokenHintVerifier) + ret0, _ := ret[0].(*op.IDTokenHintVerifier) return ret0 } diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 6a5bdfd3..409683ab 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) { func ExpectVerifier(a op.Authorizer, t *testing.T) { mockA := a.(*MockAuthorizer) mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn( - func() op.IDTokenHintVerifier { + func() *op.IDTokenHintVerifier { return op.NewIDTokenHintVerifier("", nil) }) } diff --git a/pkg/op/op.go b/pkg/op/op.go index 27c14103..9ed5662c 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -73,8 +73,8 @@ type OpenIDProvider interface { Storage() Storage Decoder() httphelper.Decoder Encoder() httphelper.Encoder - IDTokenHintVerifier(context.Context) IDTokenHintVerifier - AccessTokenVerifier(context.Context) AccessTokenVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier Crypto() Crypto DefaultLogoutRedirectURI() string Probes() []ProbesFn @@ -342,15 +342,15 @@ func (o *Provider) Encoder() httphelper.Encoder { return o.encoder } -func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier { +func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier { return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...) } -func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier { +func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier { return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second) } -func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier { +func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier { return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...) } diff --git a/pkg/op/session.go b/pkg/op/session.go index fbce125f..fd914d11 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -13,7 +13,7 @@ import ( type SessionEnder interface { Decoder() httphelper.Decoder Storage() Storage - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier DefaultLogoutRedirectURI() string } diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index 28df2175..21b79c3b 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -13,7 +13,7 @@ type Introspector interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier } type IntrospectorJWTProfile interface { diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index 4563e16f..4cd7b1e4 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -11,7 +11,7 @@ import ( type JWTAuthorizationGrantExchanger interface { Exchanger - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1 diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 058a2029..c06a51bc 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -20,8 +20,8 @@ type Exchanger interface { GrantTypeJWTAuthorizationSupported() bool GrantTypeClientCredentialsSupported() bool GrantTypeDeviceCodeSupported() bool - AccessTokenVerifier(context.Context) AccessTokenVerifier - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier } func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go index 34f8746f..fd1ee931 100644 --- a/pkg/op/token_revocation.go +++ b/pkg/op/token_revocation.go @@ -15,14 +15,14 @@ type Revoker interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier AuthMethodPrivateKeyJWTSupported() bool AuthMethodPostSupported() bool } type RevokerJWTProfile interface { Revoker - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) { diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 52a2aa20..86205b5f 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -14,7 +14,7 @@ type UserinfoProvider interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier } func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go index 7527ea69..120bfa71 100644 --- a/pkg/op/verifier_access_token.go +++ b/pkg/op/verifier_access_token.go @@ -2,62 +2,25 @@ package op import ( "context" - "time" "github.com/zitadel/oidc/v3/pkg/oidc" ) -type AccessTokenVerifier interface { - oidc.Verifier - SupportedSignAlgs() []string - KeySet() oidc.KeySet -} - -type accessTokenVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - supportedSignAlgs []string - keySet oidc.KeySet -} - -// Issuer implements oidc.Verifier interface -func (i *accessTokenVerifier) Issuer() string { - return i.issuer -} - -// MaxAgeIAT implements oidc.Verifier interface -func (i *accessTokenVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -// Offset implements oidc.Verifier interface -func (i *accessTokenVerifier) Offset() time.Duration { - return i.offset -} - -// SupportedSignAlgs implements AccessTokenVerifier interface -func (i *accessTokenVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -// KeySet implements AccessTokenVerifier interface -func (i *accessTokenVerifier) KeySet() oidc.KeySet { - return i.keySet -} +type AccessTokenVerifier oidc.Verifier -type AccessTokenVerifierOpt func(*accessTokenVerifier) +type AccessTokenVerifierOpt func(*AccessTokenVerifier) func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt { - return func(verifier *accessTokenVerifier) { - verifier.supportedSignAlgs = algs + return func(verifier *AccessTokenVerifier) { + verifier.SupportedSignAlgs = algs } } -func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) AccessTokenVerifier { - verifier := &accessTokenVerifier{ - issuer: issuer, - keySet: keySet, +// NewAccessTokenVerifier returns a AccessTokenVerifier suitable for access token verification. +func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) *AccessTokenVerifier { + verifier := &AccessTokenVerifier{ + Issuer: issuer, + KeySet: keySet, } for _, opt := range opts { opt(verifier) @@ -66,7 +29,7 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok } // VerifyAccessToken validates the access token (issuer, signature and expiration). -func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) { +func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *AccessTokenVerifier) (claims C, err error) { var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -78,15 +41,15 @@ func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v Acces return nilClaims, err } - if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err := oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } diff --git a/pkg/op/verifier_access_token_test.go b/pkg/op/verifier_access_token_test.go index a1972f1c..66e32ceb 100644 --- a/pkg/op/verifier_access_token_test.go +++ b/pkg/op/verifier_access_token_test.go @@ -20,7 +20,7 @@ func TestNewAccessTokenVerifier(t *testing.T) { tests := []struct { name string args args - want AccessTokenVerifier + want *AccessTokenVerifier }{ { name: "simple", @@ -28,9 +28,9 @@ func TestNewAccessTokenVerifier(t *testing.T) { issuer: tu.ValidIssuer, keySet: tu.KeySet{}, }, - want: &accessTokenVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, + want: &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, }, }, { @@ -42,10 +42,10 @@ func TestNewAccessTokenVerifier(t *testing.T) { WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"), }, }, - want: &accessTokenVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } @@ -58,12 +58,12 @@ func TestNewAccessTokenVerifier(t *testing.T) { } func TestVerifyAccessToken(t *testing.T) { - verifier := &accessTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, + verifier := &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, } tests := []struct { diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go index 50c3ff6a..61432527 100644 --- a/pkg/op/verifier_id_token_hint.go +++ b/pkg/op/verifier_id_token_hint.go @@ -2,69 +2,24 @@ package op import ( "context" - "time" "github.com/zitadel/oidc/v3/pkg/oidc" ) -type IDTokenHintVerifier interface { - oidc.Verifier - SupportedSignAlgs() []string - KeySet() oidc.KeySet - ACR() oidc.ACRVerifier - MaxAge() time.Duration -} - -type idTokenHintVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - supportedSignAlgs []string - maxAge time.Duration - acr oidc.ACRVerifier - keySet oidc.KeySet -} - -func (i *idTokenHintVerifier) Issuer() string { - return i.issuer -} - -func (i *idTokenHintVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -func (i *idTokenHintVerifier) Offset() time.Duration { - return i.offset -} - -func (i *idTokenHintVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -func (i *idTokenHintVerifier) KeySet() oidc.KeySet { - return i.keySet -} - -func (i *idTokenHintVerifier) ACR() oidc.ACRVerifier { - return i.acr -} - -func (i *idTokenHintVerifier) MaxAge() time.Duration { - return i.maxAge -} +type IDTokenHintVerifier oidc.Verifier -type IDTokenHintVerifierOpt func(*idTokenHintVerifier) +type IDTokenHintVerifierOpt func(*IDTokenHintVerifier) func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt { - return func(verifier *idTokenHintVerifier) { - verifier.supportedSignAlgs = algs + return func(verifier *IDTokenHintVerifier) { + verifier.SupportedSignAlgs = algs } } -func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) IDTokenHintVerifier { - verifier := &idTokenHintVerifier{ - issuer: issuer, - keySet: keySet, +func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) *IDTokenHintVerifier { + verifier := &IDTokenHintVerifier{ + Issuer: issuer, + KeySet: keySet, } for _, opt := range opts { opt(verifier) @@ -74,7 +29,7 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi // VerifyIDTokenHint validates the id token according to // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) { +func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) { var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -86,27 +41,27 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTok return nilClaims, err } - if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err := oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil { return nilClaims, err } - if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { + if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { return nilClaims, err } return claims, nil diff --git a/pkg/op/verifier_id_token_hint_test.go b/pkg/op/verifier_id_token_hint_test.go index 9f4c6c18..e514a76e 100644 --- a/pkg/op/verifier_id_token_hint_test.go +++ b/pkg/op/verifier_id_token_hint_test.go @@ -20,7 +20,7 @@ func TestNewIDTokenHintVerifier(t *testing.T) { tests := []struct { name string args args - want IDTokenHintVerifier + want *IDTokenHintVerifier }{ { name: "simple", @@ -28,9 +28,9 @@ func TestNewIDTokenHintVerifier(t *testing.T) { issuer: tu.ValidIssuer, keySet: tu.KeySet{}, }, - want: &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, + want: &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, }, }, { @@ -42,10 +42,10 @@ func TestNewIDTokenHintVerifier(t *testing.T) { WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"), }, }, - want: &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } @@ -58,14 +58,14 @@ func TestNewIDTokenHintVerifier(t *testing.T) { } func TestVerifyIDTokenHint(t *testing.T) { - verifier := &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - keySet: tu.KeySet{}, + verifier := &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + KeySet: tu.KeySet{}, } tests := []struct { diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index b7dfec71..1daa15fc 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -11,28 +11,25 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) -type JWTProfileVerifier interface { +// JWTProfileVerfiier extends oidc.Verifier with +// a jwtProfileKeyStorage and a function to check +// the subject in a token. +type JWTProfileVerifier struct { oidc.Verifier - Storage() jwtProfileKeyStorage - CheckSubject(request *oidc.JWTTokenRequest) error -} - -type jwtProfileVerifier struct { - storage jwtProfileKeyStorage - subjectCheck func(request *oidc.JWTTokenRequest) error - issuer string - maxAgeIAT time.Duration - offset time.Duration + Storage JWTProfileKeyStorage + CheckSubject func(request *oidc.JWTTokenRequest) error } // NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication) -func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) JWTProfileVerifier { - j := &jwtProfileVerifier{ - storage: storage, - subjectCheck: SubjectIsIssuer, - issuer: issuer, - maxAgeIAT: maxAgeIAT, - offset: offset, +func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier { + j := &JWTProfileVerifier{ + Verifier: oidc.Verifier{ + Issuer: issuer, + MaxAgeIAT: maxAgeIAT, + Offset: offset, + }, + Storage: storage, + CheckSubject: SubjectIsIssuer, } for _, opt := range opts { @@ -42,53 +39,35 @@ func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIA return j } -type JWTProfileVerifierOption func(*jwtProfileVerifier) +type JWTProfileVerifierOption func(*JWTProfileVerifier) +// SubjectCheck sets a custom function to check the subject. +// Defaults to SubjectIsIssuer() func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption { - return func(verifier *jwtProfileVerifier) { - verifier.subjectCheck = check + return func(verifier *JWTProfileVerifier) { + verifier.CheckSubject = check } } -func (v *jwtProfileVerifier) Issuer() string { - return v.issuer -} - -func (v *jwtProfileVerifier) Storage() jwtProfileKeyStorage { - return v.storage -} - -func (v *jwtProfileVerifier) MaxAgeIAT() time.Duration { - return v.maxAgeIAT -} - -func (v *jwtProfileVerifier) Offset() time.Duration { - return v.offset -} - -func (v *jwtProfileVerifier) CheckSubject(request *oidc.JWTTokenRequest) error { - return v.subjectCheck(request) -} - // VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication) // // checks audience, exp, iat, signature and that issuer and sub are the same -func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { +func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { request := new(oidc.JWTTokenRequest) payload, err := oidc.ParseToken(assertion, request) if err != nil { return nil, err } - if err = oidc.CheckAudience(request, v.Issuer()); err != nil { + if err = oidc.CheckAudience(request, v.Issuer); err != nil { return nil, err } - if err = oidc.CheckExpiration(request, v.Offset()); err != nil { + if err = oidc.CheckExpiration(request, v.Offset); err != nil { return nil, err } - if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT, v.Offset); err != nil { return nil, err } @@ -96,17 +75,18 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif return nil, err } - keySet := &jwtProfileKeySet{storage: v.Storage(), clientID: request.Issuer} + keySet := &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer} if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil { return nil, err } return request, nil } -type jwtProfileKeyStorage interface { +type JWTProfileKeyStorage interface { GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } +// SubjectIsIssuer func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { if request.Issuer != request.Subject { return errors.New("delegation not allowed, issuer and sub must be identical") @@ -115,7 +95,7 @@ func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { } type jwtProfileKeySet struct { - storage jwtProfileKeyStorage + storage JWTProfileKeyStorage clientID string } diff --git a/pkg/op/verifier_jwt_profile_test.go b/pkg/op/verifier_jwt_profile_test.go new file mode 100644 index 00000000..d96cbb43 --- /dev/null +++ b/pkg/op/verifier_jwt_profile_test.go @@ -0,0 +1,117 @@ +package op_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" +) + +func TestNewJWTProfileVerifier(t *testing.T) { + want := &op.JWTProfileVerifier{ + Verifier: oidc.Verifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: time.Minute, + Offset: time.Second, + }, + Storage: tu.JWTProfileKeyStorage{}, + } + got := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, time.Second, op.SubjectCheck(func(request *oidc.JWTTokenRequest) error { + return oidc.ErrSubjectMissing + })) + assert.Equal(t, want.Verifier, got.Verifier) + assert.Equal(t, want.Storage, got.Storage) + assert.ErrorIs(t, got.CheckSubject(nil), oidc.ErrSubjectMissing) +} + +func TestVerifyJWTAssertion(t *testing.T) { + errCtx, cancel := context.WithCancel(context.Background()) + cancel() + + verifier := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, 0) + tests := []struct { + name string + ctx context.Context + newToken func() (string, *oidc.JWTTokenRequest) + wantErr bool + }{ + { + name: "parse error", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { return "!", nil }, + wantErr: true, + }, + { + name: "wrong audience", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{"wrong"}, + time.Now(), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "expired", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer}, + time.Now(), time.Now().Add(-time.Hour), + ) + }, + wantErr: true, + }, + { + name: "invalid iat", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer}, + time.Now().Add(time.Hour), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "invalid subject", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, "wrong", []string{tu.ValidIssuer}, + time.Now(), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "check signature fail", + ctx: errCtx, + newToken: tu.ValidJWTProfileAssertion, + wantErr: true, + }, + { + name: "ok", + ctx: context.Background(), + newToken: tu.ValidJWTProfileAssertion, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertion, want := tt.newToken() + got, err := op.VerifyJWTAssertion(tt.ctx, assertion, verifier) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, want, got) + }) + } +} From 6af94fded0a1d5ddb448799358b029733b77d7a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Thu, 23 Mar 2023 16:31:38 +0200 Subject: [PATCH 04/14] feat: add context to all client calls (#345) BREAKING CHANGE closes #309 --- example/client/api/api.go | 3 +- example/client/app/app.go | 3 +- example/client/device/device.go | 4 +- example/client/service/service.go | 4 +- pkg/client/client.go | 30 ++++++------- pkg/client/integration_test.go | 30 ++++++++++--- pkg/client/jwt_profile.go | 5 ++- pkg/client/key.go | 8 ++-- pkg/client/profile/jwt_profile.go | 51 ++++++++++++++++------- pkg/client/rp/device.go | 4 +- pkg/client/rp/relying_party.go | 43 +++++-------------- pkg/client/rs/resource_server.go | 18 ++++---- pkg/client/tokenexchange/tokenexchange.go | 16 +++---- pkg/http/http.go | 4 +- 14 files changed, 124 insertions(+), 99 deletions(-) diff --git a/example/client/api/api.go b/example/client/api/api.go index 95e84e7e..83ec2a15 100644 --- a/example/client/api/api.go +++ b/example/client/api/api.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "log" @@ -27,7 +28,7 @@ func main() { port := os.Getenv("PORT") issuer := os.Getenv("ISSUER") - provider, err := rs.NewResourceServerFromKeyFile(issuer, keyPath) + provider, err := rs.NewResourceServerFromKeyFile(context.TODO(), issuer, keyPath) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } diff --git a/example/client/app/app.go b/example/client/app/app.go index 446c17be..2cb5dfa7 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "net/http" @@ -43,7 +44,7 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...) + provider, err := rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } diff --git a/example/client/device/device.go b/example/client/device/device.go index 88ecfe99..c186b341 100644 --- a/example/client/device/device.go +++ b/example/client/device/device.go @@ -39,13 +39,13 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...) + provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, "", scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } logrus.Info("starting device authorization flow") - resp, err := rp.DeviceAuthorization(scopes, provider) + resp, err := rp.DeviceAuthorization(ctx, scopes, provider) if err != nil { logrus.Fatal(err) } diff --git a/example/client/service/service.go b/example/client/service/service.go index 4908b095..ffcdccb3 100644 --- a/example/client/service/service.go +++ b/example/client/service/service.go @@ -25,7 +25,7 @@ func main() { scopes := strings.Split(os.Getenv("SCOPES"), " ") if keyPath != "" { - ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath, scopes) + ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), issuer, keyPath, scopes) if err != nil { logrus.Fatalf("error creating token source %s", err.Error()) } @@ -76,7 +76,7 @@ func main() { http.Error(w, err.Error(), http.StatusInternalServerError) return } - ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(issuer, key, scopes) + ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), issuer, key, scopes) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/pkg/client/client.go b/pkg/client/client.go index e9af8ce7..b9580ff0 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -23,12 +23,12 @@ var Encoder = httphelper.Encoder(oidc.NewEncoder()) // Discover calls the discovery endpoint of the provided issuer and returns its configuration // It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url -func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { +func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" { wellKnown = wellKnownUrl[0] } - req, err := http.NewRequest("GET", wellKnown, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil) if err != nil { return nil, err } @@ -48,12 +48,12 @@ type TokenEndpointCaller interface { HttpClient() *http.Client } -func CallTokenEndpoint(request interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { - return callTokenEndpoint(request, nil, caller) +func CallTokenEndpoint(ctx context.Context, request interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { + return callTokenEndpoint(ctx, request, nil, caller) } -func callTokenEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) +func callTokenEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -74,8 +74,8 @@ type EndSessionCaller interface { HttpClient() *http.Client } -func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndSessionCaller) (*url.URL, error) { - req, err := httphelper.FormRequest(caller.GetEndSessionEndpoint(), request, Encoder, authFn) +func CallEndSessionEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller EndSessionCaller) (*url.URL, error) { + req, err := httphelper.FormRequest(ctx, caller.GetEndSessionEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -117,8 +117,8 @@ type RevokeRequest struct { ClientSecret string `schema:"client_secret"` } -func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCaller) error { - req, err := httphelper.FormRequest(caller.GetRevokeEndpoint(), request, Encoder, authFn) +func CallRevokeEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller RevokeCaller) error { + req, err := httphelper.FormRequest(ctx, caller.GetRevokeEndpoint(), request, Encoder, authFn) if err != nil { return err } @@ -145,8 +145,8 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa return nil } -func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) +func CallTokenExchangeEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -186,8 +186,8 @@ type DeviceAuthorizationCaller interface { HttpClient() *http.Client } -func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { - req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil) +func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { + req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil) if err != nil { return nil, err } @@ -208,7 +208,7 @@ type DeviceAccessTokenRequest struct { } func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil) + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil) if err != nil { return nil, err } diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 709d5a13..2c3ef623 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -2,6 +2,7 @@ package client_test import ( "bytes" + "context" "io" "io/ioutil" "math/rand" @@ -10,7 +11,9 @@ import ( "net/http/httptest" "net/url" "os" + "os/signal" "strconv" + "syscall" "testing" "time" @@ -27,6 +30,18 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) +var CTX context.Context + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT) + defer cancel() + CTX, cancel = context.WithTimeout(ctx, time.Minute) + defer cancel() + return m.Run() + }()) +} + func TestRelyingPartySession(t *testing.T) { t.Log("------- start example OP ------") targetURL := "http://local-site" @@ -45,7 +60,7 @@ func TestRelyingPartySession(t *testing.T) { t.Log("------- refresh tokens ------") - newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "") + newTokens, err := rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") require.NoError(t, err, "refresh token") assert.NotNil(t, newTokens, "access token") t.Logf("new access token %s", newTokens.AccessToken) @@ -56,7 +71,7 @@ func TestRelyingPartySession(t *testing.T) { t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -66,11 +81,11 @@ func TestRelyingPartySession(t *testing.T) { t.Log("------ attempt refresh again (should fail) ------") t.Log("trying original refresh token", refreshToken) - _, err = rp.RefreshAccessToken(provider, refreshToken, "", "") + _, err = rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") assert.Errorf(t, err, "refresh with original") if newTokens.RefreshToken != "" { t.Log("trying replacement refresh token", newTokens.RefreshToken) - _, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "") + _, err = rp.RefreshAccessToken(CTX, provider, newTokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with replacement") } } @@ -92,12 +107,13 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------- run authorization code flow ------") provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) - resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret) + resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret) require.NoError(t, err, "new resource server") t.Log("------- exchage refresh tokens (impersonation) ------") tokenExchangeResponse, err := tokenexchange.ExchangeToken( + CTX, resourceServer, refreshToken, oidc.RefreshTokenType, @@ -117,7 +133,7 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -128,6 +144,7 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------- attempt exchage again (should fail) ------") tokenExchangeResponse, err = tokenexchange.ExchangeToken( + CTX, resourceServer, refreshToken, oidc.RefreshTokenType, @@ -166,6 +183,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, key := []byte("test1234test1234") cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) provider, err = rp.NewRelyingPartyOIDC( + CTX, opServer.URL, clientID, clientSecret, diff --git a/pkg/client/jwt_profile.go b/pkg/client/jwt_profile.go index 486d998e..0a5d9ec5 100644 --- a/pkg/client/jwt_profile.go +++ b/pkg/client/jwt_profile.go @@ -1,6 +1,7 @@ package client import ( + "context" "net/url" "golang.org/x/oauth2" @@ -10,8 +11,8 @@ import ( ) // JWTProfileExchange handles the oauth2 jwt profile exchange -func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) { - return CallTokenEndpoint(jwtProfileGrantRequest, caller) +func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) { + return CallTokenEndpoint(ctx, jwtProfileGrantRequest, caller) } func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption { diff --git a/pkg/client/key.go b/pkg/client/key.go index 740c6d33..0c01dd22 100644 --- a/pkg/client/key.go +++ b/pkg/client/key.go @@ -10,7 +10,7 @@ const ( applicationKey = "application" ) -type keyFile struct { +type KeyFile struct { Type string `json:"type"` // serviceaccount or application KeyID string `json:"keyId"` Key string `json:"key"` @@ -23,7 +23,7 @@ type keyFile struct { ClientID string `json:"clientId"` } -func ConfigFromKeyFile(path string) (*keyFile, error) { +func ConfigFromKeyFile(path string) (*KeyFile, error) { data, err := ioutil.ReadFile(path) if err != nil { return nil, err @@ -31,8 +31,8 @@ func ConfigFromKeyFile(path string) (*keyFile, error) { return ConfigFromKeyFileData(data) } -func ConfigFromKeyFileData(data []byte) (*keyFile, error) { - var f keyFile +func ConfigFromKeyFileData(data []byte) (*KeyFile, error) { + var f KeyFile if err := json.Unmarshal(data, &f); err != nil { return nil, err } diff --git a/pkg/client/profile/jwt_profile.go b/pkg/client/profile/jwt_profile.go index bb185707..668f749c 100644 --- a/pkg/client/profile/jwt_profile.go +++ b/pkg/client/profile/jwt_profile.go @@ -1,6 +1,7 @@ package profile import ( + "context" "net/http" "time" @@ -11,9 +12,12 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) -// jwtProfileTokenSource implement the oauth2.TokenSource -// it will request a token using the OAuth2 JWT Profile Grant -// therefore sending an `assertion` by singing a JWT with the provided private key +type TokenSource interface { + oauth2.TokenSource + TokenCtx(context.Context) (*oauth2.Token, error) +} + +// jwtProfileTokenSource implements the TokenSource type jwtProfileTokenSource struct { clientID string audience []string @@ -23,23 +27,38 @@ type jwtProfileTokenSource struct { tokenEndpoint string } -func NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath string, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { - keyData, err := client.ConfigFromKeyFile(keyPath) +// NewJWTProfileTokenSourceFromKeyFile returns an implementation of TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key from jsonFile. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSourceFromKeyFile(ctx context.Context, issuer, jsonFile string, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { + keyData, err := client.ConfigFromKeyFile(jsonFile) if err != nil { return nil, err } - return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) + return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) } -func NewJWTProfileTokenSourceFromKeyFileData(issuer string, data []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { - keyData, err := client.ConfigFromKeyFileData(data) +// NewJWTProfileTokenSourceFromKeyFileData returns an implementation of oauth2.TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key in jsonData. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSourceFromKeyFileData(ctx context.Context, issuer string, jsonData []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { + keyData, err := client.ConfigFromKeyFileData(jsonData) if err != nil { return nil, err } - return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) + return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) } -func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { +// NewJWTProfileSource returns an implementation of oauth2.TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) if err != nil { return nil, err @@ -55,7 +74,7 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes opt(source) } if source.tokenEndpoint == "" { - config, err := client.Discover(issuer, source.httpClient) + config, err := client.Discover(ctx, issuer, source.httpClient) if err != nil { return nil, err } @@ -64,13 +83,13 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes return source, nil } -func WithHTTPClient(client *http.Client) func(*jwtProfileTokenSource) { +func WithHTTPClient(client *http.Client) func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) { source.httpClient = client } } -func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*jwtProfileTokenSource) { +func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) { source.tokenEndpoint = tokenEndpoint } @@ -85,9 +104,13 @@ func (j *jwtProfileTokenSource) HttpClient() *http.Client { } func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) { + return j.TokenCtx(context.Background()) +} + +func (j *jwtProfileTokenSource) TokenCtx(ctx context.Context) (*oauth2.Token, error) { assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer) if err != nil { return nil, err } - return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) + return client.JWTProfileExchange(ctx, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) } diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index 9cfc41e5..b2c5be68 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -33,13 +33,13 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc. // DeviceAuthorization starts a new Device Authorization flow as defined // in RFC 8628, section 3.1 and 3.2: // https://www.rfc-editor.org/rfc/rfc8628#section-3.1 -func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { +func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { req, err := newDeviceClientCredentialsRequest(scopes, rp) if err != nil { return nil, err } - return client.CallDeviceAuthorizationEndpoint(req, rp) + return client.CallDeviceAuthorizationEndpoint(ctx, req, rp) } // DeviceAccessToken attempts to obtain tokens from a Device Authorization, diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index bd96e160..820107f6 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "time" "github.com/google/uuid" @@ -177,7 +176,7 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart // NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given // issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions // it will run discovery on the provided issuer and use the found endpoints -func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) { +func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) { rp := &relyingParty{ issuer: issuer, oauthConfig: &oauth2.Config{ @@ -195,7 +194,7 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco return nil, err } } - discoveryConfiguration, err := client.Discover(rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) + discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) if err != nil { return nil, err } @@ -310,26 +309,6 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey { } } -// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints -// -// deprecated: use client.Discover -func Discover(issuer string, httpClient *http.Client) (Endpoints, error) { - wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint - req, err := http.NewRequest("GET", wellKnown, nil) - if err != nil { - return Endpoints{}, err - } - discoveryConfig := new(oidc.DiscoveryConfiguration) - err = httphelper.HttpRequest(httpClient, req, &discoveryConfig) - if err != nil { - return Endpoints{}, err - } - if discoveryConfig.Issuer != issuer { - return Endpoints{}, oidc.ErrIssuerInvalid - } - return GetEndpoints(discoveryConfig), nil -} - // AuthURL returns the auth request url // (wrapping the oauth2 `AuthCodeURL`) func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string { @@ -463,7 +442,7 @@ type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r // on success it will pass the userinfo into its callback function as well func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { - info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) + info, err := Userinfo(r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) return @@ -473,8 +452,8 @@ func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeEx } // Userinfo will call the OIDC Userinfo Endpoint with the provided token -func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) { - req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil) +func Userinfo(ctx context.Context, token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil) if err != nil { return nil, err } @@ -620,7 +599,7 @@ type RefreshTokenRequest struct { GrantType oidc.GrantType `schema:"grant_type"` } -func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { +func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { request := RefreshTokenRequest{ RefreshToken: refreshToken, Scopes: rp.OAuthConfig().Scopes, @@ -630,17 +609,17 @@ func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAs ClientAssertionType: clientAssertionType, GrantType: oidc.GrantTypeRefreshToken, } - return client.CallTokenEndpoint(request, tokenEndpointCaller{RelyingParty: rp}) + return client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp}) } -func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { +func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { request := oidc.EndSessionRequest{ IdTokenHint: idToken, ClientID: rp.OAuthConfig().ClientID, PostLogoutRedirectURI: optionalRedirectURI, State: optionalState, } - return client.CallEndSessionEndpoint(request, nil, rp) + return client.CallEndSessionEndpoint(ctx, request, nil, rp) } // RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty @@ -648,7 +627,7 @@ func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState str // NewRelyingPartyOAuth() does not. // // tokenTypeHint should be either "id_token" or "refresh_token". -func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error { +func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error { request := client.RevokeRequest{ Token: token, TokenTypeHint: tokenTypeHint, @@ -656,7 +635,7 @@ func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error { ClientSecret: rp.OAuthConfig().ClientSecret, } if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" { - return client.CallRevokeEndpoint(request, nil, rc) + return client.CallRevokeEndpoint(ctx, request, nil, rc) } return fmt.Errorf("RelyingParty does not support RevokeCaller") } diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go index f0e0e0ab..054dfbef 100644 --- a/pkg/client/rs/resource_server.go +++ b/pkg/client/rs/resource_server.go @@ -42,14 +42,14 @@ func (r *resourceServer) AuthFn() (interface{}, error) { return r.authFn() } -func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { +func NewResourceServerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { authorizer := func() (interface{}, error) { return httphelper.AuthorizeBasic(clientID, clientSecret), nil } - return newResourceServer(issuer, authorizer, option...) + return newResourceServer(ctx, issuer, authorizer, option...) } -func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) { +func NewResourceServerJWTProfile(ctx context.Context, issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) { signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) if err != nil { return nil, err @@ -61,10 +61,10 @@ func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, opt } return client.ClientAssertionFormAuthorization(assertion), nil } - return newResourceServer(issuer, authorizer, options...) + return newResourceServer(ctx, issuer, authorizer, options...) } -func newResourceServer(issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) { +func newResourceServer(ctx context.Context, issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) { rs := &resourceServer{ issuer: issuer, httpClient: httphelper.DefaultHTTPClient, @@ -73,7 +73,7 @@ func newResourceServer(issuer string, authorizer func() (interface{}, error), op optFunc(rs) } if rs.introspectURL == "" || rs.tokenURL == "" { - config, err := client.Discover(rs.issuer, rs.httpClient) + config, err := client.Discover(ctx, rs.issuer, rs.httpClient) if err != nil { return nil, err } @@ -87,12 +87,12 @@ func newResourceServer(issuer string, authorizer func() (interface{}, error), op return rs, nil } -func NewResourceServerFromKeyFile(issuer, path string, options ...Option) (ResourceServer, error) { +func NewResourceServerFromKeyFile(ctx context.Context, issuer, path string, options ...Option) (ResourceServer, error) { c, err := client.ConfigFromKeyFile(path) if err != nil { return nil, err } - return NewResourceServerJWTProfile(issuer, c.ClientID, c.KeyID, []byte(c.Key), options...) + return NewResourceServerJWTProfile(ctx, issuer, c.ClientID, c.KeyID, []byte(c.Key), options...) } type Option func(*resourceServer) @@ -117,7 +117,7 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.Int if err != nil { return nil, err } - req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) + req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) if err != nil { return nil, err } diff --git a/pkg/client/tokenexchange/tokenexchange.go b/pkg/client/tokenexchange/tokenexchange.go index ce665cdc..1c10df28 100644 --- a/pkg/client/tokenexchange/tokenexchange.go +++ b/pkg/client/tokenexchange/tokenexchange.go @@ -1,6 +1,7 @@ package tokenexchange import ( + "context" "errors" "net/http" @@ -21,18 +22,18 @@ type OAuthTokenExchange struct { authFn func() (interface{}, error) } -func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { - return newOAuthTokenExchange(issuer, nil, options...) +func NewTokenExchanger(ctx context.Context, issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { + return newOAuthTokenExchange(ctx, issuer, nil, options...) } -func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { +func NewTokenExchangerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { authorizer := func() (interface{}, error) { return httphelper.AuthorizeBasic(clientID, clientSecret), nil } - return newOAuthTokenExchange(issuer, authorizer, options...) + return newOAuthTokenExchange(ctx, issuer, authorizer, options...) } -func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { +func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { te := &OAuthTokenExchange{ httpClient: httphelper.DefaultHTTPClient, } @@ -41,7 +42,7 @@ func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error) } if te.tokenEndpoint == "" { - config, err := client.Discover(issuer, te.httpClient) + config, err := client.Discover(ctx, issuer, te.httpClient) if err != nil { return nil, err } @@ -89,6 +90,7 @@ func (te *OAuthTokenExchange) AuthFn() (interface{}, error) { // ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint. // SubjectToken and SubjectTokenType are required parameters. func ExchangeToken( + ctx context.Context, te TokenExchanger, SubjectToken string, SubjectTokenType oidc.TokenType, @@ -123,5 +125,5 @@ func ExchangeToken( RequestedTokenType: RequestedTokenType, } - return client.CallTokenExchangeEndpoint(request, authFn, te) + return client.CallTokenExchangeEndpoint(ctx, request, authFn, te) } diff --git a/pkg/http/http.go b/pkg/http/http.go index d3c5b4f7..97718885 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -33,7 +33,7 @@ func AuthorizeBasic(user, password string) RequestAuthorization { } } -func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) { +func FormRequest(ctx context.Context, endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) { form := url.Values{} if err := encoder.Encode(request, form); err != nil { return nil, err @@ -42,7 +42,7 @@ func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn i fn(form) } body := strings.NewReader(form.Encode()) - req, err := http.NewRequest("POST", endpoint, body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body) if err != nil { return nil, err } From adebbe4c32e0e30ed6dc94b372e8c1f0fcbd53aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Tue, 28 Mar 2023 14:57:27 +0300 Subject: [PATCH 05/14] chore: replace gorilla/schema with zitadel/schema (#348) Fixes #302 --- go.mod | 2 +- go.sum | 4 ++-- pkg/oidc/types.go | 2 +- pkg/oidc/types_test.go | 2 +- pkg/op/auth_request_test.go | 2 +- pkg/op/client_test.go | 2 +- pkg/op/mock/authorizer.mock.impl.go | 2 +- pkg/op/op.go | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 7cee26ef..8f571576 100644 --- a/go.mod +++ b/go.mod @@ -7,13 +7,13 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-github/v31 v31.0.0 github.com/google/uuid v1.3.0 - github.com/gorilla/schema v1.2.0 github.com/gorilla/securecookie v1.1.1 github.com/jeremija/gosubmit v0.2.7 github.com/muhlemmer/gu v0.3.1 github.com/rs/cors v1.8.3 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.2 + github.com/zitadel/schema v1.3.0 golang.org/x/oauth2 v0.6.0 golang.org/x/text v0.8.0 gopkg.in/square/go-jose.v2 v2.6.0 diff --git a/go.sum b/go.sum index a5ba6426..b1757d49 100644 --- a/go.sum +++ b/go.sum @@ -21,8 +21,6 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= -github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc= @@ -49,6 +47,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= +github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index 167f8b78..86ee1e0f 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -8,7 +8,7 @@ import ( "strings" "time" - "github.com/gorilla/schema" + "github.com/zitadel/schema" "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" ) diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index 64f07f16..4bf6e55f 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -9,9 +9,9 @@ import ( "testing" "time" - "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/schema" "golang.org/x/text/language" ) diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 3179e258..4e801796 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -9,7 +9,6 @@ import ( "reflect" "testing" - "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" tu "github.com/zitadel/oidc/v3/internal/testutil" @@ -17,6 +16,7 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op/mock" + "github.com/zitadel/schema" ) // diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go index bb17192a..0321f88a 100644 --- a/pkg/op/client_test.go +++ b/pkg/op/client_test.go @@ -11,13 +11,13 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op/mock" + "github.com/zitadel/schema" ) type testClientJWTProfile struct{} diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 409683ab..4d66a922 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/gorilla/schema" + "github.com/zitadel/schema" "gopkg.in/square/go-jose.v2" "github.com/zitadel/oidc/v3/pkg/oidc" diff --git a/pkg/op/op.go b/pkg/op/op.go index 9ed5662c..1cdb3bc9 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -7,8 +7,8 @@ import ( "time" "github.com/go-chi/chi" - "github.com/gorilla/schema" "github.com/rs/cors" + "github.com/zitadel/schema" "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" From c778e8329c2b694a1e9231c3babd4489e6a4667e Mon Sep 17 00:00:00 2001 From: Thomas Hipp Date: Mon, 3 Apr 2023 14:40:29 +0200 Subject: [PATCH 06/14] feat: Allow modifying request to device authorization endpoint (#356) * feat: Allow modifying request to device authorization endpoint This change enables the caller to set URL parameters when calling the device authorization endpoint. Fixes #354 * Update device authorization example --- example/client/device/device.go | 2 +- pkg/client/client.go | 4 ++-- pkg/client/rp/device.go | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/example/client/device/device.go b/example/client/device/device.go index c186b341..bea61345 100644 --- a/example/client/device/device.go +++ b/example/client/device/device.go @@ -45,7 +45,7 @@ func main() { } logrus.Info("starting device authorization flow") - resp, err := rp.DeviceAuthorization(ctx, scopes, provider) + resp, err := rp.DeviceAuthorization(ctx, scopes, provider, nil) if err != nil { logrus.Fatal(err) } diff --git a/pkg/client/client.go b/pkg/client/client.go index b9580ff0..37c7ec27 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -186,8 +186,8 @@ type DeviceAuthorizationCaller interface { HttpClient() *http.Client } -func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { - req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil) +func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index b2c5be68..788e23eb 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -33,13 +33,13 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc. // DeviceAuthorization starts a new Device Authorization flow as defined // in RFC 8628, section 3.1 and 3.2: // https://www.rfc-editor.org/rfc/rfc8628#section-3.1 -func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { +func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) { req, err := newDeviceClientCredentialsRequest(scopes, rp) if err != nil { return nil, err } - return client.CallDeviceAuthorizationEndpoint(ctx, req, rp) + return client.CallDeviceAuthorizationEndpoint(ctx, req, rp, authFn) } // DeviceAccessToken attempts to obtain tokens from a Device Authorization, From 312c2a07e21bd1c3e862150da1f45cbddc6ba45c Mon Sep 17 00:00:00 2001 From: Thomas Hipp Date: Thu, 13 Apr 2023 15:04:58 +0200 Subject: [PATCH 07/14] fix: Only set GrantType once (#353) (#367) This fixes an issue where, when using the device authorization flow, the grant type would be set twice. Some OPs don't accept this, and fail when polling. With this fix the grant type is only set once, which will make some OPs happy again. Fixes #352 --- pkg/client/rp/device.go | 1 - pkg/oidc/token_request.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index 788e23eb..390c8cf4 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -12,7 +12,6 @@ import ( func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) { confg := rp.OAuthConfig() req := &oidc.ClientCredentialsRequest{ - GrantType: oidc.GrantTypeDeviceCode, Scope: scopes, ClientID: confg.ClientID, ClientSecret: confg.ClientSecret, diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 6b6945a1..5c5cf20f 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -241,7 +241,7 @@ type TokenExchangeRequest struct { } type ClientCredentialsRequest struct { - GrantType GrantType `schema:"grant_type"` + GrantType GrantType `schema:"grant_type,omitempty"` Scope SpaceDelimitedArray `schema:"scope"` ClientID string `schema:"client_id"` ClientSecret string `schema:"client_secret"` From e43ac6dfdfd9c9ae928e45628b9471745f327384 Mon Sep 17 00:00:00 2001 From: Giulio Ruggeri Date: Wed, 3 May 2023 12:27:28 +0200 Subject: [PATCH 08/14] fix: modify ACRValues parameter type to space separated strings (#388) Co-authored-by: Giulio Ruggeri --- pkg/oidc/authorization.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index f620ecb9..d8bf3364 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -60,7 +60,7 @@ const ( ) // AuthRequest according to: -//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest +// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest type AuthRequest struct { Scopes SpaceDelimitedArray `json:"scope" schema:"scope"` ResponseType ResponseType `json:"response_type" schema:"response_type"` @@ -77,7 +77,7 @@ type AuthRequest struct { UILocales Locales `json:"ui_locales" schema:"ui_locales"` IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"` LoginHint string `json:"login_hint" schema:"login_hint"` - ACRValues []string `json:"acr_values" schema:"acr_values"` + ACRValues SpaceDelimitedArray `json:"acr_values" schema:"acr_values"` CodeChallenge string `json:"code_challenge" schema:"code_challenge"` CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"` From d5a9bd6d0e798d43c0d8cc4ea62347bcce18548a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 5 May 2023 14:36:37 +0200 Subject: [PATCH 09/14] feat: generic Userinfo and Introspect functions (#389) BREAKING CHANGE: rp.Userinfo and rs.Introspect now require a type parameter. --- example/client/api/api.go | 4 +- pkg/client/rp/relying_party.go | 30 +++++++++----- pkg/client/rp/userinfo_example_test.go | 45 ++++++++++++++++++++ pkg/client/rs/introspect_example_test.go | 52 ++++++++++++++++++++++++ pkg/client/rs/resource_server.go | 18 +++++--- pkg/oidc/userinfo.go | 5 +++ 6 files changed, 136 insertions(+), 18 deletions(-) create mode 100644 pkg/client/rp/userinfo_example_test.go create mode 100644 pkg/client/rs/introspect_example_test.go diff --git a/example/client/api/api.go b/example/client/api/api.go index 83ec2a15..2f81c07b 100644 --- a/example/client/api/api.go +++ b/example/client/api/api.go @@ -48,7 +48,7 @@ func main() { if !ok { return } - resp, err := rs.Introspect(r.Context(), provider, token) + resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token) if err != nil { http.Error(w, err.Error(), http.StatusForbidden) return @@ -69,7 +69,7 @@ func main() { if !ok { return } - resp, err := rs.Introspect(r.Context(), provider, token) + resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token) if err != nil { http.Error(w, err.Error(), http.StatusForbidden) return diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index b93a373c..7d73a5a2 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -435,14 +435,18 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R } } -type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo) +type SubjectGetter interface { + GetSubject() string +} + +type CodeExchangeUserinfoCallback[C oidc.IDClaims, U SubjectGetter] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info U) // UserinfoCallback wraps the callback function of the CodeExchangeHandler // and calls the userinfo endpoint with the access token // on success it will pass the userinfo into its callback function as well -func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] { +func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { - info, err := Userinfo(r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) + info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) return @@ -451,19 +455,25 @@ func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeEx } } -// Userinfo will call the OIDC Userinfo Endpoint with the provided token -func Userinfo(ctx context.Context, token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) { +// Userinfo will call the OIDC [UserInfo] Endpoint with the provided token and returns +// the response in an instance of type U. +// [*oidc.UserInfo] can be used as a good example, or use a custom type if type-safe +// access to custom claims is needed. +// +// [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo +func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) { + var nilU U + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil) if err != nil { - return nil, err + return nilU, err } req.Header.Set("authorization", tokenType+" "+token) - userinfo := new(oidc.UserInfo) if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil { - return nil, err + return nilU, err } - if userinfo.Subject != subject { - return nil, ErrUserInfoSubNotMatching + if userinfo.GetSubject() != subject { + return nilU, ErrUserInfoSubNotMatching } return userinfo, nil } diff --git a/pkg/client/rp/userinfo_example_test.go b/pkg/client/rp/userinfo_example_test.go new file mode 100644 index 00000000..2cc52228 --- /dev/null +++ b/pkg/client/rp/userinfo_example_test.go @@ -0,0 +1,45 @@ +package rp_test + +import ( + "context" + "fmt" + + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +type UserInfo struct { + Subject string `json:"sub,omitempty"` + oidc.UserInfoProfile + oidc.UserInfoEmail + oidc.UserInfoPhone + Address *oidc.UserInfoAddress `json:"address,omitempty"` + + // Foo and Bar are custom claims + Foo string `json:"foo,omitempty"` + Bar struct { + Val1 string `json:"val_1,omitempty"` + Val2 string `json:"val_2,omitempty"` + } `json:"bar,omitempty"` + + // Claims are all the combined claims, including custom. + Claims map[string]any `json:"-,omitempty"` +} + +func (u *UserInfo) GetSubject() string { + return u.Subject +} + +func ExampleUserinfo_custom() { + rpo, err := rp.NewRelyingPartyOIDC(context.TODO(), "http://localhost:8080", "clientid", "clientsecret", "http://example.com/redirect", []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone}) + if err != nil { + panic(err) + } + + info, err := rp.Userinfo[*UserInfo](context.TODO(), "accesstokenstring", "Bearer", "userid", rpo) + if err != nil { + panic(err) + } + + fmt.Println(info) +} diff --git a/pkg/client/rs/introspect_example_test.go b/pkg/client/rs/introspect_example_test.go new file mode 100644 index 00000000..eac8be27 --- /dev/null +++ b/pkg/client/rs/introspect_example_test.go @@ -0,0 +1,52 @@ +package rs_test + +import ( + "context" + "fmt" + + "github.com/zitadel/oidc/v3/pkg/client/rs" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +type IntrospectionResponse struct { + Active bool `json:"active"` + Scope oidc.SpaceDelimitedArray `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + TokenType string `json:"token_type,omitempty"` + Expiration oidc.Time `json:"exp,omitempty"` + IssuedAt oidc.Time `json:"iat,omitempty"` + NotBefore oidc.Time `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` + Audience oidc.Audience `json:"aud,omitempty"` + Issuer string `json:"iss,omitempty"` + JWTID string `json:"jti,omitempty"` + Username string `json:"username,omitempty"` + oidc.UserInfoProfile + oidc.UserInfoEmail + oidc.UserInfoPhone + Address *oidc.UserInfoAddress `json:"address,omitempty"` + + // Foo and Bar are custom claims + Foo string `json:"foo,omitempty"` + Bar struct { + Val1 string `json:"val_1,omitempty"` + Val2 string `json:"val_2,omitempty"` + } `json:"bar,omitempty"` + + // Claims are all the combined claims, including custom. + Claims map[string]any `json:"-,omitempty"` +} + +func ExampleIntrospect_custom() { + rss, err := rs.NewResourceServerClientCredentials(context.TODO(), "http://localhost:8080", "clientid", "clientsecret") + if err != nil { + panic(err) + } + + resp, err := rs.Introspect[*IntrospectionResponse](context.TODO(), rss, "accesstokenstring") + if err != nil { + panic(err) + } + + fmt.Println(resp) +} diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go index 054dfbef..49232b2c 100644 --- a/pkg/client/rs/resource_server.go +++ b/pkg/client/rs/resource_server.go @@ -112,18 +112,24 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option { } } -func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) { +// Introspect calls the [RFC7662] Token Introspection +// endpoint and returns the response in an instance of type R. +// [*oidc.IntrospectionResponse] can be used as a good example, or use a custom type if type-safe +// access to custom claims is needed. +// +// [RFC7662]: https://www.rfc-editor.org/rfc/rfc7662 +func Introspect[R any](ctx context.Context, rp ResourceServer, token string) (resp R, err error) { authFn, err := rp.AuthFn() if err != nil { - return nil, err + return resp, err } req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) if err != nil { - return nil, err + return resp, err } - resp := new(oidc.IntrospectionResponse) - if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil { - return nil, err + + if err := httphelper.HttpRequest(rp.HttpClient(), req, &resp); err != nil { + return resp, err } return resp, nil } diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go index caff58e9..ef8ebe46 100644 --- a/pkg/oidc/userinfo.go +++ b/pkg/oidc/userinfo.go @@ -29,6 +29,11 @@ func (u *UserInfo) GetAddress() *UserInfoAddress { return u.Address } +// GetSubject implements [rp.SubjectGetter] +func (u *UserInfo) GetSubject() string { + return u.Subject +} + type uiAlias UserInfo func (u *UserInfo) MarshalJSON() ([]byte, error) { From e8262cbf1fc224e8a7447cfc3d1ffa93c1a70b3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 26 May 2023 11:06:33 +0300 Subject: [PATCH 10/14] chore: cleanup unneeded device storage methods (#399) BREAKING CHANGE, removes methods from DeviceAuthorizationStorage: - GetDeviceAuthorizationByUserCode - CompleteDeviceAuthorization - DenyDeviceAuthorization The methods are now moved to examples as something similar can be userful for implementers. --- example/server/exampleop/device.go | 13 +++++++++++++ pkg/op/device_test.go | 5 +++-- pkg/op/storage.go | 12 ------------ 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/example/server/exampleop/device.go b/example/server/exampleop/device.go index 0dda3d53..7478750c 100644 --- a/example/server/exampleop/device.go +++ b/example/server/exampleop/device.go @@ -1,6 +1,7 @@ package exampleop import ( + "context" "errors" "fmt" "io" @@ -16,6 +17,18 @@ import ( type deviceAuthenticate interface { CheckUsernamePasswordSimple(username, password string) error op.DeviceAuthorizationStorage + + // GetDeviceAuthorizationByUserCode resturns the current state of the device authorization flow, + // identified by the user code. + GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) + + // CompleteDeviceAuthorization marks a device authorization entry as Completed, + // identified by userCode. The Subject is added to the state, so that + // GetDeviceAuthorizatonState can use it to create a new Access Token. + CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error + + // DenyDeviceAuthorization marks a device authorization entry as Denied. + DenyDeviceAuthorization(ctx context.Context, userCode string) error } type deviceLogin struct { diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index 1e32554b..5fe6e27b 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -16,6 +16,7 @@ import ( "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/example/server/storage" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" ) @@ -304,7 +305,7 @@ func BenchmarkNewUserCode(b *testing.B) { } func TestDeviceAccessToken(t *testing.T) { - storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage := testProvider.Storage().(*storage.Storage) storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"}) storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim") @@ -329,7 +330,7 @@ func TestDeviceAccessToken(t *testing.T) { func TestCheckDeviceAuthorizationState(t *testing.T) { now := time.Now() - storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage := testProvider.Storage().(*storage.Storage) storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"}) diff --git a/pkg/op/storage.go b/pkg/op/storage.go index aa8721ac..23d21334 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -182,18 +182,6 @@ type DeviceAuthorizationStorage interface { // GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database. // The method is polled untill the the authorization is eighter Completed, Expired or Denied. GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error) - - // GetDeviceAuthorizationByUserCode resturn the current state of the device authorization flow, - // identified by the user code. - GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*DeviceAuthorizationState, error) - - // CompleteDeviceAuthorization marks a device authorization entry as Completed, - // identified by userCode. The Subject is added to the state, so that - // GetDeviceAuthorizatonState can use it to create a new Access Token. - CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error - - // DenyDeviceAuthorization marks a device authorization entry as Denied. - DenyDeviceAuthorization(ctx context.Context, userCode string) error } func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) { From 6708ef4c247e6583abe54750427e93c70aaa6c29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 18 Aug 2023 15:36:39 +0300 Subject: [PATCH 11/14] feat(rp): return oidc.Tokens on token refresh (#423) BREAKING CHANGE: - rename RefreshAccessToken to RefreshToken - RefreshToken returns *oidc.Tokens instead of *oauth2.Token This change allows the return of the id_token in an explicit manner, as part of the oidc.Tokens struct. The return type is now consistent with the CodeExchange function. When an id_token is returned, it is verified. In case no id_token was received, RefreshTokens will not return an error. As per specifictation: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse Upon successful validation of the Refresh Token, the response body is the Token Response of Section 3.1.3.3 except that it might not contain an id_token. Closes #364 --- pkg/client/integration_test.go | 43 ++++++----- pkg/client/rp/relying_party.go | 59 +++++++++------ pkg/client/rp/relying_party_test.go | 107 ++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+), 43 deletions(-) create mode 100644 pkg/client/rp/relying_party_test.go diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index d8b3f255..073efef7 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "io" - "io/ioutil" "math/rand" "net/http" "net/http/cookiejar" @@ -56,11 +55,11 @@ func TestRelyingPartySession(t *testing.T) { clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) t.Log("------- run authorization code flow ------") - provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") t.Log("------- refresh tokens ------") - newTokens, err := rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") + newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") require.NoError(t, err, "refresh token") assert.NotNil(t, newTokens, "access token") t.Logf("new access token %s", newTokens.AccessToken) @@ -68,11 +67,13 @@ func TestRelyingPartySession(t *testing.T) { t.Logf("new token type %s", newTokens.TokenType) t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) require.NotEmpty(t, newTokens.AccessToken, "new accessToken") - assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken") + assert.NotEmpty(t, newTokens.IDToken, "new idToken") + assert.NotNil(t, newTokens.IDTokenClaims) + assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject) t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -81,12 +82,12 @@ func TestRelyingPartySession(t *testing.T) { } t.Log("------ attempt refresh again (should fail) ------") - t.Log("trying original refresh token", refreshToken) - _, err = rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") + t.Log("trying original refresh token", tokens.RefreshToken) + _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with original") if newTokens.RefreshToken != "" { t.Log("trying replacement refresh token", newTokens.RefreshToken) - _, err = rp.RefreshAccessToken(CTX, provider, newTokens.RefreshToken, "", "") + _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with replacement") } } @@ -106,7 +107,7 @@ func TestResourceServerTokenExchange(t *testing.T) { clientSecret := "secret" t.Log("------- run authorization code flow ------") - provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret) require.NoError(t, err, "new resource server") @@ -116,7 +117,7 @@ func TestResourceServerTokenExchange(t *testing.T) { tokenExchangeResponse, err := tokenexchange.ExchangeToken( CTX, resourceServer, - refreshToken, + tokens.RefreshToken, oidc.RefreshTokenType, "", "", @@ -134,7 +135,7 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -147,7 +148,7 @@ func TestResourceServerTokenExchange(t *testing.T) { tokenExchangeResponse, err = tokenexchange.ExchangeToken( CTX, resourceServer, - refreshToken, + tokens.RefreshToken, oidc.RefreshTokenType, "", "", @@ -161,7 +162,7 @@ func TestResourceServerTokenExchange(t *testing.T) { require.Nil(t, tokenExchangeResponse, "token exchange response") } -func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) { +func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) { targetURL := "http://local-site" localURL, err := url.Parse(targetURL + "/login?requestID=1234") require.NoError(t, err, "local url") @@ -258,7 +259,8 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, } var email string - redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + tokens = newTokens require.NotNil(t, tokens, "tokens") require.NotNil(t, info, "info") t.Log("access token", tokens.AccessToken) @@ -266,9 +268,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, t.Log("id token", tokens.IDToken) t.Log("email", info.Email) - accessToken = tokens.AccessToken - refreshToken = tokens.RefreshToken - idToken = tokens.IDToken email = info.Email http.Redirect(w, r, targetURL, 302) } @@ -290,12 +289,12 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, require.NoError(t, err, "get fully-authorizied redirect location") require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location") - require.NotEmpty(t, idToken, "id token") - assert.NotEmpty(t, refreshToken, "refresh token") - assert.NotEmpty(t, accessToken, "access token") + require.NotEmpty(t, tokens.IDToken, "id token") + assert.NotEmpty(t, tokens.RefreshToken, "refresh token") + assert.NotEmpty(t, tokens.AccessToken, "access token") assert.NotEmpty(t, email, "email") - return provider, accessToken, refreshToken, idToken + return provider, tokens } type deferredHandler struct { @@ -343,7 +342,7 @@ func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) [ func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL { // TODO: switch to io.NopCloser when go1.15 support is dropped - req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( + req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)..., ) if req.URL.Scheme == "" { diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 7d73a5a2..5597c9d9 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -356,6 +356,25 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri return oidc.NewSHACodeChallenge(codeVerifier), nil } +// ErrMissingIDToken is returned when an id_token was expected, +// but not received in the token response. +var ErrMissingIDToken = errors.New("id_token missing") + +func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) { + if rp.IsOAuth2Only() { + return &oidc.Tokens[C]{Token: token}, nil + } + idTokenString, ok := token.Extra(idTokenKey).(string) + if !ok { + return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken + } + idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) + if err != nil { + return nil, err + } + return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil +} + // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // returning it parsed together with the oauth2 tokens (access, refresh) func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { @@ -369,22 +388,7 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP if err != nil { return nil, err } - - if rp.IsOAuth2Only() { - return &oidc.Tokens[C]{Token: token}, nil - } - - idTokenString, ok := token.Extra(idTokenKey).(string) - if !ok { - return nil, errors.New("id_token missing") - } - - idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) - if err != nil { - return nil, err - } - - return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil + return verifyTokenResponse[C](ctx, token, rp) } type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) @@ -609,11 +613,14 @@ type RefreshTokenRequest struct { GrantType oidc.GrantType `schema:"grant_type"` } -// RefreshAccessToken performs a token refresh. If it doesn't error, it will always +// RefreshTokens performs a token refresh. If it doesn't error, it will always // provide a new AccessToken. It may provide a new RefreshToken, and if it does, then -// the old one should be considered invalid. It may also provide a new IDToken. The -// new IDToken can be retrieved with token.Extra("id_token"). -func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { +// the old one should be considered invalid. +// +// In case the RP is not OAuth2 only and an IDToken was part of the response, +// the IDToken and AccessToken will be verfied +// and the IDToken and IDTokenClaims fields will be populated in the returned object. +func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) { request := RefreshTokenRequest{ RefreshToken: refreshToken, Scopes: rp.OAuthConfig().Scopes, @@ -623,7 +630,17 @@ func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clie ClientAssertionType: clientAssertionType, GrantType: oidc.GrantTypeRefreshToken, } - return client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp}) + newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp}) + if err != nil { + return nil, err + } + tokens, err := verifyTokenResponse[C](ctx, newToken, rp) + if err == nil || errors.Is(err, ErrMissingIDToken) { + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse + // ...except that it might not contain an id_token. + return tokens, nil + } + return nil, err } func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { diff --git a/pkg/client/rp/relying_party_test.go b/pkg/client/rp/relying_party_test.go new file mode 100644 index 00000000..4c5a1b31 --- /dev/null +++ b/pkg/client/rp/relying_party_test.go @@ -0,0 +1,107 @@ +package rp + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/oauth2" +) + +func Test_verifyTokenResponse(t *testing.T) { + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + ClientID: tu.ValidClientID, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, + } + tests := []struct { + name string + oauth2Only bool + tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims]) + wantErr error + }{ + { + name: "succes, oauth2 only", + oauth2Only: true, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + } + }, + }, + { + name: "id_token missing error", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + } + }, + wantErr: ErrMissingIDToken, + }, + { + name: "verify tokens error", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + token = token.WithExtra(map[string]any{ + "id_token": "foobar", + }) + return token, nil + }, + wantErr: oidc.ErrParse, + }, + { + name: "success, with id_token", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + idToken, claims := tu.ValidIDToken() + token = token.WithExtra(map[string]any{ + "id_token": idToken, + }) + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + IDTokenClaims: claims, + IDToken: idToken, + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rp := &relyingParty{ + oauth2Only: tt.oauth2Only, + idTokenVerifier: verifier, + } + token, want := tt.tokens() + got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, want, got) + }) + } +} From 0879c883996b646ef0d96faa98e142afcfc9cd08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Tue, 29 Aug 2023 15:07:45 +0300 Subject: [PATCH 12/14] feat: add slog logging (#432) * feat(op): user slog for logging integrate with golang.org/x/exp/slog for logging. provide a middleware for request scoped logging. BREAKING CHANGES: 1. OpenIDProvider and sub-interfaces get a Logger() method to return the configured logger; 2. AuthRequestError now takes the complete Authorizer, instead of only the encoder. So that it may use its Logger() method. 3. RequestError now takes a Logger as argument. * use zitadel/logging * finish op and testing without middleware for now * minimum go version 1.19 * update go mod * log value testing only on go 1.20 or later * finish the RP and example * ping logging release --- .github/workflows/release.yml | 2 +- README.md | 6 +- example/client/app/app.go | 44 ++++- example/server/exampleop/op.go | 26 ++- example/server/main.go | 19 +- example/server/storage/oidc.go | 14 ++ go.mod | 8 +- go.sum | 13 +- pkg/client/client.go | 5 + pkg/client/integration_test.go | 12 +- pkg/client/rp/device.go | 2 + pkg/client/rp/log.go | 17 ++ pkg/client/rp/relying_party.go | 31 +++- pkg/oidc/authorization.go | 13 ++ pkg/oidc/authorization_test.go | 27 +++ pkg/oidc/error.go | 33 ++++ pkg/oidc/error_go120_test.go | 83 +++++++++ pkg/oidc/error_test.go | 81 +++++++++ pkg/oidc/types.go | 8 +- pkg/op/auth_request.go | 34 ++-- pkg/op/auth_request_test.go | 3 +- pkg/op/device.go | 4 +- pkg/op/error.go | 32 +++- pkg/op/error_test.go | 277 +++++++++++++++++++++++++++++ pkg/op/mock/authorizer.mock.go | 15 ++ pkg/op/op.go | 20 +++ pkg/op/op_test.go | 10 +- pkg/op/session.go | 6 +- pkg/op/token_client_credentials.go | 6 +- pkg/op/token_code.go | 8 +- pkg/op/token_exchange.go | 6 +- pkg/op/token_jwt_profile.go | 8 +- pkg/op/token_refresh.go | 6 +- pkg/op/token_request.go | 6 +- 34 files changed, 800 insertions(+), 85 deletions(-) create mode 100644 pkg/client/rp/log.go create mode 100644 pkg/oidc/authorization_test.go create mode 100644 pkg/oidc/error_go120_test.go create mode 100644 pkg/oidc/error_test.go create mode 100644 pkg/op/error_test.go diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7483b2f7..329428d7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: ['1.18', '1.19', '1.20'] + go: ['1.19', '1.20', '1.21'] name: Go ${{ matrix.go }} test steps: - uses: actions/checkout@v3 diff --git a/README.md b/README.md index b7993e69..91a2f39d 100644 --- a/README.md +++ b/README.md @@ -115,10 +115,10 @@ Versions that also build are marked with :warning:. | Version | Supported | | ------- | ------------------ | -| <1.18 | :x: | -| 1.18 | :warning: | -| 1.19 | :white_check_mark: | +| <1.19 | :x: | +| 1.19 | :warning: | | 1.20 | :white_check_mark: | +| 1.21 | :white_check_mark: | ## Why another library diff --git a/example/client/app/app.go b/example/client/app/app.go index 2cb5dfa7..0e339f40 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -7,11 +7,14 @@ import ( "net/http" "os" "strings" + "sync/atomic" "time" "github.com/google/uuid" "github.com/sirupsen/logrus" + "golang.org/x/exp/slog" + "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/client/rp" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -33,9 +36,25 @@ func main() { redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) + logger := slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + ) + client := &http.Client{ + Timeout: time.Minute, + } + // enable outgoing request logging + logging.EnableHTTPClient(client, + logging.WithClientGroup("client"), + ) + options := []rp.Option{ rp.WithCookieHandler(cookieHandler), rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)), + rp.WithHTTPClient(client), + rp.WithLogger(logger), } if clientSecret == "" { options = append(options, rp.WithPKCE(cookieHandler)) @@ -44,7 +63,10 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, scopes, options...) + // One can add a logger to the context, + // pre-defining log attributes as required. + ctx := logging.ToContext(context.TODO(), logger) + provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } @@ -119,8 +141,22 @@ func main() { // // http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider)) + // simple counter for request IDs + var counter atomic.Int64 + // enable incomming request logging + mw := logging.Middleware( + logging.WithLogger(logger), + logging.WithGroup("server"), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + ) + lis := fmt.Sprintf("127.0.0.1:%s", port) - logrus.Infof("listening on http://%s/", lis) - logrus.Info("press ctrl+c to stop") - logrus.Fatal(http.ListenAndServe(lis, nil)) + logger.Info("server listening, press ctrl+c to stop", "addr", lis) + err = http.ListenAndServe(lis, mw(http.DefaultServeMux)) + if err != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) + } } diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 298bff69..b5ee7b37 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -4,9 +4,12 @@ import ( "crypto/sha256" "log" "net/http" + "sync/atomic" "time" "github.com/go-chi/chi" + "github.com/zitadel/logging" + "golang.org/x/exp/slog" "golang.org/x/text/language" "github.com/zitadel/oidc/v3/example/server/storage" @@ -31,26 +34,33 @@ type Storage interface { deviceAuthenticate } +// simple counter for request IDs +var counter atomic.Int64 + // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. -func SetupServer(issuer string, storage Storage) chi.Router { +func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router { // the OpenID Provider requires a 32-byte key for (token) encryption // be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) router := chi.NewRouter() + router.Use(logging.Middleware( + logging.WithLogger(logger), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + )) // for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { - _, err := w.Write([]byte("signed out successfully")) - if err != nil { - log.Printf("error serving logged out page: %v", err) - } + w.Write([]byte("signed out successfully")) + // no need to check/log error, this will be handeled by the middleware. }) // creation of the OpenIDProvider with the just created in-memory Storage - provider, err := newOP(storage, issuer, key) + provider, err := newOP(storage, issuer, key, logger) if err != nil { log.Fatal(err) } @@ -80,7 +90,7 @@ func SetupServer(issuer string, storage Storage) chi.Router { // newOP will create an OpenID Provider for localhost on a specified port with a given encryption key // and a predefined default logout uri // it will enable all options (see descriptions) -func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, error) { +func newOP(storage op.Storage, issuer string, key [32]byte, logger *slog.Logger) (op.OpenIDProvider, error) { config := &op.Config{ CryptoKey: key, @@ -117,6 +127,8 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, op.WithAllowInsecure(), // as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), + // Pass our logger to the OP + op.WithLogger(logger.WithGroup("op")), ) if err != nil { return nil, err diff --git a/example/server/main.go b/example/server/main.go index ee27bbab..a1cc4618 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -2,11 +2,12 @@ package main import ( "fmt" - "log" "net/http" + "os" "github.com/zitadel/oidc/v3/example/server/exampleop" "github.com/zitadel/oidc/v3/example/server/storage" + "golang.org/x/exp/slog" ) func main() { @@ -20,16 +21,22 @@ func main() { // in this example it will be handled in-memory storage := storage.NewStorage(storage.NewUserStore(issuer)) - router := exampleop.SetupServer(issuer, storage) + logger := slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + ) + router := exampleop.SetupServer(issuer, storage, logger) server := &http.Server{ Addr: ":" + port, Handler: router, } - log.Printf("server listening on http://localhost:%s/", port) - log.Println("press ctrl+c to stop") + logger.Info("server listening, press ctrl+c to stop", "addr", fmt.Sprintf("http://localhost:%s/", port)) err := server.ListenAndServe() - if err != nil { - log.Fatal(err) + if err != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) } } diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go index b56ad090..63afcf93 100644 --- a/example/server/storage/oidc.go +++ b/example/server/storage/oidc.go @@ -3,6 +3,7 @@ package storage import ( "time" + "golang.org/x/exp/slog" "golang.org/x/text/language" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -41,6 +42,19 @@ type AuthRequest struct { authTime time.Time } +// LogValue allows you to define which fields will be logged. +// Implements the [slog.LogValuer] +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.String("id", a.ID), + slog.Time("creation_date", a.CreationDate), + slog.Any("scopes", a.Scopes), + slog.String("response_type", string(a.ResponseType)), + slog.String("app_id", a.ApplicationID), + slog.String("callback_uri", a.CallbackURI), + ) +} + func (a *AuthRequest) GetID() string { return a.ID } diff --git a/go.mod b/go.mod index 610d2a10..62aa39be 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/zitadel/oidc/v3 -go 1.18 +go 1.19 require ( github.com/go-chi/chi v1.5.4 @@ -11,9 +11,11 @@ require ( github.com/jeremija/gosubmit v0.2.7 github.com/muhlemmer/gu v0.3.1 github.com/rs/cors v1.9.0 - github.com/sirupsen/logrus v1.9.0 + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.2 + github.com/zitadel/logging v0.4.0 github.com/zitadel/schema v1.3.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/oauth2 v0.7.0 golang.org/x/text v0.9.0 gopkg.in/square/go-jose.v2 v2.6.0 @@ -27,7 +29,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.7.0 // indirect golang.org/x/net v0.9.0 // indirect - golang.org/x/sys v0.7.0 // indirect + golang.org/x/sys v0.11.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.29.1 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index c9c85626..9c44f0f6 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -47,12 +47,16 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zitadel/logging v0.4.0 h1:lRAIFgaRoJpLNbsL7jtIYHcMDoEJP9QZB4GqMfl4xaA= +github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnIrRelUFc= github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -73,8 +77,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -100,6 +104,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/client/client.go b/pkg/client/client.go index e3efd611..7b76dfd6 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -14,6 +14,7 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" + "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/crypto" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -37,6 +38,10 @@ func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellK if err != nil { return nil, err } + if logger, ok := logging.FromContext(ctx); ok { + logger.Debug("discover", "config", discoveryConfig) + } + if discoveryConfig.Issuer != issuer { return nil, oidc.ErrIssuerInvalid } diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 073efef7..7cbb62e6 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -19,6 +19,7 @@ import ( "github.com/jeremija/gosubmit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/slog" "github.com/zitadel/oidc/v3/example/server/exampleop" "github.com/zitadel/oidc/v3/example/server/storage" @@ -29,6 +30,13 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) +var Logger = slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), +) + var CTX context.Context func TestMain(m *testing.M) { @@ -49,7 +57,7 @@ func TestRelyingPartySession(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) @@ -100,7 +108,7 @@ func TestResourceServerTokenExchange(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index 390c8cf4..02c647e3 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -33,6 +33,7 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc. // in RFC 8628, section 3.1 and 3.2: // https://www.rfc-editor.org/rfc/rfc8628#section-3.1 func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization") req, err := newDeviceClientCredentialsRequest(scopes, rp) if err != nil { return nil, err @@ -45,6 +46,7 @@ func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, // by means of polling as defined in RFC, section 3.3 and 3.4: // https://www.rfc-editor.org/rfc/rfc8628#section-3.4 func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken") req := &client.DeviceAccessTokenRequest{ DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{ GrantType: oidc.GrantTypeDeviceCode, diff --git a/pkg/client/rp/log.go b/pkg/client/rp/log.go new file mode 100644 index 00000000..6056fa2e --- /dev/null +++ b/pkg/client/rp/log.go @@ -0,0 +1,17 @@ +package rp + +import ( + "context" + + "github.com/zitadel/logging" + "golang.org/x/exp/slog" +) + +func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context { + logger, ok := rp.Logger(ctx) + if !ok { + return ctx + } + logger = logger.With(slog.Group("rp", attrs...)) + return logging.ToContext(ctx, logger) +} diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 5597c9d9..34cdb397 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -10,6 +10,8 @@ import ( "time" "github.com/google/uuid" + "github.com/zitadel/logging" + "golang.org/x/exp/slog" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" @@ -67,6 +69,9 @@ type RelyingParty interface { // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) + + // Logger from the context, or a fallback if set. + Logger(context.Context) (logger *slog.Logger, ok bool) } type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) @@ -90,6 +95,7 @@ type relyingParty struct { idTokenVerifier *IDTokenVerifier verifierOpts []VerifierOption signer jose.Signer + logger *slog.Logger } func (rp *relyingParty) OAuthConfig() *oauth2.Config { @@ -150,6 +156,14 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, return rp.errorHandler } +func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) { + logger, ok = logging.FromContext(ctx) + if ok { + return logger, ok + } + return rp.logger, rp.logger != nil +} + // NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given // OAuth2 Config and possible configOptions // it will use the AuthURL and TokenURL set in config @@ -194,6 +208,7 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re return nil, err } } + ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC") discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) if err != nil { return nil, err @@ -281,6 +296,15 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option { } } +// WithLogger sets a logger that is used +// in case the request context does not contain a logger. +func WithLogger(logger *slog.Logger) Option { + return func(rp *relyingParty) error { + rp.logger = logger + return nil + } +} + type SignerFromKey func() (jose.Signer, error) func SignerFromKeyPath(path string) SignerFromKey { @@ -378,6 +402,7 @@ func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Tok // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // returning it parsed together with the oauth2 tokens (access, refresh) func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange") ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) codeOpts := make([]oauth2.AuthCodeOption, 0) for _, opt := range opts { @@ -467,6 +492,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa // [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) { var nilU U + ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo") req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil) if err != nil { @@ -546,7 +572,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption { // This is the generalized, unexported, function used by both // URLParamOpt and AuthURLOpt. func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption { - return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()) + return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String()) } type URLParamOpt func() []oauth2.AuthCodeOption @@ -621,6 +647,7 @@ type RefreshTokenRequest struct { // the IDToken and AccessToken will be verfied // and the IDToken and IDTokenClaims fields will be populated in the returned object. func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) { + ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens") request := RefreshTokenRequest{ RefreshToken: refreshToken, Scopes: rp.OAuthConfig().Scopes, @@ -644,6 +671,7 @@ func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refres } func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { + ctx = logCtxWithRPData(ctx, rp, "function", "EndSession") request := oidc.EndSessionRequest{ IdTokenHint: idToken, ClientID: rp.OAuthConfig().ClientID, @@ -659,6 +687,7 @@ func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectU // // tokenTypeHint should be either "id_token" or "refresh_token". func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error { + ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken") request := client.RevokeRequest{ Token: token, TokenTypeHint: tokenTypeHint, diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index d8bf3364..7e7c30cc 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -1,5 +1,9 @@ package oidc +import ( + "golang.org/x/exp/slog" +) + const ( // ScopeOpenID defines the scope `openid` // OpenID Connect requests MUST contain the `openid` scope value @@ -86,6 +90,15 @@ type AuthRequest struct { RequestParam string `schema:"request"` } +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.Any("scopes", a.Scopes), + slog.String("response_type", string(a.ResponseType)), + slog.String("client_id", a.ClientID), + slog.String("redirect_uri", a.RedirectURI), + ) +} + // GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface func (a *AuthRequest) GetRedirectURI() string { return a.RedirectURI diff --git a/pkg/oidc/authorization_test.go b/pkg/oidc/authorization_test.go new file mode 100644 index 00000000..573d65c3 --- /dev/null +++ b/pkg/oidc/authorization_test.go @@ -0,0 +1,27 @@ +//go:build go1.20 + +package oidc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestAuthRequest_LogValue(t *testing.T) { + a := &AuthRequest{ + Scopes: SpaceDelimitedArray{"a", "b"}, + ResponseType: "respType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + } + want := slog.GroupValue( + slog.Any("scopes", SpaceDelimitedArray{"a", "b"}), + slog.String("response_type", "respType"), + slog.String("client_id", "123"), + slog.String("redirect_uri", "http://example.com/callback"), + ) + got := a.LogValue() + assert.Equal(t, want, got) +} diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 79acecd9..07a90697 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -3,6 +3,8 @@ package oidc import ( "errors" "fmt" + + "golang.org/x/exp/slog" ) type errorType string @@ -171,3 +173,34 @@ func DefaultToServerError(err error, description string) *Error { } return oauth } + +func (e *Error) LogLevel() slog.Level { + level := slog.LevelWarn + if e.ErrorType == ServerError { + level = slog.LevelError + } + if e.ErrorType == AuthorizationPending { + level = slog.LevelInfo + } + return level +} + +func (e *Error) LogValue() slog.Value { + attrs := make([]slog.Attr, 0, 5) + if e.Parent != nil { + attrs = append(attrs, slog.Any("parent", e.Parent)) + } + if e.Description != "" { + attrs = append(attrs, slog.String("description", e.Description)) + } + if e.ErrorType != "" { + attrs = append(attrs, slog.String("type", string(e.ErrorType))) + } + if e.State != "" { + attrs = append(attrs, slog.String("state", e.State)) + } + if e.redirectDisabled { + attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled)) + } + return slog.GroupValue(attrs...) +} diff --git a/pkg/oidc/error_go120_test.go b/pkg/oidc/error_go120_test.go new file mode 100644 index 00000000..399d7f71 --- /dev/null +++ b/pkg/oidc/error_go120_test.go @@ -0,0 +1,83 @@ +//go:build go1.20 + +package oidc + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestError_LogValue(t *testing.T) { + type fields struct { + Parent error + ErrorType errorType + Description string + State string + redirectDisabled bool + } + tests := []struct { + name string + fields fields + want slog.Value + }{ + { + name: "parent", + fields: fields{ + Parent: io.EOF, + }, + want: slog.GroupValue(slog.Any("parent", io.EOF)), + }, + { + name: "description", + fields: fields{ + Description: "oops", + }, + want: slog.GroupValue(slog.String("description", "oops")), + }, + { + name: "errorType", + fields: fields{ + ErrorType: ExpiredToken, + }, + want: slog.GroupValue(slog.String("type", string(ExpiredToken))), + }, + { + name: "state", + fields: fields{ + State: "123", + }, + want: slog.GroupValue(slog.String("state", "123")), + }, + { + name: "all fields", + fields: fields{ + Parent: io.EOF, + Description: "oops", + ErrorType: ExpiredToken, + State: "123", + }, + want: slog.GroupValue( + slog.Any("parent", io.EOF), + slog.String("description", "oops"), + slog.String("type", string(ExpiredToken)), + slog.String("state", "123"), + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Error{ + Parent: tt.fields.Parent, + ErrorType: tt.fields.ErrorType, + Description: tt.fields.Description, + State: tt.fields.State, + redirectDisabled: tt.fields.redirectDisabled, + } + got := e.LogValue() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/oidc/error_test.go b/pkg/oidc/error_test.go new file mode 100644 index 00000000..0554c8fb --- /dev/null +++ b/pkg/oidc/error_test.go @@ -0,0 +1,81 @@ +package oidc + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestDefaultToServerError(t *testing.T) { + type args struct { + err error + description string + } + tests := []struct { + name string + args args + want *Error + }{ + { + name: "default", + args: args{ + err: io.ErrClosedPipe, + description: "oops", + }, + want: &Error{ + ErrorType: ServerError, + Description: "oops", + Parent: io.ErrClosedPipe, + }, + }, + { + name: "our Error", + args: args{ + err: ErrAccessDenied(), + description: "oops", + }, + want: &Error{ + ErrorType: AccessDenied, + Description: "The authorization request was denied.", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DefaultToServerError(tt.args.err, tt.args.description) + assert.ErrorIs(t, got, tt.want) + }) + } +} + +func TestError_LogLevel(t *testing.T) { + tests := []struct { + name string + err *Error + want slog.Level + }{ + { + name: "server error", + err: ErrServerError(), + want: slog.LevelError, + }, + { + name: "authorization pending", + err: ErrAuthorizationPending(), + want: slog.LevelInfo, + }, + { + name: "some other error", + err: ErrAccessDenied(), + want: slog.LevelWarn, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.err.LogLevel() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index 86ee1e0f..5db8badc 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -106,7 +106,7 @@ type ResponseType string type ResponseMode string -func (s SpaceDelimitedArray) Encode() string { +func (s SpaceDelimitedArray) String() string { return strings.Join(s, " ") } @@ -116,11 +116,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error { } func (s SpaceDelimitedArray) MarshalText() ([]byte, error) { - return []byte(s.Encode()), nil + return []byte(s.String()), nil } func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) { - return json.Marshal((s).Encode()) + return json.Marshal((s).String()) } func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error { @@ -165,7 +165,7 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) { func NewEncoder() *schema.Encoder { e := schema.NewEncoder() e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string { - return value.Interface().(SpaceDelimitedArray).Encode() + return value.Interface().(SpaceDelimitedArray).String() }) return e } diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 7af3779e..7610248e 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -14,6 +14,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" str "github.com/zitadel/oidc/v3/pkg/strings" + "golang.org/x/exp/slog" ) type AuthRequest interface { @@ -41,6 +42,7 @@ type Authorizer interface { IDTokenHintVerifier(context.Context) *IDTokenHintVerifier Crypto() Crypto RequestObjectSupported() bool + Logger() *slog.Logger } // AuthorizeValidator is an extension of Authorizer interface @@ -67,23 +69,23 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, * func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } ctx := r.Context() if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } } if authReq.ClientID == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer) return } if authReq.RedirectURI == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer) return } validation := ValidateAuthRequest @@ -92,21 +94,21 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { } userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.RequestParam != "" { - AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer) return } req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID) if err != nil { - AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer) return } client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID()) if err != nil { - AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder()) + AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer) return } RedirectToLogin(req.GetID(), client, w, r) @@ -406,18 +408,18 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r * func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { id, err := ParseAuthorizeCallbackRequest(r) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } if !authReq.Done() { AuthRequestError(w, r, authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."), - authorizer.Encoder()) + authorizer) return } AuthResponse(authReq, authorizer, w, r) @@ -438,7 +440,7 @@ func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) { func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.GetResponseType() == oidc.ResponseTypeCode { @@ -452,7 +454,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } codeResponse := struct { @@ -464,7 +466,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques } callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } http.Redirect(w, r, callback, http.StatusFound) @@ -475,12 +477,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "") if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } http.Redirect(w, r, callback, http.StatusFound) diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index df340b6b..42fd0aa0 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -18,6 +18,7 @@ import ( "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op/mock" "github.com/zitadel/schema" + "golang.org/x/exp/slog" ) func TestAuthorize(t *testing.T) { @@ -38,7 +39,7 @@ func TestAuthorize(t *testing.T) { expect := authorizer.EXPECT() expect.Decoder().Return(schema.NewDecoder()) - expect.Encoder().Return(schema.NewEncoder()) + expect.Logger().Return(slog.Default()) if tt.expect != nil { tt.expect(expect) diff --git a/pkg/op/device.go b/pkg/op/device.go index 09c7fca1..029bed8a 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -57,7 +57,7 @@ var ( func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { if err := DeviceAuthorization(w, r, o); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, o.Logger()) } } } @@ -190,7 +190,7 @@ func (r *deviceAccessTokenRequest) GetScopes() []string { func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { if err := deviceAccessToken(w, r, exchanger); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } } diff --git a/pkg/op/error.go b/pkg/op/error.go index b2d84ae1..9981fecc 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -5,6 +5,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type ErrAuthRequest interface { @@ -13,13 +14,31 @@ type ErrAuthRequest interface { GetState() string } -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) { +// LogAuthRequest is an optional interface, +// that allows logging AuthRequest fields. +// If the AuthRequest does not implement this interface, +// no details shall be printed to the logs. +type LogAuthRequest interface { + ErrAuthRequest + slog.LogValuer +} + +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) { + e := oidc.DefaultToServerError(err, err.Error()) + logger := authorizer.Logger().With("oidc_error", e) + if authReq == nil { + logger.Log(r.Context(), e.LogLevel(), "auth request") http.Error(w, err.Error(), http.StatusBadRequest) return } - e := oidc.DefaultToServerError(err, err.Error()) + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting") http.Error(w, e.Description, http.StatusBadRequest) return } @@ -28,19 +47,22 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { responseMode = rm.GetResponseMode() } - url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder()) if err != nil { + logger.ErrorContext(r.Context(), "auth response URL", "error", err) http.Error(w, err.Error(), http.StatusBadRequest) return } + logger.Log(r.Context(), e.LogLevel(), "auth request") http.Redirect(w, r, url, http.StatusFound) } -func RequestError(w http.ResponseWriter, r *http.Request, err error) { +func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { e := oidc.DefaultToServerError(err, err.Error()) status := http.StatusBadRequest if e.ErrorType == oidc.InvalidClient { - status = 401 + status = http.StatusUnauthorized } + logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) httphelper.MarshalJSONWithStatus(w, e, status) } diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go new file mode 100644 index 00000000..dc5ef110 --- /dev/null +++ b/pkg/op/error_test.go @@ -0,0 +1,277 @@ +package op + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +func TestAuthRequestError(t *testing.T) { + type args struct { + authReq ErrAuthRequest + err error + } + tests := []struct { + name string + args args + wantCode int + wantHeaders map[string]string + wantBody string + wantLog string + }{ + { + name: "nil auth request", + args: args{ + authReq: nil, + err: io.ErrClosedPipe, + }, + wantCode: http.StatusBadRequest, + wantBody: "io: read/write on closed pipe\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth request", + "time":"not", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + } + }`, + }, + { + name: "auth request, no redirect URI", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "sign in\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request, redirect disabled", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), + }, + wantCode: http.StatusBadRequest, + wantBody: "oops\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"oops", + "type":"invalid_request", + "redirect_disabled":true + } + }`, + }, + { + name: "auth request, url parse error", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "can't parse this!\n", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth response URL", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"can't parse this!\n", + "response_type":"responseType", + "scopes":"a b" + }, + "error":{ + "type":"server_error", + "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request redirect", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusFound, + wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"}, + wantLog: `{ + "level":"WARN", + "msg":"auth request", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + authorizer := &Provider{ + encoder: schema.NewEncoder(), + logger: slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ), + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode) + for key, wantHeader := range tt.wantHeaders { + gotHeader := res.Header.Get(key) + assert.Equalf(t, wantHeader, gotHeader, "header %q", key) + } + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.Equal(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestRequestError(t *testing.T) { + tests := []struct { + name string + err error + wantCode int + wantBody string + wantLog string + }{ + { + name: "server error", + err: io.ErrClosedPipe, + wantCode: http.StatusBadRequest, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "time":"not", + "oidc_error":{ + "parent":"io: read/write on closed pipe", + "description":"io: read/write on closed pipe", + "type":"server_error"} + }`, + }, + { + name: "invalid client", + err: oidc.ErrInvalidClient().WithDescription("not good"), + wantCode: http.StatusUnauthorized, + wantBody: `{"error":"invalid_client", "error_description":"not good"}`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "time":"not", + "oidc_error":{ + "description":"not good", + "type":"invalid_client"} + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + RequestError(w, r, tt.err, logger) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode, "status code") + + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.JSONEq(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index a0c67e3d..e4297cb8 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -11,6 +11,7 @@ import ( gomock "github.com/golang/mock/gomock" http "github.com/zitadel/oidc/v3/pkg/http" op "github.com/zitadel/oidc/v3/pkg/op" + slog "golang.org/x/exp/slog" ) // MockAuthorizer is a mock of Authorizer interface. @@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0) } +// Logger mocks base method. +func (m *MockAuthorizer) Logger() *slog.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logger") + ret0, _ := ret[0].(*slog.Logger) + return ret0 +} + +// Logger indicates an expected call of Logger. +func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger)) +} + // RequestObjectSupported mocks base method. func (m *MockAuthorizer) RequestObjectSupported() bool { m.ctrl.T.Helper() diff --git a/pkg/op/op.go b/pkg/op/op.go index 1fbe7801..d8ae570b 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -9,6 +9,7 @@ import ( "github.com/go-chi/chi" "github.com/rs/cors" "github.com/zitadel/schema" + "golang.org/x/exp/slog" "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" @@ -79,6 +80,9 @@ type OpenIDProvider interface { DefaultLogoutRedirectURI() string Probes() []ProbesFn + // EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 + Logger() *slog.Logger + // Deprecated: Provider now implements http.Handler directly. HttpHandler() http.Handler } @@ -174,6 +178,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR storage: storage, endpoints: DefaultEndpoints, timer: make(<-chan time.Time), + logger: slog.Default(), } for _, optFunc := range opOpts { @@ -217,6 +222,7 @@ type Provider struct { timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt + logger *slog.Logger } func (o *Provider) IssuerFromRequest(r *http.Request) string { @@ -375,6 +381,10 @@ func (o *Provider) Probes() []ProbesFn { } } +func (o *Provider) Logger() *slog.Logger { + return o.logger +} + // Deprecated: Provider now implements http.Handler directly. func (o *Provider) HttpHandler() http.Handler { return o @@ -523,6 +533,16 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +// WithLogger lets a logger other than slog.Default(). +// +// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 +func WithLogger(logger *slog.Logger) Option { + return func(o *Provider) error { + o.logger = logger + return nil + } +} + func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler { issuerInterceptor := NewIssuerInterceptor(i) return func(handler http.Handler) http.Handler { diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index d347d048..d33b39d5 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -156,7 +156,7 @@ func TestRoutes(t *testing.T) { values: map[string]string{ "client_id": client.GetID(), "redirect_uri": "https://example.com", - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "response_type": string(oidc.ResponseTypeCode), }, wantCode: http.StatusFound, @@ -193,7 +193,7 @@ func TestRoutes(t *testing.T) { path: testProvider.TokenEndpoint().Relative(), values: map[string]string{ "grant_type": string(oidc.GrantTypeBearer), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "assertion": jwtToken, }, wantCode: http.StatusBadRequest, @@ -206,7 +206,7 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"web", "secret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeTokenExchange), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "subject_token": jwtToken, "subject_token_type": string(oidc.AccessTokenType), }, @@ -223,7 +223,7 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"sid1", "verysecret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeClientCredentials), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, @@ -338,7 +338,7 @@ func TestRoutes(t *testing.T) { path: testProvider.DeviceAuthorizationEndpoint().Relative(), basicAuth: &basicAuth{"web", "secret"}, values: map[string]string{ - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, contains: []string{ diff --git a/pkg/op/session.go b/pkg/op/session.go index fd914d11..2467b20f 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -8,6 +8,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type SessionEnder interface { @@ -15,6 +16,7 @@ type SessionEnder interface { Storage() Storage IDTokenHintVerifier(context.Context) *IDTokenHintVerifier DefaultLogoutRedirectURI() string + Logger() *slog.Logger } func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) { @@ -31,12 +33,12 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { } session, err := ValidateEndSessionRequest(r.Context(), req, ender) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, ender.Logger()) return } err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID) if err != nil { - RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session")) + RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger()) return } http.Redirect(w, r, session.RedirectURI, http.StatusFound) diff --git a/pkg/op/token_client_credentials.go b/pkg/op/token_client_credentials.go index 0cf77961..043bb072 100644 --- a/pkg/op/token_client_credentials.go +++ b/pkg/op/token_client_credentials.go @@ -14,18 +14,18 @@ import ( func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { request, err := ParseClientCredentialsRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateClientCredentialsRequest(r.Context(), request, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateClientCredentialsTokenResponse(r.Context(), validatedRequest, exchanger, client) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index b5e892af..baf377bc 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -13,20 +13,20 @@ import ( func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } if tokenReq.Code == "" { - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), exchanger.Logger()) return } authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "") if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 93aa9b24..21db1347 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -136,17 +136,17 @@ func (r *tokenExchangeRequest) SetSubject(subject string) { func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index 4cd7b1e4..357200ee 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -18,23 +18,23 @@ type JWTAuthorizationGrantExchanger interface { func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizationGrantExchanger) { profileRequest, err := ParseJWTProfileGrantRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context())) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(r.Context(), tokenRequest.Issuer, profileRequest.Scope) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index aeaa5b4b..9421033f 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -26,16 +26,16 @@ type RefreshTokenRequest interface { func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index c06a51bc..0df2fcee 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -7,6 +7,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type Exchanger interface { @@ -22,6 +23,7 @@ type Exchanger interface { GrantTypeDeviceCodeSupported() bool AccessTokenVerifier(context.Context) *AccessTokenVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier + Logger() *slog.Logger } func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { @@ -63,10 +65,10 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { return } case "": - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), exchanger.Logger()) return } - RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType)) + RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), exchanger.Logger()) } // AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest From daf82a5e041cb1a2279d7e6f8f56d64dd58b4371 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 1 Sep 2023 14:33:16 +0300 Subject: [PATCH 13/14] chore(deps): migrage jose to go-jose/v3 (#433) closes #390 --- example/server/storage/storage.go | 2 +- example/server/storage/storage_dynamic.go | 2 +- go.mod | 8 ++++---- go.sum | 19 +++++++++++-------- internal/testutil/token.go | 2 +- pkg/client/client.go | 2 +- pkg/client/profile/jwt_profile.go | 2 +- pkg/client/rp/jwks.go | 2 +- pkg/client/rp/relying_party.go | 2 +- pkg/client/rp/verifier.go | 2 +- pkg/client/rp/verifier_test.go | 2 +- pkg/crypto/hash.go | 2 +- pkg/crypto/sign.go | 2 +- pkg/oidc/keyset.go | 6 +++--- pkg/oidc/keyset_test.go | 2 +- pkg/oidc/token.go | 2 +- pkg/oidc/token_request.go | 2 +- pkg/oidc/token_test.go | 2 +- pkg/oidc/types.go | 2 +- pkg/oidc/verifier.go | 2 +- pkg/op/discovery.go | 2 +- pkg/op/discovery_test.go | 2 +- pkg/op/keys.go | 2 +- pkg/op/keys_test.go | 2 +- pkg/op/mock/authorizer.mock.impl.go | 2 +- pkg/op/mock/discovery.mock.go | 2 +- pkg/op/mock/signer.mock.go | 2 +- pkg/op/mock/storage.mock.go | 2 +- pkg/op/op.go | 2 +- pkg/op/signer.go | 2 +- pkg/op/storage.go | 2 +- pkg/op/verifier_jwt_profile.go | 2 +- 32 files changed, 47 insertions(+), 44 deletions(-) diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go index e1160b6b..56f96ce3 100644 --- a/example/server/storage/storage.go +++ b/example/server/storage/storage.go @@ -11,8 +11,8 @@ import ( "sync" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/google/uuid" - "gopkg.in/square/go-jose.v2" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" diff --git a/example/server/storage/storage_dynamic.go b/example/server/storage/storage_dynamic.go index 0d99aa27..5deb5dca 100644 --- a/example/server/storage/storage_dynamic.go +++ b/example/server/storage/storage_dynamic.go @@ -4,7 +4,7 @@ import ( "context" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" diff --git a/go.mod b/go.mod index 62aa39be..1fe66cc1 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.19 require ( github.com/go-chi/chi v1.5.4 + github.com/go-jose/go-jose/v3 v3.0.0 github.com/golang/mock v1.6.0 github.com/google/go-github/v31 v31.0.0 github.com/google/uuid v1.3.0 @@ -17,8 +18,7 @@ require ( github.com/zitadel/schema v1.3.0 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/oauth2 v0.7.0 - golang.org/x/text v0.9.0 - gopkg.in/square/go-jose.v2 v2.6.0 + golang.org/x/text v0.10.0 ) require ( @@ -27,8 +27,8 @@ require ( github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/crypto v0.7.0 // indirect - golang.org/x/net v0.9.0 // indirect + golang.org/x/crypto v0.10.0 // indirect + golang.org/x/net v0.11.0 // indirect golang.org/x/sys v0.11.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.29.1 // indirect diff --git a/go.sum b/go.sum index 9c44f0f6..9524f82b 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs= github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg= +github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo= +github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -10,6 +12,7 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -41,6 +44,7 @@ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVs github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -52,9 +56,10 @@ github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnI github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -64,8 +69,8 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= +golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.7.0 h1:qe6s0zUXlPX80/dITx3440hWZ7GwMwgDDyrSGTPJG/g= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= @@ -83,8 +88,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= +golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -102,8 +107,6 @@ google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= -gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/testutil/token.go b/internal/testutil/token.go index 41778de7..2dd788f5 100644 --- a/internal/testutil/token.go +++ b/internal/testutil/token.go @@ -8,9 +8,9 @@ import ( "errors" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/muhlemmer/gu" "github.com/zitadel/oidc/v3/pkg/oidc" - "gopkg.in/square/go-jose.v2" ) // KeySet implements oidc.Keys diff --git a/pkg/client/client.go b/pkg/client/client.go index 7b76dfd6..d7764f66 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -11,8 +11,8 @@ import ( "strings" "time" + jose "github.com/go-jose/go-jose/v3" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/crypto" diff --git a/pkg/client/profile/jwt_profile.go b/pkg/client/profile/jwt_profile.go index 419f4175..a24033c9 100644 --- a/pkg/client/profile/jwt_profile.go +++ b/pkg/client/profile/jwt_profile.go @@ -5,8 +5,8 @@ import ( "net/http" "time" + jose "github.com/go-jose/go-jose/v3" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" "github.com/zitadel/oidc/v3/pkg/client" "github.com/zitadel/oidc/v3/pkg/oidc" diff --git a/pkg/client/rp/jwks.go b/pkg/client/rp/jwks.go index 79cf2322..28aec9b9 100644 --- a/pkg/client/rp/jwks.go +++ b/pkg/client/rp/jwks.go @@ -7,7 +7,7 @@ import ( "net/http" "sync" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 34cdb397..877c837e 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -9,11 +9,11 @@ import ( "net/url" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/google/uuid" "github.com/zitadel/logging" "golang.org/x/exp/slog" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" "github.com/zitadel/oidc/v3/pkg/client" httphelper "github.com/zitadel/oidc/v3/pkg/http" diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index 3294f407..adf88725 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -4,7 +4,7 @@ import ( "context" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" "github.com/zitadel/oidc/v3/pkg/oidc" ) diff --git a/pkg/client/rp/verifier_test.go b/pkg/client/rp/verifier_test.go index 11bf2f9f..3e6d9d99 100644 --- a/pkg/client/rp/verifier_test.go +++ b/pkg/client/rp/verifier_test.go @@ -5,11 +5,11 @@ import ( "testing" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" tu "github.com/zitadel/oidc/v3/internal/testutil" "github.com/zitadel/oidc/v3/pkg/oidc" - "gopkg.in/square/go-jose.v2" ) func TestVerifyTokens(t *testing.T) { diff --git a/pkg/crypto/hash.go b/pkg/crypto/hash.go index 6fcc71fd..0ed2774d 100644 --- a/pkg/crypto/hash.go +++ b/pkg/crypto/hash.go @@ -8,7 +8,7 @@ import ( "fmt" "hash" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) var ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm") diff --git a/pkg/crypto/sign.go b/pkg/crypto/sign.go index a0b9cae7..d6c002b3 100644 --- a/pkg/crypto/sign.go +++ b/pkg/crypto/sign.go @@ -4,7 +4,7 @@ import ( "encoding/json" "errors" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) func Sign(object interface{}, signer jose.Signer) (string, error) { diff --git a/pkg/oidc/keyset.go b/pkg/oidc/keyset.go index c6e865b2..a4a5a1c3 100644 --- a/pkg/oidc/keyset.go +++ b/pkg/oidc/keyset.go @@ -7,7 +7,7 @@ import ( "crypto/rsa" "errors" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) const ( @@ -46,8 +46,8 @@ func GetKeyIDAndAlg(jws *jose.JSONWebSignature) (string, string) { // // will return false none or multiple match // -//deprecated: use FindMatchingKey which will return an error (more specific) instead of just a bool -//moved implementation already to FindMatchingKey +// deprecated: use FindMatchingKey which will return an error (more specific) instead of just a bool +// moved implementation already to FindMatchingKey func FindKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (jose.JSONWebKey, bool) { key, err := FindMatchingKey(keyID, use, expectedAlg, keys...) return key, err == nil diff --git a/pkg/oidc/keyset_test.go b/pkg/oidc/keyset_test.go index 82b3ee83..f8641f2a 100644 --- a/pkg/oidc/keyset_test.go +++ b/pkg/oidc/keyset_test.go @@ -7,7 +7,7 @@ import ( "reflect" "testing" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) func TestFindKey(t *testing.T) { diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index c02eaf4b..4624e509 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -5,8 +5,8 @@ import ( "os" "time" + jose "github.com/go-jose/go-jose/v3" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" "github.com/muhlemmer/gu" "github.com/zitadel/oidc/v3/pkg/crypto" diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 5c5cf20f..330c0c27 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) const ( diff --git a/pkg/oidc/token_test.go b/pkg/oidc/token_test.go index ef1e77f8..854f4555 100644 --- a/pkg/oidc/token_test.go +++ b/pkg/oidc/token_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/stretchr/testify/assert" "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" ) var ( diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index 5db8badc..dd604ad4 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -8,9 +8,9 @@ import ( "strings" "time" + jose "github.com/go-jose/go-jose/v3" "github.com/zitadel/schema" "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" ) type Audience []string diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index 2d4e7a67..14174a4d 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -10,7 +10,7 @@ import ( "strings" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" str "github.com/zitadel/oidc/v3/pkg/strings" ) diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 38afeab7..782a279e 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -4,7 +4,7 @@ import ( "context" "net/http" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index e55e9051..3e95ec35 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -6,10 +6,10 @@ import ( "net/http/httptest" "testing" + jose "github.com/go-jose/go-jose/v3" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/square/go-jose.v2" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" diff --git a/pkg/op/keys.go b/pkg/op/keys.go index 418dcb59..fe111f0f 100644 --- a/pkg/op/keys.go +++ b/pkg/op/keys.go @@ -4,7 +4,7 @@ import ( "context" "net/http" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" httphelper "github.com/zitadel/oidc/v3/pkg/http" ) diff --git a/pkg/op/keys_test.go b/pkg/op/keys_test.go index 259b87c6..e1a38512 100644 --- a/pkg/op/keys_test.go +++ b/pkg/op/keys_test.go @@ -7,9 +7,9 @@ import ( "net/http/httptest" "testing" + jose "github.com/go-jose/go-jose/v3" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "gopkg.in/square/go-jose.v2" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 4d66a922..ba5082f0 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -4,9 +4,9 @@ import ( "context" "testing" + jose "github.com/go-jose/go-jose/v3" "github.com/golang/mock/gomock" "github.com/zitadel/schema" - "gopkg.in/square/go-jose.v2" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" diff --git a/pkg/op/mock/discovery.mock.go b/pkg/op/mock/discovery.mock.go index 4c33953d..c5d3d3a6 100644 --- a/pkg/op/mock/discovery.mock.go +++ b/pkg/op/mock/discovery.mock.go @@ -8,8 +8,8 @@ import ( context "context" reflect "reflect" + jose "github.com/go-jose/go-jose/v3" gomock "github.com/golang/mock/gomock" - jose "gopkg.in/square/go-jose.v2" ) // MockDiscoverStorage is a mock of DiscoverStorage interface. diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index 7075241d..15718e07 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -7,8 +7,8 @@ package mock import ( reflect "reflect" + jose "github.com/go-jose/go-jose/v3" gomock "github.com/golang/mock/gomock" - jose "gopkg.in/square/go-jose.v2" ) // MockSigningKey is a mock of SigningKey interface. diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 6bfb1c98..a1ce598b 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -9,10 +9,10 @@ import ( reflect "reflect" time "time" + jose "github.com/go-jose/go-jose/v3" gomock "github.com/golang/mock/gomock" oidc "github.com/zitadel/oidc/v3/pkg/oidc" op "github.com/zitadel/oidc/v3/pkg/op" - jose "gopkg.in/square/go-jose.v2" ) // MockStorage is a mock of Storage interface. diff --git a/pkg/op/op.go b/pkg/op/op.go index d8ae570b..0175d7fe 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -7,11 +7,11 @@ import ( "time" "github.com/go-chi/chi" + jose "github.com/go-jose/go-jose/v3" "github.com/rs/cors" "github.com/zitadel/schema" "golang.org/x/exp/slog" "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" diff --git a/pkg/op/signer.go b/pkg/op/signer.go index 22ef8caa..37bd5bbf 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -3,7 +3,7 @@ package op import ( "errors" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" ) var ( diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 23d21334..aca16157 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -5,7 +5,7 @@ import ( "errors" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" "github.com/zitadel/oidc/v3/pkg/oidc" ) diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index 1daa15fc..19adbb67 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v3" "github.com/zitadel/oidc/v3/pkg/oidc" ) From 0f8a0585bf29f1daf26bd17b3b899be7fb40a37f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Thu, 28 Sep 2023 17:30:08 +0300 Subject: [PATCH 14/14] feat(op): Server interface (#447) * first draft of a new server interface * allow any response type * complete interface docs * refelct the format from the proposal * intermediate commit with some methods implemented * implement remaining token grant type methods * implement remaining server methods * error handling * rewrite auth request validation * define handlers, routes * input validation and concrete handlers * check if client credential client is authenticated * copy and modify the routes test for the legacy server * run integration tests against both Server and Provider * remove unuse ValidateAuthRequestV2 function * unit tests for error handling * cleanup tokenHandler * move server routest test * unit test authorize * handle client credentials in VerifyClient * change code exchange route test * finish http unit tests * review server interface docs and spelling * add withClient unit test * server options * cleanup unused GrantType method * resolve typo comments * make endpoints pointers to enable/disable them * jwt profile base work * jwt: correct the test expect --------- Co-authored-by: Livio Spring --- example/server/exampleop/op.go | 9 +- example/server/main.go | 2 +- example/server/storage/client.go | 2 +- pkg/client/integration_test.go | 21 +- pkg/op/auth_request.go | 18 +- pkg/op/client.go | 7 + pkg/op/config.go | 16 +- pkg/op/device.go | 34 +- pkg/op/discovery.go | 39 +- pkg/op/discovery_test.go | 8 +- pkg/op/endpoint.go | 30 +- pkg/op/endpoint_test.go | 31 +- pkg/op/error.go | 101 +++ pkg/op/error_test.go | 400 +++++++++ pkg/op/mock/configuration.mock.go | 32 +- pkg/op/op.go | 67 +- pkg/op/op_test.go | 51 ++ pkg/op/probes.go | 4 +- pkg/op/server.go | 346 ++++++++ pkg/op/server_http.go | 480 +++++++++++ pkg/op/server_http_routes_test.go | 345 ++++++++ pkg/op/server_http_test.go | 1333 +++++++++++++++++++++++++++++ pkg/op/server_legacy.go | 344 ++++++++ pkg/op/server_test.go | 5 + pkg/op/token_code.go | 2 +- pkg/op/token_exchange.go | 36 +- pkg/op/token_request.go | 6 +- pkg/op/token_revocation.go | 7 +- 28 files changed, 3652 insertions(+), 124 deletions(-) create mode 100644 pkg/op/server.go create mode 100644 pkg/op/server_http.go create mode 100644 pkg/op/server_http_routes_test.go create mode 100644 pkg/op/server_http_test.go create mode 100644 pkg/op/server_legacy.go create mode 100644 pkg/op/server_test.go diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index b5ee7b37..f1906baf 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -40,7 +40,7 @@ var counter atomic.Int64 // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. -func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router { +func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool) chi.Router { // the OpenID Provider requires a 32-byte key for (token) encryption // be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) @@ -77,12 +77,17 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router registerDeviceAuth(storage, r) }) + handler := http.Handler(provider) + if wrapServer { + handler = op.NewLegacyServer(provider, *op.DefaultEndpoints) + } + // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // is served on the correct path // // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), // then you would have to set the path prefix (/custom/path/) - router.Mount("/", provider) + router.Mount("/", handler) return router } diff --git a/example/server/main.go b/example/server/main.go index a1cc4618..38057fb7 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -27,7 +27,7 @@ func main() { Level: slog.LevelDebug, }), ) - router := exampleop.SetupServer(issuer, storage, logger) + router := exampleop.SetupServer(issuer, storage, logger, false) server := &http.Server{ Addr: ":" + port, diff --git a/example/server/storage/client.go b/example/server/storage/client.go index a3e7cc45..f512a992 100644 --- a/example/server/storage/client.go +++ b/example/server/storage/client.go @@ -185,7 +185,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client { authMethod: oidc.AuthMethodBasic, loginURL: defaultLoginURL, responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, - grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken}, + grantTypes: oidc.AllGrantTypes, accessTokenType: op.AccessTokenTypeBearer, devMode: false, idTokenUserinfoClaimsAssertion: false, diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 7cbb62e6..1d3559e3 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -3,6 +3,7 @@ package client_test import ( "bytes" "context" + "fmt" "io" "math/rand" "net/http" @@ -50,6 +51,14 @@ func TestMain(m *testing.M) { } func TestRelyingPartySession(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testRelyingPartySession(t, wrapServer) + }) + } +} + +func testRelyingPartySession(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -57,7 +66,7 @@ func TestRelyingPartySession(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) @@ -101,6 +110,14 @@ func TestRelyingPartySession(t *testing.T) { } func TestResourceServerTokenExchange(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testResourceServerTokenExchange(t, wrapServer) + }) + } +} + +func testResourceServerTokenExchange(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -108,7 +125,7 @@ func TestResourceServerTokenExchange(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 7610248e..20b1bf4c 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -74,7 +74,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { } ctx := r.Context() if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { - authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) + err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) if err != nil { AuthRequestError(w, r, authReq, err, authorizer) return @@ -130,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A // ParseRequestObject parse the `request` parameter, validates the token including the signature // and copies the token claims into the auth request -func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) { +func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error { requestObject := new(oidc.RequestObject) payload, err := oidc.ParseToken(authReq.RequestParam, requestObject) if err != nil { - return nil, err + return err } if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if requestObject.Issuer != requestObject.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if !str.Contains(requestObject.Audience, issuer) { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer} if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { - return authReq, err + return err } CopyRequestObjectToAuthRequest(authReq, requestObject) - return authReq, nil + return nil } // CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request diff --git a/pkg/op/client.go b/pkg/op/client.go index d01845f2..04ef3c71 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -180,3 +180,10 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au } return data.ClientID, false, nil } + +type ClientCredentials struct { + ClientID string `schema:"client_id"` + ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body + ClientAssertion string `schema:"client_assertion"` // JWT + ClientAssertionType string `schema:"client_assertion_type"` +} diff --git a/pkg/op/config.go b/pkg/op/config.go index c40ed39e..f61412a8 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -20,14 +20,14 @@ var ( type Configuration interface { IssuerFromRequest(r *http.Request) string Insecure() bool - AuthorizationEndpoint() Endpoint - TokenEndpoint() Endpoint - IntrospectionEndpoint() Endpoint - UserinfoEndpoint() Endpoint - RevocationEndpoint() Endpoint - EndSessionEndpoint() Endpoint - KeysEndpoint() Endpoint - DeviceAuthorizationEndpoint() Endpoint + AuthorizationEndpoint() *Endpoint + TokenEndpoint() *Endpoint + IntrospectionEndpoint() *Endpoint + UserinfoEndpoint() *Endpoint + RevocationEndpoint() *Endpoint + EndSessionEndpoint() *Endpoint + KeysEndpoint() *Endpoint + DeviceAuthorizationEndpoint() *Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool diff --git a/pkg/op/device.go b/pkg/op/device.go index 029bed8a..55d3c572 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -63,41 +63,51 @@ func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *htt } func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error { - storage, err := assertDeviceStorage(o.Storage()) + req, err := ParseDeviceCodeRequest(r, o) if err != nil { return err } - - req, err := ParseDeviceCodeRequest(r, o) + response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o) if err != nil { return err } + httphelper.MarshalJSON(w, response) + return nil +} + +func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) { + storage, err := assertDeviceStorage(o.Storage()) + if err != nil { + return nil, err + } config := o.DeviceAuthorization() deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } expires := time.Now().Add(config.Lifetime) - err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes) + err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } var verification *url.URL if config.UserFormURL != "" { if verification, err = url.Parse(config.UserFormURL); err != nil { - return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + return nil, NewStatusError(err, http.StatusInternalServerError) } } else { - if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil { - return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil { + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + return nil, NewStatusError(err, http.StatusInternalServerError) } verification.Path = config.UserFormPath } @@ -112,9 +122,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide verification.RawQuery = "user_code=" + userCode response.VerificationURIComplete = verification.String() - - httphelper.MarshalJSON(w, response) - return nil + return response, nil } func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) { diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 782a279e..82512615 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{ func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - Discover(w, CreateDiscoveryConfig(r, c, s)) + Discover(w, CreateDiscoveryConfig(r.Context(), c, s)) } } @@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { httphelper.MarshalJSON(w, config) } -func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { - issuer := config.IssuerFromRequest(r) +func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) return &oidc.DiscoveryConfiguration{ Issuer: issuer, AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer), @@ -49,7 +49,38 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov ResponseTypesSupported: ResponseTypes(config), GrantTypesSupported: GrantTypes(config), SubjectTypesSupported: SubjectTypes(config), - IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), + RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), + TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), + TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), + IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config), + IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config), + RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config), + RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config), + ClaimsSupported: SupportedClaims(config), + CodeChallengeMethodsSupported: CodeChallengeMethods(config), + UILocalesSupported: config.SupportedUILocales(), + RequestParameterSupported: config.RequestObjectSupported(), + } +} + +func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) + return &oidc.DiscoveryConfiguration{ + Issuer: issuer, + AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer), + TokenEndpoint: endpoints.Token.Absolute(issuer), + IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer), + UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer), + RevocationEndpoint: endpoints.Revocation.Absolute(issuer), + EndSessionEndpoint: endpoints.EndSession.Absolute(issuer), + JwksURI: endpoints.JwksURI.Absolute(issuer), + DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer), + ScopesSupported: Scopes(config), + ResponseTypesSupported: ResponseTypes(config), + GrantTypesSupported: GrantTypes(config), + SubjectTypesSupported: SubjectTypes(config), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 3e95ec35..84e12165 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -48,9 +48,9 @@ func TestDiscover(t *testing.T) { func TestCreateDiscoveryConfig(t *testing.T) { type args struct { - request *http.Request - c op.Configuration - s op.DiscoverStorage + ctx context.Context + c op.Configuration + s op.DiscoverStorage } tests := []struct { name string @@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s) + got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s) assert.Equal(t, tt.want, got) }) } diff --git a/pkg/op/endpoint.go b/pkg/op/endpoint.go index b1e15073..1ac1cad7 100644 --- a/pkg/op/endpoint.go +++ b/pkg/op/endpoint.go @@ -1,32 +1,46 @@ package op -import "strings" +import ( + "errors" + "strings" +) type Endpoint struct { path string url string } -func NewEndpoint(path string) Endpoint { - return Endpoint{path: path} +func NewEndpoint(path string) *Endpoint { + return &Endpoint{path: path} } -func NewEndpointWithURL(path, url string) Endpoint { - return Endpoint{path: path, url: url} +func NewEndpointWithURL(path, url string) *Endpoint { + return &Endpoint{path: path, url: url} } -func (e Endpoint) Relative() string { +func (e *Endpoint) Relative() string { + if e == nil { + return "" + } return relativeEndpoint(e.path) } -func (e Endpoint) Absolute(host string) string { +func (e *Endpoint) Absolute(host string) string { + if e == nil { + return "" + } if e.url != "" { return e.url } return absoluteEndpoint(host, e.path) } -func (e Endpoint) Validate() error { +var ErrNilEndpoint = errors.New("nil endpoint") + +func (e *Endpoint) Validate() error { + if e == nil { + return ErrNilEndpoint + } return nil // TODO: } diff --git a/pkg/op/endpoint_test.go b/pkg/op/endpoint_test.go index 46e5d478..bf112eff 100644 --- a/pkg/op/endpoint_test.go +++ b/pkg/op/endpoint_test.go @@ -3,13 +3,14 @@ package op_test import ( "testing" + "github.com/stretchr/testify/require" "github.com/zitadel/oidc/v3/pkg/op" ) func TestEndpoint_Path(t *testing.T) { tests := []struct { name string - e op.Endpoint + e *op.Endpoint want string }{ { @@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) { op.NewEndpointWithURL("/test", "http://test.com/test"), "/test", }, + { + "nil", + nil, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) { } tests := []struct { name string - e op.Endpoint + e *op.Endpoint args args want string }{ @@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) { args{"https://host"}, "https://test.com/test", }, + { + "nil", + nil, + args{"https://host"}, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) { func TestEndpoint_Validate(t *testing.T) { tests := []struct { name string - e op.Endpoint - wantErr bool + e *op.Endpoint + wantErr error }{ - // TODO: Add test cases. + { + "nil", + nil, + op.ErrNilEndpoint, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.e.Validate(); (err != nil) != tt.wantErr { - t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr) - } + err := tt.e.Validate() + require.ErrorIs(t, err, tt.wantErr) }) } } diff --git a/pkg/op/error.go b/pkg/op/error.go index 9981fecc..0cac14b9 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -1,6 +1,9 @@ package op import ( + "context" + "errors" + "fmt" "net/http" httphelper "github.com/zitadel/oidc/v3/pkg/http" @@ -66,3 +69,101 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slo logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) httphelper.MarshalJSONWithStatus(w, e, status) } + +// TryErrorRedirect tries to handle an error by redirecting a client. +// If this attempt fails, an error is returned that must be returned +// to the client instead. +func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) { + e := oidc.DefaultToServerError(parent, parent.Error()) + logger = logger.With("oidc_error", e) + + if authReq == nil { + logger.Log(ctx, e.LogLevel(), "auth request") + return nil, AsStatusError(e, http.StatusBadRequest) + } + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } + + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(ctx, e.LogLevel(), "auth request: not redirecting") + return nil, AsStatusError(e, http.StatusBadRequest) + } + + e.State = authReq.GetState() + var responseMode oidc.ResponseMode + if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { + responseMode = rm.GetResponseMode() + } + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) + if err != nil { + logger.ErrorContext(ctx, "auth response URL", "error", err) + return nil, AsStatusError(err, http.StatusBadRequest) + } + logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url) + return NewRedirect(url), nil +} + +// StatusError wraps an error with a HTTP status code. +// The status code is passed to the handler's writer. +type StatusError struct { + parent error + statusCode int +} + +// NewStatusError sets the parent and statusCode to a new StatusError. +// It is recommended for parent to be an [oidc.Error]. +// +// Typically implementations should only use this to signal something +// very specific, like an internal server error. +// If a returned error is not a StatusError, the framework +// will set a statusCode based on what the standard specifies, +// which is [http.StatusBadRequest] for most of the time. +// If the error encountered can described clearly with a [oidc.Error], +// do not use this function, as it might break standard rules! +func NewStatusError(parent error, statusCode int) StatusError { + return StatusError{ + parent: parent, + statusCode: statusCode, + } +} + +// AsStatusError unwraps a StatusError from err +// and returns it unmodified if found. +// If no StatuError was found, a new one is returned +// with statusCode set to it as a default. +func AsStatusError(err error, statusCode int) (target StatusError) { + if errors.As(err, &target) { + return target + } + return NewStatusError(err, statusCode) +} + +func (e StatusError) Error() string { + return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error()) +} + +func (e StatusError) Unwrap() error { + return e.parent +} + +func (e StatusError) Is(err error) bool { + var target StatusError + if !errors.As(err, &target) { + return false + } + return errors.Is(e.parent, target.parent) && + e.statusCode == target.statusCode +} + +// WriteError asserts for a StatusError containing an [oidc.Error]. +// If no StatusError is found, the status code will default to [http.StatusBadRequest]. +// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError]. +func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { + statusError := AsStatusError(err, http.StatusBadRequest) + e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()) + + logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) + httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode) +} diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go index dc5ef110..689ee5ab 100644 --- a/pkg/op/error_test.go +++ b/pkg/op/error_test.go @@ -1,9 +1,12 @@ package op import ( + "context" + "fmt" "io" "net/http" "net/http/httptest" + "net/url" "strings" "testing" @@ -275,3 +278,400 @@ func TestRequestError(t *testing.T) { }) } } + +func TestTryErrorRedirect(t *testing.T) { + type args struct { + ctx context.Context + authReq ErrAuthRequest + parent error + } + tests := []struct { + name string + args args + want *Redirect + wantErr error + wantLog string + }{ + { + name: "nil auth request", + args: args{ + ctx: context.Background(), + authReq: nil, + parent: io.ErrClosedPipe, + }, + wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest), + wantLog: `{ + "level":"ERROR", + "msg":"auth request", + "time":"not", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + } + }`, + }, + { + name: "auth request, no redirect URI", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest), + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request, redirect disabled", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), + }, + wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest), + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"oops", + "type":"invalid_request", + "redirect_disabled":true + } + }`, + }, + { + name: "auth request, url parse error", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "can't parse this!\n", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantErr: func() error { + //lint:ignore SA1007 just recreating the error for testing + _, err := url.Parse("can't parse this!\n") + err = oidc.ErrServerError().WithParent(err) + return NewStatusError(err, http.StatusBadRequest) + }(), + wantLog: `{ + "level":"ERROR", + "msg":"auth response URL", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"can't parse this!\n", + "response_type":"responseType", + "scopes":"a b" + }, + "error":{ + "type":"server_error", + "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request redirect", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + want: &Redirect{ + URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1", + }, + wantLog: `{ + "level":"WARN", + "msg":"auth request redirect", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + }, + "url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1" + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + encoder := schema.NewEncoder() + + got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestNewStatusError(t *testing.T) { + err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + + want := "Internal Server Error: io: read/write on closed pipe" + got := fmt.Sprint(err) + assert.Equal(t, want, got) +} + +func TestAsStatusError(t *testing.T) { + type args struct { + err error + statusCode int + } + tests := []struct { + name string + args args + want string + }{ + { + name: "already status error", + args: args{ + err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError), + statusCode: http.StatusBadRequest, + }, + want: "Internal Server Error: io: read/write on closed pipe", + }, + { + name: "oidc error", + args: args{ + err: oidc.ErrAcrInvalid, + statusCode: http.StatusBadRequest, + }, + want: "Bad Request: acr is invalid", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := AsStatusError(tt.args.err, tt.args.statusCode) + got := fmt.Sprint(err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestStatusError_Unwrap(t *testing.T) { + err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + require.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestStatusError_Is(t *testing.T) { + type args struct { + err error + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "nil error", + args: args{err: nil}, + want: false, + }, + { + name: "other error", + args: args{err: io.EOF}, + want: false, + }, + { + name: "other parent", + args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)}, + want: false, + }, + { + name: "other status", + args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)}, + want: false, + }, + { + name: "same", + args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)}, + want: true, + }, + { + name: "wrapped", + args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + if got := e.Is(tt.args.err); got != tt.want { + t.Errorf("StatusError.Is() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteError(t *testing.T) { + tests := []struct { + name string + err error + wantStatus int + wantBody string + wantLog string + }{ + { + name: "not a status or oidc error", + err: io.ErrClosedPipe, + wantStatus: http.StatusBadRequest, + wantBody: `{ + "error":"server_error", + "error_description":"io: read/write on closed pipe" + }`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + }, + "time":"not" + }`, + }, + { + name: "status error w/o oidc", + err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError), + wantStatus: http.StatusInternalServerError, + wantBody: `{ + "error":"server_error", + "error_description":"io: read/write on closed pipe" + }`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + }, + "time":"not" + }`, + }, + { + name: "oidc error w/o status", + err: oidc.ErrInvalidRequest().WithDescription("oops"), + wantStatus: http.StatusBadRequest, + wantBody: `{ + "error":"invalid_request", + "error_description":"oops" + }`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "oidc_error":{ + "description":"oops", + "type":"invalid_request" + }, + "time":"not" + }`, + }, + { + name: "status with oidc error", + err: NewStatusError( + oidc.ErrUnauthorizedClient().WithDescription("oops"), + http.StatusUnauthorized, + ), + wantStatus: http.StatusUnauthorized, + wantBody: `{ + "error":"unauthorized_client", + "error_description":"oops" + }`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "oidc_error":{ + "description":"oops", + "type":"unauthorized_client" + }, + "time":"not" + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + r := httptest.NewRequest("GET", "/target", nil) + w := httptest.NewRecorder() + + WriteError(w, r, tt.err, logger) + res := w.Result() + assert.Equal(t, tt.wantStatus, res.StatusCode, "status code") + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, tt.wantBody, string(gotBody), "body") + assert.JSONEq(t, tt.wantLog, logOut.String()) + }) + } +} diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index 96429ddb..f392a455 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom } // AuthorizationEndpoint mocks base method. -func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call { } // DeviceAuthorizationEndpoint mocks base method. -func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C } // EndSessionEndpoint mocks base method. -func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { +func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EndSessionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup } // IntrospectionEndpoint mocks base method. -func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint { +func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IntrospectionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go } // KeysEndpoint mocks base method. -func (m *MockConfiguration) KeysEndpoint() op.Endpoint { +func (m *MockConfiguration) KeysEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "KeysEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor } // RevocationEndpoint mocks base method. -func (m *MockConfiguration) RevocationEndpoint() op.Endpoint { +func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RevocationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call { } // TokenEndpoint mocks base method. -func (m *MockConfiguration) TokenEndpoint() op.Endpoint { +func (m *MockConfiguration) TokenEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TokenEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported } // UserinfoEndpoint mocks base method. -func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { +func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UserinfoEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } diff --git a/pkg/op/op.go b/pkg/op/op.go index 0175d7fe..55ee9866 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -32,7 +32,7 @@ const ( ) var ( - DefaultEndpoints = &endpoints{ + DefaultEndpoints = &Endpoints{ Authorization: NewEndpoint(defaultAuthorizationEndpoint), Token: NewEndpoint(defaultTokenEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint), @@ -131,16 +131,17 @@ type Config struct { DeviceAuthorization DeviceAuthorizationConfig } -type endpoints struct { - Authorization Endpoint - Token Endpoint - Introspection Endpoint - Userinfo Endpoint - Revocation Endpoint - EndSession Endpoint - CheckSessionIframe Endpoint - JwksURI Endpoint - DeviceAuthorization Endpoint +// Endpoints defines endpoint routes. +type Endpoints struct { + Authorization *Endpoint + Token *Endpoint + Introspection *Endpoint + Userinfo *Endpoint + Revocation *Endpoint + EndSession *Endpoint + CheckSessionIframe *Endpoint + JwksURI *Endpoint + DeviceAuthorization *Endpoint } // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) @@ -212,7 +213,7 @@ type Provider struct { config *Config issuer IssuerFromRequest insecure bool - endpoints *endpoints + endpoints *Endpoints storage Storage keySet *openIDKeySet crypto Crypto @@ -233,35 +234,35 @@ func (o *Provider) Insecure() bool { return o.insecure } -func (o *Provider) AuthorizationEndpoint() Endpoint { +func (o *Provider) AuthorizationEndpoint() *Endpoint { return o.endpoints.Authorization } -func (o *Provider) TokenEndpoint() Endpoint { +func (o *Provider) TokenEndpoint() *Endpoint { return o.endpoints.Token } -func (o *Provider) IntrospectionEndpoint() Endpoint { +func (o *Provider) IntrospectionEndpoint() *Endpoint { return o.endpoints.Introspection } -func (o *Provider) UserinfoEndpoint() Endpoint { +func (o *Provider) UserinfoEndpoint() *Endpoint { return o.endpoints.Userinfo } -func (o *Provider) RevocationEndpoint() Endpoint { +func (o *Provider) RevocationEndpoint() *Endpoint { return o.endpoints.Revocation } -func (o *Provider) EndSessionEndpoint() Endpoint { +func (o *Provider) EndSessionEndpoint() *Endpoint { return o.endpoints.EndSession } -func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { +func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint { return o.endpoints.DeviceAuthorization } -func (o *Provider) KeysEndpoint() Endpoint { +func (o *Provider) KeysEndpoint() *Endpoint { return o.endpoints.JwksURI } @@ -420,7 +421,7 @@ func WithAllowInsecure() Option { } } -func WithCustomAuthEndpoint(endpoint Endpoint) Option { +func WithCustomAuthEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -430,7 +431,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option { } } -func WithCustomTokenEndpoint(endpoint Endpoint) Option { +func WithCustomTokenEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -440,7 +441,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option { } } -func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { +func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -450,7 +451,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { } } -func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { +func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -460,7 +461,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { } } -func WithCustomRevocationEndpoint(endpoint Endpoint) Option { +func WithCustomRevocationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -470,7 +471,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { +func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -480,7 +481,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { } } -func WithCustomKeysEndpoint(endpoint Endpoint) Option { +func WithCustomKeysEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -490,7 +491,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option { } } -func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { +func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -500,8 +501,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { +// WithCustomEndpoints sets multiple endpoints at once. +// Non of the endpoints may be nil, or an error will +// be returned when the Option used by the Provider. +func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option { return func(o *Provider) error { + for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} { + if err := e.Validate(); err != nil { + return err + } + } o.endpoints.Authorization = auth o.endpoints.Token = token o.endpoints.Userinfo = userInfo diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index d33b39d5..abe53bc1 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -395,3 +395,54 @@ func TestRoutes(t *testing.T) { }) } } + +func TestWithCustomEndpoints(t *testing.T) { + type args struct { + auth *op.Endpoint + token *op.Endpoint + userInfo *op.Endpoint + revocation *op.Endpoint + endSession *op.Endpoint + keys *op.Endpoint + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "all nil", + args: args{}, + wantErr: op.ErrNilEndpoint, + }, + { + name: "all set", + args: args{ + auth: op.NewEndpoint("/authorize"), + token: op.NewEndpoint("/oauth/token"), + userInfo: op.NewEndpoint("/userinfo"), + revocation: op.NewEndpoint("/revoke"), + endSession: op.NewEndpoint("/end_session"), + keys: op.NewEndpoint("/keys"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := op.NewOpenIDProvider(testIssuer, testConfig, + storage.NewStorage(storage.NewUserStore(testIssuer)), + op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys), + ) + require.ErrorIs(t, err, tt.wantErr) + if tt.wantErr != nil { + return + } + assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint()) + assert.Equal(t, tt.args.token, provider.TokenEndpoint()) + assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint()) + assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint()) + assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint()) + assert.Equal(t, tt.args.keys, provider.KeysEndpoint()) + }) + } +} diff --git a/pkg/op/probes.go b/pkg/op/probes.go index 9ef5bb56..cb3853d8 100644 --- a/pkg/op/probes.go +++ b/pkg/op/probes.go @@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn { } func ok(w http.ResponseWriter) { - httphelper.MarshalJSON(w, status{"ok"}) + httphelper.MarshalJSON(w, Status{"ok"}) } -type status struct { +type Status struct { Status string `json:"status,omitempty"` } diff --git a/pkg/op/server.go b/pkg/op/server.go new file mode 100644 index 00000000..a9cdcf5f --- /dev/null +++ b/pkg/op/server.go @@ -0,0 +1,346 @@ +package op + +import ( + "context" + "net/http" + "net/url" + + "github.com/muhlemmer/gu" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// Server describes the interface that needs to be implemented to serve +// OpenID Connect and Oauth2 standard requests. +// +// Methods are called after the HTTP route is resolved and +// the request body is parsed into the Request's Data field. +// When a method is called, it can be assumed that required fields, +// as described in their relevant standard, are validated already. +// The Response Data field may be of any type to allow flexibility +// to extend responses with custom fields. There are however requirements +// in the standards regarding the response models. Where applicable +// the method documentation gives a recommended type which can be used +// directly or extended upon. +// +// The addition of new methods is not considered a breaking change +// as defined by semver rules. +// Implementations MUST embed [UnimplementedServer] to maintain +// forward compatibility. +// +// EXPERIMENTAL: may change until v4 +type Server interface { + // Health returns a status of "ok" once the Server is listening. + // The recommended Response Data type is [Status]. + Health(context.Context, *Request[struct{}]) (*Response, error) + + // Ready returns a status of "ok" once all dependencies, + // such as database storage, are ready. + // An error can be returned to explain what is not ready. + // The recommended Response Data type is [Status]. + Ready(context.Context, *Request[struct{}]) (*Response, error) + + // Discovery returns the OpenID Provider Configuration Information for this server. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig + // The recommended Response Data type is [oidc.DiscoveryConfiguration]. + Discovery(context.Context, *Request[struct{}]) (*Response, error) + + // Keys serves the JWK set which the client can use verify signatures from the op. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key. + // The recommended Response Data type is [jose.JSONWebKeySet]. + Keys(context.Context, *Request[struct{}]) (*Response, error) + + // VerifyAuthRequest verifies the Auth Request and + // adds the Client to the request. + // + // When the `request` field is populated with a + // "Request Object" JWT, it needs to be Validated + // and its claims overwrite any fields in the AuthRequest. + // If the implementation does not support "Request Object", + // it MUST return an [oidc.ErrRequestNotSupported]. + // https://openid.net/specs/openid-connect-core-1_0.html#RequestObject + VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) + + // Authorize initiates the authorization flow and redirects to a login page. + // See the various https://openid.net/specs/openid-connect-core-1_0.html + // authorize endpoint sections (one for each type of flow). + Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error) + + // DeviceAuthorization initiates the device authorization flow. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.1 + // The recommended Response Data type is [oidc.DeviceAuthorizationResponse]. + DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) + + // VerifyClient is called on most oauth/token handlers to authenticate, + // using either a secret (POST, Basic) or assertion (JWT). + // If no secrets are provided, the client must be public. + // This method is called before each method that takes a + // [ClientRequest] argument. + VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error) + + // CodeExchange returns Tokens after an authorization code + // is obtained in a successful Authorize flow. + // It is called by the Token endpoint handler when + // grant_type has the value authorization_code + // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + // The recommended Response Data type is [oidc.AccessTokenResponse]. + CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) + + // RefreshToken returns new Tokens after verifying a Refresh token. + // It is called by the Token endpoint handler when + // grant_type has the value refresh_token + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens + // The recommended Response Data type is [oidc.AccessTokenResponse]. + RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) + + // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer + // https://datatracker.ietf.org/doc/html/rfc7523#section-2.1 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error) + + // TokenExchange handles the OAuth 2.0 token exchange grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange + // https://datatracker.ietf.org/doc/html/rfc8693 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) + + // ClientCredentialsExchange handles the OAuth 2.0 client credentials grant + // It is called by the Token endpoint handler when + // grant_type has the value client_credentials + // https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) + + // DeviceToken handles the OAuth 2.0 Device Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:device_code. + // It is typically called in a polling fashion and appropriate errors + // should be returned to signal authorization_pending or access_denied etc. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.4, + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5. + // The recommended Response Data type is [oidc.AccessTokenResponse]. + DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) + + // Introspect handles the OAuth 2.0 Token Introspection endpoint. + // https://datatracker.ietf.org/doc/html/rfc7662 + // The recommended Response Data type is [oidc.IntrospectionResponse]. + Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) + + // UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User. + // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo + // The recommended Response Data type is [oidc.UserInfo]. + UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error) + + // Revocation handles token revocation using an access or refresh token. + // https://datatracker.ietf.org/doc/html/rfc7009 + // There are no response requirements. Data may remain empty. + Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error) + + // EndSession handles the OpenID Connect RP-Initiated Logout. + // https://openid.net/specs/openid-connect-rpinitiated-1_0.html + // There are no response requirements. Data may remain empty. + EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error) + + // mustImpl forces implementations to embed the UnimplementedServer for forward + // compatibility with the interface. + mustImpl() +} + +// Request contains the [http.Request] informational fields +// and parsed Data from the request body (POST) or URL parameters (GET). +// Data can be assumed to be validated according to the applicable +// standard for the specific endpoints. +// +// EXPERIMENTAL: may change until v4 +type Request[T any] struct { + Method string + URL *url.URL + Header http.Header + Form url.Values + PostForm url.Values + Data *T +} + +func (r *Request[_]) path() string { + return r.URL.Path +} + +func newRequest[T any](r *http.Request, data *T) *Request[T] { + return &Request[T]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + PostForm: r.PostForm, + Data: data, + } +} + +// ClientRequest is a Request with a verified client attached to it. +// Methods that receive this argument may assume the client was authenticated, +// or verified to be a public client. +// +// EXPERIMENTAL: may change until v4 +type ClientRequest[T any] struct { + *Request[T] + Client Client +} + +func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] { + return &ClientRequest[T]{ + Request: newRequest[T](r, data), + Client: client, + } +} + +// Response object for most [Server] methods. +// +// EXPERIMENTAL: may change until v4 +type Response struct { + // Header map will be merged with the + // header on the [http.ResponseWriter]. + Header http.Header + + // Data will be JSON marshaled to + // the response body. + // We allow any type, so that implementations + // can extend the standard types as they wish. + // However, each method will recommend which + // (base) type to use as model, in order to + // be compliant with the standards. + Data any +} + +// NewResponse creates a new response for data, +// without custom headers. +func NewResponse(data any) *Response { + return &Response{ + Data: data, + } +} + +func (resp *Response) writeOut(w http.ResponseWriter) { + gu.MapMerge(resp.Header, w.Header()) + httphelper.MarshalJSON(w, resp.Data) +} + +// Redirect is a special response type which will +// initiate a [http.StatusFound] redirect. +// The Params field will be encoded and set to the +// URL's RawQuery field before building the URL. +// +// EXPERIMENTAL: may change until v4 +type Redirect struct { + // Header map will be merged with the + // header on the [http.ResponseWriter]. + Header http.Header + + URL string +} + +func NewRedirect(url string) *Redirect { + return &Redirect{URL: url} +} + +func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) { + gu.MapMerge(r.Header, w.Header()) + http.Redirect(w, r, red.URL, http.StatusFound) +} + +type UnimplementedServer struct{} + +// UnimplementedStatusCode is the status code returned for methods +// that are not yet implemented. +// Note that this means methods in the sense of the Go interface, +// and not http methods covered by "501 Not Implemented". +var UnimplementedStatusCode = http.StatusNotFound + +func unimplementedError(r interface{ path() string }) StatusError { + err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path()) + return NewStatusError(err, UnimplementedStatusCode) +} + +func unimplementedGrantError(gt oidc.GrantType) StatusError { + err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt) + return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 +} + +func (UnimplementedServer) mustImpl() {} + +func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if r.Data.RequestParam != "" { + return nil, oidc.ErrRequestNotSupported() + } + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeCode) +} + +func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) +} + +func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeBearer) +} + +func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) +} + +func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) +} + +func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) +} + +func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { + return nil, unimplementedError(r) +} diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go new file mode 100644 index 00000000..3fb481d1 --- /dev/null +++ b/pkg/op/server_http.go @@ -0,0 +1,480 @@ +package op + +import ( + "context" + "net/http" + "net/url" + + "github.com/go-chi/chi" + "github.com/rs/cors" + "github.com/zitadel/logging" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +// RegisterServer registers an implementation of Server. +// The resulting handler takes care of routing and request parsing, +// with some basic validation of required fields. +// The routes can be customized with [WithEndpoints]. +// +// EXPERIMENTAL: may change until v4 +func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + + ws := &webServer{ + server: server, + endpoints: endpoints, + decoder: decoder, + logger: slog.Default(), + } + + for _, option := range options { + option(ws) + } + + ws.createRouter() + return ws +} + +type ServerOption func(s *webServer) + +// WithHTTPMiddleware sets the passed middleware chain to the root of +// the Server's router. +func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { + return func(s *webServer) { + s.middleware = m + } +} + +// WithDecoder overrides the default decoder, +// which is a [schema.Decoder] with IgnoreUnknownKeys set to true. +func WithDecoder(decoder httphelper.Decoder) ServerOption { + return func(s *webServer) { + s.decoder = decoder + } +} + +// WithFallbackLogger overrides the fallback logger, which +// is used when no logger was found in the context. +// Defaults to [slog.Default]. +func WithFallbackLogger(logger *slog.Logger) ServerOption { + return func(s *webServer) { + s.logger = logger + } +} + +type webServer struct { + http.Handler + server Server + middleware []func(http.Handler) http.Handler + endpoints Endpoints + decoder httphelper.Decoder + logger *slog.Logger +} + +func (s *webServer) getLogger(ctx context.Context) *slog.Logger { + if logger, ok := logging.FromContext(ctx); ok { + return logger + } + return s.logger +} + +func (s *webServer) createRouter() { + router := chi.NewRouter() + router.Use(cors.New(defaultCORSOptions).Handler) + router.Use(s.middleware...) + router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) + router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) + router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) + + s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler) + s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) + s.endpointRoute(router, s.endpoints.Token, s.tokensHandler) + s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler)) + s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler) + s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler)) + s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler) + s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) + s.Handler = router +} + +func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) { + if e != nil { + router.HandleFunc(e.Relative(), hf) + s.logger.Info("registered route", "endpoint", e.Relative()) + } +} + +type clientHandler func(w http.ResponseWriter, r *http.Request, client Client) + +func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" { + if !ValidateGrantType(client, grantType) { + WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context())) + return + } + } + handler(w, r, client) + } +} + +func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { + if err = r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + cc := new(ClientCredentials) + if err = s.decoder.Decode(cc, r.Form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + // Basic auth takes precedence, so if set it overwrites the form data. + if clientID, clientSecret, ok := r.BasicAuth(); ok { + cc.ClientID, err = url.QueryUnescape(clientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + cc.ClientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + } + if cc.ClientID == "" && cc.ClientAssertion == "" { + return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided") + } + if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion { + return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType) + } + return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + Data: cc, + }) +} + +func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + redirect, err := s.authorize(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + redirect.writeOut(w, r) +} + +func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) { + cr, err := s.server.VerifyAuthRequest(ctx, r) + if err != nil { + return nil, err + } + authReq := cr.Data + if authReq.RedirectURI == "" { + return nil, ErrAuthReqMissingRedirectURI + } + authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) + if err != nil { + return nil, err + } + authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes) + if err != nil { + return nil, err + } + if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil { + return nil, err + } + if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil { + return nil, err + } + return s.server.Authorize(ctx, cr) +} + +func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) + return + } + + switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType { + case oidc.GrantTypeCode: + s.withClient(s.codeExchangeHandler)(w, r) + case oidc.GrantTypeRefreshToken: + s.withClient(s.refreshTokenHandler)(w, r) + case oidc.GrantTypeClientCredentials: + s.withClient(s.clientCredentialsHandler)(w, r) + case oidc.GrantTypeBearer: + s.jwtProfileHandler(w, r) + case oidc.GrantTypeTokenExchange: + s.withClient(s.tokenExchangeHandler)(w, r) + case oidc.GrantTypeDeviceCode: + s.withClient(s.deviceTokenHandler)(w, r) + case "": + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context())) + default: + WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context())) + } +} + +func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Assertion == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Code == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context())) + return + } + if request.RedirectURI == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.RefreshToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.SubjectToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context())) + return + } + if request.SubjectTokenType == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context())) + return + } + if !request.SubjectTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context())) + return + } + if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context())) + return + } + if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context())) + return + } + resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) { + if client.AuthMethod() == oidc.AuthMethodNone { + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) + return + } + + request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.DeviceCode == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) { + if client.AuthMethod() == oidc.AuthMethodNone { + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) + return + } + request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if token, err := getAccessToken(r); err == nil { + request.AccessToken = token + } + if request.AccessToken == "" { + err = NewStatusError( + oidc.ErrInvalidRequest().WithDescription("access token missing"), + http.StatusUnauthorized, + ) + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.UserInfo(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.EndSession(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w, r) +} + +func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) + return + } + resp, err := method(r.Context(), newRequest(r, &struct{}{})) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) + } +} + +func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) { + dst := new(R) + if err := r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + form := r.Form + if postOnly { + form = r.PostForm + } + if err := decoder.Decode(dst, form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + return dst, nil +} diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go new file mode 100644 index 00000000..6a8b75d6 --- /dev/null +++ b/pkg/op/server_http_routes_test.go @@ -0,0 +1,345 @@ +package op_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/oidc/v3/pkg/client" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" +) + +func jwtProfile() (string, error) { + keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json") + if err != nil { + return "", err + } + signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID) + if err != nil { + return "", err + } + return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer) +} + +func TestServerRoutes(t *testing.T) { + server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints) + + storage := testProvider.Storage().(routesTestStorage) + ctx := op.ContextWithIssuer(context.Background(), testIssuer) + + client, err := storage.GetClientByClientID(ctx, "web") + require.NoError(t, err) + + oidcAuthReq := &oidc.AuthRequest{ + ClientID: client.GetID(), + RedirectURI: "https://example.com", + MaxAge: gu.Ptr[uint](300), + Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone}, + ResponseType: oidc.ResponseTypeCode, + } + + authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1") + require.NoError(t, err) + storage.AuthRequestDone(authReq.GetID()) + + accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") + require.NoError(t, err) + accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") + require.NoError(t, err) + idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client) + require.NoError(t, err) + jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "") + require.NoError(t, err) + jwtProfileToken, err := jwtProfile() + require.NoError(t, err) + + oidcAuthReq.IDTokenHint = idToken + + serverURL, err := url.Parse(testIssuer) + require.NoError(t, err) + + type basicAuth struct { + username, password string + } + + tests := []struct { + name string + method string + path string + basicAuth *basicAuth + header map[string]string + values map[string]string + body map[string]string + wantCode int + headerContains map[string]string + json string // test for exact json output + contains []string // when the body output is not constant, we just check for snippets to be present in the response + }{ + { + name: "health", + method: http.MethodGet, + path: "/healthz", + wantCode: http.StatusOK, + json: `{"status":"ok"}`, + }, + { + name: "ready", + method: http.MethodGet, + path: "/ready", + wantCode: http.StatusOK, + json: `{"status":"ok"}`, + }, + { + name: "discovery", + method: http.MethodGet, + path: oidc.DiscoveryEndpoint, + wantCode: http.StatusOK, + json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`, + }, + { + name: "authorization", + method: http.MethodGet, + path: testProvider.AuthorizationEndpoint().Relative(), + values: map[string]string{ + "client_id": client.GetID(), + "redirect_uri": "https://example.com", + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "response_type": string(oidc.ResponseTypeCode), + }, + wantCode: http.StatusFound, + headerContains: map[string]string{"Location": "/login/username?authRequestID="}, + }, + { + // This call will fail. A successfull test is already + // part of client/integration_test.go + name: "code exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeCode), + "client_id": client.GetID(), + "client_secret": "secret", + "redirect_uri": "https://example.com", + "code": "123", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"invalid_grant", "error_description":"invalid code"}`, + }, + { + name: "JWT authorization", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeBearer), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "assertion": jwtProfileToken, + }, + wantCode: http.StatusOK, + contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299}`}, + }, + { + name: "Token exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "grant_type": string(oidc.GrantTypeTokenExchange), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "subject_token": jwtToken, + "subject_token_type": string(oidc.AccessTokenType), + }, + wantCode: http.StatusOK, + contains: []string{ + `{"access_token":"`, + `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`, + }, + }, + { + name: "Client credentials exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"sid1", "verysecret"}, + values: map[string]string{ + "grant_type": string(oidc.GrantTypeClientCredentials), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + }, + wantCode: http.StatusOK, + contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, + }, + { + // This call will fail. A successfull test is already + // part of device_test.go + name: "device token", + method: http.MethodPost, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + header: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + body: map[string]string{ + "grant_type": string(oidc.GrantTypeDeviceCode), + "device_code": "123", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"access_denied","error_description":"The authorization request was denied."}`, + }, + { + name: "missing grant type", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + wantCode: http.StatusBadRequest, + json: `{"error":"invalid_request","error_description":"grant_type missing"}`, + }, + { + name: "unsupported grant type", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": "foo", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`, + }, + { + name: "introspection", + method: http.MethodGet, + path: testProvider.IntrospectionEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "token": accessToken, + }, + wantCode: http.StatusOK, + json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`, + }, + { + name: "user info", + method: http.MethodGet, + path: testProvider.UserinfoEndpoint().Relative(), + header: map[string]string{ + "authorization": "Bearer " + accessToken, + }, + wantCode: http.StatusOK, + json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`, + }, + { + name: "refresh token", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeRefreshToken), + "refresh_token": refreshToken, + "client_id": client.GetID(), + "client_secret": "secret", + }, + wantCode: http.StatusOK, + contains: []string{ + `{"access_token":"`, + `","token_type":"Bearer","refresh_token":"`, + `","expires_in":299,"id_token":"`, + }, + }, + { + name: "revoke", + method: http.MethodGet, + path: testProvider.RevocationEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "token": accessTokenRevoke, + }, + wantCode: http.StatusOK, + }, + { + name: "end session", + method: http.MethodGet, + path: testProvider.EndSessionEndpoint().Relative(), + values: map[string]string{ + "id_token_hint": idToken, + "client_id": "web", + }, + wantCode: http.StatusFound, + headerContains: map[string]string{"Location": "/logged-out"}, + contains: []string{`Found.`}, + }, + { + name: "keys", + method: http.MethodGet, + path: testProvider.KeysEndpoint().Relative(), + wantCode: http.StatusOK, + contains: []string{ + `{"keys":[{"use":"sig","kty":"RSA","kid":"`, + `","alg":"RS256","n":"`, `","e":"AQAB"}]}`, + }, + }, + { + name: "device authorization", + method: http.MethodGet, + path: testProvider.DeviceAuthorizationEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + }, + wantCode: http.StatusOK, + contains: []string{ + `{"device_code":"`, `","user_code":"`, + `","verification_uri":"https://localhost:9998/device"`, + `"verification_uri_complete":"https://localhost:9998/device?user_code=`, + `","expires_in":300,"interval":5}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := gu.PtrCopy(serverURL) + u.Path = tt.path + if tt.values != nil { + u.RawQuery = mapAsValues(tt.values) + } + var body io.Reader + if tt.body != nil { + body = strings.NewReader(mapAsValues(tt.body)) + } + + req := httptest.NewRequest(tt.method, u.String(), body) + for k, v := range tt.header { + req.Header.Set(k, v) + } + if tt.basicAuth != nil { + req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password) + } + + rec := httptest.NewRecorder() + server.ServeHTTP(rec, req) + + resp := rec.Result() + require.NoError(t, err) + assert.Equal(t, tt.wantCode, resp.StatusCode) + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + respBodyString := string(respBody) + t.Log(respBodyString) + t.Log(resp.Header) + + if tt.json != "" { + assert.JSONEq(t, tt.json, respBodyString) + } + for _, c := range tt.contains { + assert.Contains(t, respBodyString, c) + } + for k, v := range tt.headerContains { + assert.Contains(t, resp.Header.Get(k), v) + } + }) + } +} diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go new file mode 100644 index 00000000..86fe7ed8 --- /dev/null +++ b/pkg/op/server_http_test.go @@ -0,0 +1,1333 @@ +package op + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +func TestRegisterServer(t *testing.T) { + server := UnimplementedServer{} + endpoints := Endpoints{ + Authorization: &Endpoint{ + path: "/auth", + }, + } + decoder := schema.NewDecoder() + logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + + h := RegisterServer(server, endpoints, + WithDecoder(decoder), + WithFallbackLogger(logger), + ) + got := h.(*webServer) + assert.Equal(t, got.server, server) + assert.Equal(t, got.endpoints, endpoints) + assert.Equal(t, got.decoder, decoder) + assert.Equal(t, got.logger, logger) +} + +type testClient struct { + id string + appType ApplicationType + authMethod oidc.AuthMethod + accessTokenType AccessTokenType + responseTypes []oidc.ResponseType + grantTypes []oidc.GrantType + devMode bool +} + +type clientType string + +const ( + clientTypeWeb clientType = "web" + clientTypeNative clientType = "native" + clientTypeUserAgent clientType = "useragent" +) + +func newClient(kind clientType) *testClient { + client := &testClient{ + id: string(kind), + } + + switch kind { + case clientTypeWeb: + client.appType = ApplicationTypeWeb + client.authMethod = oidc.AuthMethodBasic + client.accessTokenType = AccessTokenTypeBearer + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} + case clientTypeNative: + client.appType = ApplicationTypeNative + client.authMethod = oidc.AuthMethodNone + client.accessTokenType = AccessTokenTypeBearer + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} + case clientTypeUserAgent: + client.appType = ApplicationTypeUserAgent + client.authMethod = oidc.AuthMethodBasic + client.accessTokenType = AccessTokenTypeJWT + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken} + default: + panic(fmt.Errorf("invalid client type %s", kind)) + } + return client +} + +func (c *testClient) RedirectURIs() []string { + return []string{ + "https://registered.com/callback", + "http://registered.com/callback", + "http://localhost:9999/callback", + "custom://callback", + } +} + +func (c *testClient) PostLogoutRedirectURIs() []string { + return []string{} +} + +func (c *testClient) LoginURL(id string) string { + return "login?id=" + id +} + +func (c *testClient) ApplicationType() ApplicationType { + return c.appType +} + +func (c *testClient) AuthMethod() oidc.AuthMethod { + return c.authMethod +} + +func (c *testClient) GetID() string { + return c.id +} + +func (c *testClient) AccessTokenLifetime() time.Duration { + return 5 * time.Minute +} + +func (c *testClient) IDTokenLifetime() time.Duration { + return 5 * time.Minute +} + +func (c *testClient) AccessTokenType() AccessTokenType { + return c.accessTokenType +} + +func (c *testClient) ResponseTypes() []oidc.ResponseType { + return c.responseTypes +} + +func (c *testClient) GrantTypes() []oidc.GrantType { + return c.grantTypes +} + +func (c *testClient) DevMode() bool { + return c.devMode +} + +func (c *testClient) AllowedScopes() []string { + return nil +} + +func (c *testClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +func (c *testClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +func (c *testClient) IsScopeAllowed(scope string) bool { + return false +} + +func (c *testClient) IDTokenUserinfoClaimsAssertion() bool { + return false +} + +func (c *testClient) ClockSkew() time.Duration { + return 0 +} + +type requestVerifier struct { + UnimplementedServer + client Client +} + +func (s *requestVerifier) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if s.client == nil { + return nil, oidc.ErrServerError() + } + return &ClientRequest[oidc.AuthRequest]{ + Request: r, + Client: s.client, + }, nil +} + +func (s *requestVerifier) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + if s.client == nil { + return nil, oidc.ErrServerError() + } + return s.client, nil +} + +var testDecoder = func() *schema.Decoder { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + return decoder +}() + +type webServerResult struct { + wantStatus int + wantBody string +} + +func runWebServerTest(t *testing.T, handler http.HandlerFunc, r *http.Request, want webServerResult) { + t.Helper() + if r.Method == http.MethodPost { + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + w := httptest.NewRecorder() + handler(w, r) + res := w.Result() + assert.Equal(t, want.wantStatus, res.StatusCode) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, want.wantBody, string(body)) +} + +func Test_webServer_withClient(t *testing.T) { + tests := []struct { + name string + r *http.Request + want webServerResult + }{ + { + name: "parse error", + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "invalid grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&grant_type=bad&foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unauthorized_client", "error_description":"grant_type \"bad\" not allowed"}`, + }, + }, + { + name: "no grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&foo=bar")), + want: webServerResult{ + wantStatus: http.StatusOK, + wantBody: `{"foo":"bar"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: &requestVerifier{ + client: newClient(clientTypeNative), + }, + decoder: testDecoder, + logger: slog.Default(), + } + handler := func(w http.ResponseWriter, r *http.Request, client Client) { + fmt.Fprintf(w, `{"foo":%q}`, r.FormValue("foo")) + } + runWebServerTest(t, s.withClient(handler), tt.r, tt.want) + }) + } +} + +func Test_webServer_verifyRequestClient(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want Client + wantErr error + }{ + { + name: "parse form error", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), + }, + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), + }, + { + name: "basic auth, client_id error", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) + r.SetBasicAuth(`%%%`, "secret") + return r + }(), + wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"), + }, + { + name: "basic auth, client_secret error", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) + r.SetBasicAuth("web", `%%%`) + return r + }(), + wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"), + }, + { + name: "missing client id and assertion", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + wantErr: oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided"), + }, + { + name: "wrong assertion type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_assertion=xxx&client_assertion_type=wrong")), + wantErr: oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type wrong"), + }, + { + name: "unimplemented verify client called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_id=web")), + wantErr: StatusError{ + parent: oidc.ErrServerError().WithDescription("/ not implemented on this server"), + statusCode: UnimplementedStatusCode, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + tt.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + got, err := s.verifyRequestClient(tt.r) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webServer_authorizeHandler(t *testing.T) { + type fields struct { + server Server + decoder httphelper.Decoder + } + tests := []struct { + name string + fields fields + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + fields: fields{ + server: &requestVerifier{}, + decoder: schema.NewDecoder(), + }, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "authorize error", + fields: fields{ + server: &requestVerifier{}, + decoder: testDecoder, + }, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"server_error"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.fields.server, + decoder: tt.fields.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.authorizeHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_authorize(t *testing.T) { + type args struct { + ctx context.Context + r *Request[oidc.AuthRequest] + } + tests := []struct { + name string + server Server + args args + want *Redirect + wantErr error + }{ + { + name: "verify error", + server: &requestVerifier{}, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + }, + }, + }, + wantErr: oidc.ErrServerError(), + }, + { + name: "missing redirect", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + MaxAge: gu.Ptr[uint](300), + }, + }, + }, + wantErr: ErrAuthReqMissingRedirectURI, + }, + { + name: "invalid prompt", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone, oidc.PromptLogin}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value"), + }, + { + name: "missing scopes", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequest(). + WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "invalid redirect", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://example.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequestRedirectURI(). + WithDescription("The requested redirect_uri is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "invalid response type", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeIDToken, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "unimplemented Authorize called", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + URL: &url.URL{ + Path: "/authorize", + }, + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: StatusError{ + parent: oidc.ErrServerError().WithDescription("/authorize not implemented on this server"), + statusCode: UnimplementedStatusCode, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.server, + decoder: testDecoder, + logger: slog.Default(), + } + got, err := s.authorize(tt.args.ctx, tt.args.r) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webServer_deviceAuthorizationHandler(t *testing.T) { + type fields struct { + server Server + decoder httphelper.Decoder + } + tests := []struct { + name string + fields fields + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + fields: fields{ + server: &requestVerifier{}, + decoder: schema.NewDecoder(), + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented DeviceAuthorization called", + fields: fields{ + server: &requestVerifier{ + client: newClient(clientTypeNative), + }, + decoder: testDecoder, + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native_client")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.fields.server, + decoder: tt.fields.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceAuthorizationHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokensHandler(t *testing.T) { + tests := []struct { + name string + r *http.Request + want webServerResult + }{ + { + name: "parse form error", + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "missing grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"grant_type missing"}`, + }, + }, + { + name: "invalid grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("grant_type=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"bar not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + logger: slog.Default(), + } + runWebServerTest(t, s.tokensHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_jwtProfileHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "assertion missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"assertion missing"}`, + }, + }, + { + name: "unimplemented JWTProfile called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("assertion=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:jwt-bearer not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.jwtProfileHandler, tt.r, tt.want) + }) + } +} + +func runWebServerClientTest(t *testing.T, handler func(http.ResponseWriter, *http.Request, Client), r *http.Request, client Client, want webServerResult) { + t.Helper() + runWebServerTest(t, func(client Client) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(w, r, client) + } + }(client), r, want) +} + +func Test_webServer_codeExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"code missing"}`, + }, + }, + { + name: "redirect missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"redirect_uri missing"}`, + }, + }, + { + name: "unimplemented CodeExchange called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123&redirect_uri=https://example.com/callback")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"authorization_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.codeExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_refreshTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "refresh token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"refresh_token missing"}`, + }, + }, + { + name: "unimplemented RefreshToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("refresh_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"refresh_token not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.refreshTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokenExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "subject token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token missing"}`, + }, + }, + { + name: "subject token type missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type missing"}`, + }, + }, + { + name: "subject token type unsupported", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type is not supported"}`, + }, + }, + { + name: "unsupported requested token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"requested_token_type is not supported"}`, + }, + }, + { + name: "unsupported actor token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"actor_token_type is not supported"}`, + }, + }, + { + name: "unimplemented TokenExchange called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=urn:ietf:params:oauth:token-type:access_token")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:token-exchange not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.tokenExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_clientCredentialsHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "unimplemented ClientCredentialsExchange called", + decoder: testDecoder, + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"client_credentials not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.clientCredentialsHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_deviceTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "device code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"device_code missing"}`, + }, + }, + { + name: "unimplemented DeviceToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("device_code=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:device_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_introspectionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Introspect called", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.introspectionHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_userInfoHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "access token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusUnauthorized, + wantBody: `{"error":"invalid_request", "error_description":"access token missing"}`, + }, + }, + { + name: "unimplemented UserInfo called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("access_token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "bearer", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("authorization", strings.Join([]string{"Bearer", "xxx"}, " ")) + return r + }(), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.userInfoHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_revocationHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Revocation called, confidential client", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "unimplemented Revocation called, public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.revocationHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_endSessionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented EndSession called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("id_token_hint=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.endSessionHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_simpleHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + method func(context.Context, *Request[struct{}]) (*Response, error) + r *http.Request + want webServerResult + }{ + { + name: "parse error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "method error", + decoder: schema.NewDecoder(), + method: func(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, io.ErrClosedPipe + }, + r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, simpleHandler(s, tt.method), tt.r, tt.want) + }) + } +} + +func Test_decodeRequest(t *testing.T) { + type dst struct { + A string `schema:"a"` + B string `schema:"b"` + } + type args struct { + r *http.Request + postOnly bool + } + tests := []struct { + name string + args args + want *dst + wantErr error + }{ + { + name: "parse error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), + }, + { + name: "decode error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), + }, + { + name: "success, get", + args: args{ + r: httptest.NewRequest(http.MethodGet, "/?a=b&b=a", nil), + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + { + name: "success, post only", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: true, + }, + want: &dst{ + A: "b", + }, + }, + { + name: "success, post mixed", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: false, + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.r.Method == http.MethodPost { + tt.args.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + got, err := decodeRequest[dst](schema.NewDecoder(), tt.args.r, tt.args.postOnly) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go new file mode 100644 index 00000000..0a7de855 --- /dev/null +++ b/pkg/op/server_legacy.go @@ -0,0 +1,344 @@ +package op + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/go-chi/chi" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// LegacyServer is an implementation of [Server[] that +// simply wraps a [OpenIDProvider]. +// It can be used to transition from the former Provider/Storage +// interfaces to the new Server interface. +type LegacyServer struct { + UnimplementedServer + provider OpenIDProvider + endpoints Endpoints +} + +// NewLegacyServer wraps provider in a `Server` and returns a handler which is +// the Server's router. +// +// Only non-nil endpoints will be registered on the router. +// Nil endpoints are disabled. +// +// The passed endpoints is also set to the provider, +// to be consistent with the discovery config. +// Any `With*Endpoint()` option used on the provider is +// therefore ineffective. +func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler { + server := RegisterServer(&LegacyServer{ + provider: provider, + endpoints: endpoints, + }, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) + + router := chi.NewRouter() + router.Mount("/", server) + router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider)) + + return router +} + +func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { + for _, probe := range s.provider.Probes() { + // shouldn't we run probes in Go routines? + if err := probe(ctx); err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + } + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return NewResponse( + createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints), + ), nil +} + +func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { + keys, err := s.provider.Storage().KeySet(ctx) + if err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + return NewResponse(jsonWebKeySet(keys)), nil +} + +var ( + ErrAuthReqMissingClientID = errors.New("auth request is missing client_id") + ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri") +) + +func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if r.Data.RequestParam != "" { + if !s.provider.RequestObjectSupported() { + return nil, oidc.ErrRequestNotSupported() + } + err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx)) + if err != nil { + return nil, err + } + } + if r.Data.ClientID == "" { + return nil, ErrAuthReqMissingClientID + } + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id") + } + + return &ClientRequest[oidc.AuthRequest]{ + Request: r, + Client: client, + }, nil +} + +func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) { + userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx)) + if err != nil { + return nil, err + } + req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID) + if err != nil { + return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger()) + } + return NewRedirect(r.Client.LoginURL(req.GetID())), nil +} + +func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider) + if err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + return NewResponse(response), nil +} + +func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials { + storage, ok := s.provider.Storage().(ClientCredentialsStorage) + if !ok { + return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported") + } + return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret) + } + + if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { + jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported") + } + return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger) + } + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithParent(err) + } + + switch client.AuthMethod() { + case oidc.AuthMethodNone: + return client, nil + case oidc.AuthMethodPrivateKeyJWT: + return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") + case oidc.AuthMethodPost: + if !s.provider.AuthMethodPostSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") + } + } + + err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage()) + if err != nil { + return nil, err + } + + return client, nil +} + +func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { + authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code) + if err != nil { + return nil, err + } + if r.Client.AuthMethod() == oidc.AuthMethodNone { + if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil { + return nil, err + } + } + resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "") + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + if !s.provider.GrantTypeRefreshTokenSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) + } + request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken) + if err != nil { + return nil, err + } + if r.Client.GetID() != request.GetClientID() { + return nil, oidc.ErrInvalidGrant() + } + if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil { + return nil, err + } + resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeBearer) + } + tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx)) + if err != nil { + return nil, err + } + + tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + if !s.provider.GrantTypeTokenExchangeSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider) + if err != nil { + return nil, err + } + resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + storage, ok := s.provider.Storage().(ClientCredentialsStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) + } + tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + if !s.provider.GrantTypeClientCredentialsSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) + } + // use a limited context timeout shorter as the default + // poll interval of 5 seconds. + ctx, cancel := context.WithTimeout(ctx, 4*time.Second) + defer cancel() + + state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider) + if err != nil { + return nil, err + } + tokenRequest := &deviceAccessTokenRequest{ + subject: state.Subject, + audience: []string{r.Client.GetID()}, + scopes: state.Scopes, + } + resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { + response := new(oidc.IntrospectionResponse) + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token) + if !ok { + return NewResponse(response), nil + } + err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID()) + if err != nil { + return NewResponse(response), nil + } + response.Active = true + return NewResponse(response), nil +} + +func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken) + if !ok { + return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized) + } + info := new(oidc.UserInfo) + err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin")) + if err != nil { + return nil, NewStatusError(err, http.StatusForbidden) + } + return NewResponse(info), nil +} + +func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { + var subject string + doDecrypt := true + if r.Data.TokenTypeHint != "access_token" { + userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token) + if err != nil { + // An invalid refresh token means that we'll try other things (leaving doDecrypt==true) + if !errors.Is(err, ErrInvalidRefreshToken) { + return nil, RevocationError(oidc.ErrServerError().WithParent(err)) + } + } else { + r.Data.Token = tokenID + subject = userID + doDecrypt = false + } + } + if doDecrypt { + tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token) + if ok { + r.Data.Token = tokenID + subject = userID + } + } + if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil { + return nil, RevocationError(err) + } + return NewResponse(nil), nil +} + +func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { + session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider) + if err != nil { + return nil, err + } + err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID) + if err != nil { + return nil, err + } + return NewRedirect(session.RedirectURI), nil +} diff --git a/pkg/op/server_test.go b/pkg/op/server_test.go new file mode 100644 index 00000000..0cad8fd5 --- /dev/null +++ b/pkg/op/server_test.go @@ -0,0 +1,5 @@ +package op + +// implementation check +var _ Server = &UnimplementedServer{} +var _ Server = &LegacyServer{} diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index baf377bc..371e1d41 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -88,7 +88,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, if err != nil { return nil, nil, err } - err = AuthorizeCodeChallenge(tokenReq, request.GetCodeChallenge()) + err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) return request, client, err } if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 21db1347..5156741d 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -197,12 +197,6 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing") } - storage := exchanger.Storage() - teStorage, ok := storage.(TokenExchangeStorage) - if !ok { - return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("token_exchange grant not supported") - } - client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger) if err != nil { return nil, nil, err @@ -220,10 +214,28 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported") } + req, err := CreateTokenExchangeRequest(ctx, oidcTokenExchangeRequest, client, exchanger) + if err != nil { + return nil, nil, err + } + return req, client, nil +} + +func CreateTokenExchangeRequest( + ctx context.Context, + oidcTokenExchangeRequest *oidc.TokenExchangeRequest, + client Client, + exchanger Exchanger, +) (TokenExchangeRequest, error) { + teStorage, ok := exchanger.Storage().(TokenExchangeStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") } var ( @@ -234,7 +246,7 @@ func ValidateTokenExchangeRequest( exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") } } @@ -258,17 +270,17 @@ func ValidateTokenExchangeRequest( authTime: time.Now(), } - err = teStorage.ValidateTokenExchangeRequest(ctx, req) + err := teStorage.ValidateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } err = teStorage.CreateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } - return req, client, nil + return req, nil } func GetTokenIDAndSubjectFromToken( diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 0df2fcee..b810633c 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -117,11 +117,11 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, // AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent // code_challenge of the auth request (PKCE) -func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error { - if tokenReq.CodeVerifier == "" { +func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error { + if codeVerifier == "" { return oidc.ErrInvalidRequest().WithDescription("code_challenge required") } - if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) { + if !oidc.VerifyCodeChallenge(challenge, codeVerifier) { return oidc.ErrInvalidGrant().WithDescription("invalid code challenge") } return nil diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go index fd1ee931..d19c7f7f 100644 --- a/pkg/op/token_revocation.go +++ b/pkg/op/token_revocation.go @@ -131,6 +131,11 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token } func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { + statusErr := RevocationError(err) + httphelper.MarshalJSONWithStatus(w, statusErr.parent, statusErr.statusCode) +} + +func RevocationError(err error) StatusError { e := oidc.DefaultToServerError(err, err.Error()) status := http.StatusBadRequest switch e.ErrorType { @@ -139,7 +144,7 @@ func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { case oidc.ServerError: status = 500 } - httphelper.MarshalJSONWithStatus(w, e, status) + return NewStatusError(e, status) } func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {