Skip to content

Commit

Permalink
Fixes for named pipe dialer (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
shueybubbles authored Jan 19, 2023
1 parent 91e6060 commit efa88a7
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 97 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
### Features

* Add driver version and name to TDS login packets
* Add `pipe` connection string parameter for named pipe dialer

### Bug fixes

* Added checks while reading prelogin for invalid data ([#64](https://github.com/microsoft/go-mssqldb/issues/64))([86ecefd8b](https://github.com/microsoft/go-mssqldb/commit/86ecefd8b57683aeb5ad9328066ee73fbccd62f5))
* Added checks while reading prelogin for invalid data ([#64](https://github.com/microsoft/go-mssqldb/issues/64))([86ecefd8b](https://github.com/microsoft/go-mssqldb/commit/86ecefd8b57683aeb5ad9328066ee73fbccd62f5))

* Fixed multi-protocol dialer path to avoid unneeded SQL Browser queries
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ Other supported formats are listed below.
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
* `Workstation ID` - The workstation name (default is the host name)
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`.
* `protocol` - forces use of a protocol. Make sure the corresponding package is imported.

### Connection parameters for namedpipe package
* `pipe` - If set, no Browser query is made and named pipe used will be `\\<host>\pipe\<pipe>`
* `protocol` can be set to `np`
* For a non-URL DSN, the `server` parameter can be set to the full pipe name like `\\host\pipe\sql\query`

If no pipe name can be derived from the DSN, connection attempts will first query the SQL Browser service to find the pipe name for the instance.

### Protocol configuration

To force a specific protocol for the connection there two several options:
1. Prepend the server name in a DSN with the protocol and a colon, like `np:host` or `lpc:host` or `tcp:host`
2. Set the `protocol` parameter to the protocol name

`msdsn.ProtocolParsers` can be reordered to prioritize other protocols ahead of `tcp`

### Kerberos Active Directory authentication outside Windows
The package supports authentication via 3 methods.
Expand Down
33 changes: 19 additions & 14 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,6 @@ func Parse(dsn string) (Config, error) {
}
p.ConnTimeout = time.Duration(timeout) * time.Second
}
f := len(p.Protocols)
if f == 0 {
f = 1
}
p.DialTimeout = time.Duration(15*f) * time.Second
if strdialtimeout, ok := params["dial timeout"]; ok {
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
if err != nil {
f := "invalid dial timeout '%v': %v"
return p, fmt.Errorf(f, strdialtimeout, err.Error())
}

p.DialTimeout = time.Duration(timeout) * time.Second
}

// default keep alive should be 30 seconds according to spec:
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
Expand Down Expand Up @@ -353,6 +339,21 @@ func Parse(dsn string) (Config, error) {
return p, fmt.Errorf("No protocol handler is available for protocol: '%s'", protocol)
}

f := len(p.Protocols)
if f == 0 {
f = 1
}
p.DialTimeout = time.Duration(15*f) * time.Second
if strdialtimeout, ok := params["dial timeout"]; ok {
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
if err != nil {
f := "invalid dial timeout '%v': %v"
return p, fmt.Errorf(f, strdialtimeout, err.Error())
}

p.DialTimeout = time.Duration(timeout) * time.Second
}

return p, nil
}

Expand All @@ -375,6 +376,10 @@ func (p Config) URL() *url.URL {
if ok {
q.Add("protocol", protocol)
}
pipe, ok := p.Parameters["pipe"]
if ok {
q.Add("pipe", pipe)
}
res := url.URL{
Scheme: "sqlserver",
Host: host,
Expand Down
32 changes: 27 additions & 5 deletions namedpipe/namedpipe_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@ type namedPipeData struct {
PipeName string
}

var azureDomains = []string{
".database.windows.net",
".database.chinacloudapi.cn",
".database.usgovcloudapi.net",
}

func (n namedPipeDialer) ParseServer(server string, p *msdsn.Config) error {
// assume a server name starting with \\ is the full named pipe path
if strings.HasPrefix(server, `\\`) {
p.ProtocolParameters[n.Protocol()] = &namedPipeData{PipeName: server}
} else if p.Port > 0 {
if p.Port > 0 {
return fmt.Errorf("Named pipes disallowed due to port being specified")
}
if strings.HasPrefix(server, `\\`) {
// assume a server name starting with \\ is the full named pipe path
p.ProtocolParameters[n.Protocol()] = namedPipeData{PipeName: server}
} else if p.Host == "" { // if the string specifies np:host\instance, tcpParser won't have filled in p.Host
parts := strings.SplitN(server, `\`, 2)
p.Host = parts[0]
Expand All @@ -30,6 +37,17 @@ func (n namedPipeDialer) ParseServer(server string, p *msdsn.Config) error {
if len(parts) > 1 {
p.Instance = parts[1]
}
} else {
host := strings.ToLower(p.Host)
for _, domain := range azureDomains {
if strings.HasSuffix(host, domain) {
return fmt.Errorf("Named pipes disallowed for Azure SQL Database connections")
}
}
}
pipe, ok := p.Parameters["pipe"]
if ok {
p.ProtocolParameters[n.Protocol()] = namedPipeData{PipeName: fmt.Sprintf(`\\%s\pipe\%s`, p.Host, pipe)}
}
return nil
}
Expand All @@ -46,7 +64,11 @@ func (n namedPipeDialer) ParseBrowserData(data msdsn.BrowserData, p *msdsn.Confi
if instance == "" {
instance = "MSSQLSERVER"
}
pipename, ok := data[instance]["np"]
ok := len(data) > 0
pipename := ""
if ok {
pipename, ok = data[instance]["np"]
}
if !ok {
f := "no named pipe instance matching '%v' returned from host '%v'"
return fmt.Errorf(f, p.Instance, p.Host)
Expand Down
5 changes: 3 additions & 2 deletions namedpipe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package mssql

import (
"strings"
"testing"

"github.com/microsoft/go-mssqldb/msdsn"
Expand All @@ -22,7 +23,7 @@ func TestNamedPipeProtocolInstalled(t *testing.T) {
func TestNamedPipeConnection(t *testing.T) {
params := testConnParams(t)
protocol, ok := params.Parameters["protocol"]
if ok && protocol != "np" {
if (ok && protocol != "np") || strings.Contains(params.Host, "database.windows.net") {
t.Skip("Test is not running with named pipe protocol set")
}
conn, _ := open(t)
Expand All @@ -31,6 +32,6 @@ func TestNamedPipeConnection(t *testing.T) {
t.Fatalf("Unable to query connection protocol %s", err.Error())
}
if protocol != "Named pipe" {
t.Fatalf("Named pips connection not made. Protocol: %s", protocol)
t.Fatalf("Named pipe connection not made. Protocol: %s", protocol)
}
}
9 changes: 7 additions & 2 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ type tcpDialer struct{}
func (t tcpDialer) ParseBrowserData(data msdsn.BrowserData, p *msdsn.Config) error {
// If instance is specified, but no port, check SQL Server Browser
// for the instance and discover its port.
p.Instance = strings.ToUpper(p.Instance)
strport, ok := data[p.Instance]["tcp"]
ok := len(data) > 0
strport := ""
if ok {
p.Instance = strings.ToUpper(p.Instance)
strport, ok = data[p.Instance]["tcp"]
}
if !ok {
f := "no instance matching '%v' returned from host '%v'"
return fmt.Errorf(f, p.Instance, p.Host)
Expand Down Expand Up @@ -106,6 +110,7 @@ func (t tcpDialer) DialSqlConnection(ctx context.Context, c *Connector, p *msdsn
if p.ServerSPN == "" {
p.ServerSPN = generateSpn(p.Host, instanceOrPort(p.Instance, p.Port))
}
p.Port = resolveServerPort(p.Port)
return conn, err
}

Expand Down
3 changes: 0 additions & 3 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2082,9 +2082,6 @@ func getLatency(t *testing.T) time.Duration {
}
c := &Connector{params: params}
now := time.Now()
if err := queryBrowser(context.Background(), context.Background(), c, nil, &params); err != nil {
t.Fatalf("queryBrowser failed: %s", err.Error())
}
// Dialing both tcp and np for a named-pipes only connection takes a long time
if len(params.Protocols) > 1 && testing.Short() {
t.Skip("short")
Expand Down
99 changes: 32 additions & 67 deletions tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,21 @@ func parseInstances(msg []byte) msdsn.BrowserData {

func getInstances(ctx context.Context, d Dialer, address string) (msdsn.BrowserData, error) {
conn, err := d.DialContext(ctx, "udp", net.JoinHostPort(address, "1434"))
emptyInstances := msdsn.BrowserData{}
if err != nil {
return nil, err
return emptyInstances, err
}
defer conn.Close()
deadline, _ := ctx.Deadline()
conn.SetDeadline(deadline)
_, err = conn.Write([]byte{3})
if err != nil {
return nil, err
return emptyInstances, err
}
var resp = make([]byte, 16*1024-1)
read, err := conn.Read(resp)
if err != nil {
return nil, err
return emptyInstances, err
}
return parseInstances(resp[:read]), nil
}
Expand Down Expand Up @@ -897,8 +898,26 @@ func sendAttention(buf *tdsBuffer) error {

// Makes an attempt to connect with each available protocol, in order, until one succeeds or the timeout elapses
func dialConnection(ctx context.Context, c *Connector, p *msdsn.Config, logger ContextLogger) (conn net.Conn, err error) {
var instances msdsn.BrowserData
for _, protocol := range p.Protocols {
dialer := msdsn.ProtocolDialers[protocol]
if dialer.CallBrowser(p) {
if instances == nil {
d := c.getDialer(p)
instances, err = getInstances(ctx, d, p.Host)
if err != nil && logger != nil && uint64(p.LogFlags)&logErrors != 0 {
e := fmt.Sprintf("unable to get instances from Sql Server Browser on host %v: %v", p.Host, err.Error())
logger.Log(ctx, msdsn.Log(logErrors), e)
}
}
err = dialer.ParseBrowserData(instances, p)
if err != nil {
if logger != nil && uint64(p.LogFlags)&logErrors != 0 {
logger.Log(ctx, msdsn.Log(logErrors), "Skipping protocol "+protocol+". Error:"+err.Error())
}
continue
}
}
sqlDialer, ok := dialer.(MssqlProtocolDialer)
if logger != nil && uint64(p.LogFlags)&logDebug != 0 {
logger.Log(ctx, msdsn.LogDebug, "Dialing with protocol "+protocol)
Expand Down Expand Up @@ -1050,63 +1069,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont
return l, nil
}

func queryBrowser(ctx context.Context, dialCtx context.Context, c *Connector, logger ContextLogger, p *msdsn.Config) error {
var instances msdsn.BrowserData
var err error
pErrors := make(map[string]error)
for _, protocol := range p.Protocols {
pd, ok := msdsn.ProtocolDialers[protocol]
if !ok {
return fmt.Errorf("No dialer is configured for protocol '%s'", protocol)
}
if pd.CallBrowser(p) {
if instances == nil {
d := c.getDialer(p)
instances, err = getInstances(dialCtx, d, p.Host)
if err != nil {
f := "unable to get instances from Sql Server Browser on host %v: %v"
return fmt.Errorf(f, p.Host, err.Error())
}
}
pErr := pd.ParseBrowserData(instances, p)
if pErr != nil {
pErrors[protocol] = pErr
if logger != nil && uint64(p.LogFlags)&logErrors != 0 {
logger.Log(ctx, msdsn.Log(logErrors), "Removing protocol "+protocol+" from dialers. Error:"+pErr.Error())
}
}
}
}
// If any dialer got an error parsing instances, remove it from the dialer list
// If no dialers are left, return an error
if len(pErrors) == len(p.Protocols) {
return fmt.Errorf("Unable to find a matching instance for any supported protocol on host %v", p.Host)
}
if len(pErrors) > 0 {
validProtocols := make([]string, len(p.Protocols)-len(pErrors))
i := 0
for _, protocol := range p.Protocols {
_, hasError := pErrors[protocol]
if !hasError {
validProtocols[i] = protocol
i++
}
}
}
return nil
}

func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) {
dialCtx := ctx
if p.DialTimeout >= 0 {
dt := p.DialTimeout
if dt == 0 {
dt = time.Duration(15*len(p.Protocols)) * time.Second
}
var cancel func()
dialCtx, cancel = context.WithTimeout(ctx, dt)
defer cancel()
}

// if instance is specified use instance resolution service
if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 {
Expand All @@ -1116,14 +1079,6 @@ func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Co
logger.Log(ctx, msdsn.LogDebug, "WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored")
}

err = queryBrowser(ctx, dialCtx, c, logger, &p)
if err != nil {
return nil, err
}
if p.Port == 0 {
p.Port = defaultServerPort
}

packetSize := p.PacketSize
if packetSize == 0 {
packetSize = defaultPacketSize
Expand All @@ -1139,6 +1094,16 @@ func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Co
}

initiate_connection:
dialCtx := ctx
if p.DialTimeout >= 0 {
dt := p.DialTimeout
if dt == 0 {
dt = time.Duration(15*len(p.Protocols)) * time.Second
}
var cancel func()
dialCtx, cancel = context.WithTimeout(ctx, dt)
defer cancel()
}
conn, err := dialConnection(dialCtx, c, &p, logger)
if err != nil {
return nil, err
Expand Down
9 changes: 6 additions & 3 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ func GetConnParams() (*msdsn.Config, error) {
if os.Getenv("PROTOCOL") != "" {
c.Parameters["protocol"] = os.Getenv("PROTOCOL")
}
if os.Getenv("PIPE") != "" {
c.Parameters["pipe"] = os.Getenv("PIPE")
}
return c, nil
}
// try loading connection string from file
Expand Down Expand Up @@ -364,7 +367,7 @@ func TestConnect(t *testing.T) {

func TestConnectViaIp(t *testing.T) {
params := testConnParams(t)
if params.Encryption == msdsn.EncryptionRequired {
if params.Encryption == msdsn.EncryptionRequired || strings.Contains(params.Host, "database.windows.net") {
t.Skip("Unable to test connection to IP for servers that expect encryption")
}

Expand Down Expand Up @@ -596,7 +599,7 @@ func TestBadHost(t *testing.T) {
}

func TestSqlBrowserNotUsedIfPortSpecified(t *testing.T) {
const errorSubstrStringToCheckFor = "unable to get instances from Sql Server Browser"
const errorSubstrStringToCheckFor = "instance matching 'foobar' returned from host 'badhost'"

// Connect to an instance on a host that doesn't exist (so connection will always expectedly fail)
params := testConnParams(t)
Expand All @@ -611,7 +614,7 @@ func TestSqlBrowserNotUsedIfPortSpecified(t *testing.T) {

err := testConnectionBad(t, params.URL().String())

if !strings.Contains(err.Error(), errorSubstrStringToCheckFor) {
if !strings.Contains(strings.ToLower(err.Error()), errorSubstrStringToCheckFor) {
t.Fatalf("Connection should have tried to use SQL Browser. Error:%s", err.Error())
}

Expand Down

0 comments on commit efa88a7

Please sign in to comment.