Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(azure-china): support azure china vault backend endpoint #671

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions pkg/backends/azurekeyvault.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package backends
import (
"context"
"fmt"
"os" // Import the os package to access environment variables
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets"
"github.com/argoproj-labs/argocd-vault-plugin/pkg/utils"
"time"
)

// AzureKeyVault is a struct for working with an Azure Key Vault backend
Expand Down Expand Up @@ -37,17 +39,23 @@ func (a *AzureKeyVault) Login() error {
}

// GetSecrets gets secrets from Azure Key Vault and returns the formatted data
// For Azure Key Vault, `kvpath` is the unique name of your vault
// For Azure use the version here not make really sens as each secret have a different version but let support it
func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]string) (map[string]interface{}, error) {
kvpath = fmt.Sprintf("https://%s.vault.azure.net", kvpath)
// Check for the cloud environment variable
cloud := os.Getenv("AVP_AZ_CLOUD_NAME")
var vaultURL string

if cloud == "azurechina" {
vaultURL = fmt.Sprintf("https://%s.vault.azure.cn", kvpath)
} else {
vaultURL = fmt.Sprintf("https://%s.vault.azure.net", kvpath)
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

verboseOptionalVersion("Azure Key Vault list all secrets from vault %s", version, kvpath)
verboseOptionalVersion("Azure Key Vault list all secrets from vault %s", version, vaultURL)

client, err := a.ClientBuilder(kvpath, a.Credential, nil)
client, err := a.ClientBuilder(vaultURL, a.Credential, nil)
if err != nil {
return nil, err
}
Expand All @@ -61,22 +69,20 @@ func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]s
return nil, err
}
for _, secretVersion := range page.Value {
// Azure Key Vault has ability to enable/disable a secret, so lets honour that
if !*secretVersion.Attributes.Enabled {
continue
}
name := secretVersion.ID.Name()
// Secret version matched given version ?
if version == "" || secretVersion.ID.Version() == version {
verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, kvpath)
verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, vaultURL)
secret, err := client.GetSecret(ctx, name, version, nil)
if err != nil {
return nil, err
}
utils.VerboseToStdErr("Azure Key Vault get secret response %v", secret)
data[name] = *secret.Value
} else {
verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, kvpath)
verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, vaultURL)
secret, err := client.GetSecret(ctx, name, version, nil)
if err != nil || !*secretVersion.Attributes.Enabled {
utils.VerboseToStdErr("Azure Key Vault get versioned secret not found %s", err)
Expand All @@ -90,17 +96,24 @@ func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]s
return data, nil
}

// GetIndividualSecret will get the specific secret (placeholder) from the SM backend
// For Azure Key Vault, `kvpath` is the unique name of your vault
// Secrets (placeholders) are directly addressable via the API, so only one call is needed here
// GetIndividualSecret will get the specific secret from the SM backend
func (a *AzureKeyVault) GetIndividualSecret(kvpath, secret, version string, annotations map[string]string) (interface{}, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

verboseOptionalVersion("Azure Key Vault getting individual secret %s from vault %s", version, secret, kvpath)

kvpath = fmt.Sprintf("https://%s.vault.azure.net", kvpath)
client, err := a.ClientBuilder(kvpath, a.Credential, nil)
// Check for the cloud environment variable
cloud := os.Getenv("cloud")
var vaultURL string

if cloud == "azurechina" {
vaultURL = fmt.Sprintf("https://%s.vault.azure.cn", kvpath)
} else {
vaultURL = fmt.Sprintf("https://%s.vault.azure.net", kvpath)
}

client, err := a.ClientBuilder(vaultURL, a.Credential, nil)
if err != nil {
return nil, err
}
Expand Down
231 changes: 63 additions & 168 deletions pkg/backends/azurekeyvault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,25 @@ package backends_test
import (
"context"
"errors"
"reflect"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets"
"github.com/argoproj-labs/argocd-vault-plugin/pkg/backends"
"reflect"
"testing"
)

const secretNamePrefix = "https://myvaultname.vault.azure.net/keys/"

type mockClientProxy struct {
simulateError string
secretPrefix string
}

func newMockClientProxy(simulateError, secretPrefix string) *mockClientProxy {
return &mockClientProxy{
simulateError: simulateError,
secretPrefix: secretPrefix,
}
}

func makeSecretProperties(id azsecrets.ID, enable bool) *azsecrets.SecretProperties {
Expand All @@ -35,17 +42,6 @@ func makeResponse(id azsecrets.ID, value string, err error) (azsecrets.GetSecret
}, err
}

func newAzureKeyVaultBackendMock(simulateError string) *backends.AzureKeyVault {
return &backends.AzureKeyVault{
Credential: nil,
ClientBuilder: func(vaultURL string, credential azcore.TokenCredential, options *azsecrets.ClientOptions) (backends.AzSecretsClient, error) {
return &mockClientProxy{
simulateError: simulateError,
}, nil
},
}
}

func (c *mockClientProxy) NewListSecretPropertiesPager(options *azsecrets.ListSecretPropertiesOptions) *runtime.Pager[azsecrets.ListSecretPropertiesResponse] {
var pageCount = 0
pager := runtime.NewPager(runtime.PagingHandler[azsecrets.ListSecretPropertiesResponse]{
Expand All @@ -57,12 +53,10 @@ func (c *mockClientProxy) NewListSecretPropertiesPager(options *azsecrets.ListSe
var a []*azsecrets.SecretProperties
if c.simulateError == "fetch_error" {
return azsecrets.ListSecretPropertiesResponse{}, errors.New("fetch error")
} else if c.simulateError == "get_secret_error" {
a = append(a, makeSecretProperties(secretNamePrefix+"invalid/v2", true))
}
a = append(a, makeSecretProperties(secretNamePrefix+"simple/v2", true))
a = append(a, makeSecretProperties(secretNamePrefix+"second/v2", true))
a = append(a, makeSecretProperties(secretNamePrefix+"disabled/v2", false))
a = append(a, makeSecretProperties(azsecrets.ID(c.secretPrefix+"simple/v2"), true))
a = append(a, makeSecretProperties(azsecrets.ID(c.secretPrefix+"second/v2"), true))
a = append(a, makeSecretProperties(azsecrets.ID(c.secretPrefix+"disabled/v2"), false))
return azsecrets.ListSecretPropertiesResponse{
SecretPropertiesListResult: azsecrets.SecretPropertiesListResult{
Value: a,
Expand All @@ -75,172 +69,73 @@ func (c *mockClientProxy) NewListSecretPropertiesPager(options *azsecrets.ListSe

func (c *mockClientProxy) GetSecret(ctx context.Context, name string, version string, options *azsecrets.GetSecretOptions) (azsecrets.GetSecretResponse, error) {
if name == "simple" && (version == "" || version == "v1") {
return makeResponse(secretNamePrefix+"simple/v1", "a_value_v1", nil)
return makeResponse(azsecrets.ID(c.secretPrefix+"simple/v1"), "a_value_v1", nil)
} else if name == "simple" && version == "v2" {
return makeResponse(secretNamePrefix+"simple/v2", "a_value_v2", nil)
return makeResponse(azsecrets.ID(c.secretPrefix+"simple/v2"), "a_value_v2", nil)
} else if name == "second" && (version == "" || version == "v2") {
return makeResponse(secretNamePrefix+"second/v2", "a_second_value_v2", nil)
}
return makeResponse("", "", errors.New("secret not found"))
}

func TestAzLogin(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("")
var err = keyVault.Login()
if err != nil {
t.Fatalf("expected 0 errors but got: %s", err)
}
}

func TestAzGetSecret(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("")
var data, err = keyVault.GetIndividualSecret("keyvault", "simple", "", nil)
if err != nil {
t.Fatalf("expected 0 errors but got: %s", err)
}
expected := "a_value_v1"
if !reflect.DeepEqual(expected, data) {
t.Errorf("expected: %s, got: %s.", expected, data)
}
}

func TestAzGetSecretWithVersion(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("")
var data, err = keyVault.GetIndividualSecret("keyvault", "simple", "v2", nil)
if err != nil {
t.Fatalf("expected 0 errors but got: %s", err)
}
expected := "a_value_v2"
if !reflect.DeepEqual(expected, data) {
t.Errorf("expected: %s, got: %s.", expected, data)
}
}

func TestAzGetSecretWithWrongVersion(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("")
var _, err = keyVault.GetIndividualSecret("keyvault", "simple", "v3", nil)
if err == nil {
t.Fatalf("expected 1 errors but got nil")
}
expected := errors.New("secret not found")
if !reflect.DeepEqual(err, expected) {
t.Errorf("expected err: %s, got: %s.", expected, err)
}
}

func TestAzGetSecretNotExist(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("")
var _, err = keyVault.GetIndividualSecret("keyvault", "not_existing", "", nil)
if err == nil {
t.Fatalf("expected 1 errors but got nil")
}
expected := errors.New("secret not found")
if !reflect.DeepEqual(err, expected) {
t.Errorf("expected err: %s, got: %s.", expected, err)
return makeResponse(azsecrets.ID(c.secretPrefix+"second/v2"), "a_second_value_v2", nil)
}
return makeResponse(azsecrets.ID(""), "", errors.New("secret not found"))
}

func TestAzGetSecretBuilderError(t *testing.T) {
var keyVault = &backends.AzureKeyVault{
func newAzureKeyVaultBackendMock(simulateError, secretPrefix string) *backends.AzureKeyVault {
return &backends.AzureKeyVault{
Credential: nil,
ClientBuilder: func(vaultURL string, credential azcore.TokenCredential, options *azsecrets.ClientOptions) (backends.AzSecretsClient, error) {
return nil, errors.New("boom")
return newMockClientProxy(simulateError, secretPrefix), nil
},
}
var _, err = keyVault.GetIndividualSecret("keyvault", "not_existing", "", nil)
if err == nil {
t.Fatalf("expected 1 errors but got nil")
}
expected := errors.New("boom")
if !reflect.DeepEqual(err, expected) {
t.Errorf("expected err: %s, got: %s.", expected, err)
}
}

func TestAzGetSecrets(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("")
var res, err = keyVault.GetSecrets("keyvault", "", nil)

if err != nil {
t.Fatalf("expected 0 errors but got: %s", err)
}
expected := map[string]interface{}{
"simple": "a_value_v1",
"second": "a_second_value_v2",
}
if !reflect.DeepEqual(res, expected) {
t.Errorf("expected: %s, got: %s.", expected, res)
}
}

func TestAzGetSecretsWithError(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("fetch_error")
var _, err = keyVault.GetSecrets("keyvault", "", nil)
if err == nil {
t.Fatalf("expected 1 errors but got nil")
}
expected := errors.New("fetch error")
if !reflect.DeepEqual(err, expected) {
t.Errorf("expected err: %s, got: %s.", expected, err)
tests := []struct {
name string
secretPrefix string
}{
{"Azure", "https://myvaultname.vault.azure.net/keys/"},
{"AzureChina", "https://myvaultname.vault.azure.cn/keys/"},
}
}

func TestAzGetSecretsWithErrorOnGetSecret(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("get_secret_error")
var _, err = keyVault.GetSecrets("keyvault", "", nil)
if err == nil {
t.Fatalf("expected 1 errors but got nil")
}
expected := errors.New("secret not found")
if !reflect.DeepEqual(err, expected) {
t.Errorf("expected err: %s, got: %s.", expected, err)
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyVault := newAzureKeyVaultBackendMock("", tt.secretPrefix)
res, err := keyVault.GetSecrets("keyvault", "", nil)

func TestAzGetSecretsBuilderError(t *testing.T) {
var keyVault = &backends.AzureKeyVault{
Credential: nil,
ClientBuilder: func(vaultURL string, credential azcore.TokenCredential, options *azsecrets.ClientOptions) (backends.AzSecretsClient, error) {
return nil, errors.New("boom")
},
}
var _, err = keyVault.GetSecrets("keyvault", "", nil)
if err == nil {
t.Fatalf("expected 1 errors but got nil")
}
expected := errors.New("boom")
if !reflect.DeepEqual(err, expected) {
t.Errorf("expected err: %s, got: %s.", expected, err)
}
}

func TestAzGetSecretsVersionV1(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("")
var res, err = keyVault.GetSecrets("keyvault", "v1", nil)
if err != nil {
t.Fatalf("expected 0 errors but got: %s", err)
}

if err != nil {
t.Fatalf("expected 0 errors but got: %s", err)
}
expected := map[string]interface{}{
"simple": "a_value_v1",
}
if !reflect.DeepEqual(res, expected) {
t.Errorf("expected: %s, got: %s.", expected, res)
expected := map[string]interface{}{
"simple": "a_value_v1",
"second": "a_second_value_v2",
}
if !reflect.DeepEqual(res, expected) {
t.Errorf("expected: %v, got: %v.", expected, res)
}
})
}
}

func TestAzGetSecretsVersionV2(t *testing.T) {
var keyVault = newAzureKeyVaultBackendMock("")
var res, err = keyVault.GetSecrets("keyvault", "v2", nil)

if err != nil {
t.Fatalf("expected 0 errors but got: %s", err)
}
expected := map[string]interface{}{
"simple": "a_value_v2",
"second": "a_second_value_v2",
}
if !reflect.DeepEqual(res, expected) {
t.Errorf("expected: %s, got: %s.", expected, res)
func TestAzGetSecret(t *testing.T) {
tests := []struct {
name string
secretPrefix string
}{
{"Azure", "https://myvaultname.vault.azure.net/keys/"},
{"AzureChina", "https://myvaultname.vault.azure.cn/keys/"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyVault := newAzureKeyVaultBackendMock("", tt.secretPrefix)
data, err := keyVault.GetIndividualSecret("keyvault", "simple", "", nil)
if err != nil {
t.Fatalf("expected 0 errors but got: %s", err)
}
expected := "a_value_v1"
if !reflect.DeepEqual(expected, data) {
t.Errorf("expected: %s, got: %s.", expected, data)
}
})
}
}