Skip to content

Commit

Permalink
[!] rewrite Prepare and Deallocate mocking implementation (#203)
Browse files Browse the repository at this point in the history
pgx doesn't have complicated logic behind `Prepare()` like `lib/pq`.
All prepared statements are accessible by name and don't need any
additional structs. One is able to call `Deallocate` (`DeallocateAll`)
without any prior `Prepare` calls.
  • Loading branch information
pashagolub authored May 13, 2024
1 parent 52a5df5 commit e0fce08
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 83 deletions.
61 changes: 18 additions & 43 deletions expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,58 +266,33 @@ func (e *ExpectedExec) WillReturnResult(result pgconn.CommandTag) *ExpectedExec
// Returned by pgxmock.ExpectPrepare.
type ExpectedPrepare struct {
commonExpectation
mock *pgxmock
expectStmtName string
expectSQL string
deallocateErr error
mustBeClosed bool
deallocated bool
}

// WillReturnCloseError allows to set an error for this prepared statement Close action
func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
e.deallocateErr = err
return e
}

// WillBeClosed is for backward compatibility only and will be removed soon.
//
// Deprecated: One should use WillBeDeallocated() instead.
func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
return e.WillBeDeallocated()
}

// WillBeDeallocated expects this prepared statement to be deallocated
func (e *ExpectedPrepare) WillBeDeallocated() *ExpectedPrepare {
e.mustBeClosed = true
return e
}

// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement.
// This method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
eq := &ExpectedQuery{}
eq.expectSQL = e.expectStmtName
e.mock.expectations = append(e.mock.expectations, eq)
return eq
// String returns string representation
func (e *ExpectedPrepare) String() string {
msg := "ExpectedPrepare => expecting call to Prepare():\n"
msg += fmt.Sprintf("\t- matches statement name: '%s'\n", e.expectStmtName)
msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL)
return msg + e.commonExpectation.String()
}

// ExpectExec allows to expect Exec() on this prepared statement.
// This method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
eq := &ExpectedExec{}
eq.expectSQL = e.expectStmtName
e.mock.expectations = append(e.mock.expectations, eq)
return eq
// ExpectedDeallocate is used to manage pgx.Deallocate and pgx.DeallocateAll expectations.
// Returned by pgxmock.ExpectDeallocate(string) and pgxmock.ExpectDeallocateAll().
type ExpectedDeallocate struct {
commonExpectation
expectStmtName string
expectAll bool
}

// String returns string representation
func (e *ExpectedPrepare) String() string {
msg := "ExpectedPrepare => expecting call to Prepare():"
msg += fmt.Sprintf("\t- matches statement name: '%s'", e.expectStmtName)
msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL)
if e.deallocateErr != nil {
msg += fmt.Sprintf("\t- returns error on Close: %s", e.deallocateErr)
func (e *ExpectedDeallocate) String() string {
msg := "ExpectedDeallocate => expecting call to Deallocate():\n"
if e.expectAll {
msg += "\t- matches all statements\n"
} else {
msg += fmt.Sprintf("\t- matches statement name: '%s'\n", e.expectStmtName)
}
return msg + e.commonExpectation.String()
}
Expand Down
71 changes: 42 additions & 29 deletions pgxmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ type Expecter interface {
ExpectClose() *ExpectedClose

// ExpectPrepare expects Prepare() to be called with expectedSQL query.
// the *ExpectedPrepare allows to mock database response.
// Note that you may expect Query() or Exec() on the *ExpectedPrepare
// statement to prevent repeating expectedSQL
ExpectPrepare(expectedStmtName, expectedSQL string) *ExpectedPrepare

// ExpectDeallocate expects Deallocate() to be called with expectedStmtName.
// The *ExpectedDeallocate allows to mock database response
ExpectDeallocate(expectedStmtName string) *ExpectedDeallocate
ExpectDeallocateAll() *ExpectedDeallocate

// ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query.
// the *ExpectedQuery allows to mock database response.
ExpectQuery(expectedSQL string) *ExpectedQuery
Expand Down Expand Up @@ -114,6 +116,7 @@ type PgxConnIface interface {
PgxCommonIface
Close(ctx context.Context) error
Deallocate(ctx context.Context, name string) error
DeallocateAll(ctx context.Context) error
Config() *pgx.ConnConfig
PgConn() *pgconn.PgConn
}
Expand Down Expand Up @@ -166,13 +169,6 @@ func (c *pgxmock) ExpectationsWereMet() error {
return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
}

// for expected prepared statement check whether it was closed if expected
if prep, ok := e.(*ExpectedPrepare); ok {
if prep.mustBeClosed && !prep.deallocated {
return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep)
}
}

// must check whether all expected queried rows are closed
if query, ok := e.(*ExpectedQuery); ok {
if query.rowsMustBeClosed && !query.rowsWereClosed {
Expand Down Expand Up @@ -241,7 +237,19 @@ func (c *pgxmock) ExpectPing() *ExpectedPing {
}

func (c *pgxmock) ExpectPrepare(expectedStmtName, expectedSQL string) *ExpectedPrepare {
e := &ExpectedPrepare{expectSQL: expectedSQL, expectStmtName: expectedStmtName, mock: c}
e := &ExpectedPrepare{expectSQL: expectedSQL, expectStmtName: expectedStmtName}
c.expectations = append(c.expectations, e)
return e
}

func (c *pgxmock) ExpectDeallocate(expectedStmtName string) *ExpectedDeallocate {
e := &ExpectedDeallocate{expectStmtName: expectedStmtName}
c.expectations = append(c.expectations, e)
return e
}

func (c *pgxmock) ExpectDeallocateAll() *ExpectedDeallocate {
e := &ExpectedDeallocate{expectAll: true}
c.expectations = append(c.expectations, e)
return e
}
Expand Down Expand Up @@ -371,27 +379,32 @@ func (c *pgxmock) Prepare(ctx context.Context, name, query string) (*pgconn.Stat
}

func (c *pgxmock) Deallocate(ctx context.Context, name string) error {
var (
expected *ExpectedPrepare
ok bool
)
for _, next := range c.expectations {
next.Lock()
expected, ok = next.(*ExpectedPrepare)
ok = ok && expected.expectStmtName == name
next.Unlock()
if ok {
break
ex, err := findExpectationFunc[*ExpectedDeallocate](c, "Deallocate()", func(deallocateExp *ExpectedDeallocate) error {
if deallocateExp.expectAll {
return fmt.Errorf("Deallocate: all prepared statements were expected to be deallocated, instead only '%s' specified", name)
}
if deallocateExp.expectStmtName != name {
return fmt.Errorf("Deallocate: prepared statement name '%s' was not expected, expected name is '%s'", name, deallocateExp.expectStmtName)
}
return nil
})
if err != nil {
return err
}
if expected == nil {
return fmt.Errorf("Deallocate: prepared statement name '%s' doesn't exist", name)
}
if ctx.Err() != nil {
return ctx.Err()
return ex.waitForDelay(ctx)
}

func (c *pgxmock) DeallocateAll(ctx context.Context) error {
ex, err := findExpectationFunc[*ExpectedDeallocate](c, "DeallocateAll()", func(deallocateExp *ExpectedDeallocate) error {
if !deallocateExp.expectAll {
return fmt.Errorf("Deallocate: deallocate all prepared statements was not expected, expected name is '%s'", deallocateExp.expectStmtName)
}
return nil
})
if err != nil {
return err
}
expected.deallocated = true
return expected.deallocateErr
return ex.waitForDelay(ctx)
}

func (c *pgxmock) Commit(ctx context.Context) error {
Expand Down
39 changes: 28 additions & 11 deletions pgxmock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ func TestPrepareExpectations(t *testing.T) {
a := assert.New(t)
expErr := errors.New("invaders must die")
mock.ExpectPrepare("foo", "SELECT (.+) FROM articles WHERE id = ?").
WillReturnCloseError(expErr).
WillDelayFor(1 * time.Second)
mock.ExpectDeallocate("foo").WillReturnError(expErr)

stmt, err := mock.Prepare(context.Background(), "baz", "SELECT (.+) FROM articles WHERE id = ?")
a.Error(err, "wrong prepare stmt name should raise an error")
Expand Down Expand Up @@ -344,12 +344,12 @@ func TestUnorderedPreparedQueryExecutions(t *testing.T) {

mock.MatchExpectationsInOrder(false)

mock.ExpectPrepare("articles_stmt", "SELECT (.+) FROM articles WHERE id = ?").
ExpectQuery().
mock.ExpectPrepare("articles_stmt", "SELECT (.+) FROM articles WHERE id = ?")
mock.ExpectQuery("articles_stmt").
WithArgs(5).
WillReturnRows(NewRows([]string{"id", "title"}).AddRow(5, "The quick brown fox"))
mock.ExpectPrepare("authors_stmt", "SELECT (.+) FROM authors WHERE id = ?").
ExpectQuery().
mock.ExpectPrepare("authors_stmt", "SELECT (.+) FROM authors WHERE id = ?")
mock.ExpectQuery("authors_stmt").
WithArgs(1).
WillReturnRows(NewRows([]string{"id", "title"}).AddRow(1, "Betty B."))

Expand Down Expand Up @@ -911,9 +911,9 @@ func TestPrepareExec(t *testing.T) {
}
defer mock.Close(context.Background())
mock.ExpectBegin()
ep := mock.ExpectPrepare("foo", "INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)")
mock.ExpectPrepare("foo", "INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)")
for i := 0; i < 3; i++ {
ep.ExpectExec().WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1))
mock.ExpectExec("foo").WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1))
}
mock.ExpectCommit()
tx, _ := mock.Begin(context.Background())
Expand Down Expand Up @@ -942,8 +942,8 @@ func TestPrepareQuery(t *testing.T) {
}
defer mock.Close(context.Background())
mock.ExpectBegin()
ep := mock.ExpectPrepare("foo", "SELECT ID, STATUS FROM ORDERS WHERE ID = \\?")
ep.ExpectQuery().WithArgs(101).WillReturnRows(NewRows([]string{"ID", "STATUS"}).AddRow(101, "Hello"))
mock.ExpectPrepare("foo", "SELECT ID, STATUS FROM ORDERS WHERE ID = \\?")
mock.ExpectQuery("foo").WithArgs(101).WillReturnRows(NewRows([]string{"ID", "STATUS"}).AddRow(101, "Hello"))
mock.ExpectCommit()
tx, _ := mock.Begin(context.Background())
_, err = tx.Prepare(context.Background(), "foo", "SELECT ID, STATUS FROM ORDERS WHERE ID = ?")
Expand Down Expand Up @@ -1021,8 +1021,10 @@ func TestPreparedStatementCloseExpectation(t *testing.T) {
mock, _ := NewConn()
a := assert.New(t)

ep := mock.ExpectPrepare("foo", "INSERT INTO ORDERS").WillBeClosed()
ep.ExpectExec().WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1))
mock.ExpectPrepare("foo", "INSERT INTO ORDERS")
mock.ExpectExec("foo").WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1))
mock.ExpectDeallocate("foo")
mock.ExpectDeallocateAll()

stmt, err := mock.Prepare(context.Background(), "foo", "INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)")
a.NoError(err)
Expand All @@ -1034,9 +1036,24 @@ func TestPreparedStatementCloseExpectation(t *testing.T) {
err = mock.Deallocate(context.Background(), "baz")
a.Error(err, "wrong prepares stmt name should raise an error")

err = mock.DeallocateAll(context.Background())
a.Error(err, "we're expecting one statement deallocation, not all")

err = mock.Ping(context.Background())
a.Error(err, "ping should raise an error, we're expecting deallocate")

err = mock.Deallocate(context.Background(), "foo")
a.NoError(err)

err = mock.Ping(context.Background())
a.Error(err, "ping should raise an error, we're expecting deallocate")

err = mock.Deallocate(context.Background(), "baz")
a.Error(err, "wrong prepares stmt name should raise an error")

err = mock.DeallocateAll(context.Background())
a.NoError(err)

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
Expand Down

0 comments on commit e0fce08

Please sign in to comment.