From 3cc2463cd7ba0de8a2a252246ca0fc477548d52e Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Thu, 6 Jul 2023 19:55:30 +0530 Subject: [PATCH] Accept additional values for encrypt --- msdsn/conn_str.go | 24 ++++++++++++++---------- msdsn/conn_str_test.go | 3 +++ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 4f71453d..bd06ac47 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -25,6 +25,7 @@ const ( EncryptionOff = 0 EncryptionRequired = 1 EncryptionDisabled = 3 + EncryptionStrict = 4 ) const ( @@ -130,21 +131,24 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e var encryption Encryption = EncryptionOff encrypt, ok := params["encrypt"] if ok { - if strings.EqualFold(encrypt, "DISABLE") { + encrypt = strings.ToLower(encrypt) + switch encrypt { + case "mandatory", "yes", "1", "t", "true": + encryption = EncryptionRequired + case "disable": encryption = EncryptionDisabled - } else { - e, err := strconv.ParseBool(encrypt) - if err != nil { - f := "invalid encrypt '%s': %s" - return encryption, nil, fmt.Errorf(f, encrypt, err.Error()) - } - if e { - encryption = EncryptionRequired - } + case "strict": + encryption = EncryptionStrict + case "optional", "no", "0", "f", "false": + encryption = EncryptionOff + default: + f := "invalid encrypt '%s'" + return encryption, nil, fmt.Errorf(f, encrypt) } } else { trustServerCert = true } + trust, ok := params["trustservercertificate"] if ok { var err error diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 5fa1a0ed..31a5f3c1 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -62,6 +62,7 @@ func TestValidConnectionString(t *testing.T) { {"encrypt=disable", func(p Config) bool { return p.Encryption == EncryptionDisabled }}, {"encrypt=disable;tlsmin=1.1", func(p Config) bool { return p.Encryption == EncryptionDisabled && p.TLSConfig == nil }}, {"encrypt=true", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, + {"encrypt=mandatory", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, {"encrypt=true;tlsmin=1.0", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS10 }}, @@ -78,6 +79,8 @@ func TestValidConnectionString(t *testing.T) { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, {"encrypt=false", func(p Config) bool { return p.Encryption == EncryptionOff }}, + {"encrypt=optional", func(p Config) bool { return p.Encryption == EncryptionOff }}, + {"encrypt=strict", func(p Config) bool { return p.Encryption == EncryptionStrict }}, {"connection timeout=3;dial timeout=4;keepalive=5", func(p Config) bool { return p.ConnTimeout == 3*time.Second && p.DialTimeout == 4*time.Second && p.KeepAlive == 5*time.Second }},