Skip to content

Commit

Permalink
Update azurekeyvault.go to support azure china
Browse files Browse the repository at this point in the history
  • Loading branch information
harshktpa authored Oct 28, 2024
1 parent b046a7d commit bc53269
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions pkg/backends/azurekeyvault.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package backends
import (
"context"
"fmt"
"os" // Import the os package to access environment variables
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets"
"github.com/argoproj-labs/argocd-vault-plugin/pkg/utils"
"time"
)

// AzureKeyVault is a struct for working with an Azure Key Vault backend
Expand Down Expand Up @@ -37,17 +39,23 @@ func (a *AzureKeyVault) Login() error {
}

// GetSecrets gets secrets from Azure Key Vault and returns the formatted data
// For Azure Key Vault, `kvpath` is the unique name of your vault
// For Azure use the version here not make really sens as each secret have a different version but let support it
func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]string) (map[string]interface{}, error) {
kvpath = fmt.Sprintf("https://%s.vault.azure.net", kvpath)
// Check for the cloud environment variable
cloud := os.Getenv("AVP_AZ_CLOUD_NAME")
var vaultURL string

if cloud == "azurechina" {
vaultURL = fmt.Sprintf("https://%s.vault.azure.cn", kvpath)
} else {
vaultURL = fmt.Sprintf("https://%s.vault.azure.net", kvpath)
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

verboseOptionalVersion("Azure Key Vault list all secrets from vault %s", version, kvpath)
verboseOptionalVersion("Azure Key Vault list all secrets from vault %s", version, vaultURL)

client, err := a.ClientBuilder(kvpath, a.Credential, nil)
client, err := a.ClientBuilder(vaultURL, a.Credential, nil)
if err != nil {
return nil, err
}
Expand All @@ -61,22 +69,20 @@ func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]s
return nil, err
}
for _, secretVersion := range page.Value {
// Azure Key Vault has ability to enable/disable a secret, so lets honour that
if !*secretVersion.Attributes.Enabled {
continue
}
name := secretVersion.ID.Name()
// Secret version matched given version ?
if version == "" || secretVersion.ID.Version() == version {
verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, kvpath)
verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, vaultURL)
secret, err := client.GetSecret(ctx, name, version, nil)
if err != nil {
return nil, err
}
utils.VerboseToStdErr("Azure Key Vault get secret response %v", secret)
data[name] = *secret.Value
} else {
verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, kvpath)
verboseOptionalVersion("Azure Key Vault getting secret %s from vault %s", version, name, vaultURL)
secret, err := client.GetSecret(ctx, name, version, nil)
if err != nil || !*secretVersion.Attributes.Enabled {
utils.VerboseToStdErr("Azure Key Vault get versioned secret not found %s", err)
Expand All @@ -90,17 +96,24 @@ func (a *AzureKeyVault) GetSecrets(kvpath string, version string, _ map[string]s
return data, nil
}

// GetIndividualSecret will get the specific secret (placeholder) from the SM backend
// For Azure Key Vault, `kvpath` is the unique name of your vault
// Secrets (placeholders) are directly addressable via the API, so only one call is needed here
// GetIndividualSecret will get the specific secret from the SM backend
func (a *AzureKeyVault) GetIndividualSecret(kvpath, secret, version string, annotations map[string]string) (interface{}, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

verboseOptionalVersion("Azure Key Vault getting individual secret %s from vault %s", version, secret, kvpath)

kvpath = fmt.Sprintf("https://%s.vault.azure.net", kvpath)
client, err := a.ClientBuilder(kvpath, a.Credential, nil)
// Check for the cloud environment variable
cloud := os.Getenv("cloud")
var vaultURL string

if cloud == "azurechina" {
vaultURL = fmt.Sprintf("https://%s.vault.azure.cn", kvpath)
} else {
vaultURL = fmt.Sprintf("https://%s.vault.azure.net", kvpath)
}

client, err := a.ClientBuilder(vaultURL, a.Credential, nil)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit bc53269

Please sign in to comment.