diff --git a/azuread/configuration.go b/azuread/configuration.go index 8c274a15..7876ae03 100644 --- a/azuread/configuration.go +++ b/azuread/configuration.go @@ -184,11 +184,11 @@ func (p *azureFedAuthConfig) provideActiveDirectoryToken(ctx context.Context, se case p.certificatePath != "": var certData []byte certData, err = os.ReadFile(p.certificatePath) - if err != nil { + if err == nil { var certs []*x509.Certificate var key crypto.PrivateKey certs, key, err = azidentity.ParseCertificates(certData, []byte(p.clientSecret)) - if err != nil { + if err == nil { cred, err = azidentity.NewClientCertificateCredential(tenant, p.clientID, certs, key, nil) } } diff --git a/azuread/configuration_test.go b/azuread/configuration_test.go index 0662df77..013a97aa 100644 --- a/azuread/configuration_test.go +++ b/azuread/configuration_test.go @@ -4,7 +4,13 @@ package azuread import ( + "context" + "errors" + "io/fs" + "net/url" + "os" "reflect" + "strings" "testing" mssql "github.com/microsoft/go-mssqldb" @@ -137,3 +143,73 @@ func TestValidateParameters(t *testing.T) { } } } + +func TestProvideActiveDirectoryTokenValidations(t *testing.T) { + nonExistentCertPath := os.TempDir() + "non_existent_cert.pem" + + f, err := os.CreateTemp("", "malformed_cert.pem") + if err != nil { + t.Fatalf("create temporary file: %v", err) + } + if err = f.Truncate(0); err != nil { + t.Fatalf("truncate temporary file: %v", err) + } + if _, err = f.Write([]byte("malformed")); err != nil { + t.Fatalf("write to temporary file: %v", err) + } + if err = f.Close(); err != nil { + t.Fatalf("close temporary file: %v", err) + } + malformedCertPath := f.Name() + t.Cleanup(func() { _ = os.Remove(malformedCertPath) }) + + tests := []struct { + name string + dsn string + expectedErr error + expectedErrContains string + }{ + { + name: "ActiveDirectoryServicePrincipal_cert_not_found", + dsn: `sqlserver://someserver.database.windows.net?` + + `user id=` + url.QueryEscape("my-app-id@my-tenant-id") + "&" + + `fedauth=ActiveDirectoryServicePrincipal` + "&" + + `clientcertpath=` + nonExistentCertPath + "&" + + `applicationclientid=someguid`, + expectedErr: fs.ErrNotExist, + }, + { + name: "ActiveDirectoryServicePrincipal_cert_malformed", + dsn: `sqlserver://someserver.database.windows.net?` + + `user id=` + url.QueryEscape("my-app-id@my-tenant-id") + "&" + + `fedauth=ActiveDirectoryServicePrincipal` + "&" + + `clientcertpath=` + malformedCertPath + "&" + + `applicationclientid=someguid`, + expectedErrContains: "error reading P12 data", + }, + } + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + config, err := parse(tst.dsn) + if err != nil { + t.Errorf("Unexpected parse error: %v", err) + return + } + _, err = config.provideActiveDirectoryToken(context.Background(), "", "authority/tenant") + if err == nil { + t.Errorf("Expected error but got nil") + return + } + if tst.expectedErr != nil { + if !errors.Is(err, tst.expectedErr) { + t.Errorf("Expected error '%v' but got err = %v", tst.expectedErr, err) + } + } + if tst.expectedErrContains != "" { + if !strings.Contains(err.Error(), tst.expectedErrContains) { + return + } + } + }) + } +}