Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CBG-801: Auto-generated OIDC callback URL should include provider when non-default #4549

Closed
wants to merge 8 commits into from
18 changes: 18 additions & 0 deletions auth/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package auth

import (
"os"
"testing"

"github.com/couchbase/sync_gateway/base"
)

func TestMain(m *testing.M) {
base.GTestBucketPool = base.NewTestBucketPool(base.FlushBucketEmptierFunc, base.NoopInitFunc)

status := m.Run()

base.GTestBucketPool.Close()

os.Exit(status)
}
28 changes: 26 additions & 2 deletions auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package auth

import (
"errors"
"fmt"
"net/http"
"net/url"
Expand All @@ -30,6 +31,12 @@ const (

var OIDCDiscoveryRetryWait = 500 * time.Millisecond

// Request parameter to specify the OpenID Connect provider to be used for authentication,
// from the list of providers defined in the Sync Gateway configuration.
var OIDCAuthProvider = "provider"

var ErrAddURLQueryParam = errors.New("URL, parameter name and value must not be empty")

// Options for OpenID Connect
type OIDCOptions struct {
Providers OIDCProviderMap `json:"providers,omitempty"` // List of OIDC issuers
Expand Down Expand Up @@ -57,7 +64,7 @@ type OIDCProvider struct {

type OIDCProviderMap map[string]*OIDCProvider

type OIDCCallbackURLFunc func() string
type OIDCCallbackURLFunc func(string, bool) string

func (opm OIDCProviderMap) GetDefaultProvider() *OIDCProvider {
for _, provider := range opm {
Expand Down Expand Up @@ -93,7 +100,7 @@ func (op *OIDCProvider) GetClient(buildCallbackURLFunc OIDCCallbackURLFunc) *oid
// If the redirect URL is not defined for the provider generate it from the
// handler request and set it on the provider
if op.CallbackURL == nil || *op.CallbackURL == "" {
callbackURL := buildCallbackURLFunc()
callbackURL := buildCallbackURLFunc(op.Name, op.IsDefault)
if callbackURL != "" {
op.CallbackURL = &callbackURL
}
Expand Down Expand Up @@ -282,3 +289,20 @@ func OIDCToHTTPError(err error) error {
}
return err
}

func AddURLQueryParam(strURL, name, value string) (string, error) {
if strURL == "" || name == "" || value == "" {
return "", ErrAddURLQueryParam
}
uri, err := url.Parse(strURL)
if err != nil {
return "", err
}
rawQuery, err := url.ParseQuery(uri.RawQuery)
if err != nil {
return "", err
}
rawQuery.Add(name, value)
uri.RawQuery = rawQuery.Encode()
return uri.String(), nil
}
70 changes: 69 additions & 1 deletion auth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
package auth

import (
"errors"
"net/http"
"net/url"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -253,7 +255,7 @@ func TestOIDCProvider_InitOIDCClient(t *testing.T) {
}

if test.Provider != nil {
client := test.Provider.GetClient(func() string { return "" })
client := test.Provider.GetClient(func(string, bool) string { return "" })
if test.ExpectOIDCClient {
assert.NotEqual(tt, (*oidc.Client)(nil), client)
} else {
Expand Down Expand Up @@ -378,3 +380,69 @@ func TestOIDCToHTTPError(t *testing.T) {
assert.Error(t, httpErr)
assert.Contains(t, httpErr.Error(), strconv.Itoa(http.StatusBadRequest))
}

func TestAddURLQueryParam(t *testing.T) {
var oidcAuthProviderGoogle = "google"
tests := []struct {
name string
inputCallbackURL string
inputParamName string
inputParamValue string
wantCallbackURL string
wantError error
}{{
name: "Add provider to callback URL",
inputCallbackURL: "http://localhost:4984/default/_oidc_callback",
inputParamName: OIDCAuthProvider,
inputParamValue: oidcAuthProviderGoogle,
wantCallbackURL: "http://localhost:4984/default/_oidc_callback?provider=google",
}, {
name: "Add provider to callback URL with ? character",
inputCallbackURL: "http://localhost:4984/default/_oidc_callback?",
inputParamName: OIDCAuthProvider,
inputParamValue: oidcAuthProviderGoogle,
wantCallbackURL: "http://localhost:4984/default/_oidc_callback?provider=google",
}, {
name: "Add provider to empty callback URL",
inputCallbackURL: "",
inputParamName: OIDCAuthProvider,
inputParamValue: oidcAuthProviderGoogle,
wantCallbackURL: "",
wantError: ErrAddURLQueryParam,
}, {
name: "Add empty provider value to callback URL",
inputCallbackURL: "http://localhost:4984/default/_oidc_callback",
inputParamName: OIDCAuthProvider,
inputParamValue: "",
wantCallbackURL: "",
wantError: ErrAddURLQueryParam,
}, {
name: "Add empty provider name to callback URL",
inputCallbackURL: "http://localhost:4984/default/_oidc_callback",
inputParamName: "",
inputParamValue: oidcAuthProviderGoogle,
wantCallbackURL: "",
wantError: ErrAddURLQueryParam,
}, {
name: "Add provider to callback URL with illegal value in query param",
inputCallbackURL: "http://localhost:4984/default/_oidc_callback?provider=%%3",
inputParamName: OIDCAuthProvider,
inputParamValue: oidcAuthProviderGoogle,
wantCallbackURL: "",
wantError: url.EscapeError("%%3"),
}, {
name: "Add provider to callback URL with missing protocol scheme",
inputCallbackURL: "://localhost:4984/default/_oidc_callback",
inputParamName: OIDCAuthProvider,
inputParamValue: oidcAuthProviderGoogle,
wantCallbackURL: "",
wantError: &url.Error{Op: "parse", URL: "://localhost:4984/default/_oidc_callback", Err: errors.New("missing protocol scheme")},
}}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
callbackURL, err := AddURLQueryParam(test.inputCallbackURL, test.inputParamName, test.inputParamValue)
assert.Equal(t, test.wantError, err)
assert.Equal(t, test.wantCallbackURL, callbackURL)
})
}
}
92 changes: 80 additions & 12 deletions base/bucket_gocb.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ type CouchbaseBucketGoCB struct {

// Creates a Bucket that talks to a real live Couchbase server.
func GetCouchbaseBucketGoCB(spec BucketSpec) (bucket *CouchbaseBucketGoCB, err error) {

// TODO: Push the above down into spec.GetConnString
connString, err := spec.GetGoCBConnString()
if err != nil {
Warnf("Unable to parse server value: %s error: %v", SD(spec.Server), err)
Expand All @@ -86,17 +84,17 @@ func GetCouchbaseBucketGoCB(spec BucketSpec) (bucket *CouchbaseBucketGoCB, err e
return nil, err
}

password := ""
bucketPassword := ""
// Check for client cert (x.509) authentication
if spec.Certpath != "" {
Infof(KeyAuth, "Attempting cert authentication against bucket %s on %s", MD(spec.BucketName), MD(connString))
Infof(KeyAuth, "Attempting cert authentication against bucket %s on %s", MD(spec.BucketName), MD(spec.Server))
certAuthErr := cluster.Authenticate(gocb.CertAuthenticator{})
if certAuthErr != nil {
Infof(KeyAuth, "Error Attempting certificate authentication %s", certAuthErr)
return nil, pkgerrors.WithStack(certAuthErr)
}
} else if spec.Auth != nil {
Infof(KeyAuth, "Attempting credential authentication against bucket %s on %s", MD(spec.BucketName), MD(connString))
Infof(KeyAuth, "Attempting credential authentication against bucket %s on %s", MD(spec.BucketName), MD(spec.Server))
user, pass, _ := spec.Auth.GetCredentials()
authErr := cluster.Authenticate(gocb.PasswordAuthenticator{
Username: user,
Expand All @@ -105,11 +103,15 @@ func GetCouchbaseBucketGoCB(spec BucketSpec) (bucket *CouchbaseBucketGoCB, err e
// If RBAC authentication fails, revert to non-RBAC authentication by including the password to OpenBucket
if authErr != nil {
Warnf("RBAC authentication against bucket %s as user %s failed - will re-attempt w/ bucketname, password", MD(spec.BucketName), UD(user))
password = pass
bucketPassword = pass
}
}

goCBBucket, err := cluster.OpenBucket(spec.BucketName, password)
return GetCouchbaseBucketGoCBFromAuthenticatedCluster(cluster, spec, bucketPassword)
}

func GetCouchbaseBucketGoCBFromAuthenticatedCluster(cluster *gocb.Cluster, spec BucketSpec, bucketPassword string) (bucket *CouchbaseBucketGoCB, err error) {
goCBBucket, err := cluster.OpenBucket(spec.BucketName, bucketPassword)
if err != nil {
Infof(KeyAll, "Error opening bucket %s: %v", spec.BucketName, err)
return nil, pkgerrors.WithStack(err)
Expand Down Expand Up @@ -181,9 +183,7 @@ func GetCouchbaseBucketGoCB(spec BucketSpec) (bucket *CouchbaseBucketGoCB, err e
bucket.Bucket.SetN1qlTimeout(bucket.spec.GetViewQueryTimeout())

Infof(KeyAll, "Set query timeouts for bucket %s to cluster:%v, bucket:%v", spec.BucketName, cluster.N1qlTimeout(), bucket.N1qlTimeout())

return bucket, err

}

func (bucket *CouchbaseBucketGoCB) GetBucketCredentials() (username, password string) {
Expand Down Expand Up @@ -1786,6 +1786,35 @@ func (bucket *CouchbaseBucketGoCB) Incr(k string, amt, def uint64, exp uint32) (

}

func (bucket *CouchbaseBucketGoCB) GetDDocs(into interface{}) error {
bucketManager, err := bucket.getBucketManager()
if err != nil {
return err
}

ddocs, err := bucketManager.GetDesignDocuments()
if err != nil {
return err
}

result := make(map[string]*gocb.DesignDocument, len(ddocs))
for _, ddoc := range ddocs {
result[ddoc.Name] = ddoc
}

resultBytes, err := JSONMarshal(result)
if err != nil {
return err
}

// Deserialize []byte into "into" empty interface
if err := JSONUnmarshal(resultBytes, into); err != nil {
return err
}

return nil
}

func (bucket *CouchbaseBucketGoCB) GetDDoc(docname string, into interface{}) error {

bucketManager, err := bucket.getBucketManager()
Expand Down Expand Up @@ -1882,7 +1911,18 @@ func (bucket *CouchbaseBucketGoCB) PutDDoc(docname string, value interface{}) er
return bucket.putDDocForTombstones(gocbDesignDoc)
}

return manager.UpsertDesignDocument(gocbDesignDoc)
// Retry for all errors (The view service sporadically returns 500 status codes with Erlang errors (for unknown reasons) - E.g: 500 {"error":"case_clause","reason":"false"})
var worker RetryWorker = func() (bool, error, interface{}) {
err := manager.UpsertDesignDocument(gocbDesignDoc)
if err != nil {
Warnf("Got error from UpsertDesignDocument: %v - Retrying...", err)
return true, err, nil
}
return false, nil, nil
}

err, _ = RetryLoop("PutDDocRetryLoop", worker, CreateSleeperFunc(5, 100))
return err

}

Expand Down Expand Up @@ -2346,10 +2386,22 @@ func (bucket *CouchbaseBucketGoCB) Flush() error {

}

// BucketItemCount first tries to retrieve an accurate bucket count via N1QL,
// but falls back to the REST API if that cannot be done (when there's no index to count all items in a bucket)
func (bucket *CouchbaseBucketGoCB) BucketItemCount() (itemCount int, err error) {
itemCount, err = bucket.QueryBucketItemCount()
if err == nil {
return itemCount, nil
}

itemCount, err = bucket.APIBucketItemCount()
return itemCount, err
}

// Get the number of items in the bucket.
// GOCB doesn't currently offer a way to do this, and so this is a workaround to go directly
// to Couchbase Server REST API.
func (bucket *CouchbaseBucketGoCB) BucketItemCount() (itemCount int, err error) {
func (bucket *CouchbaseBucketGoCB) APIBucketItemCount() (itemCount int, err error) {
uri := fmt.Sprintf("/pools/default/buckets/%s", bucket.Name())
resp, err := bucket.mgmtRequest(http.MethodGet, uri, "application/json", nil)
if err != nil {
Expand Down Expand Up @@ -2379,6 +2431,22 @@ func (bucket *CouchbaseBucketGoCB) BucketItemCount() (itemCount int, err error)
return int(itemCountFloat), nil
}

// QueryBucketItemCount uses a request plus query to get the number of items in a bucket, as the REST API can be slow to update its value.
func (bucket *CouchbaseBucketGoCB) QueryBucketItemCount() (itemCount int, err error) {
r, err := bucket.Query("SELECT COUNT(1) AS count FROM `$_bucket`", nil, gocb.RequestPlus, true)
if err != nil {
return -1, err
}
var val struct {
Count int `json:"count"`
}
err = r.One(&val)
if err != nil {
return -1, err
}
return val.Count, nil
}

func (bucket *CouchbaseBucketGoCB) getExpirySingleAttempt(k string) (expiry uint32, getMetaError error) {

bucket.singleOps <- struct{}{}
Expand Down Expand Up @@ -2653,7 +2721,7 @@ func AsGoCBBucket(bucket Bucket) (*CouchbaseBucketGoCB, bool) {
underlyingBucket = typedBucket.GetUnderlyingBucket()
case *LeakyBucket:
underlyingBucket = typedBucket.GetUnderlyingBucket()
case TestBucket:
case *TestBucket:
underlyingBucket = typedBucket.Bucket
default:
// bail out for unrecognised/unsupported buckets
Expand Down
Loading