From 7905d69d8cb06f13304f51939143495448b58cbc Mon Sep 17 00:00:00 2001 From: Andy Xu Date: Sun, 12 Nov 2023 20:05:43 -0500 Subject: [PATCH] Add unit test cases --- test/unit_tests/parser/test_parser.py | 92 +++++++++++++++------------ 1 file changed, 52 insertions(+), 40 deletions(-) diff --git a/test/unit_tests/parser/test_parser.py b/test/unit_tests/parser/test_parser.py index 9b675acac3..057a118ec6 100644 --- a/test/unit_tests/parser/test_parser.py +++ b/test/unit_tests/parser/test_parser.py @@ -537,50 +537,62 @@ def test_select_statement_where_class(self): Unit test for logical operators in the where clause. """ - parser = Parser() - select_query_new = "SELECT CLASS, REDNESS FROM TAIPAI WHERE CLASS = 'VAN' AND REDNESS < 400;" - evadb_statement_list = parser.parse(select_query_new) - - self.assertIsInstance(evadb_statement_list, list) - self.assertEqual(len(evadb_statement_list), 1) - self.assertEqual(evadb_statement_list[0].stmt_type, StatementType.SELECT) + def _verify_select_statement(evadb_statement_list): + self.assertIsInstance(evadb_statement_list, list) + self.assertEqual(len(evadb_statement_list), 1) + self.assertEqual(evadb_statement_list[0].stmt_type, StatementType.SELECT) + + select_stmt = evadb_statement_list[0] + + # target list + self.assertIsNotNone(select_stmt.target_list) + self.assertEqual(len(select_stmt.target_list), 2) + self.assertEqual(select_stmt.target_list[0].etype, ExpressionType.TUPLE_VALUE) + self.assertEqual(select_stmt.target_list[0].name, "CLASS") + self.assertEqual(select_stmt.target_list[1].etype, ExpressionType.TUPLE_VALUE) + self.assertEqual(select_stmt.target_list[1].name, "REDNESS") + + # from table + self.assertIsNotNone(select_stmt.from_table) + self.assertIsInstance(select_stmt.from_table, TableRef) + self.assertEqual(select_stmt.from_table.table.table_name, "TAIPAI") + + # where clause + self.assertIsNotNone(select_stmt.where_clause) + self.assertIsInstance(select_stmt.where_clause, LogicalExpression) + self.assertEqual(select_stmt.where_clause.etype, ExpressionType.LOGICAL_AND) + self.assertEqual(len(select_stmt.where_clause.children), 2) + left = select_stmt.where_clause.children[0] + right = select_stmt.where_clause.children[1] + self.assertEqual(left.etype, ExpressionType.COMPARE_EQUAL) + self.assertEqual(right.etype, ExpressionType.COMPARE_LESSER) + + self.assertEqual(len(left.children), 2) + self.assertEqual(left.children[0].etype, ExpressionType.TUPLE_VALUE) + self.assertEqual(left.children[0].name, "CLASS") + self.assertEqual(left.children[1].etype, ExpressionType.CONSTANT_VALUE) + self.assertEqual(left.children[1].value, "VAN") + + self.assertEqual(len(right.children), 2) + self.assertEqual(right.children[0].etype, ExpressionType.TUPLE_VALUE) + self.assertEqual(right.children[0].name, "REDNESS") + self.assertEqual(right.children[1].etype, ExpressionType.CONSTANT_VALUE) + self.assertEqual(right.children[1].value, 400) - select_stmt = evadb_statement_list[0] + parser = Parser() + select_query = "SELECT CLASS, REDNESS FROM TAIPAI WHERE CLASS = 'VAN' AND REDNESS < 400;" + _verify_select_statement(parser.parse(select_query)) - # target list - self.assertIsNotNone(select_stmt.target_list) - self.assertEqual(len(select_stmt.target_list), 2) - self.assertEqual(select_stmt.target_list[0].etype, ExpressionType.TUPLE_VALUE) - self.assertEqual(select_stmt.target_list[0].name, "CLASS") - self.assertEqual(select_stmt.target_list[1].etype, ExpressionType.TUPLE_VALUE) - self.assertEqual(select_stmt.target_list[1].name, "REDNESS") + # Case insensitive test + select_query = "select CLASS, REDNESS from TAIPAI where CLASS = 'VAN' and REDNESS < 400;" + _verify_select_statement(parser.parse(select_query)) - # from table - self.assertIsNotNone(select_stmt.from_table) - self.assertIsInstance(select_stmt.from_table, TableRef) - self.assertEqual(select_stmt.from_table.table.table_name, "TAIPAI") + # Unsupported logical operator + select_query = "SELECT CLASS, REDNESS FROM TAIPAI WHERE CLASS = 'VAN' XOR REDNESS < 400;" + with self.assertRaises(NotImplementedError) as cm: + parser.parse(select_query) + self.assertEqual(str(cm.exception), "Unsupported logical operator: XOR") - # where clause - self.assertIsNotNone(select_stmt.where_clause) - self.assertIsInstance(select_stmt.where_clause, LogicalExpression) - self.assertEqual(select_stmt.where_clause.etype, ExpressionType.LOGICAL_AND) - self.assertEqual(len(select_stmt.where_clause.children), 2) - left = select_stmt.where_clause.children[0] - right = select_stmt.where_clause.children[1] - self.assertEqual(left.etype, ExpressionType.COMPARE_EQUAL) - self.assertEqual(right.etype, ExpressionType.COMPARE_LESSER) - - self.assertEqual(len(left.children), 2) - self.assertEqual(left.children[0].etype, ExpressionType.TUPLE_VALUE) - self.assertEqual(left.children[0].name, "CLASS") - self.assertEqual(left.children[1].etype, ExpressionType.CONSTANT_VALUE) - self.assertEqual(left.children[1].value, "VAN") - - self.assertEqual(len(right.children), 2) - self.assertEqual(right.children[0].etype, ExpressionType.TUPLE_VALUE) - self.assertEqual(right.children[0].name, "REDNESS") - self.assertEqual(right.children[1].etype, ExpressionType.CONSTANT_VALUE) - self.assertEqual(right.children[1].value, 400) def test_select_statement_groupby_class(self): """Testing sample frequency"""