Skip to content

Commit

Permalink
feat: support order by using the projection columns (#1136)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav274 authored Sep 16, 2023
1 parent a0ec785 commit 74fbd8e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 25 deletions.
16 changes: 16 additions & 0 deletions evadb/binder/binder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,19 @@ def drop_row_id_from_target_list(
continue
filtered_list.append(expr)
return filtered_list


def add_func_expr_outputs_to_binder_context(
func_expr: FunctionExpression, binder_context: StatementBinderContext
):
output_cols = []
for obj, alias in zip(func_expr.output_objs, func_expr.alias.col_names):
col_alias = "{}.{}".format(func_expr.alias.alias_name, alias)
alias_obj = TupleValueExpression(
name=alias,
table_alias=func_expr.alias.alias_name,
col_object=obj,
col_alias=col_alias,
)
output_cols.append(alias_obj)
binder_context.add_derived_table_alias(func_expr.alias.alias_name, output_cols)
18 changes: 5 additions & 13 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from evadb.binder.binder_utils import (
BinderError,
add_func_expr_outputs_to_binder_context,
bind_table_info,
check_column_name_is_string,
check_groupby_pattern,
Expand Down Expand Up @@ -199,6 +200,9 @@ def _bind_select_statement(self, node: SelectStatement):
node.target_list = extend_star(self._binder_context)
for expr in node.target_list:
self.bind(expr)
if isinstance(expr, FunctionExpression):
add_func_expr_outputs_to_binder_context(expr, self._binder_context)

if node.groupby_clause:
self.bind(node.groupby_clause)
check_table_object_is_groupable(node.from_table)
Expand Down Expand Up @@ -275,19 +279,7 @@ def _bind_tableref(self, node: TableRef):
func_expr = node.table_valued_expr.func_expr
func_expr.alias = node.alias
self.bind(func_expr)
output_cols = []
for obj, alias in zip(func_expr.output_objs, func_expr.alias.col_names):
col_alias = "{}.{}".format(func_expr.alias.alias_name, alias)
alias_obj = TupleValueExpression(
name=alias,
table_alias=func_expr.alias.alias_name,
col_object=obj,
col_alias=col_alias,
)
output_cols.append(alias_obj)
self._binder_context.add_derived_table_alias(
func_expr.alias.alias_name, output_cols
)
add_func_expr_outputs_to_binder_context(func_expr, self._binder_context)
else:
raise BinderError(f"Unsupported node {type(node)}")

Expand Down
22 changes: 12 additions & 10 deletions evadb/optimizer/statement_to_opr_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def visit_select(self, statement: SelectStatement):
statement {SelectStatement} - - [input select statement]
"""

# order of evaluation
# from, where, group by, select, order by, limit, union
table_ref = statement.from_table
if table_ref is not None:
self.visit_table_ref(table_ref)
Expand All @@ -133,22 +135,22 @@ def visit_select(self, statement: SelectStatement):
if statement.groupby_clause is not None:
self._visit_groupby(statement.groupby_clause)

if statement.orderby_list is not None:
self._visit_orderby(statement.orderby_list)

if statement.limit_count is not None:
self._visit_limit(statement.limit_count)

# union
if statement.union_link is not None:
self._visit_union(statement.union_link, statement.union_all)

# Projection operator
select_columns = statement.target_list

if select_columns is not None:
self._visit_projection(select_columns)

if statement.orderby_list is not None:
self._visit_orderby(statement.orderby_list)

if statement.limit_count is not None:
self._visit_limit(statement.limit_count)

# union
if statement.union_link is not None:
self._visit_union(statement.union_link, statement.union_all)

def _visit_sample(self, sample_freq, sample_type):
sample_opr = LogicalSample(sample_freq, sample_type)
sample_opr.append_child(self._plan)
Expand Down
2 changes: 1 addition & 1 deletion test/integration_tests/long/test_model_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_forecast(self):
execute_query_fetch_all(self.evadb, create_predict_udf)

predict_query = """
SELECT AirForecast(12);
SELECT AirForecast(12) order by y;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 12)
Expand Down
2 changes: 1 addition & 1 deletion test/integration_tests/long/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_should_run_pytorch_and_facenet(self):
execute_query_fetch_all(self.evadb, create_function_query)

select_query = """SELECT FaceDetector(data) FROM MyVideo
WHERE id < 5;"""
WHERE id < 5 order by scores;"""
actual_batch = execute_query_fetch_all(self.evadb, select_query)
self.assertEqual(len(actual_batch), 5)

Expand Down

0 comments on commit 74fbd8e

Please sign in to comment.