diff --git a/expectations.go b/expectations.go index fb3b66d..4d04519 100644 --- a/expectations.go +++ b/expectations.go @@ -266,58 +266,33 @@ func (e *ExpectedExec) WillReturnResult(result pgconn.CommandTag) *ExpectedExec // Returned by pgxmock.ExpectPrepare. type ExpectedPrepare struct { commonExpectation - mock *pgxmock expectStmtName string expectSQL string - deallocateErr error - mustBeClosed bool - deallocated bool } -// WillReturnCloseError allows to set an error for this prepared statement Close action -func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { - e.deallocateErr = err - return e -} - -// WillBeClosed is for backward compatibility only and will be removed soon. -// -// Deprecated: One should use WillBeDeallocated() instead. -func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { - return e.WillBeDeallocated() -} - -// WillBeDeallocated expects this prepared statement to be deallocated -func (e *ExpectedPrepare) WillBeDeallocated() *ExpectedPrepare { - e.mustBeClosed = true - return e -} - -// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. -// This method is convenient in order to prevent duplicating sql query string matching. -func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { - eq := &ExpectedQuery{} - eq.expectSQL = e.expectStmtName - e.mock.expectations = append(e.mock.expectations, eq) - return eq +// String returns string representation +func (e *ExpectedPrepare) String() string { + msg := "ExpectedPrepare => expecting call to Prepare():\n" + msg += fmt.Sprintf("\t- matches statement name: '%s'\n", e.expectStmtName) + msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL) + return msg + e.commonExpectation.String() } -// ExpectExec allows to expect Exec() on this prepared statement. -// This method is convenient in order to prevent duplicating sql query string matching. -func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { - eq := &ExpectedExec{} - eq.expectSQL = e.expectStmtName - e.mock.expectations = append(e.mock.expectations, eq) - return eq +// ExpectedDeallocate is used to manage pgx.Deallocate and pgx.DeallocateAll expectations. +// Returned by pgxmock.ExpectDeallocate(string) and pgxmock.ExpectDeallocateAll(). +type ExpectedDeallocate struct { + commonExpectation + expectStmtName string + expectAll bool } // String returns string representation -func (e *ExpectedPrepare) String() string { - msg := "ExpectedPrepare => expecting call to Prepare():" - msg += fmt.Sprintf("\t- matches statement name: '%s'", e.expectStmtName) - msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL) - if e.deallocateErr != nil { - msg += fmt.Sprintf("\t- returns error on Close: %s", e.deallocateErr) +func (e *ExpectedDeallocate) String() string { + msg := "ExpectedDeallocate => expecting call to Deallocate():\n" + if e.expectAll { + msg += "\t- matches all statements\n" + } else { + msg += fmt.Sprintf("\t- matches statement name: '%s'\n", e.expectStmtName) } return msg + e.commonExpectation.String() } diff --git a/pgxmock.go b/pgxmock.go index 6c5620d..2204cbe 100644 --- a/pgxmock.go +++ b/pgxmock.go @@ -35,11 +35,13 @@ type Expecter interface { ExpectClose() *ExpectedClose // ExpectPrepare expects Prepare() to be called with expectedSQL query. - // the *ExpectedPrepare allows to mock database response. - // Note that you may expect Query() or Exec() on the *ExpectedPrepare - // statement to prevent repeating expectedSQL ExpectPrepare(expectedStmtName, expectedSQL string) *ExpectedPrepare + // ExpectDeallocate expects Deallocate() to be called with expectedStmtName. + // The *ExpectedDeallocate allows to mock database response + ExpectDeallocate(expectedStmtName string) *ExpectedDeallocate + ExpectDeallocateAll() *ExpectedDeallocate + // ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query. // the *ExpectedQuery allows to mock database response. ExpectQuery(expectedSQL string) *ExpectedQuery @@ -114,6 +116,7 @@ type PgxConnIface interface { PgxCommonIface Close(ctx context.Context) error Deallocate(ctx context.Context, name string) error + DeallocateAll(ctx context.Context) error Config() *pgx.ConnConfig PgConn() *pgconn.PgConn } @@ -166,13 +169,6 @@ func (c *pgxmock) ExpectationsWereMet() error { return fmt.Errorf("there is a remaining expectation which was not matched: %s", e) } - // for expected prepared statement check whether it was closed if expected - if prep, ok := e.(*ExpectedPrepare); ok { - if prep.mustBeClosed && !prep.deallocated { - return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep) - } - } - // must check whether all expected queried rows are closed if query, ok := e.(*ExpectedQuery); ok { if query.rowsMustBeClosed && !query.rowsWereClosed { @@ -241,7 +237,19 @@ func (c *pgxmock) ExpectPing() *ExpectedPing { } func (c *pgxmock) ExpectPrepare(expectedStmtName, expectedSQL string) *ExpectedPrepare { - e := &ExpectedPrepare{expectSQL: expectedSQL, expectStmtName: expectedStmtName, mock: c} + e := &ExpectedPrepare{expectSQL: expectedSQL, expectStmtName: expectedStmtName} + c.expectations = append(c.expectations, e) + return e +} + +func (c *pgxmock) ExpectDeallocate(expectedStmtName string) *ExpectedDeallocate { + e := &ExpectedDeallocate{expectStmtName: expectedStmtName} + c.expectations = append(c.expectations, e) + return e +} + +func (c *pgxmock) ExpectDeallocateAll() *ExpectedDeallocate { + e := &ExpectedDeallocate{expectAll: true} c.expectations = append(c.expectations, e) return e } @@ -371,27 +379,32 @@ func (c *pgxmock) Prepare(ctx context.Context, name, query string) (*pgconn.Stat } func (c *pgxmock) Deallocate(ctx context.Context, name string) error { - var ( - expected *ExpectedPrepare - ok bool - ) - for _, next := range c.expectations { - next.Lock() - expected, ok = next.(*ExpectedPrepare) - ok = ok && expected.expectStmtName == name - next.Unlock() - if ok { - break + ex, err := findExpectationFunc[*ExpectedDeallocate](c, "Deallocate()", func(deallocateExp *ExpectedDeallocate) error { + if deallocateExp.expectAll { + return fmt.Errorf("Deallocate: all prepared statements were expected to be deallocated, instead only '%s' specified", name) } + if deallocateExp.expectStmtName != name { + return fmt.Errorf("Deallocate: prepared statement name '%s' was not expected, expected name is '%s'", name, deallocateExp.expectStmtName) + } + return nil + }) + if err != nil { + return err } - if expected == nil { - return fmt.Errorf("Deallocate: prepared statement name '%s' doesn't exist", name) - } - if ctx.Err() != nil { - return ctx.Err() + return ex.waitForDelay(ctx) +} + +func (c *pgxmock) DeallocateAll(ctx context.Context) error { + ex, err := findExpectationFunc[*ExpectedDeallocate](c, "DeallocateAll()", func(deallocateExp *ExpectedDeallocate) error { + if !deallocateExp.expectAll { + return fmt.Errorf("Deallocate: deallocate all prepared statements was not expected, expected name is '%s'", deallocateExp.expectStmtName) + } + return nil + }) + if err != nil { + return err } - expected.deallocated = true - return expected.deallocateErr + return ex.waitForDelay(ctx) } func (c *pgxmock) Commit(ctx context.Context) error { diff --git a/pgxmock_test.go b/pgxmock_test.go index 25bb160..0bffb95 100644 --- a/pgxmock_test.go +++ b/pgxmock_test.go @@ -242,8 +242,8 @@ func TestPrepareExpectations(t *testing.T) { a := assert.New(t) expErr := errors.New("invaders must die") mock.ExpectPrepare("foo", "SELECT (.+) FROM articles WHERE id = ?"). - WillReturnCloseError(expErr). WillDelayFor(1 * time.Second) + mock.ExpectDeallocate("foo").WillReturnError(expErr) stmt, err := mock.Prepare(context.Background(), "baz", "SELECT (.+) FROM articles WHERE id = ?") a.Error(err, "wrong prepare stmt name should raise an error") @@ -344,12 +344,12 @@ func TestUnorderedPreparedQueryExecutions(t *testing.T) { mock.MatchExpectationsInOrder(false) - mock.ExpectPrepare("articles_stmt", "SELECT (.+) FROM articles WHERE id = ?"). - ExpectQuery(). + mock.ExpectPrepare("articles_stmt", "SELECT (.+) FROM articles WHERE id = ?") + mock.ExpectQuery("articles_stmt"). WithArgs(5). WillReturnRows(NewRows([]string{"id", "title"}).AddRow(5, "The quick brown fox")) - mock.ExpectPrepare("authors_stmt", "SELECT (.+) FROM authors WHERE id = ?"). - ExpectQuery(). + mock.ExpectPrepare("authors_stmt", "SELECT (.+) FROM authors WHERE id = ?") + mock.ExpectQuery("authors_stmt"). WithArgs(1). WillReturnRows(NewRows([]string{"id", "title"}).AddRow(1, "Betty B.")) @@ -911,9 +911,9 @@ func TestPrepareExec(t *testing.T) { } defer mock.Close(context.Background()) mock.ExpectBegin() - ep := mock.ExpectPrepare("foo", "INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)") + mock.ExpectPrepare("foo", "INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)") for i := 0; i < 3; i++ { - ep.ExpectExec().WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1)) + mock.ExpectExec("foo").WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1)) } mock.ExpectCommit() tx, _ := mock.Begin(context.Background()) @@ -942,8 +942,8 @@ func TestPrepareQuery(t *testing.T) { } defer mock.Close(context.Background()) mock.ExpectBegin() - ep := mock.ExpectPrepare("foo", "SELECT ID, STATUS FROM ORDERS WHERE ID = \\?") - ep.ExpectQuery().WithArgs(101).WillReturnRows(NewRows([]string{"ID", "STATUS"}).AddRow(101, "Hello")) + mock.ExpectPrepare("foo", "SELECT ID, STATUS FROM ORDERS WHERE ID = \\?") + mock.ExpectQuery("foo").WithArgs(101).WillReturnRows(NewRows([]string{"ID", "STATUS"}).AddRow(101, "Hello")) mock.ExpectCommit() tx, _ := mock.Begin(context.Background()) _, err = tx.Prepare(context.Background(), "foo", "SELECT ID, STATUS FROM ORDERS WHERE ID = ?") @@ -1021,8 +1021,10 @@ func TestPreparedStatementCloseExpectation(t *testing.T) { mock, _ := NewConn() a := assert.New(t) - ep := mock.ExpectPrepare("foo", "INSERT INTO ORDERS").WillBeClosed() - ep.ExpectExec().WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1)) + mock.ExpectPrepare("foo", "INSERT INTO ORDERS") + mock.ExpectExec("foo").WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1)) + mock.ExpectDeallocate("foo") + mock.ExpectDeallocateAll() stmt, err := mock.Prepare(context.Background(), "foo", "INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") a.NoError(err) @@ -1034,9 +1036,24 @@ func TestPreparedStatementCloseExpectation(t *testing.T) { err = mock.Deallocate(context.Background(), "baz") a.Error(err, "wrong prepares stmt name should raise an error") + err = mock.DeallocateAll(context.Background()) + a.Error(err, "we're expecting one statement deallocation, not all") + + err = mock.Ping(context.Background()) + a.Error(err, "ping should raise an error, we're expecting deallocate") + err = mock.Deallocate(context.Background(), "foo") a.NoError(err) + err = mock.Ping(context.Background()) + a.Error(err, "ping should raise an error, we're expecting deallocate") + + err = mock.Deallocate(context.Background(), "baz") + a.Error(err, "wrong prepares stmt name should raise an error") + + err = mock.DeallocateAll(context.Background()) + a.NoError(err) + if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) }