diff --git a/appveyor.yml b/appveyor.yml index 3de25a75..19121093 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -57,7 +57,7 @@ install: - go env - go get -u github.com/golang-sql/civil - go get -u github.com/golang-sql/sqlexp - + - go get -u golang.org/x/crypto/md4 build_script: - go build diff --git a/auth_unix.go b/auth_unix.go new file mode 100644 index 00000000..79be775d --- /dev/null +++ b/auth_unix.go @@ -0,0 +1,15 @@ +// +build !windows + +package mssql + +import ( + "github.com/microsoft/go-mssqldb/integratedauth" + // nolint importing the ntlm package causes it to be registered as an available authentication provider + _ "github.com/microsoft/go-mssqldb/integratedauth/ntlm" +) + +func init() { + // we set the default authentication provider name here, rather than within each imported package, + // to force a known default. Go will order execution of init() calls but it is better to be explicit. + integratedauth.DefaultProviderName = "ntlm" +} diff --git a/auth_windows.go b/auth_windows.go new file mode 100644 index 00000000..8ed454f0 --- /dev/null +++ b/auth_windows.go @@ -0,0 +1,18 @@ +// +build windows + +package mssql + +import ( + "github.com/microsoft/go-mssqldb/integratedauth" + + // nolint importing the ntlm package causes it to be registered as an available authentication provider + _ "github.com/microsoft/go-mssqldb/integratedauth/ntlm" + // nolint importing the winsspi package causes it to be registered as an available authentication provider + _ "github.com/microsoft/go-mssqldb/integratedauth/winsspi" +) + +func init() { + // we set the default authentication provider name here, rather than within each imported package, + // to force a known default. Go will order execution of init() calls but it is better to be explicit. + integratedauth.DefaultProviderName = "winsspi" +} diff --git a/azuread/configuration.go b/azuread/configuration.go index 041c0efa..e3c2eb70 100644 --- a/azuread/configuration.go +++ b/azuread/configuration.go @@ -53,7 +53,7 @@ type azureFedAuthConfig struct { // parse returns a config based on an msdsn-style connection string func parse(dsn string) (*azureFedAuthConfig, error) { - mssqlConfig, params, err := msdsn.Parse(dsn) + mssqlConfig, err := msdsn.Parse(dsn) if err != nil { return nil, err } @@ -62,7 +62,7 @@ func parse(dsn string) (*azureFedAuthConfig, error) { mssqlConfig: mssqlConfig, } - err = config.validateParameters(params) + err = config.validateParameters(mssqlConfig.Parameters) if err != nil { return nil, err } diff --git a/azuread/configuration_test.go b/azuread/configuration_test.go index 12f32821..e830d17d 100644 --- a/azuread/configuration_test.go +++ b/azuread/configuration_test.go @@ -1,128 +1,128 @@ -//go:build go1.18 -// +build go1.18 - -package azuread - -import ( - "testing" - - mssql "github.com/microsoft/go-mssqldb" - "github.com/microsoft/go-mssqldb/msdsn" -) - -func TestValidateParameters(t *testing.T) { - passphrase := "somesecret" - certificatepath := "/user/cert/cert.pfx" - appid := "applicationclientid=someguid" - certprop := "clientcertpath=" + certificatepath - tests := []struct { - name string - dsn string - expected *azureFedAuthConfig - }{ - { - name: "no fed auth configured", - dsn: "server=someserver", - expected: &azureFedAuthConfig{fedAuthLibrary: mssql.FedAuthLibraryReserved}, - }, - { - name: "application with cert/key", - dsn: `sqlserver://service-principal-id%40tenant-id:somesecret@someserver.database.windows.net?fedauth=ActiveDirectoryApplication&` + certprop + "&" + appid, - expected: &azureFedAuthConfig{ - fedAuthLibrary: mssql.FedAuthLibraryADAL, - clientID: "service-principal-id", - tenantID: "tenant-id", - certificatePath: certificatepath, - clientSecret: passphrase, - adalWorkflow: mssql.FedAuthADALWorkflowPassword, - fedAuthWorkflow: ActiveDirectoryApplication, - applicationClientID: "someguid", - }, - }, - { - name: "application with cert/key missing tenant id", - dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryApplication;user id=service-principal-id;password=somesecret;" + certprop + ";" + appid, - expected: &azureFedAuthConfig{ - fedAuthLibrary: mssql.FedAuthLibraryADAL, - clientID: "service-principal-id", - certificatePath: certificatepath, - clientSecret: passphrase, - adalWorkflow: mssql.FedAuthADALWorkflowPassword, - fedAuthWorkflow: ActiveDirectoryApplication, - applicationClientID: "someguid", - }, - }, - { - name: "application with secret", - dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryServicePrincipal;user id=service-principal-id@tenant-id;password=somesecret;", - expected: &azureFedAuthConfig{ - clientID: "service-principal-id", - tenantID: "tenant-id", - clientSecret: passphrase, - adalWorkflow: mssql.FedAuthADALWorkflowPassword, - fedAuthWorkflow: ActiveDirectoryServicePrincipal, - }, - }, - { - name: "user with password", - dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryPassword;user id=azure-ad-user@example.com;password=somesecret;" + appid, - expected: &azureFedAuthConfig{ - adalWorkflow: mssql.FedAuthADALWorkflowPassword, - user: "azure-ad-user@example.com", - password: passphrase, - applicationClientID: "someguid", - fedAuthWorkflow: ActiveDirectoryPassword, - }, - }, - { - name: "managed identity without client id", - dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryMSI", - expected: &azureFedAuthConfig{ - adalWorkflow: mssql.FedAuthADALWorkflowMSI, - fedAuthWorkflow: ActiveDirectoryMSI, - }, - }, - { - name: "managed identity with client id", - dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryManagedIdentity;user id=identity-client-id", - expected: &azureFedAuthConfig{ - adalWorkflow: mssql.FedAuthADALWorkflowMSI, - clientID: "identity-client-id", - fedAuthWorkflow: ActiveDirectoryManagedIdentity, - }, - }, - { - name: "managed identity with resource id", - dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryManagedIdentity;resource id=/subscriptions/{guid}/resourceGroups/{resource-group-name}/{resource-provider-namespace}/{resource-type}/{resource-name}", - expected: &azureFedAuthConfig{ - adalWorkflow: mssql.FedAuthADALWorkflowMSI, - resourceID: "/subscriptions/{guid}/resourceGroups/{resource-group-name}/{resource-provider-namespace}/{resource-type}/{resource-name}", - fedAuthWorkflow: ActiveDirectoryManagedIdentity, - }, - }, - } - for _, tst := range tests { - config, err := parse(tst.dsn) - if tst.expected == nil { - if err == nil { - t.Errorf("No error returned when error expected in test case '%s'", tst.name) - } - continue - } - if err != nil { - t.Errorf("Error returned when none expected in test case '%s': %v", tst.name, err) - continue - } - if tst.expected.fedAuthLibrary != mssql.FedAuthLibraryReserved { - if tst.expected.fedAuthLibrary == 0 { - tst.expected.fedAuthLibrary = mssql.FedAuthLibraryADAL - } - } - // mssqlConfig is not idempotent due to pointers in it, plus we aren't testing its correctness here - config.mssqlConfig = msdsn.Config{} - if *config != *tst.expected { - t.Errorf("Captured parameters do not match in test case '%s'. Expected:%+v, Actual:%+v", tst.name, tst.expected, config) - } - } - -} +//go:build go1.18 +// +build go1.18 + +package azuread + +import ( + "reflect" + "testing" + + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/msdsn" +) + +func TestValidateParameters(t *testing.T) { + passphrase := "somesecret" + certificatepath := "/user/cert/cert.pfx" + appid := "applicationclientid=someguid" + certprop := "clientcertpath=" + certificatepath + tests := []struct { + name string + dsn string + expected *azureFedAuthConfig + }{ + { + name: "no fed auth configured", + dsn: "server=someserver", + expected: &azureFedAuthConfig{fedAuthLibrary: mssql.FedAuthLibraryReserved}, + }, + { + name: "application with cert/key", + dsn: `sqlserver://service-principal-id%40tenant-id:somesecret@someserver.database.windows.net?fedauth=ActiveDirectoryApplication&` + certprop + "&" + appid, + expected: &azureFedAuthConfig{ + fedAuthLibrary: mssql.FedAuthLibraryADAL, + clientID: "service-principal-id", + tenantID: "tenant-id", + certificatePath: certificatepath, + clientSecret: passphrase, + adalWorkflow: mssql.FedAuthADALWorkflowPassword, + fedAuthWorkflow: ActiveDirectoryApplication, + applicationClientID: "someguid", + }, + }, + { + name: "application with cert/key missing tenant id", + dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryApplication;user id=service-principal-id;password=somesecret;" + certprop + ";" + appid, + expected: &azureFedAuthConfig{ + fedAuthLibrary: mssql.FedAuthLibraryADAL, + clientID: "service-principal-id", + certificatePath: certificatepath, + clientSecret: passphrase, + adalWorkflow: mssql.FedAuthADALWorkflowPassword, + fedAuthWorkflow: ActiveDirectoryApplication, + applicationClientID: "someguid", + }, + }, + { + name: "application with secret", + dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryServicePrincipal;user id=service-principal-id@tenant-id;password=somesecret;", + expected: &azureFedAuthConfig{ + clientID: "service-principal-id", + tenantID: "tenant-id", + clientSecret: passphrase, + adalWorkflow: mssql.FedAuthADALWorkflowPassword, + fedAuthWorkflow: ActiveDirectoryServicePrincipal, + }, + }, + { + name: "user with password", + dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryPassword;user id=azure-ad-user@example.com;password=somesecret;" + appid, + expected: &azureFedAuthConfig{ + adalWorkflow: mssql.FedAuthADALWorkflowPassword, + user: "azure-ad-user@example.com", + password: passphrase, + applicationClientID: "someguid", + fedAuthWorkflow: ActiveDirectoryPassword, + }, + }, + { + name: "managed identity without client id", + dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryMSI", + expected: &azureFedAuthConfig{ + adalWorkflow: mssql.FedAuthADALWorkflowMSI, + fedAuthWorkflow: ActiveDirectoryMSI, + }, + }, + { + name: "managed identity with client id", + dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryManagedIdentity;user id=identity-client-id", + expected: &azureFedAuthConfig{ + adalWorkflow: mssql.FedAuthADALWorkflowMSI, + clientID: "identity-client-id", + fedAuthWorkflow: ActiveDirectoryManagedIdentity, + }, + }, + { + name: "managed identity with resource id", + dsn: "server=someserver.database.windows.net;fedauth=ActiveDirectoryManagedIdentity;resource id=/subscriptions/{guid}/resourceGroups/{resource-group-name}/{resource-provider-namespace}/{resource-type}/{resource-name}", + expected: &azureFedAuthConfig{ + adalWorkflow: mssql.FedAuthADALWorkflowMSI, + resourceID: "/subscriptions/{guid}/resourceGroups/{resource-group-name}/{resource-provider-namespace}/{resource-type}/{resource-name}", + fedAuthWorkflow: ActiveDirectoryManagedIdentity, + }, + }, + } + for _, tst := range tests { + config, err := parse(tst.dsn) + if tst.expected == nil { + if err == nil { + t.Errorf("No error returned when error expected in test case '%s'", tst.name) + } + continue + } + if err != nil { + t.Errorf("Error returned when none expected in test case '%s': %v", tst.name, err) + continue + } + if tst.expected.fedAuthLibrary != mssql.FedAuthLibraryReserved { + if tst.expected.fedAuthLibrary == 0 { + tst.expected.fedAuthLibrary = mssql.FedAuthLibraryADAL + } + } + // mssqlConfig is not idempotent due to pointers in it, plus we aren't testing its correctness here + config.mssqlConfig = msdsn.Config{} + if !reflect.DeepEqual(config, tst.expected) { + t.Errorf("Captured parameters do not match in test case '%s'. Expected:%+v, Actual:%+v", tst.name, tst.expected, config) + } + } +} diff --git a/integratedauth/auth.go b/integratedauth/auth.go new file mode 100644 index 00000000..0d68da75 --- /dev/null +++ b/integratedauth/auth.go @@ -0,0 +1,73 @@ +package integratedauth + +import ( + "errors" + "fmt" + + "github.com/microsoft/go-mssqldb/msdsn" +) + +var ( + providers map[string]Provider + DefaultProviderName string + + ErrProviderCannotBeNil = errors.New("provider cannot be nil") + ErrProviderNameMustBePopulated = errors.New("provider name must be populated") +) + +func init() { + providers = make(map[string]Provider) +} + +// GetIntegratedAuthenticator calls the authProvider specified in the 'authenticator' connection string parameter, if supplied. +// Otherwise fails back to the DefaultProviderName implementation for the platform. +func GetIntegratedAuthenticator(config msdsn.Config) (IntegratedAuthenticator, error) { + authenticatorName, ok := config.Parameters["authenticator"] + if !ok { + provider, err := getProvider(DefaultProviderName) + if err != nil { + return nil, err + } + + p, err := provider.GetIntegratedAuthenticator(config) + // we ignore the error in this case to force a fallback to sqlserver authentication. + // this preserves the original behaviour + if err != nil { + return nil, nil + } + + return p, nil + } + + provider, err := getProvider(authenticatorName) + if err != nil { + return nil, err + } + + return provider.GetIntegratedAuthenticator(config) +} + +func getProvider(name string) (Provider, error) { + provider, ok := providers[name] + + if !ok { + return nil, fmt.Errorf("provider %v not found", name) + } + + return provider, nil +} + +// SetIntegratedAuthenticationProvider stores a named authentication provider. It should be called before any connections are created. +func SetIntegratedAuthenticationProvider(providerName string, p Provider) error { + if p == nil { + return ErrProviderCannotBeNil + } + + if providerName == "" { + return ErrProviderNameMustBePopulated + } + + providers[providerName] = p + + return nil +} diff --git a/integratedauth/auth_test.go b/integratedauth/auth_test.go new file mode 100644 index 00000000..a3d25e30 --- /dev/null +++ b/integratedauth/auth_test.go @@ -0,0 +1,203 @@ +package integratedauth + +import ( + "errors" + "fmt" + "testing" + + "github.com/microsoft/go-mssqldb/msdsn" +) + +const providerName = "stub" + +type stubAuth struct { + user string +} + +func (s *stubAuth) InitialBytes() ([]byte, error) { return nil, nil } +func (s *stubAuth) NextBytes([]byte) ([]byte, error) { return nil, nil } +func (s *stubAuth) Free() {} + +func getAuth(config msdsn.Config) (IntegratedAuthenticator, error) { + return &stubAuth{config.User}, nil +} + +func TestSetIntegratedAuthenticationProviderReturnsErrOnNilProvider(t *testing.T) { + err := SetIntegratedAuthenticationProvider(providerName, nil) + + if err != ErrProviderCannotBeNil { + t.Errorf("SetIntegratedAuthenticationProvider() returned err: %v, want %v", err, ErrProviderCannotBeNil) + } +} + +func TestSetIntegratedAuthenticationProviderReturnsErrOnEmptyProviderName(t *testing.T) { + err := SetIntegratedAuthenticationProvider("", ProviderFunc(getAuth)) + + if err != ErrProviderNameMustBePopulated { + t.Errorf("SetIntegratedAuthenticationProvider() returned err: %v, want %v", err, ErrProviderNameMustBePopulated) + } +} + +func TestSetIntegratedAuthenticationProviderStored(t *testing.T) { + err := SetIntegratedAuthenticationProvider(providerName, ProviderFunc(getAuth)) + if err != nil { + t.Errorf("SetIntegratedAuthenticationProvider() returned unexpected err %v", err) + } + defer removeStubProvider() + + if _, ok := providers[providerName]; !ok { + t.Error("SetIntegratedAuthenticationProvider() added provider not found") + } +} + +func TestSetIntegratedAuthenticationProviderInstanceIsPassedConnString(t *testing.T) { + err := SetIntegratedAuthenticationProvider(providerName, ProviderFunc(getAuth)) + if err != nil { + t.Errorf("SetIntegratedAuthenticationProvider() returned unexpected err %v", err) + } + defer removeStubProvider() + + config, err := msdsn.Parse(fmt.Sprintf("authenticator=%v;user id=username", providerName)) + if err != nil { + t.Errorf("msdsn.Parse : Unexpected error %v", err) + return + } + + authenticator, err := GetIntegratedAuthenticator(config) + + if err != nil { + t.Errorf("expected GetIntegratedAuthenticator() to return ok, found %v", err) + } + + a, ok := authenticator.(*stubAuth) + if !ok { + t.Errorf("expected result of GetIntegratedAuthenticator() to be an instance of stubAuth") + } + + if a.user != "username" { + t.Errorf("expected stubAuth username to be correct") + } +} + +func TestSetIntegratedAuthenticationProviderInstanceIsDefaultWhenAuthenticatorParamNotPassed(t *testing.T) { + removeStubProvider() + + config, err := msdsn.Parse("user id=username") + if err != nil { + t.Errorf("msdsn.Parse : Unexpected error %v", err) + return + } + + DefaultProviderName = "DEFAULT_PROVIDER" + defer func() { DefaultProviderName = "" }() + + err = SetIntegratedAuthenticationProvider(DefaultProviderName, ProviderFunc(func(config msdsn.Config) (IntegratedAuthenticator, error) { + return &stubAuth{"DEFAULT INSTANCE"}, nil + })) + if err != nil { + t.Errorf("SetIntegratedAuthenticationProvider() returned unexpected err %v", err) + } + + result, err := GetIntegratedAuthenticator(config) + + if err != nil { + t.Errorf("expected GetIntegratedAuthenticator() to return ok, found %v", err) + } + + a, ok := result.(*stubAuth) + if !ok { + t.Errorf("expected result of GetIntegratedAuthenticator() to be an instance of stubAuth") + } + + if a.user != "DEFAULT INSTANCE" { + t.Errorf("expected GetIntegratedAuthenticator for return DefaultProviderName instance when no authenticator param is passed, found %v", a.user) + } +} + +func TestGetIntegratedAuthenticatorFallBackToSqlAuthOnErrorOfDefaultProvider(t *testing.T) { + removeStubProvider() + + config, err := msdsn.Parse("user id=username") + if err != nil { + t.Errorf("msdsn.Parse : Unexpected error %v", err) + return + } + + DefaultProviderName = "DEFAULT_PROVIDER" + defer func() { DefaultProviderName = "" }() + + err = SetIntegratedAuthenticationProvider(DefaultProviderName, ProviderFunc(func(config msdsn.Config) (IntegratedAuthenticator, error) { + return nil, errors.New("default authenticator cant continue") + })) + if err != nil { + t.Errorf("SetIntegratedAuthenticationProvider() returned unexpected err %v", err) + } + + result, err := GetIntegratedAuthenticator(config) + + if err != nil { + t.Errorf("expected GetIntegratedAuthenticator() to return ok, found %v", err) + } + + if result != nil { + t.Errorf("expected GetIntegratedAuthenticator() to return nill authenticator, found %v", result) + } +} + +func TestGetIntegratedAuthenticatorToErrorWhenNoDefaultProviderFound(t *testing.T) { + removeStubProvider() + + // dont set an authenticator + config, err := msdsn.Parse("user id=username") + if err != nil { + t.Errorf("msdsn.Parse : Unexpected error %v", err) + return + } + + DefaultProviderName = "NONEXISTANT_DEFAULT_PROVIDER" + defer func() { DefaultProviderName = "" }() + + result, err := GetIntegratedAuthenticator(config) + + if err == nil { + t.Error("expected GetIntegratedAuthenticator() to return error, found nil") + } + + if result != nil { + t.Errorf("expected GetIntegratedAuthenticator() to return nill provider, found %v", result) + } + + if err != nil && err.Error() != "provider NONEXISTANT_DEFAULT_PROVIDER not found" { + t.Errorf("expected err that default provider was not found, found %v", err) + } +} + +func TestGetIntegratedAuthenticatorToErrorWhenNoSpecifiedProviderFound(t *testing.T) { + removeStubProvider() + defer removeStubProvider() + + config, err := msdsn.Parse("authenticator=NONEXISTANTPROVIDER;user id=username") + if err != nil { + t.Errorf("msdsn.Parse : Unexpected error %v", err) + return + } + + // dont set an authenticator + result, err := GetIntegratedAuthenticator(config) + + if err == nil { + t.Error("expected GetIntegratedAuthenticator() to return error, found nil") + } + + if result != nil { + t.Errorf("expected GetIntegratedAuthenticator() to return nill provider, found %v", result) + } + + if err != nil && err.Error() != "provider NONEXISTANTPROVIDER not found" { + t.Errorf("expected err that default provider was not found, found %v", err) + } +} + +func removeStubProvider() { + delete(providers, providerName) +} diff --git a/integratedauth/integratedauthenticator.go b/integratedauth/integratedauthenticator.go new file mode 100644 index 00000000..ce8240d7 --- /dev/null +++ b/integratedauth/integratedauthenticator.go @@ -0,0 +1,25 @@ +package integratedauth + +import ( + "github.com/microsoft/go-mssqldb/msdsn" +) + +// Provider returns an SSPI compatible authentication provider +type Provider interface { + // GetIntegratedAuthenticator is responsible for returning an instance of the required IntegratedAuthenticator interface + GetIntegratedAuthenticator(config msdsn.Config) (IntegratedAuthenticator, error) +} + +// IntegratedAuthenticator is the interface for SSPI Login Authentication providers +type IntegratedAuthenticator interface { + InitialBytes() ([]byte, error) + NextBytes([]byte) ([]byte, error) + Free() +} + +// ProviderFunc is an adapter to convert a GetIntegratedAuthenticator func into a Provider +type ProviderFunc func(config msdsn.Config) (IntegratedAuthenticator, error) + +func (f ProviderFunc) GetIntegratedAuthenticator(config msdsn.Config) (IntegratedAuthenticator, error) { + return f(config) +} diff --git a/ntlm.go b/integratedauth/ntlm/ntlm.go similarity index 94% rename from ntlm.go rename to integratedauth/ntlm/ntlm.go index 90adb5a0..d95032f2 100644 --- a/ntlm.go +++ b/integratedauth/ntlm/ntlm.go @@ -1,6 +1,4 @@ -// +build !windows - -package mssql +package ntlm import ( "crypto/des" @@ -14,6 +12,9 @@ import ( "time" "unicode/utf16" + "github.com/microsoft/go-mssqldb/integratedauth" + "github.com/microsoft/go-mssqldb/msdsn" + //lint:ignore SA1019 MD4 is used by legacy NTLM "golang.org/x/crypto/md4" ) @@ -56,24 +57,26 @@ const _NEGOTIATE_FLAGS = _NEGOTIATE_UNICODE | _NEGOTIATE_ALWAYS_SIGN | _NEGOTIATE_EXTENDED_SESSIONSECURITY -type ntlmAuth struct { +type Auth struct { Domain string UserName string Password string Workstation string } -func getAuth(user, password, service, workstation string) (auth, bool) { - if !strings.ContainsRune(user, '\\') { - return nil, false +// getAuth returns an authentication handle Auth to provide authentication content +// to mssql.connect +func getAuth(config msdsn.Config) (integratedauth.IntegratedAuthenticator, error) { + if !strings.ContainsRune(config.User, '\\') { + return nil, fmt.Errorf("ntlm : invalid username %v", config.User) } - domain_user := strings.SplitN(user, "\\", 2) - return &ntlmAuth{ - Domain: domain_user[0], - UserName: domain_user[1], - Password: password, - Workstation: workstation, - }, true + domainUser := strings.SplitN(config.User, "\\", 2) + return &Auth{ + Domain: domainUser[0], + UserName: domainUser[1], + Password: config.Password, + Workstation: config.Workstation, + }, nil } func utf16le(val string) []byte { @@ -90,7 +93,7 @@ func utf16le(val string) []byte { return v } -func (auth *ntlmAuth) InitialBytes() ([]byte, error) { +func (auth *Auth) InitialBytes() ([]byte, error) { domain_len := len(auth.Domain) workstation_len := len(auth.Workstation) msg := make([]byte, 40+domain_len+workstation_len) @@ -358,7 +361,7 @@ func buildNTLMResponsePayload(lm, nt []byte, flags uint32, domain, workstation, return msg, nil } -func (auth *ntlmAuth) NextBytes(bytes []byte) ([]byte, error) { +func (auth *Auth) NextBytes(bytes []byte) ([]byte, error) { signature := string(bytes[0:8]) if signature != "NTLMSSP\x00" { return nil, errorNTLM @@ -389,5 +392,5 @@ func (auth *ntlmAuth) NextBytes(bytes []byte) ([]byte, error) { return buildNTLMResponsePayload(lm, nt, flags, auth.Domain, auth.Workstation, auth.UserName) } -func (auth *ntlmAuth) Free() { +func (auth *Auth) Free() { } diff --git a/ntlm_test.go b/integratedauth/ntlm/ntlm_test.go similarity index 99% rename from ntlm_test.go rename to integratedauth/ntlm/ntlm_test.go index edac77af..1df99e23 100644 --- a/ntlm_test.go +++ b/integratedauth/ntlm/ntlm_test.go @@ -1,6 +1,4 @@ -// +build !windows - -package mssql +package ntlm import ( "bytes" diff --git a/integratedauth/ntlm/provider.go b/integratedauth/ntlm/provider.go new file mode 100644 index 00000000..b0c780b5 --- /dev/null +++ b/integratedauth/ntlm/provider.go @@ -0,0 +1,15 @@ +package ntlm + +import ( + "github.com/microsoft/go-mssqldb/integratedauth" +) + +// AuthProvider handles NTLM SSPI Windows Authentication +var AuthProvider integratedauth.Provider = integratedauth.ProviderFunc(getAuth) + +func init() { + err := integratedauth.SetIntegratedAuthenticationProvider("ntlm", AuthProvider) + if err != nil { + panic(err) + } +} diff --git a/integratedauth/winsspi/provider.go b/integratedauth/winsspi/provider.go new file mode 100644 index 00000000..05da93fd --- /dev/null +++ b/integratedauth/winsspi/provider.go @@ -0,0 +1,15 @@ +// +build windows + +package winsspi + +import "github.com/microsoft/go-mssqldb/integratedauth" + +// AuthProvider handles SSPI Windows Authentication via secur32.dll functions +var AuthProvider integratedauth.Provider = integratedauth.ProviderFunc(getAuth) + +func init() { + err := integratedauth.SetIntegratedAuthenticationProvider("winsspi", AuthProvider) + if err != nil { + panic(err) + } +} \ No newline at end of file diff --git a/sspi_windows.go b/integratedauth/winsspi/winsspi.go similarity index 88% rename from sspi_windows.go rename to integratedauth/winsspi/winsspi.go index 9b5bc689..195d2288 100644 --- a/sspi_windows.go +++ b/integratedauth/winsspi/winsspi.go @@ -1,10 +1,15 @@ -package mssql +// +build windows + +package winsspi import ( "fmt" "strings" "syscall" "unsafe" + + "github.com/microsoft/go-mssqldb/integratedauth" + "github.com/microsoft/go-mssqldb/msdsn" ) var ( @@ -104,7 +109,7 @@ type SecBufferDesc struct { pBuffers *SecBuffer } -type SSPIAuth struct { +type Auth struct { Domain string UserName string Password string @@ -113,23 +118,25 @@ type SSPIAuth struct { ctxt SecHandle } -func getAuth(user, password, service, workstation string) (auth, bool) { - if user == "" { - return &SSPIAuth{Service: service}, true +// getAuth returns an authentication handle Auth to provide authentication content +// to mssql.connect +func getAuth(config msdsn.Config) (integratedauth.IntegratedAuthenticator, error) { + if config.User == "" { + return &Auth{Service: config.ServerSPN}, nil } - if !strings.ContainsRune(user, '\\') { - return nil, false + if !strings.ContainsRune(config.User, '\\') { + return nil, fmt.Errorf("winsspi : invalid username %v", config.User) } - domain_user := strings.SplitN(user, "\\", 2) - return &SSPIAuth{ - Domain: domain_user[0], - UserName: domain_user[1], - Password: password, - Service: service, - }, true + domainUser := strings.SplitN(config.User, "\\", 2) + return &Auth{ + Domain: domainUser[0], + UserName: domainUser[1], + Password: config.Password, + Service: config.ServerSPN, + }, nil } -func (auth *SSPIAuth) InitialBytes() ([]byte, error) { +func (auth *Auth) InitialBytes() ([]byte, error) { var identity *SEC_WINNT_AUTH_IDENTITY if auth.UserName != "" { identity = &SEC_WINNT_AUTH_IDENTITY{ @@ -202,7 +209,7 @@ func (auth *SSPIAuth) InitialBytes() ([]byte, error) { return outbuf[:buf.cbBuffer], nil } -func (auth *SSPIAuth) NextBytes(bytes []byte) ([]byte, error) { +func (auth *Auth) NextBytes(bytes []byte) ([]byte, error) { var in_buf, out_buf SecBuffer var in_desc, out_desc SecBufferDesc @@ -254,7 +261,7 @@ func (auth *SSPIAuth) NextBytes(bytes []byte) ([]byte, error) { return outbuf[:out_buf.cbBuffer], nil } -func (auth *SSPIAuth) Free() { +func (auth *Auth) Free() { syscall.Syscall6(sec_fn.DeleteSecurityContext, 1, uintptr(unsafe.Pointer(&auth.ctxt)), diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 0b5354c0..74799ee5 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -73,6 +73,8 @@ type Config struct { ConnTimeout time.Duration // Use context for timeouts. KeepAlive time.Duration // Leave at default. PacketSize uint16 + + Parameters map[string]string } func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string, minTLSVersion string) (*tls.Config, error) { @@ -109,31 +111,32 @@ func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate str var skipSetup = errors.New("skip setting up TLS") -func Parse(dsn string) (Config, map[string]string, error) { +func Parse(dsn string) (Config, error) { p := Config{} var params map[string]string + var err error if strings.HasPrefix(dsn, "odbc:") { - parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):]) + params, err = splitConnectionStringOdbc(dsn[len("odbc:"):]) if err != nil { - return p, params, err + return p, err } - params = parameters } else if strings.HasPrefix(dsn, "sqlserver://") { - parameters, err := splitConnectionStringURL(dsn) + params, err = splitConnectionStringURL(dsn) if err != nil { - return p, params, err + return p, err } - params = parameters } else { params = splitConnectionString(dsn) } + p.Parameters = params + strlog, ok := params["log"] if ok { flags, err := strconv.ParseUint(strlog, 10, 64) if err != nil { - return p, params, fmt.Errorf("invalid log parameter '%s': %s", strlog, err.Error()) + return p, fmt.Errorf("invalid log parameter '%s': %s", strlog, err.Error()) } p.LogFlags = Log(flags) } @@ -157,7 +160,7 @@ func Parse(dsn string) (Config, map[string]string, error) { p.Port, err = strconv.ParseUint(strport, 10, 16) if err != nil { f := "invalid tcp port '%v': %v" - return p, params, fmt.Errorf(f, strport, err.Error()) + return p, fmt.Errorf(f, strport, err.Error()) } } @@ -168,7 +171,7 @@ func Parse(dsn string) (Config, map[string]string, error) { psize, err := strconv.ParseUint(strpsize, 0, 16) if err != nil { f := "invalid packet size '%v': %v" - return p, params, fmt.Errorf(f, strpsize, err.Error()) + return p, fmt.Errorf(f, strpsize, err.Error()) } // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes @@ -191,7 +194,7 @@ func Parse(dsn string) (Config, map[string]string, error) { timeout, err := strconv.ParseUint(strconntimeout, 10, 64) if err != nil { f := "invalid connection timeout '%v': %v" - return p, params, fmt.Errorf(f, strconntimeout, err.Error()) + return p, fmt.Errorf(f, strconntimeout, err.Error()) } p.ConnTimeout = time.Duration(timeout) * time.Second } @@ -200,7 +203,7 @@ func Parse(dsn string) (Config, map[string]string, error) { timeout, err := strconv.ParseUint(strdialtimeout, 10, 64) if err != nil { f := "invalid dial timeout '%v': %v" - return p, params, fmt.Errorf(f, strdialtimeout, err.Error()) + return p, fmt.Errorf(f, strdialtimeout, err.Error()) } p.DialTimeout = time.Duration(timeout) * time.Second } @@ -212,7 +215,7 @@ func Parse(dsn string) (Config, map[string]string, error) { timeout, err := strconv.ParseUint(keepAlive, 10, 64) if err != nil { f := "invalid keepAlive value '%s': %s" - return p, params, fmt.Errorf(f, keepAlive, err.Error()) + return p, fmt.Errorf(f, keepAlive, err.Error()) } p.KeepAlive = time.Duration(timeout) * time.Second } @@ -230,7 +233,7 @@ func Parse(dsn string) (Config, map[string]string, error) { e, err := strconv.ParseBool(encrypt) if err != nil { f := "invalid encrypt '%s': %s" - return p, params, fmt.Errorf(f, encrypt, err.Error()) + return p, fmt.Errorf(f, encrypt, err.Error()) } if e { p.Encryption = EncryptionRequired @@ -245,7 +248,7 @@ func Parse(dsn string) (Config, map[string]string, error) { trustServerCert, err = strconv.ParseBool(trust) if err != nil { f := "invalid trust server certificate '%s': %s" - return p, params, fmt.Errorf(f, trust, err.Error()) + return p, fmt.Errorf(f, trust, err.Error()) } } certificate = params["certificate"] @@ -262,7 +265,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.TLSConfig, err = SetupTLS(certificate, trustServerCert, hostInCertificate, tlsMin) if err != nil { - return p, params, fmt.Errorf("failed to setup TLS: %w", err) + return p, fmt.Errorf("failed to setup TLS: %w", err) } } @@ -270,7 +273,8 @@ func Parse(dsn string) (Config, map[string]string, error) { if ok { p.ServerSPN = serverSPN } else { - p.ServerSPN = generateSpn(p.Host, p.Port) + // allow connections to sql server instances + p.ServerSPN = generateSpn(p.Host, instanceOrPort(p.Instance, p.Port)) } workstation, ok := params["workstation id"] @@ -293,7 +297,7 @@ func Parse(dsn string) (Config, map[string]string, error) { if ok { if appintent == "ReadOnly" { if p.Database == "" { - return p, params, fmt.Errorf("database must be specified when ApplicationIntent is ReadOnly") + return p, fmt.Errorf("database must be specified when ApplicationIntent is ReadOnly") } p.ReadOnlyIntent = true } @@ -310,7 +314,7 @@ func Parse(dsn string) (Config, map[string]string, error) { p.FailOverPort, err = strconv.ParseUint(failOverPort, 0, 16) if err != nil { f := "invalid failover port '%v': %v" - return p, params, fmt.Errorf(f, failOverPort, err.Error()) + return p, fmt.Errorf(f, failOverPort, err.Error()) } } @@ -320,13 +324,13 @@ func Parse(dsn string) (Config, map[string]string, error) { p.DisableRetry, err = strconv.ParseBool(disableRetry) if err != nil { f := "invalid disableRetry '%s': %s" - return p, params, fmt.Errorf(f, disableRetry, err.Error()) + return p, fmt.Errorf(f, disableRetry, err.Error()) } } else { p.DisableRetry = disableRetryDefault } - return p, params, nil + return p, nil } // convert connectionParams to url style connection string @@ -608,6 +612,26 @@ func normalizeOdbcKey(s string) string { return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace)) } -func generateSpn(host string, port uint64) string { - return fmt.Sprintf("MSSQLSvc/%s:%d", host, port) +func instanceOrPort(instance string, port uint64) string { + if len(instance) > 0 { + return instance + } + + port = resolveServerPort(port) + + return strconv.FormatInt(int64(port), 10) +} + +const defaultServerPort = 1433 + +func resolveServerPort(port uint64) uint64 { + if port == 0 { + return defaultServerPort + } + + return port +} + +func generateSpn(host string, port string) string { + return fmt.Sprintf("MSSQLSvc/%s:%s", host, port) } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index a7601870..20a6cd25 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -35,7 +35,7 @@ func TestInvalidConnectionString(t *testing.T) { "sqlserver://host?key=value1&key=value2", // duplicate keys } for _, connStr := range connStrings { - _, _, err := Parse(connStr) + _, err := Parse(connStr) if err == nil { t.Errorf("Connection expected to fail for connection string %s but it didn't", connStr) continue @@ -181,7 +181,7 @@ func TestValidConnectionString(t *testing.T) { }}, } for _, ts := range connStrings { - p, _, err := Parse(ts.connStr) + p, err := Parse(ts.connStr) if err == nil { t.Logf("Connection string was parsed successfully %s", ts.connStr) } else { @@ -203,12 +203,12 @@ func TestSplitConnectionStringURL(t *testing.T) { } func TestConnParseRoundTripFixed(t *testing.T) { - connStr := "sqlserver://sa:sa@localhost/sqlexpress?database=master&log=127" - params, _, err := Parse(connStr) + connStr := "sqlserver://sa:sa@localhost/sqlexpress?database=master&log=127&disableretry=true" + params, err := Parse(connStr) if err != nil { t.Fatal("Test URL is not valid", err) } - rtParams, _, err := Parse(params.URL().String()) + rtParams, err := Parse(params.URL().String()) if err != nil { t.Fatal("Params after roundtrip are not valid", err) } @@ -216,3 +216,39 @@ func TestConnParseRoundTripFixed(t *testing.T) { t.Fatal("Parameters do not match after roundtrip", params, rtParams) } } + +func TestAllKeysAreAvailableInParametersMap(t *testing.T) { + keys := map[string]string{ + "user id": "1", + "testparam": "testvalue", + "password": "test", + "thisisanunknownkey": "thisisthevalue", + "server": "name", + } + + connString := "" + for key, val := range keys { + connString += key + "=" + val + ";" + } + + params, err := Parse(connString) + if err != nil { + t.Errorf("unexpected error while parsing, %v", err) + } + + if params.Parameters == nil { + t.Error("Expected parameters map to be instanciated, found nil") + return + } + + if len(params.Parameters) != len(keys) { + t.Errorf("Expected parameters map to be same length as input map length, expected %v, found %v", len(keys), len(params.Parameters)) + return + } + + for key, val := range keys { + if params.Parameters[key] != val { + t.Errorf("Expected parameters map to contain key %v and value %v, found %v", key, val, params.Parameters[key]) + } + } +} diff --git a/mssql.go b/mssql.go index f0cba5ac..d0361294 100644 --- a/mssql.go +++ b/mssql.go @@ -61,7 +61,7 @@ type Driver struct { // OpenConnector opens a new connector. Useful to dial with a context. func (d *Driver) OpenConnector(dsn string) (*Connector, error) { - params, _, err := msdsn.Parse(dsn) + params, err := msdsn.Parse(dsn) if err != nil { return nil, err } @@ -115,7 +115,7 @@ func (d *Driver) SetContextLogger(ctxLogger ContextLogger) { // NewConnector creates a new connector from a DSN. // The returned connector may be used with sql.OpenDB. func NewConnector(dsn string) (*Connector, error) { - params, _, err := msdsn.Parse(dsn) + params, err := msdsn.Parse(dsn) if err != nil { return nil, err } @@ -129,7 +129,7 @@ func NewConnector(dsn string) (*Connector, error) { // NewConnectorWithAccessTokenProvider creates a new connector from a DSN using the given // access token provider. The returned connector may be used with sql.OpenDB. func NewConnectorWithAccessTokenProvider(dsn string, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) { - params, _, err := msdsn.Parse(dsn) + params, err := msdsn.Parse(dsn) if err != nil { return nil, err } @@ -387,7 +387,7 @@ func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) { } func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) { - params, _, err := msdsn.Parse(dsn) + params, err := msdsn.Parse(dsn) if err != nil { return nil, err } diff --git a/tds.go b/tds.go index ced4a7de..82da5f09 100644 --- a/tds.go +++ b/tds.go @@ -16,6 +16,7 @@ import ( "unicode/utf16" "unicode/utf8" + "github.com/microsoft/go-mssqldb/integratedauth" "github.com/microsoft/go-mssqldb/msdsn" ) @@ -836,12 +837,6 @@ func sendAttention(buf *tdsBuffer) error { return buf.FinishPacket() } -type auth interface { - InitialBytes() ([]byte, error) - NextBytes([]byte) ([]byte, error) - Free() -} - // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a // list of IP addresses. So if there is more than one, try them all and // use the first one that allows a connection. @@ -966,7 +961,7 @@ func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map return } -func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth auth, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) { +func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth integratedauth.IntegratedAuthenticator, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) { var typeFlags uint8 if p.ReadOnlyIntent { typeFlags |= fReadOnlyIntent @@ -1171,11 +1166,17 @@ initiate_connection: } } - auth, authOk := getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) - if authOk { + auth, err := integratedauth.GetIntegratedAuthenticator(p) + if err != nil { + if uint64(p.LogFlags)&logDebug != 0 { + logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("Error while creating integrated authenticator: %v", err)) + } + + return nil, err + } + + if auth != nil { defer auth.Free() - } else { - auth = nil } login, err := prepareLogin(ctx, c, p, logger, auth, fedAuth, uint32(outbuf.PackageSize())) diff --git a/tds_login_test.go b/tds_login_test.go index 3bb11e59..e6d1bebe 100644 --- a/tds_login_test.go +++ b/tds_login_test.go @@ -170,7 +170,7 @@ func TestLoginWithSQLServerAuth(t *testing.T) { } func TestLoginWithSecurityTokenAuth(t *testing.T) { - config, _, err := msdsn.Parse("sqlserver://localhost:1433?Workstation ID=localhost&log=128") + config, err := msdsn.Parse("sqlserver://localhost:1433?Workstation ID=localhost&log=128") if err != nil { t.Fatal(err) } @@ -231,7 +231,7 @@ func TestLoginWithSecurityTokenAuth(t *testing.T) { } func TestLoginWithADALUsernamePasswordAuth(t *testing.T) { - config, _, err := msdsn.Parse("sqlserver://localhost:1433?Workstation ID=localhost&log=128") + config, err := msdsn.Parse("sqlserver://localhost:1433?Workstation ID=localhost&log=128") if err != nil { t.Fatal(err) } @@ -305,7 +305,7 @@ func TestLoginWithADALUsernamePasswordAuth(t *testing.T) { } func TestLoginWithADALManagedIdentityAuth(t *testing.T) { - config, _, err := msdsn.Parse("sqlserver://localhost:1433?Workstation ID=localhost&log=128") + config, err := msdsn.Parse("sqlserver://localhost:1433?Workstation ID=localhost&log=128") if err != nil { t.Fatal(err) } diff --git a/tds_test.go b/tds_test.go index 0caeed8a..0dfeeb2b 100644 --- a/tds_test.go +++ b/tds_test.go @@ -156,7 +156,7 @@ func TestSendLoginWithFeatureExt(t *testing.T) { func TestSendSqlBatch(t *testing.T) { checkConnStr(t) - p, _, err := msdsn.Parse(makeConnStr(t).String()) + p, err := msdsn.Parse(makeConnStr(t).String()) if err != nil { t.Error("parseConnectParams failed:", err.Error()) return @@ -225,7 +225,7 @@ func GetConnParams() (*msdsn.Config, error) { dsn := os.Getenv("SQLSERVER_DSN") const logFlags = 127 if len(dsn) > 0 { - params, _, err := msdsn.Parse(dsn) + params, err := msdsn.Parse(dsn) if err != nil { return nil, err } @@ -250,7 +250,7 @@ func GetConnParams() (*msdsn.Config, error) { if err != io.EOF && err != nil { return nil, err } - params, _, err := msdsn.Parse(dsn) + params, err := msdsn.Parse(dsn) if err != nil { return nil, err } @@ -875,7 +875,7 @@ func TestReadBVarByte(t *testing.T) { func BenchmarkPacketSize(b *testing.B) { checkConnStr(b) - p, _, err := msdsn.Parse(makeConnStr(b).String()) + p, err := msdsn.Parse(makeConnStr(b).String()) if err != nil { b.Error("parseConnectParams failed:", err.Error()) return