Skip to content

Commit

Permalink
Fix: Connection not closed when database name is incorrect #173 fix (#…
Browse files Browse the repository at this point in the history
…224)

* connection not closed when database name is incorrect #173 fix

* test for leaked connections (connection not closed when database name is incorrect #173

* Checking the number of open connections from local_net_address only

* using sql.NullString for localNetAddr

* handling local_net_address==NULL correctly
  • Loading branch information
parMaster authored Oct 14, 2024
1 parent 02deabf commit 573423d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
1 change: 1 addition & 0 deletions tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -1382,6 +1382,7 @@ initiate_connection:
if token.isError() {
tokenErr := token.getError()
tokenErr.Message = "login error: " + tokenErr.Message
conn.Close()
return nil, tokenErr
}
case error:
Expand Down
66 changes: 66 additions & 0 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,72 @@ func TestBadCredentials(t *testing.T) {
_ = testConnectionBad(t, params.URL().String())
}

func TestLeakedConnections(t *testing.T) {
goodParams := testConnParams(t)
badParams := testConnParams(t)
badParams.Database = "unknown_db"

// Connecting with good credentials should not fail
goodConn, err := sql.Open("sqlserver", goodParams.URL().String())
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
err = goodConn.Ping()
if err != nil {
t.Fatal("Ping with good credentials should not fail, but got error:", err.Error())
}

var localNetAddr sql.NullString
err = goodConn.QueryRow("SELECT local_net_address FROM sys.dm_exec_connections WHERE session_id=@@SPID").Scan(&localNetAddr)
if err != nil {
t.Fatal("cannot scan local_net_address value", err)
}

// Remember the number of open connections from local_net_address, excluding the current one
// NULL value is possible, particularly for non-tcp local connections
var openConnections int
err = goodConn.QueryRow(`
SELECT COUNT(*) AS openConnections
FROM sys.dm_exec_connections
WHERE session_id != @@SPID
AND ((@p1 IS NULL AND local_net_address IS NULL)
OR local_net_address = @p1)`,
localNetAddr).Scan(&openConnections)
if err != nil {
t.Fatal("cannot scan value", err)
}

// Open 10 connections to the unknown database, all should be closed immediately
for i := 0; i < 10; i++ {
conn, err := sql.Open("sqlserver", badParams.URL().String())
if err != nil {
// should not fail here
t.Fatal("sql.Open failed:", err.Error())
}
err = conn.Ping()
if err == nil {
t.Fatalf("Pinging %s should fail, but it succeeded", badParams.Database)
}
conn.Close() // force close the connection
}

// Check if the number of open connections is the same as before
var newOpenConnections int
err = goodConn.QueryRow(`
SELECT COUNT(*) AS openConnections
FROM sys.dm_exec_connections
WHERE session_id != @@SPID
AND ((@p1 IS NULL AND local_net_address IS NULL)
OR local_net_address = @p1)`,
localNetAddr).Scan(&newOpenConnections)
if err != nil {
t.Fatal("cannot scan value", err)
}
if openConnections != newOpenConnections {
t.Fatalf("Number of open connections should be the same as before, %d leaked connections found", newOpenConnections-openConnections)
}
}

func TestBadHost(t *testing.T) {
params := testConnParams(t)
params.Host = "badhost"
Expand Down

0 comments on commit 573423d

Please sign in to comment.