Skip to content

Commit

Permalink
Use Idp Account Name as key for credentials store
Browse files Browse the repository at this point in the history
At the moment the credentials are stored using the idp server url as
a primary key for the keychain store.

This is very useful if you need to authenticate through different
providers or environment (and thus different endpoints), however this
prevents the use case where users may want to use different idp accounts
to authenticate using different users.

A typical example of this is that some users often have a normal user
and an admin user with different properties, roles and limitations.

This change is refactoring how credentials are stored across the code
base to use the idp Name instead of the server URL.
  • Loading branch information
sledigabel committed Feb 16, 2022
1 parent cdd6da8 commit c52b819
Show file tree
Hide file tree
Showing 13 changed files with 129 additions and 70 deletions.
9 changes: 5 additions & 4 deletions cmd/saml2aws/commands/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"path"

"github.com/pkg/errors"
"github.com/versent/saml2aws/v2"
saml2aws "github.com/versent/saml2aws/v2"
"github.com/versent/saml2aws/v2/helper/credentials"
"github.com/versent/saml2aws/v2/pkg/cfg"
"github.com/versent/saml2aws/v2/pkg/flags"
Expand Down Expand Up @@ -68,14 +68,14 @@ func storeCredentials(configFlags *flags.CommonFlags, account *cfg.IDPAccount) e
return nil
}
if configFlags.Password != "" {
if err := credentials.SaveCredentials(account.URL, account.Username, configFlags.Password); err != nil {
if err := credentials.SaveCredentials(account.Name, account.URL, account.Username, configFlags.Password); err != nil {
return errors.Wrap(err, "error storing password in keychain")
}
} else {
password := prompter.Password("Password")
if password != "" {
if confirmPassword := prompter.Password("Confirm"); confirmPassword == password {
if err := credentials.SaveCredentials(account.URL, account.Username, password); err != nil {
if err := credentials.SaveCredentials(account.Name, account.URL, account.Username, password); err != nil {
return errors.Wrap(err, "error storing password in keychain")
}
} else {
Expand All @@ -91,7 +91,8 @@ func storeCredentials(configFlags *flags.CommonFlags, account *cfg.IDPAccount) e
log.Println("OneLogin provider requires --client_id and --client_secret flags to be set.")
os.Exit(1)
}
if err := credentials.SaveCredentials(path.Join(account.URL, OneLoginOAuthPath), configFlags.ClientID, configFlags.ClientSecret); err != nil {
// we store the OneLogin token in a different secret (idpName + the one login suffix)
if err := credentials.SaveCredentials(account.Name+credentials.OneLoginTokenSuffix, path.Join(account.URL, OneLoginOAuthPath), configFlags.ClientID, configFlags.ClientSecret); err != nil {
return errors.Wrap(err, "error storing client_id and client_secret in keychain")
}
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/saml2aws/commands/list_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/versent/saml2aws/v2"
saml2aws "github.com/versent/saml2aws/v2"
"github.com/versent/saml2aws/v2/helper/credentials"
"github.com/versent/saml2aws/v2/pkg/flags"
"github.com/versent/saml2aws/v2/pkg/samlcache"
Expand Down Expand Up @@ -83,7 +83,7 @@ func ListRoles(loginFlags *flags.LoginExecFlags) error {
}

if !loginFlags.CommonFlags.DisableKeychain {
err = credentials.SaveCredentials(loginDetails.URL, loginDetails.Username, loginDetails.Password)
err = credentials.SaveCredentials(loginDetails.IdpName, loginDetails.URL, loginDetails.Username, loginDetails.Password)
if err != nil {
return errors.Wrap(err, "error storing password in keychain")
}
Expand Down
17 changes: 11 additions & 6 deletions cmd/saml2aws/commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/versent/saml2aws/v2"
saml2aws "github.com/versent/saml2aws/v2"
"github.com/versent/saml2aws/v2/helper/credentials"
"github.com/versent/saml2aws/v2/pkg/awsconfig"
"github.com/versent/saml2aws/v2/pkg/cfg"
Expand Down Expand Up @@ -122,7 +122,7 @@ func Login(loginFlags *flags.LoginExecFlags) error {
}

if !loginFlags.CommonFlags.DisableKeychain {
err = credentials.SaveCredentials(loginDetails.URL, loginDetails.Username, loginDetails.Password)
err = credentials.SaveCredentials(loginDetails.IdpName, loginDetails.URL, loginDetails.Username, loginDetails.Password)
if err != nil {
return errors.Wrap(err, "Error storing password in keychain.")
}
Expand Down Expand Up @@ -174,15 +174,20 @@ func buildIdpAccount(loginFlags *flags.LoginExecFlags) (*cfg.IDPAccount, error)

func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFlags) (*creds.LoginDetails, error) {

// log.Printf("loginFlags %+v", loginFlags)

loginDetails := &creds.LoginDetails{URL: account.URL, Username: account.Username, MFAToken: loginFlags.CommonFlags.MFAToken, DuoMFAOption: loginFlags.DuoMFAOption}
loginDetails := &creds.LoginDetails{
URL: account.URL,
Username: account.Username,
MFAToken: loginFlags.CommonFlags.MFAToken,
DuoMFAOption: loginFlags.DuoMFAOption,
IdpName: account.Name,
IdpProvider: account.Provider,
}

log.Printf("Using IdP Account %s to access %s %s", loginFlags.CommonFlags.IdpAccount, account.Provider, account.URL)

var err error
if !loginFlags.CommonFlags.DisableKeychain {
err = credentials.LookupCredentials(loginDetails, account.Provider)
err = credentials.LookupCredentials(loginDetails)
if err != nil {
if !credentials.IsErrCredentialsNotFound(err) {
return nil, errors.Wrap(err, "Error loading saved password.")
Expand Down
59 changes: 52 additions & 7 deletions cmd/saml2aws/commands/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"

"github.com/stretchr/testify/assert"
"github.com/versent/saml2aws/v2"
saml2aws "github.com/versent/saml2aws/v2"
"github.com/versent/saml2aws/v2/pkg/awsconfig"
"github.com/versent/saml2aws/v2/pkg/cfg"
"github.com/versent/saml2aws/v2/pkg/creds"
Expand All @@ -15,10 +15,17 @@ import (

func TestResolveLoginDetailsWithFlags(t *testing.T) {

commonFlags := &flags.CommonFlags{URL: "https://id.example.com", Username: "wolfeidau", Password: "testtestlol", MFAToken: "123456", SkipPrompt: true}
commonFlags := &flags.CommonFlags{
URL: "https://id.example.com",
Username: "wolfeidau",
Password: "testtestlol",
MFAToken: "123456",
SkipPrompt: true,
}
loginFlags := &flags.LoginExecFlags{CommonFlags: commonFlags}

idpa := &cfg.IDPAccount{
Name: "AccountName",
URL: "https://id.example.com",
MFA: "none",
Provider: "Ping",
Expand All @@ -27,16 +34,31 @@ func TestResolveLoginDetailsWithFlags(t *testing.T) {
loginDetails, err := resolveLoginDetails(idpa, loginFlags)

assert.Empty(t, err)
assert.Equal(t, &creds.LoginDetails{Username: "wolfeidau", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails)
assert.Equal(t,
&creds.LoginDetails{
IdpName: "AccountName",
IdpProvider: "Ping",
Username: "wolfeidau",
Password: "testtestlol",
URL: "https://id.example.com",
MFAToken: "123456",
}, loginDetails)
}

func TestOktaResolveLoginDetailsWithFlags(t *testing.T) {

// Default state - user did not supply values for DisableSessions and DisableSessions
commonFlags := &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", MFAToken: "123456", SkipPrompt: true}
commonFlags := &flags.CommonFlags{
URL: "https://id.example.com",
Username: "testuser",
Password: "testtestlol",
MFAToken: "123456",
SkipPrompt: true,
}
loginFlags := &flags.LoginExecFlags{CommonFlags: commonFlags}

idpa := &cfg.IDPAccount{
Name: "AnotherAccountName",
URL: "https://id.example.com",
MFA: "none",
Provider: "Okta",
Expand All @@ -47,19 +69,42 @@ func TestOktaResolveLoginDetailsWithFlags(t *testing.T) {
assert.Nil(t, err)
assert.False(t, idpa.DisableSessions, fmt.Errorf("default state, DisableSessions should be false"))
assert.False(t, idpa.DisableRememberDevice, fmt.Errorf("default state, DisableRememberDevice should be false"))
assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails)
assert.Equal(t,
&creds.LoginDetails{
IdpName: "AnotherAccountName",
IdpProvider: "Okta",
Username: "testuser",
Password: "testtestlol",
URL: "https://id.example.com",
MFAToken: "123456",
}, loginDetails)

// User disabled keychain, resolveLoginDetails should set the account's DisableSessions and DisableSessions fields to true

commonFlags = &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", MFAToken: "123456", SkipPrompt: true, DisableKeychain: true}
commonFlags = &flags.CommonFlags{
URL: "https://id.example.com",
Username: "testuser",
Password: "testtestlol",
MFAToken: "123456",
SkipPrompt: true,
DisableKeychain: true,
}
loginFlags = &flags.LoginExecFlags{CommonFlags: commonFlags}

loginDetails, err = resolveLoginDetails(idpa, loginFlags)

assert.Nil(t, err)
assert.True(t, idpa.DisableSessions, fmt.Errorf("user disabled keychain, DisableSessions should be true"))
assert.True(t, idpa.DisableRememberDevice, fmt.Errorf("user disabled keychain, DisableRememberDevice should be true"))
assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails)
assert.Equal(t,
&creds.LoginDetails{
IdpName: "AnotherAccountName",
IdpProvider: "Okta",
Username: "testuser",
Password: "testtestlol",
URL: "https://id.example.com",
MFAToken: "123456",
}, loginDetails)

}

Expand Down
23 changes: 17 additions & 6 deletions helper/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package credentials

import (
"errors"
"fmt"
)

var (
Expand All @@ -14,25 +15,35 @@ var (

// Credentials holds the information shared between saml2aws and the credentials store.
type Credentials struct {
IdpName string
ServerURL string
Username string
Secret string
}

// CredsLabel saml2aws credentials should be labeled as such in credentials stores that allow labelling.
// That label allows to filter out non-Docker credentials too at lookup/search in macOS keychain,
// Windows credentials manager and Linux libsecret. Default value is "saml2aws Credentials"
var CredsLabel = "saml2aws Credentials"
const (
// CredsLabel saml2aws credentials should be labeled as such in credentials stores that allow labelling.
// That label allows to filter out non-Docker credentials too at lookup/search in macOS keychain,
// Windows credentials manager and Linux libsecret. Default value is "saml2aws Credentials"
CredsLabel = "saml2aws Credentials"
CredsKeyPrefix = "saml2aws_credentials"
OktaSessionCookieSuffix = "_okta_session"
OneLoginTokenSuffix = "_onelogin_token"
)

func GetKeyFromAccount(accountName string) string {
return fmt.Sprintf("%s_%s", CredsKeyPrefix, accountName)
}

// Helper is the interface a credentials store helper must implement.
type Helper interface {
// Add appends credentials to the store.
Add(*Credentials) error
// Delete removes credentials from the store.
Delete(serverURL string) error
Delete(idpName string) error
// Get retrieves credentials from the store.
// It returns username and secret as strings.
Get(serverURL string) (string, string, error)
Get(idpName string) (string, string, error)
// SupportsCredentialStorage returns true or false if there is credential storage.
SupportsCredentialStorage() bool
}
Expand Down
19 changes: 10 additions & 9 deletions helper/credentials/saml.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package credentials

import (
"path"

"github.com/versent/saml2aws/v2/pkg/creds"
)

// LookupCredentials lookup an existing set of credentials and validate it.
func LookupCredentials(loginDetails *creds.LoginDetails, provider string) error {
func LookupCredentials(loginDetails *creds.LoginDetails) error {

username, password, err := CurrentHelper.Get(loginDetails.URL)
username, password, err := CurrentHelper.Get(loginDetails.IdpName)
if err != nil {
return err
}
Expand All @@ -18,15 +16,17 @@ func LookupCredentials(loginDetails *creds.LoginDetails, provider string) error
loginDetails.Password = password

// If the provider is Okta, check for existing Okta Session Cookie (sid)
if provider == "Okta" {
_, oktaSessionCookie, err := CurrentHelper.Get(loginDetails.URL + "/sessionCookie")
if loginDetails.IdpProvider == "Okta" {
// load up the Okta token from a different secret (idp name + Okta suffix)
_, oktaSessionCookie, err := CurrentHelper.Get(loginDetails.IdpName + OktaSessionCookieSuffix)
if err == nil {
loginDetails.OktaSessionCookie = oktaSessionCookie
}
}

if provider == "OneLogin" {
id, secret, err := CurrentHelper.Get(path.Join(loginDetails.URL, "/auth/oauth2/v2/token"))
if loginDetails.IdpProvider == "OneLogin" {
// load up the one login token from a different secret (idp name + one login suffix)
id, secret, err := CurrentHelper.Get(loginDetails.IdpName + OneLoginTokenSuffix)
if err != nil {
return err
}
Expand All @@ -37,9 +37,10 @@ func LookupCredentials(loginDetails *creds.LoginDetails, provider string) error
}

// SaveCredentials save the user credentials.
func SaveCredentials(url, username, password string) error {
func SaveCredentials(idpName, url, username, password string) error {

creds := &Credentials{
IdpName: idpName,
ServerURL: url,
Username: username,
Secret: password,
Expand Down
10 changes: 5 additions & 5 deletions helper/linuxkeyring/linuxkeyring_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ func (kr *KeyringHelper) Add(creds *credentials.Credentials) error {
}

return kr.keyring.Set(keyring.Item{
Key: creds.ServerURL,
Key: credentials.GetKeyFromAccount(creds.IdpName),
Label: credentials.CredsLabel,
Data: encoded,
KeychainNotTrustApplication: false,
})
}

func (kr *KeyringHelper) Delete(serverURL string) error {
return kr.keyring.Remove(serverURL)
func (kr *KeyringHelper) Delete(idpName string) error {
return kr.keyring.Remove(credentials.GetKeyFromAccount(idpName))
}

func (kr *KeyringHelper) Get(serverURL string) (string, string, error) {
item, err := kr.keyring.Get(serverURL)
func (kr *KeyringHelper) Get(idpName string) (string, string, error) {
item, err := kr.keyring.Get(credentials.GetKeyFromAccount(idpName))
if err != nil {
logger.WithField("err", err).Error("keychain Get returned error")
return "", "", credentials.ErrCredentialsNotFound
Expand Down
32 changes: 11 additions & 21 deletions helper/osxkeychain/osxkeychain.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build darwin && cgo
// +build darwin,cgo

package osxkeychain
Expand All @@ -18,14 +19,15 @@ type Osxkeychain struct{}

// Add adds new credentials to the keychain.
func (h Osxkeychain) Add(creds *credentials.Credentials) error {
err := h.Delete(creds.ServerURL)
err := h.Delete(creds.IdpName)
if err != nil {
logger.WithError(err).Debug("delete of existing keychain entry failed")
}

item := keychain.NewItem()
item.SetSecClass(keychain.SecClassInternetPassword)
item.SetLabel(credentials.CredsLabel)
item.SetLabel(credentials.GetKeyFromAccount(creds.IdpName))
item.SetString("Purpose", credentials.CredsLabel)
item.SetAccount(creds.Username)
item.SetData([]byte(creds.Secret))
err = splitServer3(creds.ServerURL, item)
Expand All @@ -43,36 +45,24 @@ func (h Osxkeychain) Add(creds *credentials.Credentials) error {
}

// Delete removes credentials from the keychain.
func (h Osxkeychain) Delete(serverURL string) error {
func (h Osxkeychain) Delete(idpName string) error {

item := keychain.NewItem()
item.SetSecClass(keychain.SecClassInternetPassword)
err := splitServer3(serverURL, item)
if err != nil {
return err
}

err = keychain.DeleteItem(item)
if err != nil {
return err
}

return nil
item.SetLabel(credentials.GetKeyFromAccount(idpName))
return keychain.DeleteItem(item)
}

// Get returns the username and secret to use for a given registry server URL.
func (h Osxkeychain) Get(serverURL string) (string, string, error) {
func (h Osxkeychain) Get(idpName string) (string, string, error) {

logger.WithField("serverURL", serverURL).Debug("Get credentials")
logger.WithField("Credential Key", idpName).Debug("Get credentials")

query := keychain.NewItem()
query.SetSecClass(keychain.SecClassInternetPassword)

err := splitServer3(serverURL, query)
if err != nil {
return "", "", err
}

// only search on the idp name
query.SetLabel(credentials.GetKeyFromAccount(idpName))
query.SetMatchLimit(keychain.MatchLimitOne)
query.SetReturnAttributes(true)
query.SetReturnData(true)
Expand Down
Loading

0 comments on commit c52b819

Please sign in to comment.