From 41bcc5d889e24f75f17bfa6142c9f367d1c9b28d Mon Sep 17 00:00:00 2001 From: Andy Xu Date: Sun, 12 Nov 2023 19:57:39 -0500 Subject: [PATCH] Add unit test case for the where clause in select statement --- test/unit_tests/parser/test_parser.py | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/test/unit_tests/parser/test_parser.py b/test/unit_tests/parser/test_parser.py index 60624b825..9b675acac 100644 --- a/test/unit_tests/parser/test_parser.py +++ b/test/unit_tests/parser/test_parser.py @@ -22,6 +22,7 @@ from evadb.expression.constant_value_expression import ConstantValueExpression from evadb.expression.function_expression import FunctionExpression from evadb.expression.tuple_value_expression import TupleValueExpression +from evadb.expression.logical_expression import LogicalExpression from evadb.parser.alias import Alias from evadb.parser.create_function_statement import CreateFunctionStatement from evadb.parser.create_index_statement import CreateIndexStatement @@ -531,6 +532,56 @@ def test_select_statement_class(self): self.assertEqual(select_stmt_new.from_table, select_stmt.from_table) self.assertEqual(str(select_stmt_new), str(select_stmt)) + def test_select_statement_where_class(self): + """ + Unit test for logical operators in the where clause. + """ + + parser = Parser() + select_query_new = "SELECT CLASS, REDNESS FROM TAIPAI WHERE CLASS = 'VAN' AND REDNESS < 400;" + evadb_statement_list = parser.parse(select_query_new) + + self.assertIsInstance(evadb_statement_list, list) + self.assertEqual(len(evadb_statement_list), 1) + self.assertEqual(evadb_statement_list[0].stmt_type, StatementType.SELECT) + + select_stmt = evadb_statement_list[0] + + # target list + self.assertIsNotNone(select_stmt.target_list) + self.assertEqual(len(select_stmt.target_list), 2) + self.assertEqual(select_stmt.target_list[0].etype, ExpressionType.TUPLE_VALUE) + self.assertEqual(select_stmt.target_list[0].name, "CLASS") + self.assertEqual(select_stmt.target_list[1].etype, ExpressionType.TUPLE_VALUE) + self.assertEqual(select_stmt.target_list[1].name, "REDNESS") + + # from table + self.assertIsNotNone(select_stmt.from_table) + self.assertIsInstance(select_stmt.from_table, TableRef) + self.assertEqual(select_stmt.from_table.table.table_name, "TAIPAI") + + # where clause + self.assertIsNotNone(select_stmt.where_clause) + self.assertIsInstance(select_stmt.where_clause, LogicalExpression) + self.assertEqual(select_stmt.where_clause.etype, ExpressionType.LOGICAL_AND) + self.assertEqual(len(select_stmt.where_clause.children), 2) + left = select_stmt.where_clause.children[0] + right = select_stmt.where_clause.children[1] + self.assertEqual(left.etype, ExpressionType.COMPARE_EQUAL) + self.assertEqual(right.etype, ExpressionType.COMPARE_LESSER) + + self.assertEqual(len(left.children), 2) + self.assertEqual(left.children[0].etype, ExpressionType.TUPLE_VALUE) + self.assertEqual(left.children[0].name, "CLASS") + self.assertEqual(left.children[1].etype, ExpressionType.CONSTANT_VALUE) + self.assertEqual(left.children[1].value, "VAN") + + self.assertEqual(len(right.children), 2) + self.assertEqual(right.children[0].etype, ExpressionType.TUPLE_VALUE) + self.assertEqual(right.children[0].name, "REDNESS") + self.assertEqual(right.children[1].etype, ExpressionType.CONSTANT_VALUE) + self.assertEqual(right.children[1].value, 400) + def test_select_statement_groupby_class(self): """Testing sample frequency"""