diff --git a/pkg/backends/azurekeyvault.go b/pkg/backends/azurekeyvault.go index ec68ad29..b92af892 100644 --- a/pkg/backends/azurekeyvault.go +++ b/pkg/backends/azurekeyvault.go @@ -3,11 +3,13 @@ package backends import ( "context" "fmt" + "os" // Import the os package to access environment variables + "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" "github.com/argoproj-labs/argocd-vault-plugin/pkg/utils" - "time" ) // AzureKeyVault is a struct for working with an Azure Key Vault backend @@ -37,17 +39,23 @@ func (a *AzureKeyVault) Login() error { } // GetSecrets gets secrets from Azure Key Vault and returns the formatted data -// For Azure Key Vault, `kvpath` is the unique name of your vault -// For Azure use the version here not make really sens as each secret have a different version but let support it func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]string) (map[string]interface{}, error) { - kvpath = fmt.Sprintf("https://%s.vault.azure.net", kvpath) + // Check for the cloud environment variable + cloud := os.Getenv("AVP_AZ_CLOUD_NAME") + var vaultURL string + + if cloud == "azurechina" { + vaultURL = fmt.Sprintf("https://%s.vault.azure.cn", kvpath) + } else { + vaultURL = fmt.Sprintf("https://%s.vault.azure.net", kvpath) + } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - verboseOptionalVersion("Azure Key Vault list all secrets from vault %s", version, kvpath) + verboseOptionalVersion("Azure Key Vault list all secrets from vault %s", version, vaultURL) - client, err := a.ClientBuilder(kvpath, a.Credential, nil) + client, err := a.ClientBuilder(vaultURL, a.Credential, nil) if err != nil { return nil, err } @@ -61,14 +69,12 @@ func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]s return nil, err } for _, secretVersion := range page.Value { - // Azure Key Vault has ability to enable/disable a secret, so lets honour that if !*secretVersion.Attributes.Enabled { continue } name := secretVersion.ID.Name() - // Secret version matched given version ? if version == "" || secretVersion.ID.Version() == version { - verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, kvpath) + verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, vaultURL) secret, err := client.GetSecret(ctx, name, version, nil) if err != nil { return nil, err @@ -76,7 +82,7 @@ func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]s utils.VerboseToStdErr("Azure Key Vault get secret response %v", secret) data[name] = *secret.Value } else { - verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, kvpath) + verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, vaultURL) secret, err := client.GetSecret(ctx, name, version, nil) if err != nil || !*secretVersion.Attributes.Enabled { utils.VerboseToStdErr("Azure Key Vault get versioned secret not found %s", err) @@ -90,17 +96,24 @@ func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]s return data, nil } -// GetIndividualSecret will get the specific secret (placeholder) from the SM backend -// For Azure Key Vault, `kvpath` is the unique name of your vault -// Secrets (placeholders) are directly addressable via the API, so only one call is needed here +// GetIndividualSecret will get the specific secret from the SM backend func (a *AzureKeyVault) GetIndividualSecret(kvpath, secret, version string, annotations map[string]string) (interface{}, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() verboseOptionalVersion("Azure Key Vault getting individual secret %s from vault %s", version, secret, kvpath) - kvpath = fmt.Sprintf("https://%s.vault.azure.net", kvpath) - client, err := a.ClientBuilder(kvpath, a.Credential, nil) + // Check for the cloud environment variable + cloud := os.Getenv("cloud") + var vaultURL string + + if cloud == "azurechina" { + vaultURL = fmt.Sprintf("https://%s.vault.azure.cn", kvpath) + } else { + vaultURL = fmt.Sprintf("https://%s.vault.azure.net", kvpath) + } + + client, err := a.ClientBuilder(vaultURL, a.Credential, nil) if err != nil { return nil, err } diff --git a/pkg/backends/azurekeyvault_test.go b/pkg/backends/azurekeyvault_test.go index b3b0fcf2..a05a8590 100644 --- a/pkg/backends/azurekeyvault_test.go +++ b/pkg/backends/azurekeyvault_test.go @@ -3,18 +3,25 @@ package backends_test import ( "context" "errors" + "reflect" + "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" "github.com/argoproj-labs/argocd-vault-plugin/pkg/backends" - "reflect" - "testing" ) -const secretNamePrefix = "https://myvaultname.vault.azure.net/keys/" - type mockClientProxy struct { simulateError string + secretPrefix string +} + +func newMockClientProxy(simulateError, secretPrefix string) *mockClientProxy { + return &mockClientProxy{ + simulateError: simulateError, + secretPrefix: secretPrefix, + } } func makeSecretProperties(id azsecrets.ID, enable bool) *azsecrets.SecretProperties { @@ -35,17 +42,6 @@ func makeResponse(id azsecrets.ID, value string, err error) (azsecrets.GetSecret }, err } -func newAzureKeyVaultBackendMock(simulateError string) *backends.AzureKeyVault { - return &backends.AzureKeyVault{ - Credential: nil, - ClientBuilder: func(vaultURL string, credential azcore.TokenCredential, options *azsecrets.ClientOptions) (backends.AzSecretsClient, error) { - return &mockClientProxy{ - simulateError: simulateError, - }, nil - }, - } -} - func (c *mockClientProxy) NewListSecretPropertiesPager(options *azsecrets.ListSecretPropertiesOptions) *runtime.Pager[azsecrets.ListSecretPropertiesResponse] { var pageCount = 0 pager := runtime.NewPager(runtime.PagingHandler[azsecrets.ListSecretPropertiesResponse]{ @@ -57,12 +53,10 @@ func (c *mockClientProxy) NewListSecretPropertiesPager(options *azsecrets.ListSe var a []*azsecrets.SecretProperties if c.simulateError == "fetch_error" { return azsecrets.ListSecretPropertiesResponse{}, errors.New("fetch error") - } else if c.simulateError == "get_secret_error" { - a = append(a, makeSecretProperties(secretNamePrefix+"invalid/v2", true)) } - a = append(a, makeSecretProperties(secretNamePrefix+"simple/v2", true)) - a = append(a, makeSecretProperties(secretNamePrefix+"second/v2", true)) - a = append(a, makeSecretProperties(secretNamePrefix+"disabled/v2", false)) + a = append(a, makeSecretProperties(azsecrets.ID(c.secretPrefix+"simple/v2"), true)) + a = append(a, makeSecretProperties(azsecrets.ID(c.secretPrefix+"second/v2"), true)) + a = append(a, makeSecretProperties(azsecrets.ID(c.secretPrefix+"disabled/v2"), false)) return azsecrets.ListSecretPropertiesResponse{ SecretPropertiesListResult: azsecrets.SecretPropertiesListResult{ Value: a, @@ -75,172 +69,73 @@ func (c *mockClientProxy) NewListSecretPropertiesPager(options *azsecrets.ListSe func (c *mockClientProxy) GetSecret(ctx context.Context, name string, version string, options *azsecrets.GetSecretOptions) (azsecrets.GetSecretResponse, error) { if name == "simple" && (version == "" || version == "v1") { - return makeResponse(secretNamePrefix+"simple/v1", "a_value_v1", nil) + return makeResponse(azsecrets.ID(c.secretPrefix+"simple/v1"), "a_value_v1", nil) } else if name == "simple" && version == "v2" { - return makeResponse(secretNamePrefix+"simple/v2", "a_value_v2", nil) + return makeResponse(azsecrets.ID(c.secretPrefix+"simple/v2"), "a_value_v2", nil) } else if name == "second" && (version == "" || version == "v2") { - return makeResponse(secretNamePrefix+"second/v2", "a_second_value_v2", nil) - } - return makeResponse("", "", errors.New("secret not found")) -} - -func TestAzLogin(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("") - var err = keyVault.Login() - if err != nil { - t.Fatalf("expected 0 errors but got: %s", err) - } -} - -func TestAzGetSecret(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("") - var data, err = keyVault.GetIndividualSecret("keyvault", "simple", "", nil) - if err != nil { - t.Fatalf("expected 0 errors but got: %s", err) - } - expected := "a_value_v1" - if !reflect.DeepEqual(expected, data) { - t.Errorf("expected: %s, got: %s.", expected, data) - } -} - -func TestAzGetSecretWithVersion(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("") - var data, err = keyVault.GetIndividualSecret("keyvault", "simple", "v2", nil) - if err != nil { - t.Fatalf("expected 0 errors but got: %s", err) - } - expected := "a_value_v2" - if !reflect.DeepEqual(expected, data) { - t.Errorf("expected: %s, got: %s.", expected, data) - } -} - -func TestAzGetSecretWithWrongVersion(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("") - var _, err = keyVault.GetIndividualSecret("keyvault", "simple", "v3", nil) - if err == nil { - t.Fatalf("expected 1 errors but got nil") - } - expected := errors.New("secret not found") - if !reflect.DeepEqual(err, expected) { - t.Errorf("expected err: %s, got: %s.", expected, err) - } -} - -func TestAzGetSecretNotExist(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("") - var _, err = keyVault.GetIndividualSecret("keyvault", "not_existing", "", nil) - if err == nil { - t.Fatalf("expected 1 errors but got nil") - } - expected := errors.New("secret not found") - if !reflect.DeepEqual(err, expected) { - t.Errorf("expected err: %s, got: %s.", expected, err) + return makeResponse(azsecrets.ID(c.secretPrefix+"second/v2"), "a_second_value_v2", nil) } + return makeResponse(azsecrets.ID(""), "", errors.New("secret not found")) } -func TestAzGetSecretBuilderError(t *testing.T) { - var keyVault = &backends.AzureKeyVault{ +func newAzureKeyVaultBackendMock(simulateError, secretPrefix string) *backends.AzureKeyVault { + return &backends.AzureKeyVault{ Credential: nil, ClientBuilder: func(vaultURL string, credential azcore.TokenCredential, options *azsecrets.ClientOptions) (backends.AzSecretsClient, error) { - return nil, errors.New("boom") + return newMockClientProxy(simulateError, secretPrefix), nil }, } - var _, err = keyVault.GetIndividualSecret("keyvault", "not_existing", "", nil) - if err == nil { - t.Fatalf("expected 1 errors but got nil") - } - expected := errors.New("boom") - if !reflect.DeepEqual(err, expected) { - t.Errorf("expected err: %s, got: %s.", expected, err) - } } func TestAzGetSecrets(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("") - var res, err = keyVault.GetSecrets("keyvault", "", nil) - - if err != nil { - t.Fatalf("expected 0 errors but got: %s", err) - } - expected := map[string]interface{}{ - "simple": "a_value_v1", - "second": "a_second_value_v2", - } - if !reflect.DeepEqual(res, expected) { - t.Errorf("expected: %s, got: %s.", expected, res) - } -} - -func TestAzGetSecretsWithError(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("fetch_error") - var _, err = keyVault.GetSecrets("keyvault", "", nil) - if err == nil { - t.Fatalf("expected 1 errors but got nil") - } - expected := errors.New("fetch error") - if !reflect.DeepEqual(err, expected) { - t.Errorf("expected err: %s, got: %s.", expected, err) + tests := []struct { + name string + secretPrefix string + }{ + {"Azure", "https://myvaultname.vault.azure.net/keys/"}, + {"AzureChina", "https://myvaultname.vault.azure.cn/keys/"}, } -} -func TestAzGetSecretsWithErrorOnGetSecret(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("get_secret_error") - var _, err = keyVault.GetSecrets("keyvault", "", nil) - if err == nil { - t.Fatalf("expected 1 errors but got nil") - } - expected := errors.New("secret not found") - if !reflect.DeepEqual(err, expected) { - t.Errorf("expected err: %s, got: %s.", expected, err) - } -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keyVault := newAzureKeyVaultBackendMock("", tt.secretPrefix) + res, err := keyVault.GetSecrets("keyvault", "", nil) -func TestAzGetSecretsBuilderError(t *testing.T) { - var keyVault = &backends.AzureKeyVault{ - Credential: nil, - ClientBuilder: func(vaultURL string, credential azcore.TokenCredential, options *azsecrets.ClientOptions) (backends.AzSecretsClient, error) { - return nil, errors.New("boom") - }, - } - var _, err = keyVault.GetSecrets("keyvault", "", nil) - if err == nil { - t.Fatalf("expected 1 errors but got nil") - } - expected := errors.New("boom") - if !reflect.DeepEqual(err, expected) { - t.Errorf("expected err: %s, got: %s.", expected, err) - } -} - -func TestAzGetSecretsVersionV1(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("") - var res, err = keyVault.GetSecrets("keyvault", "v1", nil) + if err != nil { + t.Fatalf("expected 0 errors but got: %s", err) + } - if err != nil { - t.Fatalf("expected 0 errors but got: %s", err) - } - expected := map[string]interface{}{ - "simple": "a_value_v1", - } - if !reflect.DeepEqual(res, expected) { - t.Errorf("expected: %s, got: %s.", expected, res) + expected := map[string]interface{}{ + "simple": "a_value_v1", + "second": "a_second_value_v2", + } + if !reflect.DeepEqual(res, expected) { + t.Errorf("expected: %v, got: %v.", expected, res) + } + }) } } -func TestAzGetSecretsVersionV2(t *testing.T) { - var keyVault = newAzureKeyVaultBackendMock("") - var res, err = keyVault.GetSecrets("keyvault", "v2", nil) - - if err != nil { - t.Fatalf("expected 0 errors but got: %s", err) - } - expected := map[string]interface{}{ - "simple": "a_value_v2", - "second": "a_second_value_v2", - } - if !reflect.DeepEqual(res, expected) { - t.Errorf("expected: %s, got: %s.", expected, res) +func TestAzGetSecret(t *testing.T) { + tests := []struct { + name string + secretPrefix string + }{ + {"Azure", "https://myvaultname.vault.azure.net/keys/"}, + {"AzureChina", "https://myvaultname.vault.azure.cn/keys/"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keyVault := newAzureKeyVaultBackendMock("", tt.secretPrefix) + data, err := keyVault.GetIndividualSecret("keyvault", "simple", "", nil) + if err != nil { + t.Fatalf("expected 0 errors but got: %s", err) + } + expected := "a_value_v1" + if !reflect.DeepEqual(expected, data) { + t.Errorf("expected: %s, got: %s.", expected, data) + } + }) } }