Skip to content

Commit

Permalink
Fixed #93: servername was not supplied to TLS cfg (#94)
Browse files Browse the repository at this point in the history
* Fixed #93: servername was not supplied to TLS cfg

A change in fd44003 moved the code for setting p.Host down in the parse
function, resulting in an empty p.Host value at the time of the creation
of p.TLSConfig. I extracted the TLS parsing into a separate function and
moved the TLSConfig creation to the end of the parsing function.

* Added extra test scenario with hostnameincertificate parameter
  • Loading branch information
jwbargsten authored Feb 27, 2023
1 parent f481f92 commit 0a7b7a4
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 59 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## [Unreleased]

### Bug fixes

* Fixed uninitialized server name in TLS config ([#93](https://github.com/microsoft/go-mssqldb/issues/93))([#94](https://github.com/microsoft/go-mssqldb/pull/94))

## 0.20.0

### Features
Expand Down
116 changes: 57 additions & 59 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ type Config struct {
ProtocolParameters map[string]interface{}
}

// Build a tls.Config object from the supplied certificate.
func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string, minTLSVersion string) (*tls.Config, error) {
config := tls.Config{
ServerName: hostInCertificate,
Expand Down Expand Up @@ -113,6 +114,49 @@ func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate str
return &config, nil
}

// Parse and handle encryption parameters. If encryption is desired, it returns the corresponding tls.Config object.
func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, error) {
trustServerCert := false

var encryption Encryption = EncryptionOff
encrypt, ok := params["encrypt"]
if ok {
if strings.EqualFold(encrypt, "DISABLE") {
encryption = EncryptionDisabled
} else {
e, err := strconv.ParseBool(encrypt)
if err != nil {
f := "invalid encrypt '%s': %s"
return encryption, nil, fmt.Errorf(f, encrypt, err.Error())
}
if e {
encryption = EncryptionRequired
}
}
} else {
trustServerCert = true
}
trust, ok := params["trustservercertificate"]
if ok {
var err error
trustServerCert, err = strconv.ParseBool(trust)
if err != nil {
f := "invalid trust server certificate '%s': %s"
return encryption, nil, fmt.Errorf(f, trust, err.Error())
}
}
certificate := params["certificate"]
if encryption != EncryptionDisabled {
tlsMin := params["tlsmin"]
tlsConfig, err := SetupTLS(certificate, trustServerCert, host, tlsMin)
if err != nil {
return encryption, nil, fmt.Errorf("failed to setup TLS: %w", err)
}
return encryption, tlsConfig, nil
}
return encryption, nil, nil
}

var skipSetup = errors.New("skip setting up TLS")

func Parse(dsn string) (Config, error) {
Expand Down Expand Up @@ -210,55 +254,6 @@ func Parse(dsn string) (Config, error) {
p.KeepAlive = time.Duration(timeout) * time.Second
}

var (
trustServerCert = false
certificate = ""
hostInCertificate = ""
)
encrypt, ok := params["encrypt"]
if ok {
if strings.EqualFold(encrypt, "DISABLE") {
p.Encryption = EncryptionDisabled
} else {
e, err := strconv.ParseBool(encrypt)
if err != nil {
f := "invalid encrypt '%s': %s"
return p, fmt.Errorf(f, encrypt, err.Error())
}
if e {
p.Encryption = EncryptionRequired
}
}
} else {
trustServerCert = true
}
trust, ok := params["trustservercertificate"]
if ok {
var err error
trustServerCert, err = strconv.ParseBool(trust)
if err != nil {
f := "invalid trust server certificate '%s': %s"
return p, fmt.Errorf(f, trust, err.Error())
}
}
certificate = params["certificate"]
hostInCertificate, ok = params["hostnameincertificate"]
if ok {
p.HostInCertificateProvided = true
} else {
hostInCertificate = p.Host
p.HostInCertificateProvided = false
}

if p.Encryption != EncryptionDisabled {
tlsMin := params["tlsmin"]
var err error
p.TLSConfig, err = SetupTLS(certificate, trustServerCert, hostInCertificate, tlsMin)
if err != nil {
return p, fmt.Errorf("failed to setup TLS: %w", err)
}
}

serverSPN, ok := params["serverspn"]
if ok {
p.ServerSPN = serverSPN
Expand Down Expand Up @@ -354,6 +349,19 @@ func Parse(dsn string) (Config, error) {
p.DialTimeout = time.Duration(timeout) * time.Second
}

hostInCertificate, ok := params["hostnameincertificate"]
if ok {
p.HostInCertificateProvided = true
} else {
hostInCertificate = p.Host
p.HostInCertificateProvided = false
}

p.Encryption, p.TLSConfig, err = parseTLS(params, hostInCertificate)
if err != nil {
return p, err
}

return p, nil
}

Expand Down Expand Up @@ -651,16 +659,6 @@ func normalizeOdbcKey(s string) string {
return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
}

const defaultServerPort = 1433

func resolveServerPort(port uint64) uint64 {
if port == 0 {
return defaultServerPort
}

return port
}

// ProtocolParser can populate Config with parameters to dial using its protocol
type ProtocolParser interface {
ParseServer(server string, p *Config) error
Expand Down
27 changes: 27 additions & 0 deletions msdsn/conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,33 @@ func TestConnParseRoundTripFixed(t *testing.T) {
}
}

func TestServerNameInTLSConfig(t *testing.T) {
var tests = []struct {
dsn string
host string
hasTLSConfig bool
}{
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=true", "somehost", true},
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=false", "somehost", true},
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=true&hostnameincertificate=someotherhost", "someotherhost", true},
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false", "somehost", true},
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=DISABLE", "", false},
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=DISABLE&hostnameincertificate=someotherhost", "", false},
{"sqlserver://someuser:somepass@somehost?TrustServerCertificate=false&encrypt=false", "somehost", true},
}
for _, test := range tests {
cfg, err := Parse(test.dsn)
if err != nil {
t.Errorf("Could not parse valid connection string %s: %v", test.dsn, err)
}
if !test.hasTLSConfig && cfg.TLSConfig != nil {
t.Errorf("Expected empty TLS config, but got %v (cfg.Host was %s)", cfg.TLSConfig, cfg.Host)
}
if test.hasTLSConfig && cfg.TLSConfig.ServerName != test.host {
t.Errorf("Expected somehost as TLS server, but got %s (cfg.Host was %s)", cfg.TLSConfig.ServerName, cfg.Host)
}
}
}
func TestAllKeysAreAvailableInParametersMap(t *testing.T) {
keys := map[string]string{
"user id": "1",
Expand Down

0 comments on commit 0a7b7a4

Please sign in to comment.