diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index d560041f6..5afbb3731 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -234,6 +234,7 @@ func TestDB(t *testing.T) { {testNilModel}, {testSelectScan}, {testSelectCount}, + {testSelectLimit}, {testSelectMap}, {testSelectMapSlice}, {testSelectStruct}, @@ -348,6 +349,36 @@ func testSelectCount(t *testing.T, db *bun.DB) { require.Equal(t, 3, count) } +func testSelectLimit(t *testing.T, db *bun.DB) { + if !db.Dialect().Features().Has(feature.CTE) { + t.Skip() + return + } + + values := db.NewValues(&[]map[string]interface{}{ + {"num": 1}, + {"num": 2}, + {"num": 3}, + }) + + q := db.NewSelect(). + With("t", values). + Column("t.num"). + TableExpr("t") + + count, err := q.Limit(5).Count(ctx) + require.NoError(t, err) + require.Equal(t, 3, count) + + count, err = q.Limit(2).Count(ctx) + require.NoError(t, err) + require.Equal(t, 2, count) + + count, err = q.Limit(0).Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, count) +} + func testSelectMap(t *testing.T, db *bun.DB) { var m map[string]interface{} err := db.NewSelect(). @@ -1357,6 +1388,21 @@ func testScanAndCount(t *testing.T, db *bun.DB) { require.Equal(t, 2, count) require.Equal(t, 2, len(dest)) }) + + t.Run("limit 0", func(t *testing.T) { + src := []Model{ + {Str: "str1"}, + {Str: "str2"}, + } + _, err = db.NewInsert().Model(&src).Exec(ctx) + require.NoError(t, err) + + var dest []Model + count, err := db.NewSelect().Model(&dest).Limit(0).ScanAndCount(ctx) + require.NoError(t, err) + require.Equal(t, 0, count) + require.Equal(t, 0, len(dest)) + }) } func testEmbedModelValue(t *testing.T, db *bun.DB) { diff --git a/query_select.go b/query_select.go index c0e145110..95e3a0925 100644 --- a/query_select.go +++ b/query_select.go @@ -48,6 +48,7 @@ func NewSelectQuery(db *DB) *SelectQuery { conn: db.DB, }, }, + limit: -1, } } @@ -631,7 +632,7 @@ func (q *SelectQuery) appendQuery( b = append(b, " ROWS"...) } } else { - if q.limit > 0 { + if q.limit >= 0 { b = append(b, " LIMIT "...) b = strconv.AppendInt(b, int64(q.limit), 10) } @@ -958,7 +959,7 @@ func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) var mu sync.Mutex var firstErr error - if q.limit >= 0 { + if q.limit >= -1 { wg.Add(1) go func() { defer wg.Done() @@ -995,7 +996,7 @@ func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) func (q *SelectQuery) scanAndCountSeq(ctx context.Context, dest ...interface{}) (int, error) { var firstErr error - if q.limit >= 0 { + if q.limit >= -1 { firstErr = q.Scan(ctx, dest...) }