From c902b674869e3f177b723cf2a5b1364c14efda47 Mon Sep 17 00:00:00 2001 From: abairmj Date: Thu, 19 Oct 2023 14:43:35 -0400 Subject: [PATCH] fix: Added multisubnetfailover option, set to false to prevent issue #158 (#159) * fix: Added multisubnetfailover option that can be set to false to prevent issue #158 --- README.md | 3 +++ msdsn/conn_str.go | 21 +++++++++++++++++++++ msdsn/conn_str_test.go | 3 +++ protocol.go | 20 ++++++++++++++------ 4 files changed, 41 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ba87fca5..a505f166 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,9 @@ Other supported formats are listed below. * `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. * `protocol` - forces use of a protocol. Make sure the corresponding package is imported. * `columnencryption` or `column encryption setting` - a boolean value indicating whether Always Encrypted should be enabled on the connection. +* `multisubnetfailover` + * `true` (Default) Client attempt to connect to all IPs simultaneously. + * `false` Client attempts to connect to IPs in serial. ### Connection parameters for namedpipe package * `pipe` - If set, no Browser query is made and named pipe used will be `\\\pipe\` diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 2bdddb57..460dd4e3 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -78,6 +78,7 @@ const ( Protocol = "protocol" DialTimeout = "dial timeout" Pipe = "pipe" + MultiSubnetFailover = "multisubnetfailover" ) type Config struct { @@ -128,6 +129,8 @@ type Config struct { ChangePassword string //ColumnEncryption is true if the application needs to decrypt or encrypt Always Encrypted values ColumnEncryption bool + // Attempt to connect to all IPs in parallel when MultiSubnetFailover is true + MultiSubnetFailover bool } func readDERFile(filename string) ([]byte, error) { @@ -483,6 +486,24 @@ func Parse(dsn string) (Config, error) { } p.ColumnEncryption = columnEncryption } + + msf, ok := params[MultiSubnetFailover] + if ok { + multiSubnetFailover, err := strconv.ParseBool(msf) + if err != nil { + if strings.EqualFold(msf, "Enabled") { + multiSubnetFailover = true + } else if strings.EqualFold(msf, "Disabled") { + multiSubnetFailover = false + } else { + return p, fmt.Errorf("invalid multiSubnetFailover value '%v': %v", multiSubnetFailover, err.Error()) + } + } + p.MultiSubnetFailover = multiSubnetFailover + } else { + // Defaulting to true to prevent breaking change although other client libraries default to false + p.MultiSubnetFailover = true + } return p, nil } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 0645313b..f1bf03eb 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -25,6 +25,7 @@ func TestInvalidConnectionString(t *testing.T) { "failoverport=invalid", "applicationintent=ReadOnly", "disableretry=invalid", + "multisubnetfailover=invalid", // ODBC mode "odbc:password={", @@ -104,6 +105,8 @@ func TestValidConnectionString(t *testing.T) { {"disableretry=1", func(p Config) bool { return p.DisableRetry }}, {"disableretry=0", func(p Config) bool { return !p.DisableRetry }}, {"", func(p Config) bool { return p.DisableRetry == disableRetryDefault }}, + {"MultiSubnetFailover=true", func(p Config) bool { return p.MultiSubnetFailover }}, + {"MultiSubnetFailover=false", func(p Config) bool { return !p.MultiSubnetFailover }}, // those are supported currently, but maybe should not be {"someparam", func(p Config) bool { return true }}, diff --git a/protocol.go b/protocol.go index 6a83ab25..23ad0681 100644 --- a/protocol.go +++ b/protocol.go @@ -69,12 +69,14 @@ func (t tcpDialer) DialConnection(ctx context.Context, p *msdsn.Config) (conn ne func (t tcpDialer) DialSqlConnection(ctx context.Context, c *Connector, p *msdsn.Config) (conn net.Conn, err error) { var ips []net.IP ip := net.ParseIP(p.Host) + portStr := strconv.Itoa(int(resolveServerPort(p.Port))) + if ip == nil { // if the custom dialer is a host dialer, the DNS is resolved within the network // the dialer is sending the request to, rather than the one the driver is running on d := c.getDialer(p) if _, ok := d.(HostDialer); ok { - addr := net.JoinHostPort(p.Host, strconv.Itoa(int(resolveServerPort(p.Port)))) + addr := net.JoinHostPort(p.Host, portStr) return d.DialContext(ctx, "tcp", addr) } @@ -85,16 +87,22 @@ func (t tcpDialer) DialSqlConnection(ctx context.Context, c *Connector, p *msdsn } else { ips = []net.IP{ip} } - if len(ips) == 1 { - d := c.getDialer(p) - addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.Port)))) - conn, err = d.DialContext(ctx, "tcp", addr) + if len(ips) == 1 || !p.MultiSubnetFailover { + // Try to connect to IPs sequentially until one is successful per MultiSubnetFailover false rules + for _, ipaddress := range ips { + d := c.getDialer(p) + addr := net.JoinHostPort(ipaddress.String(), portStr) + conn, err = d.DialContext(ctx, "tcp", addr) + if err == nil { + break + } + } } else { //Try Dials in parallel to avoid waiting for timeouts. connChan := make(chan net.Conn, len(ips)) errChan := make(chan error, len(ips)) - portStr := strconv.Itoa(int(resolveServerPort(p.Port))) + for _, ip := range ips { go func(ip net.IP) { d := c.getDialer(p)