Skip to content

Commit

Permalink
Multiple joins in one query (#984)
Browse files Browse the repository at this point in the history
* Remove unnecessary variable

* Refactor `check_join_as()`

* Use `NULL` instead of `NA`

* Refactor semi join (#981)

* Transfer attributes while inlining

* Change semi-join vars to tibble

* Document

* Add `types` to `copy_inline()` (#964)

* Add `types` to `copy_inline()`

* Remove commented code

* Clarify `types` documentation

* Remove incorrect `types` argument

* Check `types` argument

* Test that `types` argument works

* Remove unnecessary code

* `rows_*()` casts `y` columns if it copies them (#965)

* `rows_*()` casts `y` columns if it copies them

* Fix `name = NULL` case

* Check containment before copying

* Only use type inference for Postgres

* Multiple joins in one query

* Fix incorrect check in `join_needs_new_query()`

* Can use by column from other table than first

* Remove unnecessary `join_as_already_used()`

* Fix NEWS

* Replace `vec_unchop()` by `list_unchop()`

* Document `joins` data structure

* Minor refactoring

* Update documentation

* Remove duplicated code

* Update R/db-sql.R

Co-authored-by: Hadley Wickham <[email protected]>

* Avoid pipe

* Refactor `joins` data structure

* Pull out generation of table names

* Refactor `sql_build.lazy_multi_join_query()`

* Use `length()` instead of `vec_size()`

* Rename to `make_join_aliases()`

* Avoid unnecessary negation

* Use postgres for simpler snapshot

* Split test

* Replace explicit test with snapshot

* Refactor join table name logic

Co-authored-by: Hadley Wickham <[email protected]>
  • Loading branch information
mgirlich and hadley authored Dec 6, 2022
1 parent 8bf404a commit b16399f
Show file tree
Hide file tree
Showing 15 changed files with 811 additions and 200 deletions.
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ S3method(explain,tbl_sql)
S3method(flatten_query,ident)
S3method(flatten_query,join_query)
S3method(flatten_query,lazy_values_query)
S3method(flatten_query,multi_join_query)
S3method(flatten_query,select_query)
S3method(flatten_query,semi_join_query)
S3method(flatten_query,set_op_query)
Expand Down Expand Up @@ -135,6 +136,7 @@ S3method(op_sort,lazy_query)
S3method(op_sort,tbl_lazy)
S3method(op_vars,lazy_base_query)
S3method(op_vars,lazy_join_query)
S3method(op_vars,lazy_multi_join_query)
S3method(op_vars,lazy_query)
S3method(op_vars,lazy_semi_join_query)
S3method(op_vars,lazy_set_op_query)
Expand All @@ -150,6 +152,7 @@ S3method(print,lazy_join_query)
S3method(print,lazy_select_query)
S3method(print,lazy_semi_join_query)
S3method(print,lazy_set_op_query)
S3method(print,multi_join_query)
S3method(print,select_query)
S3method(print,semi_join_query)
S3method(print,set_op_query)
Expand Down Expand Up @@ -190,6 +193,7 @@ S3method(sql_build,ident)
S3method(sql_build,lazy_base_local_query)
S3method(sql_build,lazy_base_remote_query)
S3method(sql_build,lazy_join_query)
S3method(sql_build,lazy_multi_join_query)
S3method(sql_build,lazy_select_query)
S3method(sql_build,lazy_semi_join_query)
S3method(sql_build,lazy_set_op_query)
Expand Down Expand Up @@ -260,6 +264,7 @@ S3method(sql_query_join,MariaDBConnection)
S3method(sql_query_join,MySQL)
S3method(sql_query_join,MySQLConnection)
S3method(sql_query_join,SQLiteConnection)
S3method(sql_query_multi_join,DBIConnection)
S3method(sql_query_rows,DBIConnection)
S3method(sql_query_save,"Microsoft SQL Server")
S3method(sql_query_save,DBIConnection)
Expand Down Expand Up @@ -303,6 +308,7 @@ S3method(sql_render,ident)
S3method(sql_render,join_query)
S3method(sql_render,lazy_query)
S3method(sql_render,lazy_values_query)
S3method(sql_render,multi_join_query)
S3method(sql_render,select_query)
S3method(sql_render,semi_join_query)
S3method(sql_render,set_op_query)
Expand Down Expand Up @@ -435,6 +441,7 @@ export(lahman_srcs)
export(lazy_base_query)
export(lazy_frame)
export(lazy_join_query)
export(lazy_multi_join_query)
export(lazy_query)
export(lazy_select_query)
export(lazy_semi_join_query)
Expand Down Expand Up @@ -501,6 +508,7 @@ export(sql_query_explain)
export(sql_query_fields)
export(sql_query_insert)
export(sql_query_join)
export(sql_query_multi_join)
export(sql_query_rows)
export(sql_query_save)
export(sql_query_select)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@

* `summarise()` + `filter()` now translates to `HAVING` where possible
(@mgirlich, #877).

* `left/inner_join()` + `left/inner_join()` (@mgirlich, #865).

* The generated SQL is now shorter and more readable:

Expand Down
87 changes: 87 additions & 0 deletions R/db-sql.R
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,93 @@ sql_join.DBIConnection <- function(con, x, y, vars, type = "inner", by = NULL, n
)
}

#' @rdname db-sql
#' @export
sql_query_multi_join <- function(con,
x,
joins,
table_vars,
vars,
...,
lvl = 0) {
# S3 generic: `UseMethod()` dispatches on the class of `con`, the first
# argument. In the `DBIConnection` method, `x` is the first table, `joins`
# supplies the joined tables with their join type and `by` specification,
# the names of `table_vars` are the table aliases, `vars` drives the SELECT
# list, and `lvl` is passed on to the subquery and formatting helpers.
UseMethod("sql_query_multi_join")
}

#' @export
#' @param joins tibble with one row per joined table. `lazy_multi_join_query()`
#'   requires the columns:
#'   * `table`: the tables to join with.
#'   * `type` `<character>`: the join type (left, right, inner, full). It is
#'     upper-cased and suffixed with `" JOIN"` to form the join keyword.
#'   * `by_x_table_id` `<list_of<integer>>`: The table index where the join
#'     column comes from. This needs to be a list because the join columns
#'     might come from different tables. Not read by this method.
#'   * `by`: one join specification per row. Each is passed to
#'     `sql_join_tbls()` together with its `na_matches` field.
#' @param vars See [sql_multi_join_vars()].
#' @param table_vars `named <list_of<character>>`: All variables in each table.
#'   The names are used as the table aliases and must be unique.
#' @noRd
#' @examples
#' # Left join with *
#' df1 <- lazy_frame(x = 1, y = 1)
#' df2 <- lazy_frame(x = 1, z = 1)
#' df3 <- lazy_frame(x = 1, z2 = 1)
#'
#' tmp <- left_join(df1, df2, by = "x") %>%
#'   left_join(df3, by = c("x", z = "z2"))
#' tibble(
#'   table = list(df1, df2),
#'   type = c("left", "left"),
#'   by_x = list("x", c("x", "z")),
#'   by_y = list("x", c("x", "z2")),
#'   by_x_table_id = list(1L, c(1L, 2L)),
#'   on = c(NA, NA),
#'   na_matches = c("never", "never")
#' )
sql_query_multi_join.DBIConnection <- function(con,
                                               x,
                                               joins,
                                               table_vars,
                                               vars,
                                               ...,
                                               lvl = 0) {
  # NOTE(review): the tibble in the example above uses `by_x`/`by_y`/`on`/
  # `na_matches` columns, whereas this method reads `joins$table`,
  # `joins$type` and `joins$by` -- confirm and refresh the example.

  # The names of `table_vars` double as the SQL aliases of the tables.
  aliases <- names(table_vars)
  if (vctrs::vec_duplicate_any(aliases)) {
    cli_abort("{.arg table_names} must be unique.")
  }

  select_cols <- sql_multi_join_vars(con, vars, table_vars)
  first_table <- dbplyr_sql_subquery(con, x, name = aliases[[1]], lvl = lvl)

  # Every joined table becomes a subquery under its own alias; the first
  # alias belongs to `x`, hence `aliases[-1]`.
  as_subquery <- function(tbl, alias) {
    dbplyr_sql_subquery(con, tbl, name = alias, lvl = lvl)
  }
  joined_tables <- purrr::map2(joins$table, aliases[-1], as_subquery)

  join_keywords <- toupper(paste0(joins$type, " JOIN"))
  join_clauses <- purrr::map2(
    join_keywords,
    joined_tables,
    function(keyword, tbl) sql_clause(keyword, tbl)
  )

  # One parenthesised `ON` clause per join.
  build_on_clause <- function(by) {
    conditions <- sql_join_tbls(con, by = by, na_matches = by$na_matches)
    sql_clause("ON", conditions, sep = " AND", parens = TRUE, lvl = 1)
  }
  on_clauses <- purrr::map(joins$by, build_on_clause)

  # Alternate JOIN / ON so each join is directly followed by its condition.
  join_on_clauses <- vctrs::vec_interleave(join_clauses, on_clauses)

  all_clauses <- list2(
    sql_clause_select(con, select_cols),
    sql_clause_from(first_table),
    !!!join_on_clauses
  )

  sql_format_clauses(all_clauses, lvl = lvl, con = con)
}

#' @rdname db-sql
#' @export
sql_query_semi_join <- function(con, x, y, anti, by, vars, ..., lvl = 0) {
Expand Down
121 changes: 121 additions & 0 deletions R/lazy-join-query.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,49 @@ lazy_join_query <- function(x,
)
}

#' @export
#' @rdname sql_build
lazy_multi_join_query <- function(x,
                                  joins,
                                  table_names,
                                  vars,
                                  group_vars = op_grps(x),
                                  order_vars = op_sort(x),
                                  frame = op_frame(x),
                                  call = caller_env()) {
  # Validate the pieces of a multi-table join and wrap them in a lazy query
  # of type "multi_join".
  #
  # * `x` must inherit from "lazy_query".
  # * `joins` must have exactly the columns `table`, `type`, `by_x_table_id`,
  #   `by`, with a character `type`.
  # * `table_names` must have exactly the columns `name`, `from`, both
  #   character.
  # * `vars` must have exactly the columns `name`, `table`, `var`, with a
  #   character `name` and list `table` / `var`.
  # * `call` is the environment reported in type-check errors. It defaults to
  #   the caller's environment, so the default behaviour is unchanged; an
  #   explicitly supplied `call` is now honoured rather than ignored.
  stopifnot(inherits(x, "lazy_query"))

  if (!identical(colnames(joins), c("table", "type", "by_x_table_id", "by"))) {
    cli_abort("`joins` must have fields `table`, `type`, `by_x_table_id`, `by`", .internal = TRUE)
  }
  vctrs::vec_assert(joins$type, character(), arg = "joins$type", call = call)

  if (!identical(colnames(table_names), c("name", "from"))) {
    cli_abort("`table_names` must have fields `name`, `from`", .internal = TRUE)
  }
  # The label must name the column that is actually checked (`name`).
  vctrs::vec_assert(table_names$name, character(), arg = "table_names$name", call = call)
  vctrs::vec_assert(table_names$from, character(), arg = "table_names$from", call = call)

  if (!identical(colnames(vars), c("name", "table", "var"))) {
    cli_abort("`vars` must have fields `name`, `table`, `var`", .internal = TRUE)
  }
  vctrs::vec_assert(vars$name, character(), arg = "vars$name", call = call)
  vctrs::vec_assert(vars$table, list(), arg = "vars$table", call = call)
  vctrs::vec_assert(vars$var, list(), arg = "vars$var", call = call)

  lazy_query(
    query_type = "multi_join",
    x = x,
    joins = joins,
    table_names = table_names,
    vars = vars,
    last_op = "join",
    group_vars = group_vars,
    order_vars = order_vars,
    frame = frame
  )
}

join_check_vars <- function(vars, call) {
if (!vctrs::vec_is_list(vars)) {
# TODO use `cli_abort()` after https://github.com/r-lib/rlang/issues/1386
Expand Down Expand Up @@ -139,6 +182,10 @@ op_vars.lazy_join_query <- function(op) {
op$vars$alias
}
#' @export
op_vars.lazy_multi_join_query <- function(op) {
  # The output variable names are stored in the `name` column of `op$vars`.
  join_vars <- op$vars
  join_vars$name
}
#' @export
op_vars.lazy_semi_join_query <- function(op) {
  # The output variable names are stored in the `name` field of `op$vars`.
  semi_join_vars <- op$vars
  semi_join_vars$name
}
Expand All @@ -156,6 +203,80 @@ sql_build.lazy_join_query <- function(op, con, ...) {
)
}

#' @export
sql_build.lazy_multi_join_query <- function(op, con, ...) {
  # Turn a lazy multi-join into a query object. Joins of more than two tables,
  # and two-table joins that are neither full nor right joins, become a
  # `multi_join_query()`; a single full/right join becomes a classic
  # `join_query()`.
  aliases <- generate_join_table_names(op$table_names)

  if (length(aliases) > 2 || !op$joins$type %in% c("full", "right")) {
    # Collect the variables of every table (first table plus joined tables),
    # keyed by table alias, before the lazy tables are replaced below.
    all_tables <- c(list(op$x), op$joins$table)
    table_vars <- purrr::map(set_names(all_tables, aliases), op_vars)

    build_table <- function(tbl) sql_optimise(sql_build(tbl, con), con)
    op$joins$table <- purrr::map(op$joins$table, build_table)

    # Attach the aliases of both sides to every `by` specification: the left
    # side comes from `by_x_table_id`, the right side is the table joined at
    # this position (offset by one because the first alias belongs to `op$x`).
    op$joins$by <- purrr::map2(
      op$joins$by,
      seq_along(op$joins$by),
      function(join_by, pos) {
        join_by$x_as <- ident(aliases[op$joins$by_x_table_id[[pos]]])
        join_by$y_as <- ident(aliases[pos + 1L])
        join_by
      }
    )

    return(
      multi_join_query(
        x = build_table(op$x),
        joins = op$joins,
        table_vars = table_vars,
        vars = op$vars
      )
    )
  }

  # construct a classical `join_query()` so that special handling of
  # full and right join continue to work
  join_type <- op$joins$type

  # For each output variable: the source column in table 1 (`x`), or NA if
  # the variable does not come from that table.
  from_x <- purrr::map_lgl(op$vars$table, function(ids) 1L %in% ids)
  x_vars <- purrr::map2_chr(
    op$vars$var,
    from_x,
    function(var, used) {
      if (used) var[[1]] else NA_character_
    }
  )

  # Same for table 2 (`y`), preferring the last recorded source column.
  from_y <- purrr::map_lgl(op$vars$table, function(ids) 2L %in% ids)
  y_vars <- purrr::map2_chr(
    op$vars$var,
    from_y,
    function(var, used) {
      if (used) dplyr::last(var) %||% var[[1]] else NA_character_
    }
  )

  classic_vars <- list(
    alias = op$vars$name,
    x = x_vars,
    y = y_vars,
    all_x = op_vars(op$x),
    all_y = op_vars(op$joins$table[[1]])
  )

  join_by <- op$joins$by[[1]]
  join_by$x_as <- ident(aliases[[1]])
  join_by$y_as <- ident(aliases[[2]])

  join_query(
    sql_optimise(sql_build(op$x, con), con),
    sql_optimise(sql_build(op$joins$table[[1]], con), con),
    vars = classic_vars,
    type = join_type,
    by = join_by,
    suffix = NULL, # it seems like the suffix is not used for rendering
    na_matches = join_by$na_matches
  )
}

# Compute the table aliases used in the generated SQL.
#
# `table_names` carries a `name` and a `from` entry per table. With exactly
# two tables the aliases come from `join_two_table_alias()`. Otherwise names
# are made unique, but only for entries whose `from` is not "as"; entries
# with `from == "as"` keep their name unchanged.
generate_join_table_names <- function(table_names) {
  aliases <- table_names$name

  if (length(aliases) == 2) {
    return(join_two_table_alias(aliases, table_names$from))
  }

  unique_aliases <- vctrs::vec_as_names(aliases, repair = "unique", quiet = TRUE)
  is_auto <- table_names$from != "as"
  aliases[is_auto] <- unique_aliases[is_auto]

  aliases
}

#' @export
sql_build.lazy_semi_join_query <- function(op, con, ...) {
vars_prev <- op_vars(op$x)
Expand Down
51 changes: 50 additions & 1 deletion R/query-join.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ join_query <- function(x, y, vars, type = "inner", by = NULL, suffix = c(".x", "
)
}

# Constructor for a built multi-join query: a list holding the first table
# (`x`), the join specifications (`joins`), the variables of each table
# (`table_vars`) and the output variables (`vars`), classed as
# c("multi_join_query", "query").
multi_join_query <- function(x, joins, table_vars, vars) {
  query <- list(
    x = x,
    joins = joins,
    table_vars = table_vars,
    vars = vars
  )
  class(query) <- c("multi_join_query", "query")
  query
}

#' @export
print.join_query <- function(x, ...) {
cat_line("<SQL JOIN (", toupper(x$type), ")>")
Expand All @@ -28,6 +40,24 @@ print.join_query <- function(x, ...) {
cat_line(indent_print(sql_build(x$y)))
}

#' @export
print.multi_join_query <- function(x, ...) {
  # Print a header, the first table, and then for every join its type, its
  # `by` column pairs ("x-y") and the joined table.
  cat_line("<SQL JOINS>")

  cat_line("X:")
  cat_line(indent_print(sql_build(x$x)))

  joins <- x$joins
  for (pos in vctrs::vec_seq_along(joins)) {
    cat_line("Type: ", paste0(joins$type[[pos]]))

    cat_line("By:")
    join_by <- joins$by[[pos]]
    cat_line(indent(paste0(join_by$x, "-", join_by$y)))

    cat_line("Y:")
    cat_line(indent_print(sql_build(joins$table[[pos]])))
  }
}

#' @export
sql_render.join_query <- function(query, con = NULL, ..., subquery = FALSE, lvl = 0) {
from_x <- sql_render(query$x, con, ..., subquery = TRUE, lvl = lvl + 1)
Expand All @@ -42,6 +72,25 @@ sql_render.join_query <- function(query, con = NULL, ..., subquery = FALSE, lvl
)
}

#' @export
sql_render.multi_join_query <- function(query, con = NULL, ..., subquery = FALSE, lvl = 0) {
  # Render a multi-join query: the first table and every joined table are
  # rendered as subqueries one level deeper, then handed to
  # `sql_query_multi_join()` to assemble the final statement.
  x <- sql_render(query$x, con, ..., subquery = TRUE, lvl = lvl + 1)

  # Use an explicit function rather than a `~` lambda: inside a formula
  # lambda `...` refers to the lambda's own arguments (the table itself), so
  # the table would be passed twice and this method's `...` would never be
  # forwarded.
  query$joins$table <- purrr::map(
    query$joins$table,
    function(table) sql_render(table, con, ..., subquery = TRUE, lvl = lvl + 1)
  )

  # `multi_join_query()` stores only `x`, `joins`, `table_vars` and `vars`,
  # so no `by_list` is passed on (it was always `NULL`).
  sql_query_multi_join(
    con = con,
    x = x,
    joins = query$joins,
    table_vars = query$table_vars,
    vars = query$vars,
    lvl = lvl
  )
}

# SQL generation ----------------------------------------------------------

#' @param vars tibble with three columns:
Expand Down Expand Up @@ -248,7 +297,7 @@ sql_table_prefix <- function(con, var, table = NULL) {
var <- sql_escape_ident(con, var)

if (!is.null(table)) {
table <- escape(table, con = con)
table <- escape(table, collapse = NULL, con = con)
sql(paste0(table, ".", var))
} else {
var
Expand Down
Loading

0 comments on commit b16399f

Please sign in to comment.