Skip to content

Commit

Permalink
SNOW-955538: Multiple SAML Integrations Support (#1025)
Browse files Browse the repository at this point in the history
sdk issue #726
  • Loading branch information
sfc-gh-ext-simba-jl authored Jan 15, 2024
1 parent f1dfe51 commit 5c89d42
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 13 deletions.
4 changes: 4 additions & 0 deletions assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func assertStringContainsE(t *testing.T, actual string, expectedToContain string
errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...))
}

func assertStringContainsF(t *testing.T, actual string, expectedToContain string, descriptions ...string) {
fatalOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...))
}

func assertHasPrefixE(t *testing.T, actual string, expectedPrefix string, descriptions ...string) {
errorOnNonEmpty(t, validateHasPrefix(actual, expectedPrefix, descriptions...))
}
Expand Down
7 changes: 6 additions & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,10 @@ func authenticateWithConfig(sc *snowflakeConn) error {
if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
fillCachedIDToken(sc)
}
// Disable console login by default
if sc.cfg.DisableConsoleLogin == configBoolNotSet {
sc.cfg.DisableConsoleLogin = ConfigBoolTrue
}
}

if sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA {
Expand All @@ -524,7 +528,8 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.cfg.Account,
sc.cfg.User,
sc.cfg.Password,
sc.cfg.ExternalBrowserTimeout)
sc.cfg.ExternalBrowserTimeout,
sc.cfg.DisableConsoleLogin)
if err != nil {
sc.cleanup()
return err
Expand Down
16 changes: 16 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,22 @@ func TestUnitAuthenticateWithConfigOkta(t *testing.T) {
assertEqualE(t, err.Error(), "failed to get SAML response")
}

func TestUnitAuthenticateWithConfigExternalBrowser(t *testing.T) {
var err error
sr := &snowflakeRestful{
FuncPostAuthSAML: postAuthSAMLError,
TokenAccessor: getSimpleTokenAccessor(),
}
sc := getDefaultSnowflakeConn()
sc.cfg.Authenticator = AuthTypeExternalBrowser
sc.cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout
sc.rest = sr
sc.ctx = context.Background()
err = authenticateWithConfig(sc)
assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
assertEqualE(t, err.Error(), "failed to get SAML response")
}

func TestUnitAuthenticateExternalBrowser(t *testing.T) {
var err error
sr := &snowflakeRestful{
Expand Down
48 changes: 40 additions & 8 deletions authexternalbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package gosnowflake
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -70,11 +71,11 @@ func createLocalTCPListener() (*net.TCPListener, error) {
return tcpListener, nil
}

// Opens a browser window (or new tab) with the configured IDP Url.
// Opens a browser window (or new tab) with the configured login Url.
// This can / will fail if running inside a shell with no display, ie
// ssh'ing into a box attempting to authenticate via external browser.
func openBrowser(idpURL string) error {
err := browser.OpenURL(idpURL)
func openBrowser(loginURL string) error {
err := browser.OpenURL(loginURL)
if err != nil {
logger.Infof("failed to open a browser. err: %v", err)
return err
Expand All @@ -91,6 +92,7 @@ func getIdpURLProofKey(
authenticator string,
application string,
account string,
user string,
callbackPort int) (string, string, error) {

headers := make(map[string]string)
Expand All @@ -108,6 +110,7 @@ func getIdpURLProofKey(
ClientAppID: clientType,
ClientAppVersion: SnowflakeGoDriverVersion,
AccountName: account,
LoginName: user,
ClientEnvironment: clientEnvironment,
Authenticator: authenticator,
BrowserModeRedirectPort: strconv.Itoa(callbackPort),
Expand Down Expand Up @@ -144,6 +147,24 @@ func getIdpURLProofKey(
return respd.Data.SSOURL, respd.Data.ProofKey, nil
}

// Gets the login URL for multiple SAML
func getLoginURL(sr *snowflakeRestful, user string, callbackPort int) (string, string, error) {
proofKey := generateProofKey()

params := &url.Values{}
params.Add("login_name", user)
params.Add("browser_mode_redirect_port", strconv.Itoa(callbackPort))
params.Add("proof_key", proofKey)
url := sr.getFullURL(consoleLoginRequestPath, params)

return url.String(), proofKey, nil
}

func generateProofKey() string {
randomness := getSecureRandom(32)
return base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString(randomness)
}

// The response returned from Snowflake looks like so:
// GET /?token=encodedSamlToken
// Host: localhost:54001
Expand Down Expand Up @@ -187,10 +208,11 @@ func authenticateByExternalBrowser(
user string,
password string,
externalBrowserTimeout time.Duration,
disableConsoleLogin ConfigBool,
) ([]byte, []byte, error) {
resultChan := make(chan authenticateByExternalBrowserResult, 1)
go func() {
resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password)
resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password, disableConsoleLogin)
}()
select {
case <-time.After(externalBrowserTimeout):
Expand All @@ -204,7 +226,7 @@ func authenticateByExternalBrowser(
// - the golang snowflake driver communicates to Snowflake that the user wishes to
// authenticate via external browser
// - snowflake sends back the IDP Url configured at the Snowflake side for the
// provided account
// provided account, or use the multiple SAML way via console login
// - the default browser is opened to that URL
// - user authenticates at the IDP, and is redirected to Snowflake
// - Snowflake directs the user back to the driver
Expand All @@ -217,6 +239,7 @@ func doAuthenticateByExternalBrowser(
account string,
user string,
password string,
disableConsoleLogin ConfigBool,
) authenticateByExternalBrowserResult {
l, err := createLocalTCPListener()
if err != nil {
Expand All @@ -225,13 +248,22 @@ func doAuthenticateByExternalBrowser(
defer l.Close()

callbackPort := l.Addr().(*net.TCPAddr).Port
idpURL, proofKey, err := getIdpURLProofKey(
ctx, sr, authenticator, application, account, callbackPort)

var loginURL string
var proofKey string
if disableConsoleLogin == ConfigBoolTrue {
// Gets the IDP URL and Proof Key from Snowflake
loginURL, proofKey, err = getIdpURLProofKey(ctx, sr, authenticator, application, account, user, callbackPort)
} else {
// Multiple SAML way to do authentication via console login
loginURL, proofKey, err = getLoginURL(sr, user, callbackPort)
}

if err != nil {
return authenticateByExternalBrowserResult{nil, nil, err}
}

if err = openBrowser(idpURL); err != nil {
if err = openBrowser(loginURL); err != nil {
return authenticateByExternalBrowserResult{nil, nil, err}
}

Expand Down
35 changes: 31 additions & 4 deletions authexternalbrowser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package gosnowflake
import (
"context"
"errors"
"net/url"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -91,17 +92,17 @@ func TestUnitAuthenticateByExternalBrowser(t *testing.T) {
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFail
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -128,7 +129,7 @@ func TestAuthenticationTimeout(t *testing.T) {
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue)
if err.Error() != "authentication timed out" {
t.Fatal("should have timed out")
}
Expand All @@ -146,3 +147,29 @@ func Test_createLocalTCPListener(t *testing.T) {
// Close the listener after the test.
defer listener.Close()
}

func TestUnitGetLoginURL(t *testing.T) {
expectedScheme := "https"
expectedHost := "abc.com:443"
user := "u"
callbackPort := 123
sr := &snowflakeRestful{
Protocol: "https",
Host: "abc.com",
Port: 443,
TokenAccessor: getSimpleTokenAccessor(),
}

loginURL, proofKey, err := getLoginURL(sr, user, callbackPort)
assertNilF(t, err, "failed to get login URL")
assertNotNilF(t, len(proofKey), "proofKey should be non-empty string")

urlPtr, err := url.Parse(loginURL)
assertNilF(t, err, "failed to parse the login URL")
assertEqualF(t, urlPtr.Scheme, expectedScheme)
assertEqualF(t, urlPtr.Host, expectedHost)
assertEqualF(t, urlPtr.Path, consoleLoginRequestPath)
assertStringContainsF(t, urlPtr.RawQuery, "login_name")
assertStringContainsF(t, urlPtr.RawQuery, "browser_mode_redirect_port")
assertStringContainsF(t, urlPtr.RawQuery, "proof_key")
}
16 changes: 16 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ type Config struct {
IncludeRetryReason ConfigBool // Should retried request contain retry reason

ClientConfigFile string // File path to the client configuration json file

DisableConsoleLogin ConfigBool // Indicates whether console login should be disabled
}

// Validate enables testing if config is correct.
Expand Down Expand Up @@ -262,6 +264,9 @@ func DSN(cfg *Config) (dsn string, err error) {
if cfg.ClientConfigFile != "" {
params.Add("clientConfigFile", cfg.ClientConfigFile)
}
if cfg.DisableConsoleLogin != configBoolNotSet {
params.Add("disableConsoleLogin", strconv.FormatBool(cfg.DisableConsoleLogin != ConfigBoolFalse))
}

dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port)
if params.Encode() != "" {
Expand Down Expand Up @@ -754,6 +759,17 @@ func parseDSNParams(cfg *Config, params string) (err error) {
}
case "clientConfigFile":
cfg.ClientConfigFile = value
case "disableConsoleLogin":
var vv bool
vv, err = strconv.ParseBool(value)
if err != nil {
return
}
if vv {
cfg.DisableConsoleLogin = ConfigBoolTrue
} else {
cfg.DisableConsoleLogin = ConfigBoolFalse
}
default:
if cfg.Params == nil {
cfg.Params = make(map[string]*string)
Expand Down
57 changes: 57 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,40 @@ func TestParseDSN(t *testing.T) {
dsn: "u:[email protected]:443?authenticator=http%3A%2F%2Fsc.okta.com&ocspFailOpen=true&validateDefaultParameters=true",
err: errFailedToParseAuthenticator(),
},
{
dsn: "u:[email protected]:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableConsoleLogin=true",
config: &Config{
Account: "a", User: "u", Password: "p",
Authenticator: AuthTypeExternalBrowser,
Protocol: "http", Host: "a.snowflake.local", Port: 9876,
OCSPFailOpen: OCSPFailOpenTrue,
ValidateDefaultParameters: ConfigBoolTrue,
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
IncludeRetryReason: ConfigBoolTrue,
DisableConsoleLogin: ConfigBoolTrue,
},
ocspMode: ocspModeFailOpen,
err: nil,
},
{
dsn: "u:[email protected]:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableConsoleLogin=false",
config: &Config{
Account: "a", User: "u", Password: "p",
Authenticator: AuthTypeExternalBrowser,
Protocol: "http", Host: "a.snowflake.local", Port: 9876,
OCSPFailOpen: OCSPFailOpenTrue,
ValidateDefaultParameters: ConfigBoolTrue,
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
IncludeRetryReason: ConfigBoolTrue,
DisableConsoleLogin: ConfigBoolFalse,
},
ocspMode: ocspModeFailOpen,
err: nil,
},
}

for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} {
Expand Down Expand Up @@ -873,6 +907,9 @@ func TestParseDSN(t *testing.T) {
if test.config.IncludeRetryReason != cfg.IncludeRetryReason {
t.Fatalf("%v: Failed to match IncludeRetryReason. expected: %v, got: %v", i, test.config.IncludeRetryReason, cfg.IncludeRetryReason)
}
if test.config.DisableConsoleLogin != cfg.DisableConsoleLogin {
t.Fatalf("%v: Failed to match DisableConsoleLogin. expected: %v, got: %v", i, test.config.DisableConsoleLogin, cfg.DisableConsoleLogin)
}
assertEqualF(t, cfg.ClientConfigFile, test.config.ClientConfigFile, "client config file")
case test.err != nil:
driverErrE, okE := test.err.(*SnowflakeError)
Expand Down Expand Up @@ -1322,6 +1359,26 @@ func TestDSN(t *testing.T) {
},
dsn: "u:[email protected]:443?clientConfigFile=c%3A%5CUsers%5Cuser%5Cconfig.json&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
},
{
cfg: &Config{
User: "u",
Password: "p",
Account: "a.b.c",
Authenticator: AuthTypeExternalBrowser,
DisableConsoleLogin: ConfigBoolTrue,
},
dsn: "u:[email protected]:443?authenticator=externalbrowser&disableConsoleLogin=true&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
},
{
cfg: &Config{
User: "u",
Password: "p",
Account: "a.b.c",
Authenticator: AuthTypeExternalBrowser,
DisableConsoleLogin: ConfigBoolFalse,
},
dsn: "u:[email protected]:443?authenticator=externalbrowser&disableConsoleLogin=false&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
},
}
for _, test := range testcases {
t.Run(test.dsn, func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ const (
monitoringQueriesPath = "/monitoring/queries"
sessionRequestPath = "/session"
heartBeatPath = "/session/heartbeat"
consoleLoginRequestPath = "/console/login"
)

type (
Expand Down

0 comments on commit 5c89d42

Please sign in to comment.