From 74fbd8e1eae16410285b19369f3900b718d07776 Mon Sep 17 00:00:00 2001 From: Gaurav Tarlok Kakkar Date: Sat, 16 Sep 2023 16:48:47 -0700 Subject: [PATCH] feat: support order by using the projection columns (#1136) --- evadb/binder/binder_utils.py | 16 ++++++++++++++ evadb/binder/statement_binder.py | 18 +++++---------- evadb/optimizer/statement_to_opr_converter.py | 22 ++++++++++--------- .../long/test_model_forecasting.py | 2 +- test/integration_tests/long/test_pytorch.py | 2 +- 5 files changed, 35 insertions(+), 25 deletions(-) diff --git a/evadb/binder/binder_utils.py b/evadb/binder/binder_utils.py index 1c6f4ff91d..c219382afc 100644 --- a/evadb/binder/binder_utils.py +++ b/evadb/binder/binder_utils.py @@ -361,3 +361,19 @@ def drop_row_id_from_target_list( continue filtered_list.append(expr) return filtered_list + + +def add_func_expr_outputs_to_binder_context( + func_expr: FunctionExpression, binder_context: StatementBinderContext +): + output_cols = [] + for obj, alias in zip(func_expr.output_objs, func_expr.alias.col_names): + col_alias = "{}.{}".format(func_expr.alias.alias_name, alias) + alias_obj = TupleValueExpression( + name=alias, + table_alias=func_expr.alias.alias_name, + col_object=obj, + col_alias=col_alias, + ) + output_cols.append(alias_obj) + binder_context.add_derived_table_alias(func_expr.alias.alias_name, output_cols) diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py index d4c684ce6f..4df16bc72f 100644 --- a/evadb/binder/statement_binder.py +++ b/evadb/binder/statement_binder.py @@ -18,6 +18,7 @@ from evadb.binder.binder_utils import ( BinderError, + add_func_expr_outputs_to_binder_context, bind_table_info, check_column_name_is_string, check_groupby_pattern, @@ -199,6 +200,9 @@ def _bind_select_statement(self, node: SelectStatement): node.target_list = extend_star(self._binder_context) for expr in node.target_list: self.bind(expr) + if isinstance(expr, FunctionExpression): + add_func_expr_outputs_to_binder_context(expr, self._binder_context) + if node.groupby_clause: self.bind(node.groupby_clause) check_table_object_is_groupable(node.from_table) @@ -275,19 +279,7 @@ def _bind_tableref(self, node: TableRef): func_expr = node.table_valued_expr.func_expr func_expr.alias = node.alias self.bind(func_expr) - output_cols = [] - for obj, alias in zip(func_expr.output_objs, func_expr.alias.col_names): - col_alias = "{}.{}".format(func_expr.alias.alias_name, alias) - alias_obj = TupleValueExpression( - name=alias, - table_alias=func_expr.alias.alias_name, - col_object=obj, - col_alias=col_alias, - ) - output_cols.append(alias_obj) - self._binder_context.add_derived_table_alias( - func_expr.alias.alias_name, output_cols - ) + add_func_expr_outputs_to_binder_context(func_expr, self._binder_context) else: raise BinderError(f"Unsupported node {type(node)}") diff --git a/evadb/optimizer/statement_to_opr_converter.py b/evadb/optimizer/statement_to_opr_converter.py index a8e5ed3302..aacfffa209 100644 --- a/evadb/optimizer/statement_to_opr_converter.py +++ b/evadb/optimizer/statement_to_opr_converter.py @@ -119,6 +119,8 @@ def visit_select(self, statement: SelectStatement): statement {SelectStatement} - - [input select statement] """ + # order of evaluation + # from, where, group by, select, order by, limit, union table_ref = statement.from_table if table_ref is not None: self.visit_table_ref(table_ref) @@ -133,22 +135,22 @@ def visit_select(self, statement: SelectStatement): if statement.groupby_clause is not None: self._visit_groupby(statement.groupby_clause) - if statement.orderby_list is not None: - self._visit_orderby(statement.orderby_list) - - if statement.limit_count is not None: - self._visit_limit(statement.limit_count) - - # union - if statement.union_link is not None: - self._visit_union(statement.union_link, statement.union_all) - # Projection operator select_columns = statement.target_list if select_columns is not None: self._visit_projection(select_columns) + if statement.orderby_list is not None: + self._visit_orderby(statement.orderby_list) + + if statement.limit_count is not None: + self._visit_limit(statement.limit_count) + + # union + if statement.union_link is not None: + self._visit_union(statement.union_link, statement.union_all) + def _visit_sample(self, sample_freq, sample_type): sample_opr = LogicalSample(sample_freq, sample_type) sample_opr.append_child(self._plan) diff --git a/test/integration_tests/long/test_model_forecasting.py b/test/integration_tests/long/test_model_forecasting.py index b7778bba68..2a9b266c7e 100644 --- a/test/integration_tests/long/test_model_forecasting.py +++ b/test/integration_tests/long/test_model_forecasting.py @@ -75,7 +75,7 @@ def test_forecast(self): execute_query_fetch_all(self.evadb, create_predict_udf) predict_query = """ - SELECT AirForecast(12); + SELECT AirForecast(12) order by y; """ result = execute_query_fetch_all(self.evadb, predict_query) self.assertEqual(len(result), 12) diff --git a/test/integration_tests/long/test_pytorch.py b/test/integration_tests/long/test_pytorch.py index f30b4af9fe..f2bd66ba04 100644 --- a/test/integration_tests/long/test_pytorch.py +++ b/test/integration_tests/long/test_pytorch.py @@ -215,7 +215,7 @@ def test_should_run_pytorch_and_facenet(self): execute_query_fetch_all(self.evadb, create_function_query) select_query = """SELECT FaceDetector(data) FROM MyVideo - WHERE id < 5;""" + WHERE id < 5 order by scores;""" actual_batch = execute_query_fetch_all(self.evadb, select_query) self.assertEqual(len(actual_batch), 5)