From 3b47a4febefa84f99b2385d54fc3756d5338d2fb Mon Sep 17 00:00:00 2001 From: parmaster Date: Wed, 9 Oct 2024 14:01:36 +0300 Subject: [PATCH 1/5] connection not closed when database name is incorrect #173 fix --- tds.go | 1 + 1 file changed, 1 insertion(+) diff --git a/tds.go b/tds.go index 5142df10..312d412e 100644 --- a/tds.go +++ b/tds.go @@ -1382,6 +1382,7 @@ initiate_connection: if token.isError() { tokenErr := token.getError() tokenErr.Message = "login error: " + tokenErr.Message + conn.Close() return nil, tokenErr } case error: From 5b9d44f846dc0d8b870c812e793c0def6523681c Mon Sep 17 00:00:00 2001 From: parmaster Date: Fri, 11 Oct 2024 12:45:44 +0300 Subject: [PATCH 2/5] test for leaked connections (connection not closed when database name is incorrect #173 --- tds_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tds_test.go b/tds_test.go index a656231e..386dab5d 100644 --- a/tds_test.go +++ b/tds_test.go @@ -694,6 +694,52 @@ func TestBadCredentials(t *testing.T) { _ = testConnectionBad(t, params.URL().String()) } +func TestLeakedConnections(t *testing.T) { + goodParams := testConnParams(t) + badParams := testConnParams(t) + badParams.Database = "unknown_db" + + // Connecting with good credentials should not fail + goodConn, err := sql.Open("mssql", goodParams.URL().String()) + if err != nil { + t.Fatal("Open connection failed:", err.Error()) + } + err = goodConn.Ping() + if err != nil { + t.Fatal("Ping with good credentials should not fail, but got error:", err.Error()) + } + // Remember the number of open connections, excluding the current one + var openConnections int + err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID").Scan(&openConnections) + if err != nil { + t.Fatal("cannot scan value", err) + } + + // Open 10 connections to the unknown database, all should be closed immediately + for i := 0; i < 10; i++ { + conn, err := sql.Open("mssql", badParams.URL().String()) + if err != nil { + // should not fail here + t.Fatal("sql.Open failed:", err.Error()) + } + err = conn.Ping() + if err == nil { + t.Fatalf("Pinging %s should fail, but it succeeded", badParams.Database) + } + conn.Close() // force close the connection + } + + // Check if the number of open connections is the same as before + var newOpenConnections int + err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID").Scan(&newOpenConnections) + if err != nil { + t.Fatal("cannot scan value", err) + } + if openConnections != newOpenConnections { + t.Fatalf("Number of open connections should be the same as before, %d leaked connections found", newOpenConnections-openConnections) + } +} + func TestBadHost(t *testing.T) { params := testConnParams(t) params.Host = "badhost" From 4cc16353774eb058b6bff4ac90fe4d4ffd13057d Mon Sep 17 00:00:00 2001 From: parmaster Date: Fri, 11 Oct 2024 17:33:05 +0300 Subject: [PATCH 3/5] Checking the number of open connections from local_net_address only --- tds_test.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tds_test.go b/tds_test.go index 386dab5d..497f7a12 100644 --- a/tds_test.go +++ b/tds_test.go @@ -700,7 +700,7 @@ func TestLeakedConnections(t *testing.T) { badParams.Database = "unknown_db" // Connecting with good credentials should not fail - goodConn, err := sql.Open("mssql", goodParams.URL().String()) + goodConn, err := sql.Open("sqlserver", goodParams.URL().String()) if err != nil { t.Fatal("Open connection failed:", err.Error()) } @@ -708,16 +708,23 @@ func TestLeakedConnections(t *testing.T) { if err != nil { t.Fatal("Ping with good credentials should not fail, but got error:", err.Error()) } - // Remember the number of open connections, excluding the current one + + var localNetAddr string + err = goodConn.QueryRow("SELECT local_net_address FROM sys.dm_exec_connections WHERE session_id=@@SPID").Scan(&localNetAddr) + if err != nil { + t.Fatal("cannot scan local_net_address value", err) + } + + // Remember the number of open connections from local_net_address, excluding the current one var openConnections int - err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID").Scan(&openConnections) + err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID AND local_net_address=@p1", localNetAddr).Scan(&openConnections) if err != nil { t.Fatal("cannot scan value", err) } // Open 10 connections to the unknown database, all should be closed immediately for i := 0; i < 10; i++ { - conn, err := sql.Open("mssql", badParams.URL().String()) + conn, err := sql.Open("sqlserver", badParams.URL().String()) if err != nil { // should not fail here t.Fatal("sql.Open failed:", err.Error()) @@ -731,7 +738,7 @@ func TestLeakedConnections(t *testing.T) { // Check if the number of open connections is the same as before var newOpenConnections int - err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID").Scan(&newOpenConnections) + err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID AND local_net_address=@p1", localNetAddr).Scan(&newOpenConnections) if err != nil { t.Fatal("cannot scan value", err) } From 5c143a721ddbfe8b1512f78d0af4a36b722d3f7d Mon Sep 17 00:00:00 2001 From: parmaster Date: Fri, 11 Oct 2024 18:16:44 +0300 Subject: [PATCH 4/5] using sql.NullString for localNetAddr --- tds_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tds_test.go b/tds_test.go index 497f7a12..3fd102fa 100644 --- a/tds_test.go +++ b/tds_test.go @@ -709,11 +709,14 @@ func TestLeakedConnections(t *testing.T) { t.Fatal("Ping with good credentials should not fail, but got error:", err.Error()) } - var localNetAddr string + var localNetAddr sql.NullString err = goodConn.QueryRow("SELECT local_net_address FROM sys.dm_exec_connections WHERE session_id=@@SPID").Scan(&localNetAddr) if err != nil { t.Fatal("cannot scan local_net_address value", err) } + if !localNetAddr.Valid { + t.Fatal("local_net_address should not be NULL") + } // Remember the number of open connections from local_net_address, excluding the current one var openConnections int From baba64018a895e78f9970ce1c7154711fd192df9 Mon Sep 17 00:00:00 2001 From: parmaster Date: Fri, 11 Oct 2024 20:10:08 +0300 Subject: [PATCH 5/5] handling local_net_address==NULL correctly --- tds_test.go | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tds_test.go b/tds_test.go index 3fd102fa..499ee084 100644 --- a/tds_test.go +++ b/tds_test.go @@ -714,13 +714,17 @@ func TestLeakedConnections(t *testing.T) { if err != nil { t.Fatal("cannot scan local_net_address value", err) } - if !localNetAddr.Valid { - t.Fatal("local_net_address should not be NULL") - } // Remember the number of open connections from local_net_address, excluding the current one + // NULL value is possible, particularly for non-tcp local connections var openConnections int - err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID AND local_net_address=@p1", localNetAddr).Scan(&openConnections) + err = goodConn.QueryRow(` + SELECT COUNT(*) AS openConnections + FROM sys.dm_exec_connections + WHERE session_id != @@SPID + AND ((@p1 IS NULL AND local_net_address IS NULL) + OR local_net_address = @p1)`, + localNetAddr).Scan(&openConnections) if err != nil { t.Fatal("cannot scan value", err) } @@ -741,7 +745,13 @@ func TestLeakedConnections(t *testing.T) { // Check if the number of open connections is the same as before var newOpenConnections int - err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID AND local_net_address=@p1", localNetAddr).Scan(&newOpenConnections) + err = goodConn.QueryRow(` + SELECT COUNT(*) AS openConnections + FROM sys.dm_exec_connections + WHERE session_id != @@SPID + AND ((@p1 IS NULL AND local_net_address IS NULL) + OR local_net_address = @p1)`, + localNetAddr).Scan(&newOpenConnections) if err != nil { t.Fatal("cannot scan value", err) }