Skip to content

Commit

Permalink
Improve unparsing for ORDER BY, UNION, Window functions with Agg…
Browse files Browse the repository at this point in the history
…regation (apache#12946)

* Improve unparsing for ORDER BY with Aggregation functions (#38)

* Improve UNION unparsing (#39)

* Scalar functions in ORDER BY unparsing support (#41)

* Improve unparsing for complex Window functions with Aggregation (#42)

* WindowFunction order_by should respect `supports_nulls_first_in_sort` dialect setting (#43)

* Fix plan_to_sql

* Improve
  • Loading branch information
sgrebnov authored Oct 17, 2024
1 parent ccfe020 commit ad273ca
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 36 deletions.
10 changes: 3 additions & 7 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
unparser.expr_to_sql(expr)
}

/// Converts a logical [`Sort`] expression into a SQL AST `ORDER BY` item.
///
/// The conversion is delegated to [`Unparser::sort_to_sql`] on an unparser
/// built with `Unparser::default()`, so no caller-supplied dialect is applied.
pub fn sort_to_sql(sort: &Sort) -> Result<ast::OrderByExpr> {
    Unparser::default().sort_to_sql(sort)
}

const LOWEST: &BinaryOperator = &BinaryOperator::Or;
// Closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs
// (https://www.postgresql.org/docs/7.2/sql-precedence.html)
Expand Down Expand Up @@ -229,9 +224,10 @@ impl Unparser<'_> {
ast::WindowFrameUnits::Groups
}
};
let order_by: Vec<ast::OrderByExpr> = order_by

let order_by = order_by
.iter()
.map(sort_to_sql)
.map(|sort_expr| self.sort_to_sql(sort_expr))
.collect::<Result<Vec<_>>>()?;

let start_bound = self.convert_bound(&window_frame.start_bound)?;
Expand Down
42 changes: 29 additions & 13 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use super::{
},
utils::{
find_agg_node_within_select, find_window_nodes_within_select,
unproject_window_exprs,
unproject_sort_expr, unproject_window_exprs,
},
Unparser,
};
Expand Down Expand Up @@ -352,19 +352,30 @@ impl Unparser<'_> {
if select.already_projected() {
return self.derive(plan, relation);
}
if let Some(query_ref) = query {
if let Some(fetch) = sort.fetch {
query_ref.limit(Some(ast::Expr::Value(ast::Value::Number(
fetch.to_string(),
false,
))));
}
query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?);
} else {
let Some(query_ref) = query else {
return internal_err!(
"Sort operator only valid in a statement context."
);
}
};

if let Some(fetch) = sort.fetch {
query_ref.limit(Some(ast::Expr::Value(ast::Value::Number(
fetch.to_string(),
false,
))));
};

let agg = find_agg_node_within_select(plan, select.already_projected());
// unproject sort expressions
let sort_exprs: Vec<SortExpr> = sort
.expr
.iter()
.map(|sort_expr| {
unproject_sort_expr(sort_expr, agg, sort.input.as_ref())
})
.collect::<Result<Vec<_>>>()?;

query_ref.order_by(self.sorts_to_sql(&sort_exprs)?);

self.select_to_sql_recursively(
sort.input.as_ref(),
Expand Down Expand Up @@ -402,7 +413,7 @@ impl Unparser<'_> {
.collect::<Result<Vec<_>>>()?;
if let Some(sort_expr) = &on.sort_expr {
if let Some(query_ref) = query {
query_ref.order_by(self.sorts_to_sql(sort_expr.clone())?);
query_ref.order_by(self.sorts_to_sql(sort_expr)?);
} else {
return internal_err!(
"Sort operator only valid in a statement context."
Expand Down Expand Up @@ -546,6 +557,11 @@ impl Unparser<'_> {
);
}

// Covers cases where the UNION is a subquery and the projection is at the top level
if select.already_projected() {
return self.derive(plan, relation);
}

let input_exprs: Vec<SetExpr> = union
.inputs
.iter()
Expand Down Expand Up @@ -691,7 +707,7 @@ impl Unparser<'_> {
}
}

fn sorts_to_sql(&self, sort_exprs: Vec<SortExpr>) -> Result<Vec<ast::OrderByExpr>> {
fn sorts_to_sql(&self, sort_exprs: &[SortExpr]) -> Result<Vec<ast::OrderByExpr>> {
sort_exprs
.iter()
.map(|sort_expr| self.sort_to_sql(sort_expr))
Expand Down
69 changes: 54 additions & 15 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use std::cmp::Ordering;
use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Column, DataFusionError, Result, ScalarValue,
Column, Result, ScalarValue,
};
use datafusion_expr::{
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr,
Window,
};
use sqlparser::ast;

Expand Down Expand Up @@ -118,21 +119,11 @@ pub(crate) fn unproject_agg_exprs(
if let Expr::Column(c) = sub_expr {
if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
Ok(Transformed::yes(unprojected_expr.clone()))
} else if let Some(mut unprojected_expr) =
} else if let Some(unprojected_expr) =
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
{
if let Expr::WindowFunction(func) = &mut unprojected_expr {
// Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
func.args.iter_mut().try_for_each(|arg| {
if let Expr::Column(c) = arg {
if let Some(expr) = find_agg_expr(agg, c)? {
*arg = expr.clone();
}
}
Ok::<(), DataFusionError>(())
})?;
}
Ok(Transformed::yes(unprojected_expr))
// A window function can contain aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...', that need to be unprojected
return Ok(Transformed::yes(unproject_agg_exprs(&unprojected_expr, agg, None)?));
} else {
internal_err!(
"Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
Expand Down Expand Up @@ -200,6 +191,54 @@ fn find_window_expr<'a>(
.find(|expr| expr.schema_name().to_string() == column_name)
}

/// Rewrites an ORDER BY expression that is only a column reference back into the
/// expression that produced that column, when such an expression can be found.
///
/// When an ORDER BY expression also appears in an Aggregate or a projection, the sort
/// key is replaced by a `Column` named after the expression
/// (e.g. "sum(catalog_returns.cr_net_loss)"). To unparse it, the column has to be turned
/// back into the real expression, such as sum("catalog_returns"."cr_net_loss").
///
/// * `sort_expr` - the sort expression to unproject; it is cloned, never modified.
/// * `agg` - the aggregate node to resolve the column against, if there is one.
/// * `input` - the input plan of the sort; only consulted when it is a `Projection`.
///
/// Returns the (possibly rewritten) copy of `sort_expr`. Errors raised by
/// `unproject_agg_exprs` are propagated.
pub(crate) fn unproject_sort_expr(
    sort_expr: &SortExpr,
    agg: Option<&Aggregate>,
    input: &LogicalPlan,
) -> Result<SortExpr> {
    let mut unprojected = sort_expr.clone();

    // Strip an alias wrapper so the sort key is the aliased expression itself.
    if let Expr::Alias(alias) = &unprojected.expr {
        unprojected.expr = (*alias.expr).clone();
    }

    // Only an unqualified column reference is a candidate for unprojection;
    // anything else is returned as it stands.
    let col_ref = match &unprojected.expr {
        Expr::Column(col) if col.relation.is_none() => col,
        _ => return Ok(unprojected),
    };

    // With an aggregation, the column may name an aggregate function that needs unprojecting.
    if let Some(agg) = agg.filter(|agg| agg.schema.is_column_from_schema(col_ref)) {
        unprojected.expr = unproject_agg_exprs(&unprojected.expr, agg, None)?;
        return Ok(unprojected);
    }

    // If SELECT and ORDER BY share the same scalar-function expression, the ORDER BY side
    // becomes a Column (e.g. "substr(customer.c_last_name, Int64(0), Int64(5))"); restore
    // the function call from the matching projection slot.
    if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input {
        let projected = schema
            .index_of_column(col_ref)
            .ok()
            .and_then(|idx| expr.get(idx));
        if let Some(Expr::ScalarFunction(scalar_fn)) = projected {
            unprojected.expr = Expr::ScalarFunction(scalar_fn.clone());
        }
    }

    Ok(unprojected)
}

/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
pub(crate) fn date_part_to_sql(
unparser: &Unparser,
Expand Down
63 changes: 62 additions & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ use arrow_schema::*;
use datafusion_common::{DFSchema, Result, TableReference};
use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf};
use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
use datafusion_functions::unicode;
use datafusion_functions_aggregate::grouping::grouping_udaf;
use datafusion_functions_window::rank::rank_udwf;
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_sql::unparser::dialect::{
DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
Expand Down Expand Up @@ -139,6 +142,13 @@ fn roundtrip_statement() -> Result<()> {
SELECT j2_string as string FROM j2
ORDER BY string DESC
LIMIT 10"#,
r#"SELECT col1, id FROM (
SELECT j1_string AS col1, j1_id AS id FROM j1
UNION ALL
SELECT j2_string AS col1, j2_id AS id FROM j2
UNION ALL
SELECT j3_string AS col1, j3_id AS id FROM j3
) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#,
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
first_name from person",
Expand Down Expand Up @@ -657,7 +667,12 @@ where
.unwrap();

let context = MockContextProvider {
state: MockSessionState::default(),
state: MockSessionState::default()
.with_aggregate_function(sum_udaf())
.with_aggregate_function(max_udaf())
.with_aggregate_function(grouping_udaf())
.with_window_function(rank_udwf())
.with_scalar_function(Arc::new(unicode::substr().as_ref().clone())),
};
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
Expand Down Expand Up @@ -969,3 +984,49 @@ fn test_with_offset0() {
// Passes `select 1 offset 95` to `sql_round_trip` with the MySQL dialect; the expected
// SQL is `SELECT 1 OFFSET 95`.
fn test_with_offset95() {
sql_round_trip(MySqlDialect {}, "select 1 offset 95", "SELECT 1 OFFSET 95");
}

/// Round-trips queries whose ORDER BY refers to an aggregate function, an aggregate
/// alias, and a scalar function that is also projected, comparing each result with the
/// expected SQL text.
#[test]
fn test_order_by_to_sql() {
    // Each entry is (input query, expected SQL); cases run in declaration order.
    let cases = [
        // order by aggregation function
        (
            r#"SELECT id, first_name, SUM(id) FROM person GROUP BY id, first_name ORDER BY SUM(id) ASC, first_name DESC, id, first_name LIMIT 10"#,
            r#"SELECT person.id, person.first_name, sum(person.id) FROM person GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#,
        ),
        // order by aggregation function alias
        (
            r#"SELECT id, first_name, SUM(id) as total_sum FROM person GROUP BY id, first_name ORDER BY total_sum ASC, first_name DESC, id, first_name LIMIT 10"#,
            r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#,
        ),
        // order by scalar function from projection
        (
            r#"SELECT id, first_name, substr(first_name,0,5) FROM person ORDER BY id, substr(first_name,0,5)"#,
            r#"SELECT person.id, person.first_name, substr(person.first_name, 0, 5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0, 5) ASC NULLS LAST"#,
        ),
    ];

    for (query, expected) in cases {
        sql_round_trip(GenericDialect {}, query, expected);
    }
}

// Round-trips one GROUP BY query that mixes a plain aggregate (`SUM(id)`), a windowed
// aggregate, an aggregate nested inside a window function (`MAX(SUM(id)) OVER ...`), and
// `rank()` windows whose PARTITION BY / ORDER BY clauses contain `grouping(...)` and
// `sum(id)`.
#[test]
fn test_aggregation_to_sql() {
// The expected SQL is written across several lines for readability; the trailing
// `.replace("\n", " ")` turns each newline into a single space before it is passed on.
sql_round_trip(
GenericDialect {},
r#"SELECT id, first_name,
SUM(id) AS total_sum,
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total,
rank() OVER (PARTITION BY grouping(id) + grouping(age), CASE WHEN grouping(age) = 0 THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_1,
rank() OVER (PARTITION BY grouping(age) + grouping(id), CASE WHEN (CAST(grouping(age) AS BIGINT) = 0) THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_2
FROM person
GROUP BY id, first_name;"#,
r#"SELECT person.id, person.first_name,
sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN '5' PRECEDING AND '2' FOLLOWING) AS moving_sum,
max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total,
rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1,
rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2
FROM person
GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(),
);
}

0 comments on commit ad273ca

Please sign in to comment.