From 6b6956068faaea7ea45b183c8b77dd57e99d5dc8 Mon Sep 17 00:00:00 2001 From: David Shiflet Date: Mon, 21 Aug 2023 08:30:27 -0500 Subject: [PATCH] Feat: Implement change password during login (#141) * Feat: Implement change password during login * use -v for go test * move assert usage to go117+ --- .github/workflows/pr-validation.yml | 3 +- .pipelines/TestSql2017.yml | 9 +- appveyor.yml | 1 + msdsn/conn_str.go | 133 ++++++++++++++++++---------- tds.go | 28 +++--- tds_go117_test.go | 60 +++++++++++++ 6 files changed, 164 insertions(+), 70 deletions(-) create mode 100644 tds_go117_test.go diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 7963a660..3a8f831a 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -21,7 +21,6 @@ jobs: - name: Run tests against Linux SQL run: | go version - go get -d export SQLCMDPASSWORD=$(date +%s|sha256sum|base64|head -c 32) export SQLCMDUSER=sa export SQLUSER=sa @@ -30,4 +29,4 @@ jobs: docker run -m 2GB -e ACCEPT_EULA=1 -d --name sqlserver -p:1433:1433 -e SA_PASSWORD=$SQLCMDPASSWORD mcr.microsoft.com/mssql/server:${{ matrix.sqlImage }} sleep 10 sqlcmd -Q "CREATE DATABASE test" - go test -race -cpu 4 ./... + go test -v ./... diff --git a/.pipelines/TestSql2017.yml b/.pipelines/TestSql2017.yml index 9bdb5303..9633308a 100644 --- a/.pipelines/TestSql2017.yml +++ b/.pipelines/TestSql2017.yml @@ -8,14 +8,7 @@ variables: steps: - task: GoTool@0 inputs: - version: '1.19' -- task: Go@0 - displayName: 'Go: get sources' - inputs: - command: 'get' - arguments: '-d' - workingDirectory: '$(Build.SourcesDirectory)' - + version: '1.20' - task: Go@0 displayName: 'Go: install gotest.tools/gotestsum' diff --git a/appveyor.yml b/appveyor.yml index fdeeedf3..b3bcc5c2 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -49,6 +49,7 @@ install: - go get -u github.com/golang-sql/civil - go get -u github.com/golang-sql/sqlexp - go get -u golang.org/x/crypto/md4 + - go get github.com/stretchr/testify/assert@v1.8.1 - go get -u golang.org/x/text/encoding/unicode build_script: diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index ee186e0d..544cc4db 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -44,6 +44,34 @@ const ( BrowserDAC BrowserMsg = 0x0f ) +const ( + Database = "database" + Encrypt = "encrypt" + Password = "password" + ChangePassword = "change password" + UserID = "user id" + Port = "port" + TrustServerCertificate = "trustservercertificate" + Certificate = "certificate" + TLSMin = "tlsmin" + PacketSize = "packet size" + LogParam = "log" + ConnectionTimeout = "connection timeout" + HostNameInCertificate = "hostnameincertificate" + KeepAlive = "keepalive" + ServerSpn = "serverspn" + WorkstationID = "workstation id" + AppName = "app name" + ApplicationIntent = "applicationintent" + FailoverPartner = "failoverpartner" + FailOverPort = "failoverport" + DisableRetry = "disableretry" + Server = "server" + Protocol = "protocol" + DialTimeout = "dial timeout" + Pipe = "pipe" +) + type Config struct { Port uint64 Host string @@ -88,6 +116,8 @@ type Config struct { ProtocolParameters map[string]interface{} // BrowserMsg is the message identifier to fetch instance data from SQL browser BrowserMessage BrowserMsg + // ChangePassword is used to set the login's password during login. Ignored for non-SQL authentication. + ChangePassword string //ColumnEncryption is true if the application needs to decrypt or encrypt Always Encrypted values ColumnEncryption bool } @@ -130,7 +160,7 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e trustServerCert := false var encryption Encryption = EncryptionOff - encrypt, ok := params["encrypt"] + encrypt, ok := params[Encrypt] if ok { if strings.EqualFold(encrypt, "DISABLE") { encryption = EncryptionDisabled @@ -147,7 +177,7 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e } else { trustServerCert = true } - trust, ok := params["trustservercertificate"] + trust, ok := params[TrustServerCertificate] if ok { var err error trustServerCert, err = strconv.ParseBool(trust) @@ -156,9 +186,9 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e return encryption, nil, fmt.Errorf(f, trust, err.Error()) } } - certificate := params["certificate"] + certificate := params[Certificate] if encryption != EncryptionDisabled { - tlsMin := params["tlsmin"] + tlsMin := params[TLSMin] tlsConfig, err := SetupTLS(certificate, trustServerCert, host, tlsMin) if err != nil { return encryption, nil, fmt.Errorf("failed to setup TLS: %w", err) @@ -194,7 +224,7 @@ func Parse(dsn string) (Config, error) { p.Parameters = params - strlog, ok := params["log"] + strlog, ok := params[LogParam] if ok { flags, err := strconv.ParseUint(strlog, 10, 64) if err != nil { @@ -203,12 +233,12 @@ func Parse(dsn string) (Config, error) { p.LogFlags = Log(flags) } - p.Database = params["database"] - p.User = params["user id"] - p.Password = params["password"] - + p.Database = params[Database] + p.User = params[UserID] + p.Password = params[Password] + p.ChangePassword = params[ChangePassword] p.Port = 0 - strport, ok := params["port"] + strport, ok := params[Port] if ok { var err error p.Port, err = strconv.ParseUint(strport, 10, 16) @@ -219,7 +249,7 @@ func Parse(dsn string) (Config, error) { } // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option\ - strpsize, ok := params["packet size"] + strpsize, ok := params[PacketSize] if ok { var err error psize, err := strconv.ParseUint(strpsize, 0, 16) @@ -244,7 +274,7 @@ func Parse(dsn string) (Config, error) { // // Do not set a connection timeout. Use Context to manage such things. // Default to zero, but still allow it to be set. - if strconntimeout, ok := params["connection timeout"]; ok { + if strconntimeout, ok := params[ConnectionTimeout]; ok { timeout, err := strconv.ParseUint(strconntimeout, 10, 64) if err != nil { f := "invalid connection timeout '%v': %v" @@ -256,7 +286,7 @@ func Parse(dsn string) (Config, error) { // default keep alive should be 30 seconds according to spec: // https://msdn.microsoft.com/en-us/library/dd341108.aspx p.KeepAlive = 30 * time.Second - if keepAlive, ok := params["keepalive"]; ok { + if keepAlive, ok := params[KeepAlive]; ok { timeout, err := strconv.ParseUint(keepAlive, 10, 64) if err != nil { f := "invalid keepAlive value '%s': %s" @@ -265,12 +295,12 @@ func Parse(dsn string) (Config, error) { p.KeepAlive = time.Duration(timeout) * time.Second } - serverSPN, ok := params["serverspn"] + serverSPN, ok := params[ServerSpn] if ok { p.ServerSPN = serverSPN } // If not set by the app, ServerSPN will be set by the successful dialer. - workstation, ok := params["workstation id"] + workstation, ok := params[WorkstationID] if ok { p.Workstation = workstation } else { @@ -280,13 +310,13 @@ func Parse(dsn string) (Config, error) { } } - appname, ok := params["app name"] + appname, ok := params[AppName] if !ok { appname = "go-mssqldb" } p.AppName = appname - appintent, ok := params["applicationintent"] + appintent, ok := params[ApplicationIntent] if ok { if appintent == "ReadOnly" { if p.Database == "" { @@ -296,12 +326,12 @@ func Parse(dsn string) (Config, error) { } } - failOverPartner, ok := params["failoverpartner"] + failOverPartner, ok := params[FailoverPartner] if ok { p.FailOverPartner = failOverPartner } - failOverPort, ok := params["failoverport"] + failOverPort, ok := params[FailOverPort] if ok { var err error p.FailOverPort, err = strconv.ParseUint(failOverPort, 0, 16) @@ -311,7 +341,7 @@ func Parse(dsn string) (Config, error) { } } - disableRetry, ok := params["disableretry"] + disableRetry, ok := params[DisableRetry] if ok { var err error p.DisableRetry, err = strconv.ParseBool(disableRetry) @@ -323,8 +353,8 @@ func Parse(dsn string) (Config, error) { p.DisableRetry = disableRetryDefault } - server := params["server"] - protocol, ok := params["protocol"] + server := params[Server] + protocol, ok := params[Protocol] for _, parser := range ProtocolParsers { if (!ok && !parser.Hidden()) || parser.Protocol() == protocol { @@ -350,7 +380,7 @@ func Parse(dsn string) (Config, error) { f = 1 } p.DialTimeout = time.Duration(15*f) * time.Second - if strdialtimeout, ok := params["dial timeout"]; ok { + if strdialtimeout, ok := params[DialTimeout]; ok { timeout, err := strconv.ParseUint(strdialtimeout, 10, 64) if err != nil { f := "invalid dial timeout '%v': %v" @@ -360,7 +390,7 @@ func Parse(dsn string) (Config, error) { p.DialTimeout = time.Duration(timeout) * time.Second } - hostInCertificate, ok := params["hostnameincertificate"] + hostInCertificate, ok := params[HostNameInCertificate] if ok { p.HostInCertificateProvided = true } else { @@ -394,10 +424,10 @@ func Parse(dsn string) (Config, error) { func (p Config) URL() *url.URL { q := url.Values{} if p.Database != "" { - q.Add("database", p.Database) + q.Add(Database, p.Database) } if p.LogFlags != 0 { - q.Add("log", strconv.FormatUint(uint64(p.LogFlags), 10)) + q.Add(LogParam, strconv.FormatUint(uint64(p.LogFlags), 10)) } host := p.Host protocol := "" @@ -412,8 +442,8 @@ func (p Config) URL() *url.URL { if p.Port > 0 { host = fmt.Sprintf("%s:%d", host, p.Port) } - q.Add("disableRetry", fmt.Sprintf("%t", p.DisableRetry)) - protocolParam, ok := p.Parameters["protocol"] + q.Add(DisableRetry, fmt.Sprintf("%t", p.DisableRetry)) + protocolParam, ok := p.Parameters[Protocol] if ok { if protocol != "" && protocolParam != protocol { panic("Mismatched protocol parameters!") @@ -421,11 +451,11 @@ func (p Config) URL() *url.URL { protocol = protocolParam } if protocol != "" { - q.Add("protocol", protocol) + q.Add(Protocol, protocol) } - pipe, ok := p.Parameters["pipe"] + pipe, ok := p.Parameters[Pipe] if ok { - q.Add("pipe", pipe) + q.Add(Pipe, pipe) } res := url.URL{ Scheme: "sqlserver", @@ -435,7 +465,14 @@ func (p Config) URL() *url.URL { if p.Instance != "" { res.Path = p.Instance } - q.Add("dial timeout", strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64)) + q.Add(DialTimeout, strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64)) + + switch p.Encryption { + case EncryptionDisabled: + q.Add(Encrypt, "DISABLE") + case EncryptionRequired: + q.Add(Encrypt, "true") + } if p.ColumnEncryption { q.Add("columnencryption", "true") } @@ -448,14 +485,14 @@ func (p Config) URL() *url.URL { // ADO connection string keywords at https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs var adoSynonyms = map[string]string{ - "application name": "app name", - "data source": "server", - "address": "server", - "network address": "server", - "addr": "server", - "user": "user id", - "uid": "user id", - "initial catalog": "database", + "application name": AppName, + "data source": Server, + "address": Server, + "network address": Server, + "addr": Server, + "user": UserID, + "uid": UserID, + "initial catalog": Database, "column encryption setting": "columnencryption", } @@ -480,18 +517,18 @@ func splitConnectionString(dsn string) (res map[string]string) { name = synonym } // "server" in ADO can include a protocol and a port. - if name == "server" { + if name == Server { for _, parser := range ProtocolParsers { prot := parser.Protocol() + ":" if strings.HasPrefix(value, prot) { - res["protocol"] = parser.Protocol() + res[Protocol] = parser.Protocol() } value = strings.TrimPrefix(value, prot) } serverParts := strings.Split(value, ",") if len(serverParts) == 2 && len(serverParts[1]) > 0 { value = serverParts[0] - res["port"] = serverParts[1] + res[Port] = serverParts[1] } } res[name] = value @@ -513,10 +550,10 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) { } if u.User != nil { - res["user id"] = u.User.Username() + res[UserID] = u.User.Username() p, exists := u.User.Password() if exists { - res["password"] = p + res[Password] = p } } @@ -526,13 +563,13 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) { } if len(u.Path) > 0 { - res["server"] = host + "\\" + u.Path[1:] + res[Server] = host + "\\" + u.Path[1:] } else { - res["server"] = host + res[Server] = host } if len(port) > 0 { - res["port"] = port + res[Port] = port } query := u.Query() diff --git a/tds.go b/tds.go index 891630c6..db6eda14 100644 --- a/tds.go +++ b/tds.go @@ -601,7 +601,7 @@ func sendLogin(w *tdsBuffer, login *login) error { language := str2ucs2(login.Language) database := str2ucs2(login.Database) atchdbfile := str2ucs2(login.AtchDBFile) - changepassword := str2ucs2(login.ChangePassword) + changepassword := manglePassword(login.ChangePassword) featureExt := login.FeatureExt.toBytes() hdr := loginHeader{ @@ -662,6 +662,9 @@ func sendLogin(w *tdsBuffer, login *login) error { offset += hdr.ExtensionLength // DWORD featureExtOffset = uint32(offset) } + if len(changepassword) > 0 { + hdr.OptionFlags3 |= fChangePassword + } hdr.Length = uint32(offset) + uint32(featureExtLen) var err error @@ -1059,17 +1062,18 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont serverName = p.Host } l = &login{ - TDSVersion: verTDS74, - PacketSize: packetSize, - Database: p.Database, - OptionFlags2: fODBC, // to get unlimited TEXTSIZE - OptionFlags1: fUseDB | fSetLang, - HostName: p.Workstation, - ServerName: serverName, - AppName: p.AppName, - TypeFlags: typeFlags, - CtlIntName: "go-mssqldb", - ClientProgVer: getDriverVersion(driverVersion), + TDSVersion: verTDS74, + PacketSize: packetSize, + Database: p.Database, + OptionFlags2: fODBC, // to get unlimited TEXTSIZE + OptionFlags1: fUseDB | fSetLang, + HostName: p.Workstation, + ServerName: serverName, + AppName: p.AppName, + TypeFlags: typeFlags, + CtlIntName: "go-mssqldb", + ClientProgVer: getDriverVersion(driverVersion), + ChangePassword: p.ChangePassword, } if p.ColumnEncryption { _ = l.FeatureExt.Add(&featureExtColumnEncryption{}) diff --git a/tds_go117_test.go b/tds_go117_test.go new file mode 100644 index 00000000..24bd967c --- /dev/null +++ b/tds_go117_test.go @@ -0,0 +1,60 @@ +//go:build go1.17 +// +build go1.17 + +package mssql + +import ( + "context" + "crypto/rand" + "database/sql" + "fmt" + "math/big" + "testing" + + "github.com/microsoft/go-mssqldb/msdsn" + "github.com/stretchr/testify/assert" +) + +func TestChangePassword(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + login, pwd := createLogin(t, conn) + defer dropLogin(t, conn, login) + p, err := msdsn.Parse(makeConnStr(t).String()) + assert.NoError(t, err, "Parse failed") + p.ChangePassword = "Change" + pwd + p.User = login + p.Password = pwd + p.Parameters[msdsn.UserID] = p.User + p.Parameters[msdsn.Password] = p.Password + tl := testLogger{t: t} + defer tl.StopLogging() + c, err := connect(context.Background(), &Connector{params: p}, optionalLogger{loggerAdapter{&tl}}, p) + if assert.NoError(t, err, "Login with new login failed") { + c.buf.transport.Close() + + p.Password = p.ChangePassword + p.ChangePassword = "" + c, err = connect(context.Background(), &Connector{params: p}, optionalLogger{loggerAdapter{&tl}}, p) + if assert.NoError(t, err, "Login with new password failed") { + c.buf.transport.Close() + } + } + +} + +func createLogin(t *testing.T, conn *sql.DB) (login string, password string) { + t.Helper() + suffix, _ := rand.Int(rand.Reader, big.NewInt(10000)) + login = fmt.Sprintf("mssqlLogin%d", suffix.Int64()) + password = fmt.Sprintf("mssqlPwd!%d", suffix.Int64()) + _, err := conn.Exec(fmt.Sprintf("CREATE LOGIN [%s] WITH PASSWORD = '%s', CHECK_POLICY=OFF\nCREATE USER %s", login, password, login)) + assert.NoError(t, err, "create login failed") + return +} + +func dropLogin(t *testing.T, conn *sql.DB, login string) { + t.Helper() + _, _ = conn.Exec(fmt.Sprintf("DROP USER %s\nDROP LOGIN [%s]", login, login)) +}