Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connection not closed when database name is incorrect #173 fix #224

Merged
merged 5 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -1382,6 +1382,7 @@ initiate_connection:
if token.isError() {
tokenErr := token.getError()
tokenErr.Message = "login error: " + tokenErr.Message
conn.Close()
return nil, tokenErr
}
case error:
Expand Down
56 changes: 56 additions & 0 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,62 @@ func TestBadCredentials(t *testing.T) {
_ = testConnectionBad(t, params.URL().String())
}

func TestLeakedConnections(t *testing.T) {
goodParams := testConnParams(t)
badParams := testConnParams(t)
badParams.Database = "unknown_db"

// Connecting with good credentials should not fail
goodConn, err := sql.Open("sqlserver", goodParams.URL().String())
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
err = goodConn.Ping()
if err != nil {
t.Fatal("Ping with good credentials should not fail, but got error:", err.Error())
}

var localNetAddr sql.NullString
err = goodConn.QueryRow("SELECT local_net_address FROM sys.dm_exec_connections WHERE session_id=@@SPID").Scan(&localNetAddr)
shueybubbles marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
t.Fatal("cannot scan local_net_address value", err)
}
if !localNetAddr.Valid {
shueybubbles marked this conversation as resolved.
Show resolved Hide resolved
t.Fatal("local_net_address should not be NULL")
}

// Remember the number of open connections from local_net_address, excluding the current one
var openConnections int
err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID AND local_net_address=@p1", localNetAddr).Scan(&openConnections)
if err != nil {
t.Fatal("cannot scan value", err)
}

// Open 10 connections to the unknown database, all should be closed immediately
for i := 0; i < 10; i++ {
conn, err := sql.Open("sqlserver", badParams.URL().String())
if err != nil {
// should not fail here
t.Fatal("sql.Open failed:", err.Error())
}
err = conn.Ping()
if err == nil {
t.Fatalf("Pinging %s should fail, but it succeeded", badParams.Database)
}
conn.Close() // force close the connection
}

// Check if the number of open connections is the same as before
var newOpenConnections int
err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID AND local_net_address=@p1", localNetAddr).Scan(&newOpenConnections)
if err != nil {
t.Fatal("cannot scan value", err)
}
if openConnections != newOpenConnections {
t.Fatalf("Number of open connections should be the same as before, %d leaked connections found", newOpenConnections-openConnections)
}
}

func TestBadHost(t *testing.T) {
params := testConnParams(t)
params.Host = "badhost"
Expand Down
Loading