Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple joins in one query #984

Merged
merged 33 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7e7fce6
Remove unnecessary variable
mgirlich Aug 18, 2022
42e4071
Refactor `check_join_as()`
mgirlich Aug 19, 2022
c953e07
Use `NULL` instead of `NA`
mgirlich Aug 19, 2022
a4fb7c5
Refactor semi join (#981)
mgirlich Aug 19, 2022
09d93d6
Add `types` to `copy_inline()` (#964)
mgirlich Aug 19, 2022
faf8d63
`rows_*()` casts `y` columns if it copies them (#965)
mgirlich Aug 19, 2022
e980aaa
Multiple joins in one query
mgirlich Aug 19, 2022
a8ae7da
Merge commit '79b1fd59491357ce214a6fcfe4ad372bc548109f'
mgirlich Aug 25, 2022
c4dd32b
Fix incorrect check in `join_needs_new_query()`
mgirlich Aug 25, 2022
fe93ae6
Can use by column from other table than first
mgirlich Sep 14, 2022
f1f5961
Remove unnecessary `join_as_already_used()`
mgirlich Sep 14, 2022
2847b48
Merged upstream/main into multi_join
mgirlich Sep 14, 2022
e029821
Fix NEWS
mgirlich Sep 14, 2022
98730d4
Merge commit 'f4cef2360088c39620bd85f474ec2e888a520dcc'
mgirlich Oct 18, 2022
082d357
Replace `vec_unchop()` by `list_unchop()`
mgirlich Oct 18, 2022
fbc89bf
Merged upstream/main into multi_join
mgirlich Nov 18, 2022
1183574
Merge commit '2fc8db416527f096afb66f9d0008e12f812d8ad2'
mgirlich Nov 22, 2022
023cb46
Document `joins` data structure
mgirlich Nov 23, 2022
64bc627
Minor refactoring
mgirlich Nov 23, 2022
6504c18
Update documentation
mgirlich Nov 23, 2022
78b63d7
Remove duplicated code
mgirlich Nov 23, 2022
ef21e2a
Update R/db-sql.R
mgirlich Nov 24, 2022
ef7ff2f
Avoid pipe
mgirlich Nov 24, 2022
eaf60f4
Refactor `joins` data structure
mgirlich Nov 24, 2022
a442e28
Pull out generation of table names
mgirlich Nov 24, 2022
40c84b2
Refactor `sql_build.lazy_multi_join_query()`
mgirlich Nov 24, 2022
4f7e430
Use `length()` instead of `vec_size()`
mgirlich Nov 24, 2022
57090a6
Rename to `make_join_aliases()`
mgirlich Nov 24, 2022
0b0502c
Avoid unnecessary negation
mgirlich Nov 24, 2022
155f5ca
Use postgres for simpler snapshot
mgirlich Nov 24, 2022
5d4fd19
Split test
mgirlich Nov 24, 2022
8a267d1
Replace explicity test with snapshot
mgirlich Nov 24, 2022
678ffa6
Refactor join table name logic
mgirlich Nov 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ S3method(explain,tbl_sql)
S3method(flatten_query,ident)
S3method(flatten_query,join_query)
S3method(flatten_query,lazy_values_query)
S3method(flatten_query,multi_join_query)
S3method(flatten_query,select_query)
S3method(flatten_query,semi_join_query)
S3method(flatten_query,set_op_query)
Expand Down Expand Up @@ -135,6 +136,7 @@ S3method(op_sort,lazy_query)
S3method(op_sort,tbl_lazy)
S3method(op_vars,lazy_base_query)
S3method(op_vars,lazy_join_query)
S3method(op_vars,lazy_multi_join_query)
S3method(op_vars,lazy_query)
S3method(op_vars,lazy_semi_join_query)
S3method(op_vars,lazy_set_op_query)
Expand All @@ -150,6 +152,7 @@ S3method(print,lazy_join_query)
S3method(print,lazy_select_query)
S3method(print,lazy_semi_join_query)
S3method(print,lazy_set_op_query)
S3method(print,multi_join_query)
S3method(print,select_query)
S3method(print,semi_join_query)
S3method(print,set_op_query)
Expand Down Expand Up @@ -190,6 +193,7 @@ S3method(sql_build,ident)
S3method(sql_build,lazy_base_local_query)
S3method(sql_build,lazy_base_remote_query)
S3method(sql_build,lazy_join_query)
S3method(sql_build,lazy_multi_join_query)
S3method(sql_build,lazy_select_query)
S3method(sql_build,lazy_semi_join_query)
S3method(sql_build,lazy_set_op_query)
Expand Down Expand Up @@ -260,6 +264,7 @@ S3method(sql_query_join,MariaDBConnection)
S3method(sql_query_join,MySQL)
S3method(sql_query_join,MySQLConnection)
S3method(sql_query_join,SQLiteConnection)
S3method(sql_query_multi_join,DBIConnection)
S3method(sql_query_rows,DBIConnection)
S3method(sql_query_save,"Microsoft SQL Server")
S3method(sql_query_save,DBIConnection)
Expand Down Expand Up @@ -303,6 +308,7 @@ S3method(sql_render,ident)
S3method(sql_render,join_query)
S3method(sql_render,lazy_query)
S3method(sql_render,lazy_values_query)
S3method(sql_render,multi_join_query)
S3method(sql_render,select_query)
S3method(sql_render,semi_join_query)
S3method(sql_render,set_op_query)
Expand Down Expand Up @@ -435,6 +441,7 @@ export(lahman_srcs)
export(lazy_base_query)
export(lazy_frame)
export(lazy_join_query)
export(lazy_multi_join_query)
export(lazy_query)
export(lazy_select_query)
export(lazy_semi_join_query)
Expand Down Expand Up @@ -501,6 +508,7 @@ export(sql_query_explain)
export(sql_query_fields)
export(sql_query_insert)
export(sql_query_join)
export(sql_query_multi_join)
export(sql_query_rows)
export(sql_query_save)
export(sql_query_select)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@

* `summarise()` + `filter()` now translates to `HAVING` where possible
(@mgirlich, #877).

* `left/inner_join()` + `left/inner_join()` is now translated to a single query
with multiple joins where possible (@mgirlich, #865).

* The generated SQL is now shorter and more readable:

Expand Down
100 changes: 100 additions & 0 deletions R/db-sql.R
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,106 @@ sql_join.DBIConnection <- function(con, x, y, vars, type = "inner", by = NULL, n
)
}

#' @rdname db-sql
#' @export
sql_query_multi_join <- function(con, x, joins, table_vars, vars, ..., lvl = 0) {
  # S3 generic: the method is selected by the class of `con`, the first
  # argument.
  UseMethod("sql_query_multi_join")
}

#' @export
#' @param joins tibble with one row per joined table and the columns:
#'   * `table` `<tbl_lazy>`: the tables to join with.
#'   * `type` `<character>`: the join type (left, right, inner, full).
#'   * `by_x`, `by_y` `<list_of<character>>`: The columns to join by.
#'   * `by_x_table_id` `<list_of<integer>>`: The table index where the join
#'     column comes from. This needs to be a list because the join columns
#'     might come from different tables.
#'   * `on` `<character>`: SQL used verbatim as the `ON` condition; if `NA` the
#'     condition is generated from `by_x`, `by_y` and `na_matches`.
#'   * `na_matches` `<character>`: Either `"na"` or `"never"`.
#' @param vars See [sql_multi_join_vars()].
#' @param table_vars `named <list_of<character>>`: All variables in each table.
#'   The names are used as the table aliases and must be unique.
#' @noRd
#' @examples
#' # Left join with *
#' df1 <- lazy_frame(x = 1, y = 1)
#' df2 <- lazy_frame(x = 1, z = 1)
#' df3 <- lazy_frame(x = 1, z2 = 1)
#'
#' tmp <- left_join(df1, df2, by = "x") %>%
#'   left_join(df3, by = c("x", z = "z2"))
#' # The corresponding `joins` (`df1` is passed separately as `x`):
#' tibble(
#'   table = list(df2, df3),
#'   type = c("left", "left"),
#'   by_x = list("x", c("x", "z")),
#'   by_y = list("x", c("x", "z2")),
#'   by_x_table_id = list(1L, c(1L, 2L)),
#'   on = c(NA_character_, NA_character_),
#'   na_matches = c("never", "never")
#' )
sql_query_multi_join.DBIConnection <- function(con,
                                               x,
                                               joins,
                                               table_vars,
                                               vars,
                                               ...,
                                               lvl = 0) {
  # The names of `table_vars` serve as the aliases of `x` and the joined tables.
  table_names <- names(table_vars)
  if (vctrs::vec_duplicate_any(table_names)) {
    cli_abort("{.arg table_names} must be unique.")
  }

  select_sql <- sql_multi_join_vars(con, vars, table_vars)
  from <- dbplyr_sql_subquery(con, x, name = table_names[[1]], lvl = lvl)

  # The first alias belongs to `x`; the remaining ones pair up with
  # `joins$table` in order.
  join_table_queries <- purrr::map2(
    joins$table,
    table_names[-1],
    function(table, name) dbplyr_sql_subquery(con, table, name = name, lvl = lvl)
  )
  types <- toupper(paste0(joins$type, " JOIN"))
  join_clauses <- purrr::map2(
    types,
    join_table_queries,
    function(join_kw, join_table) sql_clause(join_kw, join_table)
  )

  # One `ON` clause per join, built row-wise from the join specification.
  on_clauses <- purrr::pmap(
    vctrs::vec_cbind(
      rhs = table_names[-1],
      joins[c("by_x", "by_x_table_id", "by_y", "on", "na_matches")]
    ),
    function(rhs, by_x, by_x_table_id, by_y, on, na_matches) {
      if (!is.na(on)) {
        # A user supplied condition is passed through as literal SQL.
        on <- sql(on)
      } else {
        by <- list(
          x = ident(by_x),
          y = ident(by_y),
          # Each left-hand column is qualified with the alias of the table it
          # comes from, which need not be the first table.
          x_as = ident(table_names[by_x_table_id]),
          y_as = ident(rhs)
        )
        on <- sql_join_tbls(con, by = by, na_matches = na_matches)
      }

      sql_clause("ON", on, sep = " AND", parens = TRUE, lvl = 1)
    }
  )
  # Alternate the clauses: JOIN 1, ON 1, JOIN 2, ON 2, ...
  join_on_clauses <- vctrs::vec_interleave(join_clauses, on_clauses)

  list2(
    sql_clause_select(con, select_sql),
    sql_clause_from(from),
    !!!join_on_clauses
  ) %>%
    sql_format_clauses(lvl = lvl, con = con)
}

#' @rdname db-sql
#' @export
sql_query_semi_join <- function(con, x, y, anti, by, vars, ..., lvl = 0) {
Expand Down
124 changes: 124 additions & 0 deletions R/lazy-join-query.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,51 @@ lazy_join_query <- function(x,
)
}

#' @export
#' @rdname sql_build
lazy_multi_join_query <- function(x,
                                  joins,
                                  table_names,
                                  vars,
                                  group_vars = op_grps(x),
                                  order_vars = op_sort(x),
                                  frame = op_frame(x),
                                  call = caller_env()) {
  # Validates the three data structures describing a multi join and wraps them
  # in a lazy query of type "multi_join". Column names are compared with
  # `identical()`, so both the set and the order of the columns must match.
  # `call` is the environment reported in the errors raised by the type checks.
  stopifnot(inherits(x, "lazy_query"))

  joins_fields <- c("table", "type", "by_x", "by_y", "by_x_table_id", "on", "na_matches")
  if (!identical(colnames(joins), joins_fields)) {
    cli_abort("`joins` must have fields `table`, `type`, `by_x`, `by_y`, `by_x_table_id`, `on`, `na_matches`", .internal = TRUE)
  }
  vctrs::vec_assert(joins$type, character(), arg = "joins$type", call = call)
  vctrs::vec_assert(joins$on, character(), arg = "joins$on", call = call)
  vctrs::vec_assert(joins$na_matches, character(), arg = "joins$na_matches", call = call)

  if (!identical(colnames(table_names), c("as", "name"))) {
    cli_abort("`table_names` must have fields `as`, `name`", .internal = TRUE)
  }
  vctrs::vec_assert(table_names$as, character(), arg = "table_names$as", call = call)
  # Report the column that is actually checked (previously mislabelled `as`).
  vctrs::vec_assert(table_names$name, character(), arg = "table_names$name", call = call)

  if (!identical(colnames(vars), c("name", "table", "var"))) {
    cli_abort("`vars` must have fields `name`, `table`, `var`", .internal = TRUE)
  }
  vctrs::vec_assert(vars$name, character(), arg = "vars$name", call = call)
  vctrs::vec_assert(vars$table, list(), arg = "vars$table", call = call)
  vctrs::vec_assert(vars$var, list(), arg = "vars$var", call = call)

  lazy_query(
    query_type = "multi_join",
    x = x,
    joins = joins,
    table_names = table_names,
    vars = vars,
    last_op = "join",
    group_vars = group_vars,
    order_vars = order_vars,
    frame = frame
  )
}

join_check_vars <- function(vars, call) {
if (!vctrs::vec_is_list(vars)) {
# TODO use `cli_abort()` after https://github.com/r-lib/rlang/issues/1386
Expand Down Expand Up @@ -139,6 +184,10 @@ op_vars.lazy_join_query <- function(op) {
op$vars$alias
}
#' @export
op_vars.lazy_multi_join_query <- function(op) {
  # The variables of a multi join are the `name` entries of its `vars` field.
  join_vars <- op$vars
  join_vars$name
}
#' @export
op_vars.lazy_semi_join_query <- function(op) {
  # The variables of a semi join are the `name` entries of its `vars` field.
  join_vars <- op$vars
  join_vars$name
}
Expand All @@ -156,6 +205,81 @@ sql_build.lazy_join_query <- function(op, con, ...) {
)
}

#' @export
sql_build.lazy_multi_join_query <- function(op, con, ...) {
  # Tables without an explicit alias (`as` is `NA`) fall back to their name;
  # `auto_name` remembers which entries were filled in this way.
  auto_name <- is.na(op$table_names$as)
  table_names_out <- dplyr::coalesce(op$table_names$as, op$table_names$name)

  if (length(table_names_out) == 2) {
    # Exactly two tables, i.e. a single join: defer the alias generation to
    # `join_two_table_alias()`, which takes `NULL` rather than `NA` for
    # missing values.
    na_to_null <- function(x) {
      if (is.na(x)) {
        NULL
      } else {
        x
      }
    }

    x_name <- na_to_null(op$table_names$name[[1]])
    y_name <- na_to_null(op$table_names$name[[2]])
    x_alias <- na_to_null(op$table_names$as[[1]])
    y_alias <- na_to_null(op$table_names$as[[2]])
    table_names_out <- join_two_table_alias(x_name, y_name, x_alias, y_alias)

    if (op$joins$type %in% c("full", "right")) {
      # construct a classical `join_query()` so that special handling of
      # full and right join continue to work
      type <- op$joins$type

      # For every output variable: the source column in `x` (table 1) and in
      # `y` (table 2), or `NA` if the variable does not come from that table.
      x_idx <- purrr::map_lgl(op$vars$table, function(tbl_ids) 1L %in% tbl_ids)
      vars_x <- purrr::map2_chr(
        op$vars$var,
        x_idx,
        function(var, from_x) {
          if (from_x) var[[1]] else NA_character_
        }
      )

      y_idx <- purrr::map_lgl(op$vars$table, function(tbl_ids) 2L %in% tbl_ids)
      vars_y <- purrr::map2_chr(
        op$vars$var,
        y_idx,
        function(var, from_y) {
          if (from_y) dplyr::last(var) %||% var[[1]] else NA_character_
        }
      )

      vars_classic <- list(
        alias = op$vars$name,
        x = vars_x,
        y = vars_y,
        all_x = op_vars(op$x),
        all_y = op_vars(op$joins$table[[1]])
      )

      out <- join_query(
        sql_optimise(sql_build(op$x, con), con),
        sql_optimise(sql_build(op$joins$table[[1]], con), con),
        vars = vars_classic,
        type = type,
        by = list(
          on = sql(op$joins$on),
          x = op$joins$by_x[[1]],
          y = op$joins$by_y[[1]],
          x_as = ident(table_names_out[[1]]),
          y_as = ident(table_names_out[[2]])
        ),
        suffix = NULL, # it seems like the suffix is not used for rendering
        na_matches = op$joins$na_matches
      )
      return(out)
    }
  } else {
    # Make the aliases unique, but only replace the automatically generated
    # ones; aliases given explicitly via `as` are kept untouched.
    table_names_repaired <- vctrs::vec_as_names(table_names_out, repair = "unique", quiet = TRUE)
    table_names_out[auto_name] <- table_names_repaired[auto_name]
  }

  # Collect the variables of every table before the lazy tables in
  # `op$joins$table` are replaced by their built queries below.
  all_vars_list <- purrr::map(
    c(list(op$x), op$joins$table),
    op_vars
  )

  op$joins$table <- purrr::map(
    op$joins$table,
    function(table) sql_optimise(sql_build(table, con), con)
  )

  multi_join_query(
    x = sql_optimise(sql_build(op$x, con), con),
    joins = op$joins,
    table_vars = set_names(all_vars_list, table_names_out),
    vars = op$vars
  )
}

#' @export
sql_build.lazy_semi_join_query <- function(op, con, ...) {
vars_prev <- op_vars(op$x)
Expand Down
Loading