From 8a80a141b9ac2b9650e3649b1ed32d0ebcc5257c Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Mon, 2 Oct 2023 10:22:02 -0400 Subject: [PATCH] only include tv_expr when necessary during create index --- evadb/parser/lark_visitor/_create_statements.py | 3 ++- test/unit_tests/parser/test_parser.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/evadb/parser/lark_visitor/_create_statements.py b/evadb/parser/lark_visitor/_create_statements.py index ecb332853b..86decf1efe 100644 --- a/evadb/parser/lark_visitor/_create_statements.py +++ b/evadb/parser/lark_visitor/_create_statements.py @@ -267,11 +267,12 @@ def create_index(self, tree): while not isinstance(index_elem, TupleValueExpression): index_elem = index_elem.children[0] index_elem = [index_elem] + else: + project_expr_list += index_elem # Add tv_expr for projected columns. col_list = [] for tv_expr in index_elem: - project_expr_list += [tv_expr] col_list += [ColumnDefinition(tv_expr.name, None, None, None)] return CreateIndexStatement( diff --git a/test/unit_tests/parser/test_parser.py b/test/unit_tests/parser/test_parser.py index 0488e16096..6bc2fe058d 100644 --- a/test/unit_tests/parser/test_parser.py +++ b/test/unit_tests/parser/test_parser.py @@ -113,7 +113,7 @@ def test_create_index_statement(self): ColumnDefinition("featCol", None, None, None), ], VectorStoreType.FAISS, - [TupleValueExpression(name="*")], + [TupleValueExpression(name="featCol")], ) actual_stmt = evadb_stmt_list[0] self.assertEqual(actual_stmt, expected_stmt) @@ -128,7 +128,7 @@ def test_create_index_statement(self): ColumnDefinition("featCol", None, None, None), ], VectorStoreType.FAISS, - [TupleValueExpression(name="*")], + [TupleValueExpression(name="featCol")], ) create_index_query = ( "CREATE INDEX IF NOT EXISTS testindex ON MyVideo (featCol) USING FAISS;" @@ -160,7 +160,7 @@ def test_create_index_statement(self): ColumnDefinition("featCol", None, None, None), ], VectorStoreType.FAISS, - [TupleValueExpression(name="*"), func_expr], + [func_expr], ) actual_stmt = evadb_stmt_list[0] self.assertEqual(actual_stmt, expected_stmt)