Skip to content

Commit

Permalink
feat: support --with-access-token for auth (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrschumacher authored Oct 29, 2024
1 parent cdaae40 commit 856efa4
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 64 deletions.
2 changes: 1 addition & 1 deletion cmd/auth-clientCredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var clientCredentialsCmd = man.Docs.GetCommand("auth/client-credentials",

func auth_clientCredentials(cmd *cobra.Command, args []string) {
c := cli.New(cmd, args)
cp := InitProfile(c, false)
_, cp := InitProfile(c, false)

var clientId string
var clientSecret string
Expand Down
2 changes: 1 addition & 1 deletion cmd/auth-login.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func auth_codeLogin(cmd *cobra.Command, args []string) {
c := cli.New(cmd, args)
cp := InitProfile(c, false)
_, cp := InitProfile(c, false)

c.Print("Initiating login...")
tok, publicClientID, err := auth.LoginWithPKCE(
Expand Down
2 changes: 1 addition & 1 deletion cmd/auth-logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func auth_logout(cmd *cobra.Command, args []string) {
c := cli.New(cmd, args)
cp := InitProfile(c, false)
_, cp := InitProfile(c, false)
c.Println("Initiating logout...")

// we can only revoke access tokens stored for the code login flow, not client credentials
Expand Down
2 changes: 1 addition & 1 deletion cmd/auth-printAccessToken.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var auth_printAccessTokenCmd = man.Docs.GetCommand("auth/print-access-token",

func auth_printAccessToken(cmd *cobra.Command, args []string) {
c := cli.New(cmd, args)
cp := InitProfile(c, false)
_, cp := InitProfile(c, false)

ac := cp.GetAuthCredentials()
switch ac.AuthType {
Expand Down
128 changes: 82 additions & 46 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd

import (
Expand Down Expand Up @@ -35,7 +32,10 @@ type version struct {
BuildTime string `json:"build_time"`
}

func InitProfile(c *cli.Cli, onlyNew bool) *profiles.ProfileStore {
// InitProfile initializes the profile store and loads the profile specified in the flags
// if onlyNew is set to true, a new profile will be created and returned
// returns the profile and the current profile store
func InitProfile(c *cli.Cli, onlyNew bool) (*profiles.Profile, *profiles.ProfileStore) {
var err error
profileName := c.FlagHelper.GetOptionalString("profile")

Expand All @@ -45,126 +45,157 @@ func InitProfile(c *cli.Cli, onlyNew bool) *profiles.ProfileStore {
}

// short circuit if onlyNew is set to enable creating a new profile
if onlyNew {
return nil
if onlyNew && profileName == "" {
return profile, nil
}

// check if there exists a default profile and warn if not with steps to create one
if profile.GetGlobalConfig().GetDefaultProfile() == "" {
c.ExitWithWarning("No default profile set. Use `" + config.AppName + " profile create <profile> <endpoint>` to create a default profile.")
c.ExitWithWarning(fmt.Sprintf("No default profile set. Use `%s profile create <profile> <endpoint>` to create a default profile.", config.AppName))
}
c.Printf("Using profile [%s]\n", profile.GetGlobalConfig().GetDefaultProfile())

if profileName == "" {
profileName = profile.GetGlobalConfig().GetDefaultProfile()
}

c.Printf("Using profile [%s]\n", profileName)

// load profile
cp, err := profile.UseProfile(profileName)
if err != nil {
c.ExitWithError("Failed to load profile "+profileName, err)
c.ExitWithError(fmt.Sprintf("Failed to load profile: %s", profileName), err)
}

return cp
return profile, cp
}

// instantiates a new handler with authentication via client credentials
// TODO make this a preRun hook
//
//nolint:nestif // separate refactor [https://github.com/opentdf/otdfctl/issues/383]
func NewHandler(c *cli.Cli) handlers.Handler {
// if global flags are set then validate and create a temporary profile in memory
var cp *profiles.ProfileStore

// Non-profile flags
host := c.FlagHelper.GetOptionalString("host")
tlsNoVerify := c.FlagHelper.GetOptionalBool("tls-no-verify")
withClientCreds := c.FlagHelper.GetOptionalString("with-client-creds")
withClientCredsFile := c.FlagHelper.GetOptionalString("with-client-creds-file")
withAccessToken := c.FlagHelper.GetOptionalString("with-access-token")
var inMemoryProfile bool

// if global flags are set then validate and create a temporary profile in memory
var cp *profiles.ProfileStore
authFlags := []string{"--with-access-token", "--with-client-creds", "--with-client-creds-file"}
nonProfileFlags := append([]string{"--host", "--tls-no-verify"}, authFlags...)
hasNonProfileFlags := host != "" || tlsNoVerify || withClientCreds != "" || withClientCredsFile != "" || withAccessToken != ""

//nolint:nestif // nested if statements are necessary for validation
if host != "" || tlsNoVerify || withClientCreds != "" || withClientCredsFile != "" {
err := errors.New(
"when using global flags --host, --tls-no-verify, --with-client-creds, or --with-client-creds-file, " +
"profiles will not be used and all required flags must be set",
)
if hasNonProfileFlags {
err := fmt.Errorf("when using global flags %s, profiles will not be used and all required flags must be set", cli.PrettyList(nonProfileFlags))

// host must be set
if host == "" {
cli.ExitWithError("Host must be set", err)
}

// either with-client-creds or with-client-creds-file must be set
if withClientCreds == "" && withClientCredsFile == "" {
cli.ExitWithError("Either --with-client-creds or --with-client-creds-file must be set", err)
} else if withClientCreds != "" && withClientCredsFile != "" {
cli.ExitWithError("Only one of --with-client-creds or --with-client-creds-file can be set", err)
authFlagsCounter := 0
if withAccessToken != "" {
authFlagsCounter++
}

var cc auth.ClientCredentials
if withClientCreds != "" {
cc, err = auth.GetClientCredsFromJSON([]byte(withClientCreds))
} else {
cc, err = auth.GetClientCredsFromFile(withClientCredsFile)
authFlagsCounter++
}
if err != nil {
cli.ExitWithError("Failed to get client credentials", err)
if withClientCredsFile != "" {
authFlagsCounter++
}
if authFlagsCounter == 0 {
cli.ExitWithError(fmt.Sprintf("One of %s must be set", cli.PrettyList(authFlags)), err)
} else if authFlagsCounter > 1 {
cli.ExitWithError(fmt.Sprintf("Only one of %s must be set", cli.PrettyList(authFlags)), err)
}

inMemoryProfile = true
profile, err = profiles.New(profiles.WithInMemoryStore())
if err != nil || profile == nil {
cli.ExitWithError("Failed to initialize a temporary profile", err)
cli.ExitWithError("Failed to initialize in-memory profile", err)
}

if err := profile.AddProfile("temp", host, tlsNoVerify, true); err != nil {
cli.ExitWithError("Failed to create temporary profile", err)
cli.ExitWithError("Failed to create in-memory profile", err)
}

// add credentials to the temporary profile
cp, err = profile.UseProfile("temp")
if err != nil {
cli.ExitWithError("Failed to load temporary profile", err)
cli.ExitWithError("Failed to load in-memory profile", err)
}

// add credentials to the temporary profile
if err := cp.SetAuthCredentials(profiles.AuthCredentials{
AuthType: profiles.PROFILE_AUTH_TYPE_CLIENT_CREDENTIALS,
ClientId: cc.ClientId,
ClientSecret: cc.ClientSecret,
}); err != nil {
cli.ExitWithError("Failed to set client credentials", err)
// get credentials from flags
if withAccessToken != "" {
claims, err := auth.ParseClaimsJWT(withAccessToken)
if err != nil {
cli.ExitWithError("Failed to get access token", err)
}

if err := cp.SetAuthCredentials(profiles.AuthCredentials{
AuthType: profiles.PROFILE_AUTH_TYPE_ACCESS_TOKEN,
AccessToken: profiles.AuthCredentialsAccessToken{
AccessToken: withAccessToken,
Expiration: claims.Expiration,
},
}); err != nil {
cli.ExitWithError("Failed to set access token", err)
}
} else {
var cc auth.ClientCredentials
if withClientCreds != "" {
cc, err = auth.GetClientCredsFromJSON([]byte(withClientCreds))
} else if withClientCredsFile != "" {
cc, err = auth.GetClientCredsFromFile(withClientCredsFile)
}
if err != nil {
cli.ExitWithError("Failed to get client credentials", err)
}

// add credentials to the temporary profile
if err := cp.SetAuthCredentials(profiles.AuthCredentials{
AuthType: profiles.PROFILE_AUTH_TYPE_CLIENT_CREDENTIALS,
ClientId: cc.ClientId,
ClientSecret: cc.ClientSecret,
}); err != nil {
cli.ExitWithError("Failed to set client credentials", err)
}
}
if err := cp.Save(); err != nil {
cli.ExitWithError("Failed to save profile", err)
}
} else {
cp = InitProfile(c, false)
profile, cp = InitProfile(c, false)
}

if err := auth.ValidateProfileAuthCredentials(c.Context(), cp); err != nil {
if errors.Is(err, auth.ErrPlatformConfigNotFound) {
cli.ExitWithError(fmt.Sprintf("Failed to get platform configuration. Is the platform accepting connections at '%s'?", cp.GetEndpoint()), nil)
}
if inMemoryProfile {
cli.ExitWithError("Failed to authenticate with flag-provided client credentials", err)
cli.ExitWithError("Failed to authenticate with flag-provided client credentials.", err)
}
if errors.Is(err, auth.ErrProfileCredentialsNotFound) {
cli.ExitWithWarning("Profile missing credentials. Please login or add client credentials.")
}

if errors.Is(err, auth.ErrAccessTokenExpired) {
cli.ExitWithWarning("Access token expired. Please login again.")
cli.ExitWithWarning("Access token expired. Please login or add flag-provided credentials.")
}
if errors.Is(err, auth.ErrAccessTokenNotFound) {
cli.ExitWithWarning("No access token found. Please login or add client credentials.")
cli.ExitWithWarning("No access token found. Please login or add flag-provided credentials.")
}
cli.ExitWithError("Failed to get access token", err)
cli.ExitWithError("Failed to get access token.", err)
}

h, err := handlers.New(handlers.WithProfile(cp))
if err != nil {
cli.ExitWithError("Failed to create handler", err)
cli.ExitWithError("Unexpected error", err)
}
return h
}
Expand All @@ -181,7 +212,7 @@ func init() {
BuildTime: config.BuildTime,
}

c.Println(config.AppName + " version " + config.Version + " (" + config.BuildTime + ") " + config.CommitSha)
c.Println(fmt.Sprintf("%s version %s (%s) %s", config.AppName, config.Version, config.BuildTime, config.CommitSha))
c.ExitWithJSON(v)
return
}
Expand Down Expand Up @@ -243,5 +274,10 @@ func init() {
rootCmd.GetDocFlag("with-client-creds").Default,
rootCmd.GetDocFlag("with-client-creds").Description,
)
RootCmd.PersistentFlags().String(
rootCmd.GetDocFlag("with-access-token").Name,
rootCmd.GetDocFlag("with-access-token").Default,
rootCmd.GetDocFlag("with-access-token").Description,
)
RootCmd.AddGroup(&cobra.Group{ID: TDF})
}
2 changes: 2 additions & 0 deletions docs/man/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ command:
- fatal
- panic
default: info
- name: with-access-token
description: access token for authentication via bearer token
- name: with-client-creds-file
description: path to a JSON file containing a 'clientId' and 'clientSecret' for auth via client-credentials flow
- name: with-client-creds
Expand Down
9 changes: 5 additions & 4 deletions e2e/auth.bats
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ teardown_file() {
BAD_HOST='--host http://localhost:9000'
run_otdfctl $BAD_HOST $WITH_CREDS policy attributes list
assert_failure
assert_output --partial "Failed to get platform configuration. Is the platform accepting connections at 'http://localhost:9000'?"
assert_output --partial "Failed to get platform configuration. Is the platform accepting connections at"
}

@test "helpful error if bad credentials" {
Expand All @@ -43,17 +43,18 @@ teardown_file() {
BAD_CREDS="--with-client-creds '{clientId:"badClient",clientSecret:"badSecret"}'"
run_otdfctl $HOST $BAD_CREDS policy attributes list
assert_failure
assert_output --partial "Failed to get client credentials: failed to decode creds JSON"
assert_output --partial "Failed to get client credentials"
}

@test "helpful error if missing client credentials" {
run_otdfctl $HOST policy attributes list
assert_failure
assert_output --partial "Either --with-client-creds or --with-client-creds-file must be set: when using global flags --host, --tls-no-verify, --with-client-creds, or --with-client-creds-file, profiles will not be used and all required flags must be set"
assert_output --partial "One of"
assert_output --partial "must be set: when using global flags"
}

@test "helpful error if missing host" {
run_otdfctl $WITH_CREDS policy attributes list
assert_failure
assert_output --partial "Host must be set: when using global flags --host, --tls-no-verify, --with-client-creds, or --with-client-creds-file, profiles will not be used and all required flags must be set"
assert_output --partial "Host must be set: when using global flags"
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-jose/go-jose/v3 v3.0.3 // indirect
github.com/go-jose/go-jose/v4 v4.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
Expand Down
Loading

0 comments on commit 856efa4

Please sign in to comment.