Skip to content

Commit

Permalink
Make sure channels are closed and close Rows and Stmt (denisenkom#600)
Browse files Browse the repository at this point in the history
* make sure channels are closed

* fix checking errors

* fix race condition in (*Stmt).processExec

* fix race condition in (*Stmt).processQueryResponse

* fix race cindition in TestIgnoreEmptyResults

* fix TestQueryCancelLowLevel

* fix sqlclosecheck warnings
  • Loading branch information
shogo82148 authored Nov 3, 2020
1 parent d4db269 commit 628e054
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 58 deletions.
4 changes: 4 additions & 0 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ func TestBulkcopy(t *testing.T) {
t.Log("Preparing copy in statement")

stmt, err := conn.PrepareContext(ctx, CopyIn(tableName, BulkOptions{}, columns...))
if err != nil {
t.Fatal(err)
}
defer stmt.Close()

for i := 0; i < 10; i++ {
t.Logf("Executing copy in statement %d time with %d values", i+1, len(values))
Expand Down
4 changes: 3 additions & 1 deletion datetimeoffset_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"log"
"time"

mssql "github.com/denisenkom/go-mssqldb"
"github.com/golang-sql/civil"
"github.com/denisenkom/go-mssqldb"
)

// This example shows how to insert and retrieve date and time types data
Expand Down Expand Up @@ -49,6 +49,8 @@ func insertDateTime(db *sql.DB) {
if err != nil {
log.Fatal(err)
}
defer stmt.Close()

tin, err := time.Parse(time.RFC3339, "2006-01-02T22:04:05.787-07:00")
if err != nil {
log.Fatal(err)
Expand Down
1 change: 1 addition & 0 deletions lastinsertid_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func ExampleLastInsertId() {
if err != nil {
log.Fatal(err)
}
defer rows.Close()
var lastInsertId1 int64
for rows.Next() {
rows.Scan(&lastInsertId1)
Expand Down
38 changes: 30 additions & 8 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,21 @@ func (c *Conn) simpleProcessResp(ctx context.Context) error {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, c.sess, tokchan, c.outs)
c.clearOuts()

var err error
for tok := range tokchan {
switch token := tok.(type) {
case doneStruct:
if token.isError() {
return c.checkBadConn(token.getError())
if token.isError() && err == nil {
err = c.checkBadConn(token.getError())
}
case error:
return c.checkBadConn(token)
if err == nil {
err = c.checkBadConn(token)
}
}
}
return nil
return err
}

func (c *Conn) Commit() error {
Expand Down Expand Up @@ -617,12 +621,22 @@ loop:
case doneStruct:
if token.isError() {
cancel()

// make sure tokchan is closed
for range tokchan {
}

return nil, s.c.checkBadConn(token.getError())
}
case ReturnStatus:
s.c.setReturnStatus(token)
case error:
cancel()

// make sure tokchan is closed
for range tokchan {
}

return nil, s.c.checkBadConn(token)
}
}
Expand Down Expand Up @@ -662,15 +676,20 @@ func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
if token.Status&doneCount != 0 {
rowCount += int64(token.RowCount)
}
if token.isError() {
return nil, token.getError()
if token.isError() && err == nil {
err = token.getError()
}
case ReturnStatus:
s.c.setReturnStatus(token)
case error:
return nil, token
if err == nil {
err = token
}
}
}
if err != nil {
return nil, err
}
return &Result{s.c, rowCount}, nil
}

Expand All @@ -686,9 +705,12 @@ type Rows struct {

func (rc *Rows) Close() error {
rc.cancel()
for _ = range rc.tokchan {

// make sure tokchan is closed
for range rc.tokchan {
}
rc.tokchan = nil

return nil
}

Expand Down
2 changes: 2 additions & 0 deletions queries_go110_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ func TestReturnStatusWithQuery(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer rows.Close()

var str string
for rows.Next() {
err = rows.Scan(&str)
Expand Down
3 changes: 2 additions & 1 deletion queries_go19_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ END;
if err != nil {
t.Error(err)
}
defer rows.Close()
// reading first row
if !rows.Next() {
t.Error("Next returned false")
Expand Down Expand Up @@ -947,10 +948,10 @@ with
}
return
}
defer rows.Close()
for rows.Next() {
// Nothing.
}
rows.Close()
})
}
}
Expand Down
81 changes: 52 additions & 29 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,8 +622,9 @@ func TestError(t *testing.T) {
conn := open(t)
defer conn.Close()

_, err := conn.Query("exec bad")
row, err := conn.Query("exec bad")
if err == nil {
defer row.Close()
t.Fatal("Query should fail")
}

Expand All @@ -645,6 +646,7 @@ func TestQueryNoRows(t *testing.T) {
if rows, err = conn.Query("create table #abc (fld int)"); err != nil {
t.Fatal("Query failed", err)
}
defer rows.Close()
if rows.Next() {
t.Fatal("Query shoulnd't return any rows")
}
Expand Down Expand Up @@ -697,6 +699,7 @@ func TestOrderBy(t *testing.T) {
if err != nil {
t.Fatal("Query failed", err)
}
defer rows.Close()

for rows.Next() {
var fld1 int32
Expand Down Expand Up @@ -880,21 +883,25 @@ func TestUniqueIdentifierParam(t *testing.T) {
func TestBigQuery(t *testing.T) {
conn := open(t)
defer conn.Close()
rows, err := conn.Query(`WITH n(n) AS

func() {
rows, err := conn.Query(`WITH n(n) AS
(
SELECT 1
UNION ALL
SELECT n+1 FROM n WHERE n < 10000
)
SELECT n, @@version FROM n ORDER BY n
OPTION (MAXRECURSION 10000);`)
if err != nil {
t.Fatal("cannot exec query", err)
}
rows.Next()
rows.Close()
if err != nil {
t.Fatal("cannot exec query", err)
}
defer rows.Close()
rows.Next()
}()

var res int
err = conn.QueryRow("select 0").Scan(&res)
err := conn.QueryRow("select 0").Scan(&res)
if err != nil {
t.Fatal("cannot scan value", err)
}
Expand Down Expand Up @@ -936,6 +943,7 @@ func TestIgnoreEmptyResults(t *testing.T) {
if err != nil {
t.Fatal("Query failed", err.Error())
}
defer rows.Close()
if !rows.Next() {
t.Fatal("Query didn't return row")
}
Expand Down Expand Up @@ -1039,17 +1047,20 @@ func TestConnectionClosing(t *testing.T) {
return
}

stmt, err := pool.Query("select 1")
if err != nil {
t.Fatalf("Query failed with unexpected error %s", err)
}
for stmt.Next() {
var val interface{}
err := stmt.Scan(&val)
func() {
rows, err := pool.Query("select 1")
if err != nil {
t.Fatalf("Query failed with unexpected error %s", err)
}
}
defer rows.Close()
for rows.Next() {
var val interface{}
err := rows.Scan(&val)
if err != nil {
t.Fatalf("Query failed with unexpected error %s", err)
}
}
}()
}
}

Expand Down Expand Up @@ -1536,6 +1547,7 @@ func TestColumnIntrospection(t *testing.T) {
if err != nil {
t.Fatalf("Query failed with unexpected error %s", err)
}
defer rows.Close()

ct, err := rows.ColumnTypes()
if err != nil {
Expand Down Expand Up @@ -1595,19 +1607,27 @@ func TestContext(t *testing.T) {
t.Errorf("BeginTx failed with unexpected error %s", err)
return
}
rows, err := tx.QueryContext(ctx, "DBCC USEROPTIONS")
properties := make(map[string]string)
for rows.Next() {
var name, value string
if err = rows.Scan(&name, &value); err != nil {
t.Errorf("Scan failed with unexpected error %s", err)
}
properties[name] = value
}
defer tx.Rollback()

if properties["isolation level"] != "serializable" {
t.Errorf("Expected isolation level to be serializable but it is %s", properties["isolation level"])
}
// check the isolation level
func() {
rows, err := tx.QueryContext(ctx, "DBCC USEROPTIONS")
if err != nil {
t.Fatal(err)
}
defer rows.Close()
properties := make(map[string]string)
for rows.Next() {
var name, value string
if err = rows.Scan(&name, &value); err != nil {
t.Errorf("Scan failed with unexpected error %s", err)
}
properties[name] = value
}
if properties["isolation level"] != "serializable" {
t.Errorf("Expected isolation level to be serializable but it is %s", properties["isolation level"])
}
}()

row := tx.QueryRowContext(ctx, "select 1")
var val int64
Expand All @@ -1624,11 +1644,12 @@ func TestContext(t *testing.T) {
return
}

_, err = tx.PrepareContext(ctx, "select 1")
stmt, err := tx.PrepareContext(ctx, "select 1")
if err != nil {
t.Errorf("PrepareContext failed with unexpected error %s", err)
return
}
defer stmt.Close()
}

func TestBeginTxtReadOnlyNotSupported(t *testing.T) {
Expand Down Expand Up @@ -1673,6 +1694,7 @@ func TestConn_BeginTx(t *testing.T) {
if err != nil {
t.Fatal("select failed with error", err)
}
defer rows.Close()
values := []int64{}
for rows.Next() {
var val int64
Expand Down Expand Up @@ -1806,6 +1828,7 @@ func TestQueryCancelLowLevel(t *testing.T) {
if err != nil {
t.Fatalf("Query failed with error %v", err)
}
defer rows.Close()

values := []driver.Value{nil}
err = rows.Next(values)
Expand Down
5 changes: 5 additions & 0 deletions tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,11 @@ initiate_connection:
if token.isError() {
return nil, fmt.Errorf("Login error: %s", token.getError())
}

// make sure tokchan is closed
for range tokchan {
}

goto loginEnd
}
}
Expand Down
33 changes: 20 additions & 13 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ func TestSendSqlBatch(t *testing.T) {

ch := make(chan tokenStruct, 5)
go processResponse(context.Background(), conn, ch, nil)
defer func() {
// make share ch is closed
for range ch {
}
}()

var lastRow []interface{}
loop:
Expand Down Expand Up @@ -338,20 +343,22 @@ func TestMultipleQueryClose(t *testing.T) {
}
defer stmt.Close()

rows, err := stmt.Query()
if err != nil {
t.Error("Query failed:", err.Error())
return
}
rows.Close()
func() {
rows, err := stmt.Query()
if err != nil {
t.Fatal("Query failed:", err.Error())
}
defer rows.Close()
}()

rows, err = stmt.Query()
if err != nil {
t.Error("Query failed:", err.Error())
return
}
defer rows.Close()
checkSimpleQuery(rows, t)
func() {
rows, err := stmt.Query()
if err != nil {
t.Fatal("Query failed:", err.Error())
}
defer rows.Close()
checkSimpleQuery(rows, t)
}()
}

func TestPing(t *testing.T) {
Expand Down
Loading

0 comments on commit 628e054

Please sign in to comment.