Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CBG-801: Auto-generated OIDC callback URL should include provider when non-default #4549

Closed
wants to merge 8 commits into from
51 changes: 50 additions & 1 deletion rest/oidc_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package rest

import (
"errors"
"fmt"
"net/http"
"net/url"
Expand All @@ -24,13 +25,26 @@ const (
OIDC_AUTH_RESPONSE_TYPE = "response_type"
OIDC_AUTH_CLIENT_ID = "client_id"
OIDC_AUTH_SCOPE = "scope"
OIDC_AUTH_REDIRECT_URI = "redirect_uri"
OIDC_AUTH_STATE = "state"

// Request parameter to specify the OpenID Connect provider to be used for authentication,
// from the list of providers defined in the Sync Gateway configuration.
oidcAuthProvider = "provider"

// Request parameter to specify the URL to which you want the end-user to be redirected
// after the authorization is complete.
oidcAuthRedirectURI = "redirect_uri"

OIDC_RESPONSE_TYPE_CODE = "code"
OIDC_RESPONSE_TYPE_IMPLICIT = "id_token%20token"
)

// Error codes returned by failures to add parameters to callback URL.
var (
ErrBadCallbackURL = errors.New("oidc: callback URL must not be nil")
ErrNoRedirectURI = errors.New("oidc: no redirect_uri parameter found in URL")
)

type OIDCTokenResponse struct {
IDToken string `json:"id_token"` // ID token, from OP
RefreshToken string `json:"refresh_token,omitempty"` // Refresh token, from OP
Expand Down Expand Up @@ -99,9 +113,44 @@ func (h *handler) handleOIDCCommon() (redirectURLString string, err error) {
return redirectURLString, err
}

if !provider.IsDefault {
sarathkumarsivan marked this conversation as resolved.
Show resolved Hide resolved
base.Debugf(base.KeyAuth, "Adding provider (%v) to callback URL", base.UD(provider.Name))
if err = addCallbackURLQueryParam(redirectURL, oidcAuthProvider, provider.Name); err != nil {
base.Errorf("Failed to add provider to callback URL, err: %v", err)
}
base.Debugf(base.KeyAuth, "Callback URL: %s", redirectURL.String())
sarathkumarsivan marked this conversation as resolved.
Show resolved Hide resolved
}

return redirectURL.String(), nil
}

func addCallbackURLQueryParam(uri *url.URL, name, value string) error {
if uri == nil {
return ErrBadCallbackURL
}
rawQuery, err := url.ParseQuery(uri.RawQuery)
if err != nil {
return err
}
redirectURL := rawQuery.Get(oidcAuthRedirectURI)
if redirectURL == "" {
return ErrNoRedirectURI
}
redirectURI, err := url.Parse(redirectURL)
if err != nil {
return err
}
rawQueryRedirectURI, err := url.ParseQuery(redirectURI.RawQuery)
if err != nil {
return err
}
rawQueryRedirectURI.Set(name, value)
redirectURI.RawQuery = rawQueryRedirectURI.Encode()
rawQuery.Set(oidcAuthRedirectURI, redirectURI.String())
uri.RawQuery = rawQuery.Encode()
return nil
}

func (h *handler) handleOIDCCallback() error {
callbackError := h.getQuery("error")
if callbackError != "" {
Expand Down
61 changes: 61 additions & 0 deletions rest/oidc_api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package rest

import (
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAddCallbackURLQueryParam(t *testing.T) {
var oidcAuthProviderGoogle = "google"
tests := []struct {
name string
inputURL string
inputParamName string
inputParamValue string
wantURL string
wantError error
}{{
name: "Add provider parameter to callback URL",
inputURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=EADGBE&redirect_uri=http%3A%2F%2Flocalhost%3A4984%2Fdefault%2F_oidc_callback&response_type=code&scope=openid+email&state=GDCEm",
inputParamName: oidcAuthProvider,
inputParamValue: oidcAuthProviderGoogle,
wantURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=EADGBE&redirect_uri=http%3A%2F%2Flocalhost%3A4984%2Fdefault%2F_oidc_callback%3Fprovider%3Dgoogle&response_type=code&scope=openid+email&state=GDCEm",
}, {
name: "Add provider parameter with empty value to callback URL",
inputURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=EADGBE&redirect_uri=http%3A%2F%2Flocalhost%3A4984%2Fdefault%2F_oidc_callback&response_type=code&scope=openid+email&state=GDCEm",
inputParamName: oidcAuthProvider,
wantURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=EADGBE&redirect_uri=http%3A%2F%2Flocalhost%3A4984%2Fdefault%2F_oidc_callback%3Fprovider%3D&response_type=code&scope=openid+email&state=GDCEm",
}, {
name: "Add provider parameter to callback URL which doesn't have redirect_uri",
inputURL: "https://accounts.google.com/o/oauth2/v2/auth?access_type=offline&client_id=client123&prompt=consent",
inputParamName: oidcAuthProvider,
inputParamValue: oidcAuthProviderGoogle,
wantURL: "https://accounts.google.com/o/oauth2/v2/auth?access_type=offline&client_id=client123&prompt=consent",
wantError: ErrNoRedirectURI,
}, {
name: "Add provider parameter to callback URL which has invalid redirect_uri",
inputURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=EADGBE&redirect_uri=http%%3A%2F%2Flocalhost%3A4984%2Fdefault%2F_oidc_callback&response_type=code&scope=openid+email&state=GDCEm",
inputParamName: oidcAuthProvider,
inputParamValue: oidcAuthProviderGoogle,
wantURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=EADGBE&redirect_uri=http%%3A%2F%2Flocalhost%3A4984%2Fdefault%2F_oidc_callback&response_type=code&scope=openid+email&state=GDCEm",
wantError: url.EscapeError("%%3"),
sarathkumarsivan marked this conversation as resolved.
Show resolved Hide resolved
}}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
inputURL, err := url.Parse(test.inputURL)
require.NoError(t, err, "Couldn't parse URL")
err = addCallbackURLQueryParam(inputURL, test.inputParamName, test.inputParamValue)
assert.Equal(t, test.wantError, err)
assert.Equal(t, test.wantURL, inputURL.String())
})
}
}

func TestAddCallbackURLQueryParamNoURL(t *testing.T) {
var oidcAuthProviderGoogle = "google"
err := addCallbackURLQueryParam(nil, oidcAuthProvider, oidcAuthProviderGoogle)
assert.Equal(t, ErrBadCallbackURL, err)
}