Skip to content

Commit

Permalink
support multiple tables in a WITH clause
Browse files Browse the repository at this point in the history
  • Loading branch information
huandu committed Jul 24, 2024
1 parent 5588d0a commit a0af5e4
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 54 deletions.
63 changes: 21 additions & 42 deletions cte.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ package sqlbuilder
const (
cteMarkerInit injectionMarker = iota
cteMarkerAfterWith
cteMarkerAfterAs
)

// With creates a new CTE builder with default flavor.
func With(name string, cols ...string) *CTEBuilder {
return DefaultFlavor.NewCTEBuilder().With(name, cols...)
func With(tables ...*CTETableBuilder) *CTEBuilder {
return DefaultFlavor.NewCTEBuilder().With(tables...)
}

func newCTEBuilder() *CTEBuilder {
Expand All @@ -23,9 +22,8 @@ func newCTEBuilder() *CTEBuilder {

// CTEBuilder is a CTE (Common Table Expression) builder.
type CTEBuilder struct {
name string
cols []string
builderVar string
tableNames []string
tableBuilderVars []string

args *Args

Expand All @@ -36,17 +34,18 @@ type CTEBuilder struct {
var _ Builder = new(CTEBuilder)

// With sets the CTE name and columns.
func (cteb *CTEBuilder) With(name string, cols ...string) *CTEBuilder {
cteb.name = name
cteb.cols = cols
cteb.marker = cteMarkerAfterWith
return cteb
}
func (cteb *CTEBuilder) With(tables ...*CTETableBuilder) *CTEBuilder {
tableNames := make([]string, 0, len(tables))
tableBuilderVars := make([]string, 0, len(tables))

// As sets the builder to select data.
func (cteb *CTEBuilder) As(builder Builder) *CTEBuilder {
cteb.builderVar = cteb.args.Add(builder)
cteb.marker = cteMarkerAfterAs
for _, table := range tables {
tableNames = append(tableNames, table.TableName())
tableBuilderVars = append(tableBuilderVars, cteb.args.Add(table))
}

cteb.tableNames = tableNames
cteb.tableBuilderVars = tableBuilderVars
cteb.marker = cteMarkerAfterWith
return cteb
}

Expand All @@ -72,27 +71,12 @@ func (cteb *CTEBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}
buf := newStringBuilder()
cteb.injection.WriteTo(buf, cteMarkerInit)

if cteb.name != "" {
if len(cteb.tableBuilderVars) > 0 {
buf.WriteLeadingString("WITH ")
buf.WriteString(cteb.name)

if len(cteb.cols) > 0 {
buf.WriteLeadingString("(")
buf.WriteStrings(cteb.cols, ", ")
buf.WriteString(")")
}

cteb.injection.WriteTo(buf, cteMarkerAfterWith)
}

if cteb.builderVar != "" {
buf.WriteLeadingString("AS (")
buf.WriteString(cteb.builderVar)
buf.WriteRune(')')

cteb.injection.WriteTo(buf, cteMarkerAfterAs)
buf.WriteStrings(cteb.tableBuilderVars, ", ")
}

cteb.injection.WriteTo(buf, cteMarkerAfterWith)
return cteb.args.CompileWithFlavor(buf.String(), flavor, initialArg...)
}

Expand All @@ -103,18 +87,13 @@ func (cteb *CTEBuilder) SetFlavor(flavor Flavor) (old Flavor) {
return
}

// Var returns a placeholder for value.
func (cteb *CTEBuilder) Var(arg interface{}) string {
return cteb.args.Add(arg)
}

// SQL adds an arbitrary sql to current position.
func (cteb *CTEBuilder) SQL(sql string) *CTEBuilder {
cteb.injection.SQL(cteb.marker, sql)
return cteb
}

// TableName returns the CTE table name.
func (cteb *CTEBuilder) TableName() string {
return cteb.name
// TableNames returns all table names in a CTE.
func (cteb *CTEBuilder) TableNames() []string {
return cteb.tableNames
}
38 changes: 27 additions & 11 deletions cte_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,33 @@ import (
)

func ExampleWith() {
sb := With("users", "id", "name").As(
Select("id", "name").From("users").Where("name IS NOT NULL"),
).Select("users.id", "orders.id").Join("orders", "users.id = orders.user_id")
sb := With(
CTETable("users", "id", "name").As(
Select("id", "name").From("users").Where("name IS NOT NULL"),
),
CTETable("devices").As(
Select("device_id").From("devices"),
),
).Select("users.id", "orders.id", "devices.device_id").Join(
"orders",
"users.id = orders.user_id",
"devices.device_id = orders.device_id",
)

fmt.Println(sb)

// Output:
// WITH users (id, name) AS (SELECT id, name FROM users WHERE name IS NOT NULL) SELECT users.id, orders.id FROM users JOIN orders ON users.id = orders.user_id
// WITH users (id, name) AS (SELECT id, name FROM users WHERE name IS NOT NULL), devices AS (SELECT device_id FROM devices) SELECT users.id, orders.id, devices.device_id FROM users, devices JOIN orders ON users.id = orders.user_id AND devices.device_id = orders.device_id
}

func ExampleCTEBuilder() {
usersBuilder := Select("id", "name", "level").From("users")
usersBuilder.Where(
usersBuilder.GreaterEqualThan("level", 10),
)
cteb := With("valid_users").As(usersBuilder)
cteb := With(
CTETable("valid_users").As(usersBuilder),
)
fmt.Println(cteb)

sb := Select("valid_users.id", "valid_users.name", "orders.id").With(cteb)
Expand All @@ -49,17 +60,22 @@ func ExampleCTEBuilder() {
func TestCTEBuilder(t *testing.T) {
a := assert.New(t)
cteb := newCTEBuilder()
ctetb := newCTETableBuilder()
cteb.SQL("/* init */")
cteb.With("t", "a", "b")
cteb.With(ctetb)
cteb.SQL("/* after with */")

// Make sure that calling Var() will not affect the As().
cteb.Var(123)
ctetb.SQL("/* table init */")
ctetb.Table("t", "a", "b")
ctetb.SQL("/* after table */")

cteb.As(Select("a", "b").From("t"))
cteb.SQL("/* after as */")
ctetb.As(Select("a", "b").From("t"))
ctetb.SQL("/* after table as */")

sql, args := cteb.Build()
a.Equal(sql, "/* init */ WITH t (a, b) /* after with */ AS (SELECT a, b FROM t) /* after as */")
a.Equal(sql, "/* init */ WITH /* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */ /* after with */")
a.Assert(args == nil)

sql = ctetb.String()
a.Equal(sql, "/* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */")
}
106 changes: 106 additions & 0 deletions ctetable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2024 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.

package sqlbuilder

const (
cteTableMarkerInit injectionMarker = iota
cteTableMarkerAfterTable
cteTableMarkerAfterAs
)

// CTETable creates a new CTE table builder with default flavor.
func CTETable(name string, cols ...string) *CTETableBuilder {
return DefaultFlavor.NewCTETableBuilder().Table(name, cols...)
}

func newCTETableBuilder() *CTETableBuilder {
return &CTETableBuilder{
args: &Args{},
injection: newInjection(),
}
}

// CTETableBuilder is a builder to build one table in CTE (Common Table Expression).
type CTETableBuilder struct {
name string
cols []string
builderVar string

args *Args

injection *injection
marker injectionMarker
}

// Table sets the table name and columns in a CTE table.
func (ctetb *CTETableBuilder) Table(name string, cols ...string) *CTETableBuilder {
ctetb.name = name
ctetb.cols = cols
ctetb.marker = cteTableMarkerAfterTable
return ctetb
}

// As sets the builder to select data.
func (ctetb *CTETableBuilder) As(builder Builder) *CTETableBuilder {
ctetb.builderVar = ctetb.args.Add(builder)
ctetb.marker = cteTableMarkerAfterAs
return ctetb
}

// String returns the compiled CTE string.
func (ctetb *CTETableBuilder) String() string {
sql, _ := ctetb.Build()
return sql
}

// Build returns compiled CTE string and args.
func (ctetb *CTETableBuilder) Build() (sql string, args []interface{}) {
return ctetb.BuildWithFlavor(ctetb.args.Flavor)
}

// BuildWithFlavor builds a CTE with the specified flavor and initial arguments.
func (ctetb *CTETableBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
buf := newStringBuilder()
ctetb.injection.WriteTo(buf, cteTableMarkerInit)

if ctetb.name != "" {
buf.WriteLeadingString(ctetb.name)

if len(ctetb.cols) > 0 {
buf.WriteLeadingString("(")
buf.WriteStrings(ctetb.cols, ", ")
buf.WriteString(")")
}

ctetb.injection.WriteTo(buf, cteTableMarkerAfterTable)
}

if ctetb.builderVar != "" {
buf.WriteLeadingString("AS (")
buf.WriteString(ctetb.builderVar)
buf.WriteRune(')')

ctetb.injection.WriteTo(buf, cteTableMarkerAfterAs)
}

return ctetb.args.CompileWithFlavor(buf.String(), flavor, initialArg...)
}

// SetFlavor sets the flavor of compiled sql.
func (ctetb *CTETableBuilder) SetFlavor(flavor Flavor) (old Flavor) {
old = ctetb.args.Flavor
ctetb.args.Flavor = flavor
return
}

// SQL adds an arbitrary sql to current position.
func (ctetb *CTETableBuilder) SQL(sql string) *CTETableBuilder {
ctetb.injection.SQL(ctetb.marker, sql)
return ctetb
}

// TableName returns the CTE table name.
func (ctetb *CTETableBuilder) TableName() string {
return ctetb.name
}
7 changes: 7 additions & 0 deletions flavor.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ func (f Flavor) NewCTEBuilder() *CTEBuilder {
return b
}

// NewCTETableBuilder creates a new CTE table builder with flavor.
func (f Flavor) NewCTETableBuilder() *CTETableBuilder {
b := newCTETableBuilder()
b.SetFlavor(f)
return b
}

// Quote adds quote for name to make sure the name can be used safely
// as table name or field name.
//
Expand Down
2 changes: 1 addition & 1 deletion select.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func Select(col ...string) *SelectBuilder {
func (sb *SelectBuilder) With(builder *CTEBuilder) *SelectBuilder {
sb.marker = selectMarkerAfterWith
sb.cteBuilder = sb.Var(builder)
sb.tables = []string{builder.TableName()}
sb.tables = builder.TableNames()
return sb
}

Expand Down

0 comments on commit a0af5e4

Please sign in to comment.