diff --git a/evadb/binder/binder_utils.py b/evadb/binder/binder_utils.py index 92746eb9a..47cccf0aa 100644 --- a/evadb/binder/binder_utils.py +++ b/evadb/binder/binder_utils.py @@ -34,6 +34,8 @@ if TYPE_CHECKING: from evadb.binder.statement_binder_context import StatementBinderContext from evadb.catalog.catalog_manager import CatalogManager + +from evadb.catalog.sql_config import ROW_NUM_COLUMN from evadb.expression.abstract_expression import AbstractExpression, ExpressionType from evadb.expression.function_expression import FunctionExpression from evadb.expression.tuple_value_expression import TupleValueExpression @@ -171,6 +173,16 @@ def extend_star( return target_list +def create_row_num_tv_expr(table_alias): + tv_expr = TupleValueExpression(name=ROW_NUM_COLUMN) + tv_expr.table_alias = table_alias + tv_expr.col_alias = f"{table_alias}.{ROW_NUM_COLUMN.lower()}" + tv_expr.col_object = ColumnCatalogEntry( + name=ROW_NUM_COLUMN, type=ColumnType.INTEGER + ) + return tv_expr + + def check_groupby_pattern(table_ref: TableRef, groupby_string: str) -> None: # match the pattern of group by clause (e.g., 16 frames or 8 samples) pattern = re.search(r"^\d+\s*(?:frames|samples|paragraphs)$", groupby_string) diff --git a/evadb/binder/create_index_statement_binder.py b/evadb/binder/create_index_statement_binder.py index ea14c4902..fb9de7ebe 100644 --- a/evadb/binder/create_index_statement_binder.py +++ b/evadb/binder/create_index_statement_binder.py @@ -12,17 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from evadb.binder.binder_utils import BinderError +from evadb.binder.binder_utils import BinderError, create_row_num_tv_expr from evadb.binder.statement_binder import StatementBinder from evadb.catalog.catalog_type import NdArrayType, VectorStoreType +from evadb.expression.function_expression import FunctionExpression from evadb.parser.create_index_statement import CreateIndexStatement from evadb.third_party.databases.interface import get_database_handler def bind_create_index(binder: StatementBinder, node: CreateIndexStatement): binder.bind(node.table_ref) - if node.function: - binder.bind(node.function) + + # Bind all projection expressions. + func_project_expr = None + for project_expr in node.project_expr_list: + binder.bind(project_expr) + if isinstance(project_expr, FunctionExpression): + func_project_expr = project_expr + + # Append ROW_NUM_COLUMN. + node.project_expr_list += [create_row_num_tv_expr(node.table_ref.alias)] # TODO: create index currently only supports single numpy column. assert len(node.col_list) == 1, "Index cannot be created on more than 1 column" @@ -54,13 +63,14 @@ def bind_create_index(binder: StatementBinder, node: CreateIndexStatement): # underlying native storage engine. return - if not node.function: - # Feature table type needs to be float32 numpy array. - assert ( - len(node.col_list) == 1 - ), f"Index can be only created on one column, but instead {len(node.col_list)} are provided" - col_def = node.col_list[0] + # Index can be only created on single column. + assert ( + len(node.col_list) == 1 + ), f"Index can be only created on one column, but instead {len(node.col_list)} are provided" + col_def = node.col_list[0] + if func_project_expr is None: + # Feature table type needs to be float32 numpy array. table_ref_obj = node.table_ref.table.table_obj col_list = [col for col in table_ref_obj.columns if col.name == col_def.name] assert ( @@ -78,7 +88,7 @@ def bind_create_index(binder: StatementBinder, node: CreateIndexStatement): else: # Output of the function should be 2 dimension and float32 type. function_obj = binder._catalog().get_function_catalog_entry_by_name( - node.function.name + func_project_expr.name ) for output in function_obj.outputs: assert ( diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py index eb881c483..199c53518 100644 --- a/evadb/binder/statement_binder.py +++ b/evadb/binder/statement_binder.py @@ -30,7 +30,7 @@ resolve_alias_table_value_expression, ) from evadb.binder.statement_binder_context import StatementBinderContext -from evadb.catalog.catalog_type import ColumnType, TableType, VideoColumnName +from evadb.catalog.catalog_type import ColumnType, TableType from evadb.catalog.catalog_utils import get_metadata_properties, is_document_table from evadb.configuration.constants import EvaDB_INSTALLATION_DIR from evadb.expression.abstract_expression import AbstractExpression, ExpressionType @@ -258,16 +258,9 @@ def _bind_tableref(self, node: TableRef): @bind.register(TupleValueExpression) def _bind_tuple_expr(self, node: TupleValueExpression): - table_alias, col_obj = self._binder_context.get_binded_column( - node.name, node.table_alias - ) - node.table_alias = table_alias - if node.name == VideoColumnName.audio: - self._binder_context.enable_audio_retrieval() - if node.name == VideoColumnName.data: - self._binder_context.enable_video_retrieval() - node.col_alias = "{}.{}".format(table_alias, node.name.lower()) - node.col_object = col_obj + from evadb.binder.tuple_value_expression_binder import bind_tuple_expr + + bind_tuple_expr(self, node) @bind.register(FunctionExpression) def _bind_func_expr(self, node: FunctionExpression): diff --git a/evadb/binder/tuple_value_expression_binder.py b/evadb/binder/tuple_value_expression_binder.py new file mode 100644 index 000000000..f1b6c898f --- /dev/null +++ b/evadb/binder/tuple_value_expression_binder.py @@ -0,0 +1,30 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from evadb.binder.statement_binder import StatementBinder +from evadb.catalog.catalog_type import VideoColumnName +from evadb.expression.tuple_value_expression import TupleValueExpression + + +def bind_tuple_expr(binder: StatementBinder, node: TupleValueExpression): + table_alias, col_obj = binder._binder_context.get_binded_column( + node.name, node.table_alias + ) + node.table_alias = table_alias + if node.name == VideoColumnName.audio: + binder._binder_context.enable_audio_retrieval() + if node.name == VideoColumnName.data: + binder._binder_context.enable_video_retrieval() + node.col_alias = "{}.{}".format(table_alias, node.name.lower()) + node.col_object = col_obj diff --git a/evadb/catalog/catalog_manager.py b/evadb/catalog/catalog_manager.py index f5a657101..b7c55c9bf 100644 --- a/evadb/catalog/catalog_manager.py +++ b/evadb/catalog/catalog_manager.py @@ -414,9 +414,15 @@ def insert_index_catalog_entry( vector_store_type: VectorStoreType, feat_column: ColumnCatalogEntry, function_signature: str, + index_def: str, ) -> IndexCatalogEntry: index_catalog_entry = self._index_service.insert_entry( - name, save_file_path, vector_store_type, feat_column, function_signature + name, + save_file_path, + vector_store_type, + feat_column, + function_signature, + index_def, ) return index_catalog_entry diff --git a/evadb/catalog/models/index_catalog.py b/evadb/catalog/models/index_catalog.py index f4a9ff00b..40a40f63a 100644 --- a/evadb/catalog/models/index_catalog.py +++ b/evadb/catalog/models/index_catalog.py @@ -31,6 +31,8 @@ class IndexCatalog(BaseModel): `_feat_column_id:` the `_row_id` of the `ColumnCatalog` entry for the column on which the index is built. `_function_signature:` if the index is created by running function expression on input column, this will store the function signature of the used function. Otherwise, this field is None. + `_index_def:` the original SQL statement that is used to create this index. We record this to rerun create index + on updated table. """ __tablename__ = "index_catalog" @@ -42,6 +44,7 @@ class IndexCatalog(BaseModel): "column_id", Integer, ForeignKey("column_catalog._row_id", ondelete="CASCADE") ) _function_signature = Column("function", String, default=None) + _index_def = Column("index_def", String, default=None) _feat_column = relationship( "ColumnCatalog", @@ -55,12 +58,14 @@ def __init__( type: VectorStoreType, feat_column_id: int = None, function_signature: str = None, + index_def: str = None, ): self._name = name self._save_file_path = save_file_path self._type = type self._feat_column_id = feat_column_id self._function_signature = function_signature + self._index_def = index_def def as_dataclass(self) -> "IndexCatalogEntry": feat_column = self._feat_column.as_dataclass() if self._feat_column else None @@ -71,5 +76,6 @@ def as_dataclass(self) -> "IndexCatalogEntry": type=self._type, feat_column_id=self._feat_column_id, function_signature=self._function_signature, + index_def=self._index_def, feat_column=feat_column, ) diff --git a/evadb/catalog/models/utils.py b/evadb/catalog/models/utils.py index cdcd6c8ec..b1c067aa0 100644 --- a/evadb/catalog/models/utils.py +++ b/evadb/catalog/models/utils.py @@ -201,6 +201,7 @@ class IndexCatalogEntry: row_id: int = None feat_column_id: int = None function_signature: str = None + index_def: str = None feat_column: ColumnCatalogEntry = None diff --git a/evadb/catalog/services/index_catalog_service.py b/evadb/catalog/services/index_catalog_service.py index b2a907a77..4b4b67578 100644 --- a/evadb/catalog/services/index_catalog_service.py +++ b/evadb/catalog/services/index_catalog_service.py @@ -35,9 +35,15 @@ def insert_entry( type: VectorStoreType, feat_column: ColumnCatalogEntry, function_signature: str, + index_def: str, ) -> IndexCatalogEntry: index_entry = IndexCatalog( - name, save_file_path, type, feat_column.row_id, function_signature + name, + save_file_path, + type, + feat_column.row_id, + function_signature, + index_def, ) index_entry = index_entry.save(self.session) return index_entry.as_dataclass() diff --git a/evadb/executor/create_index_executor.py b/evadb/executor/create_index_executor.py index 8e9ff56c9..54e43d170 100644 --- a/evadb/executor/create_index_executor.py +++ b/evadb/executor/create_index_executor.py @@ -21,9 +21,9 @@ from evadb.database import EvaDBDatabase from evadb.executor.abstract_executor import AbstractExecutor from evadb.executor.executor_utils import ExecutorError, handle_vector_store_params +from evadb.expression.function_expression import FunctionExpression from evadb.models.storage.batch import Batch from evadb.plan_nodes.create_index_plan import CreateIndexPlan -from evadb.storage.storage_engine import StorageEngine from evadb.third_party.databases.interface import get_database_handler from evadb.third_party.vector_stores.types import FeaturePayload from evadb.third_party.vector_stores.utils import VectorStoreFactory @@ -101,25 +101,21 @@ def _create_evadb_index(self): col for col in feat_catalog_entry.columns if col.name == feat_col_name ][0] + # Find function expression. + function_expression = None + for project_expr in self.node.project_expr_list: + if isinstance(project_expr, FunctionExpression): + function_expression = project_expr + + if function_expression is not None: + feat_col_name = function_expression.output_objs[0].name + # Add features to index. # TODO: batch size is hardcoded for now. input_dim = -1 - storage_engine = StorageEngine.factory(self.db, feat_catalog_entry) - for input_batch in storage_engine.read(feat_catalog_entry): - if self.node.function: - # Create index through function expression. - # Function(input column) -> 2 dimension feature vector. - input_batch.modify_column_alias(feat_catalog_entry.name.lower()) - feat_batch = self.node.function.evaluate(input_batch) - feat_batch.drop_column_alias() - input_batch.drop_column_alias() - feat = feat_batch.column_as_numpy_array("features") - else: - # Create index on the feature table directly. - # Pandas wraps numpy array as an object inside a numpy - # array. Use zero index to get the actual numpy array. - feat = input_batch.column_as_numpy_array(feat_col_name) - + for input_batch in self.children[0].exec(): + input_batch.drop_column_alias() + feat = input_batch.column_as_numpy_array(feat_col_name) row_num = input_batch.column_as_numpy_array(ROW_NUM_COLUMN) for i in range(len(input_batch)): @@ -147,7 +143,10 @@ def _create_evadb_index(self): index_path, self.node.vector_store_type, feat_column, - self.node.function.signature() if self.node.function else None, + function_expression.signature() + if function_expression is not None + else None, + self.node.index_def, ) except Exception as e: # Delete index. diff --git a/evadb/optimizer/operators.py b/evadb/optimizer/operators.py index 59d6fa1c0..6d7e32613 100644 --- a/evadb/optimizer/operators.py +++ b/evadb/optimizer/operators.py @@ -1084,7 +1084,8 @@ def __init__( table_ref: TableRef, col_list: List[ColumnDefinition], vector_store_type: VectorStoreType, - function: FunctionExpression = None, + project_expr_list: List[AbstractExpression], + index_def: str, children: List = None, ): super().__init__(OperatorType.LOGICALCREATEINDEX, children) @@ -1093,7 +1094,8 @@ def __init__( self._table_ref = table_ref self._col_list = col_list self._vector_store_type = vector_store_type - self._function = function + self._project_expr_list = project_expr_list + self._index_def = index_def @property def name(self): @@ -1116,8 +1118,12 @@ def vector_store_type(self): return self._vector_store_type @property - def function(self): - return self._function + def project_expr_list(self): + return self._project_expr_list + + @property + def index_def(self): + return self._index_def def __eq__(self, other): is_subtree_equal = super().__eq__(other) @@ -1130,7 +1136,8 @@ def __eq__(self, other): and self.table_ref == other.table_ref and self.col_list == other.col_list and self.vector_store_type == other.vector_store_type - and self.function == other.function + and self.project_expr_list == other.project_expr_list + and self.index_def == other.index_def ) def __hash__(self) -> int: @@ -1142,7 +1149,8 @@ def __hash__(self) -> int: self.table_ref, tuple(self.col_list), self.vector_store_type, - self.function, + tuple(self.project_expr_list), + self.index_def, ) ) diff --git a/evadb/optimizer/rules/rules.py b/evadb/optimizer/rules/rules.py index 5481321f4..955885f1b 100644 --- a/evadb/optimizer/rules/rules.py +++ b/evadb/optimizer/rules/rules.py @@ -832,8 +832,19 @@ def apply(self, before: LogicalCreateIndex, context: OptimizerContext): before.table_ref, before.col_list, before.vector_store_type, - before.function, + before.project_expr_list, + before.index_def, ) + child = SeqScanPlan(None, before.project_expr_list, before.table_ref.alias) + batch_mem_size = context.db.config.get_value("executor", "batch_mem_size") + child.append_child( + StoragePlan( + before.table_ref.table.table_obj, + before.table_ref, + batch_mem_size=batch_mem_size, + ) + ) + after.append_child(child) yield after diff --git a/evadb/optimizer/statement_to_opr_converter.py b/evadb/optimizer/statement_to_opr_converter.py index d0e36e16a..d7c79bd0f 100644 --- a/evadb/optimizer/statement_to_opr_converter.py +++ b/evadb/optimizer/statement_to_opr_converter.py @@ -359,7 +359,8 @@ def visit_create_index(self, statement: CreateIndexStatement): statement.table_ref, statement.col_list, statement.vector_store_type, - statement.function, + statement.project_expr_list, + statement.index_def, ) self._plan = create_index_opr diff --git a/evadb/parser/create_index_statement.py b/evadb/parser/create_index_statement.py index 396228004..77f352520 100644 --- a/evadb/parser/create_index_statement.py +++ b/evadb/parser/create_index_statement.py @@ -15,6 +15,7 @@ from typing import List from evadb.catalog.catalog_type import VectorStoreType +from evadb.expression.abstract_expression import AbstractExpression from evadb.expression.function_expression import FunctionExpression from evadb.parser.create_statement import ColumnDefinition from evadb.parser.statement import AbstractStatement @@ -30,7 +31,7 @@ def __init__( table_ref: TableRef, col_list: List[ColumnDefinition], vector_store_type: VectorStoreType, - function: FunctionExpression = None, + project_expr_list: List[AbstractStatement], ): super().__init__(StatementType.CREATE_INDEX) self._name = name @@ -38,16 +39,28 @@ def __init__( self._table_ref = table_ref self._col_list = col_list self._vector_store_type = vector_store_type - self._function = function + self._project_expr_list = project_expr_list + + # Definition of CREATE INDEX. + self._index_def = self.__str__() def __str__(self) -> str: - print_str = "CREATE INDEX {} {} ON {} ({}{}) ".format( - self._name, - "IF NOT EXISTS" if self._if_not_exists else "", - self._table_ref, - "" if self._function else self._function, - tuple(self._col_list), - ) + function_expr = None + for project_expr in self._project_expr_list: + if isinstance(project_expr, FunctionExpression): + function_expr = project_expr + + print_str = "CREATE INDEX" + if self._if_not_exists: + print_str += " IF NOT EXISTS" + print_str += f" {self._name}" + print_str += " ON" + print_str += f" {self._table_ref.table.table_name}" + if function_expr is None: + print_str += f" ({self.col_list[0].name})" + else: + print_str += f" ({function_expr.name}({self.col_list[0].name}))" + print_str += f" USING {self._vector_store_type};" return print_str @property @@ -71,8 +84,16 @@ def vector_store_type(self): return self._vector_store_type @property - def function(self): - return self._function + def project_expr_list(self): + return self._project_expr_list + + @project_expr_list.setter + def project_expr_list(self, project_expr_list: List[AbstractExpression]): + self._project_expr_list = project_expr_list + + @property + def index_def(self): + return self._index_def def __eq__(self, other): if not isinstance(other, CreateIndexStatement): @@ -83,7 +104,8 @@ def __eq__(self, other): and self._table_ref == other.table_ref and self.col_list == other.col_list and self._vector_store_type == other.vector_store_type - and self._function == other.function + and self._project_expr_list == other.project_expr_list + and self._index_def == other.index_def ) def __hash__(self) -> int: @@ -95,6 +117,7 @@ def __hash__(self) -> int: self._table_ref, tuple(self.col_list), self._vector_store_type, - self._function, + tuple(self._project_expr_list), + self._index_def, ) ) diff --git a/evadb/parser/lark_visitor/_create_statements.py b/evadb/parser/lark_visitor/_create_statements.py index 96aa7a136..86decf1ef 100644 --- a/evadb/parser/lark_visitor/_create_statements.py +++ b/evadb/parser/lark_visitor/_create_statements.py @@ -256,22 +256,32 @@ def create_index(self, tree): elif child.data == "index_elem": index_elem = self.visit(child) + # Projection list of child of index creation. + project_expr_list = [] + # Parse either a single function call or column list. - col_list, function = None, None if not isinstance(index_elem, list): - function = index_elem + project_expr_list += [index_elem] # Traverse to the tuple value expression. while not isinstance(index_elem, TupleValueExpression): index_elem = index_elem.children[0] index_elem = [index_elem] + else: + project_expr_list += index_elem - col_list = [ - ColumnDefinition(tv_expr.name, None, None, None) for tv_expr in index_elem - ] + # Add tv_expr for projected columns. + col_list = [] + for tv_expr in index_elem: + col_list += [ColumnDefinition(tv_expr.name, None, None, None)] return CreateIndexStatement( - index_name, if_not_exists, table_ref, col_list, vector_store_type, function + index_name, + if_not_exists, + table_ref, + col_list, + vector_store_type, + project_expr_list, ) def vector_store_type(self, tree): diff --git a/evadb/plan_nodes/create_index_plan.py b/evadb/plan_nodes/create_index_plan.py index 96aa7a12a..e5c573477 100644 --- a/evadb/plan_nodes/create_index_plan.py +++ b/evadb/plan_nodes/create_index_plan.py @@ -15,6 +15,7 @@ from typing import List from evadb.catalog.catalog_type import VectorStoreType +from evadb.expression.abstract_expression import AbstractExpression from evadb.expression.function_expression import FunctionExpression from evadb.parser.create_statement import ColumnDefinition from evadb.parser.table_ref import TableRef @@ -30,7 +31,8 @@ def __init__( table_ref: TableRef, col_list: List[ColumnDefinition], vector_store_type: VectorStoreType, - function: FunctionExpression = None, + project_expr_list: List[AbstractExpression], + index_def: str, ): super().__init__(PlanOprType.CREATE_INDEX) self._name = name @@ -38,7 +40,8 @@ def __init__( self._table_ref = table_ref self._col_list = col_list self._vector_store_type = vector_store_type - self._function = function + self._project_expr_list = project_expr_list + self._index_def = index_def @property def name(self): @@ -61,10 +64,19 @@ def vector_store_type(self): return self._vector_store_type @property - def function(self): - return self._function + def project_expr_list(self): + return self._project_expr_list + + @property + def index_def(self): + return self._index_def def __str__(self): + function_expr = None + for project_expr in self._project_expr_list: + if isinstance(project_expr, FunctionExpression): + function_expr = project_expr + return "CreateIndexPlan(name={}, \ table_ref={}, \ col_list={}, \ @@ -74,7 +86,7 @@ def __str__(self): self._table_ref, tuple(self._col_list), self._vector_store_type, - "" if not self._function else "function={}".format(self._function), + "" if function_expr is None else "function={}".format(function_expr), ) def __hash__(self) -> int: @@ -86,6 +98,7 @@ def __hash__(self) -> int: self.table_ref, tuple(self.col_list), self.vector_store_type, - self.function, + tuple(self.project_expr_list), + self.index_def, ) ) diff --git a/test/integration_tests/long/test_create_index_executor.py b/test/integration_tests/long/test_create_index_executor.py index feabb5bff..804bb5b2b 100644 --- a/test/integration_tests/long/test_create_index_executor.py +++ b/test/integration_tests/long/test_create_index_executor.py @@ -134,6 +134,7 @@ def test_index_already_exist(self): @macos_skip_marker def test_should_create_index_faiss(self): query = "CREATE INDEX testCreateIndexName ON testCreateIndexFeatTable (feat) USING FAISS;" + execute_query_fetch_all(self.evadb, query) # Test index catalog. diff --git a/test/unit_tests/binder/test_statement_binder.py b/test/unit_tests/binder/test_statement_binder.py index 3de8d1745..d6642ea9a 100644 --- a/test/unit_tests/binder/test_statement_binder.py +++ b/test/unit_tests/binder/test_statement_binder.py @@ -21,6 +21,7 @@ from evadb.catalog.catalog_type import ColumnType, NdArrayType from evadb.catalog.models.utils import ColumnCatalogEntry from evadb.catalog.sql_config import IDENTIFIER_COLUMN +from evadb.expression.function_expression import FunctionExpression from evadb.expression.tuple_value_expression import TupleValueExpression from evadb.parser.alias import Alias from evadb.parser.create_statement import ColumnDefinition @@ -332,11 +333,23 @@ def test_bind_create_index(self): with self.assertRaises(AssertionError): binder._bind_create_index_statement(create_index_statement) - create_index_statement.col_list = ["foo"] + col_def = MagicMock() + col_def.name = "a" + create_index_statement.col_list = [col_def] + + col = MagicMock() + col.name = "a" + create_index_statement.table_ref.table.table_obj.columns = [col] + function_obj = MagicMock() output = MagicMock() function_obj.outputs = [output] + create_index_statement.project_expr_list = [ + FunctionExpression(MagicMock(), name="a"), + TupleValueExpression(name="*"), + ] + with patch.object( catalog(), "get_function_catalog_entry_by_name", @@ -350,13 +363,7 @@ def test_bind_create_index(self): output.array_dimensions = [1, 100] binder._bind_create_index_statement(create_index_statement) - create_index_statement.function = None - col_def = MagicMock() - col_def.name = "a" - create_index_statement.col_list = [col_def] - col = MagicMock() - col.name = "a" - create_index_statement.table_ref.table.table_obj.columns = [col] + create_index_statement.project_expr_list = [TupleValueExpression(name="*")] with self.assertRaises(AssertionError): binder._bind_create_index_statement(create_index_statement) diff --git a/test/unit_tests/optimizer/test_statement_to_opr_converter.py b/test/unit_tests/optimizer/test_statement_to_opr_converter.py index d60284b13..beeeac94b 100644 --- a/test/unit_tests/optimizer/test_statement_to_opr_converter.py +++ b/test/unit_tests/optimizer/test_statement_to_opr_converter.py @@ -286,7 +286,13 @@ def test_check_plan_equality(self): MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() ) create_index_plan = LogicalCreateIndex( - MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), + MagicMock(), ) delete_plan = LogicalDelete(MagicMock()) insert_plan = LogicalInsert( diff --git a/test/unit_tests/parser/test_parser.py b/test/unit_tests/parser/test_parser.py index f9005b0a6..6bc2fe058 100644 --- a/test/unit_tests/parser/test_parser.py +++ b/test/unit_tests/parser/test_parser.py @@ -113,11 +113,23 @@ def test_create_index_statement(self): ColumnDefinition("featCol", None, None, None), ], VectorStoreType.FAISS, + [TupleValueExpression(name="featCol")], ) actual_stmt = evadb_stmt_list[0] self.assertEqual(actual_stmt, expected_stmt) + self.assertEqual(actual_stmt.index_def, create_index_query) # create if_not_exists + expected_stmt = CreateIndexStatement( + "testindex", + True, + TableRef(TableInfo("MyVideo")), + [ + ColumnDefinition("featCol", None, None, None), + ], + VectorStoreType.FAISS, + [TupleValueExpression(name="featCol")], + ) create_index_query = ( "CREATE INDEX IF NOT EXISTS testindex ON MyVideo (featCol) USING FAISS;" ) @@ -125,6 +137,7 @@ def test_create_index_statement(self): actual_stmt = evadb_stmt_list[0] expected_stmt._if_not_exists = True self.assertEqual(actual_stmt, expected_stmt) + self.assertEqual(actual_stmt.index_def, create_index_query) # create index on Function expression create_index_query = ( @@ -147,10 +160,11 @@ def test_create_index_statement(self): ColumnDefinition("featCol", None, None, None), ], VectorStoreType.FAISS, - func_expr, + [func_expr], ) actual_stmt = evadb_stmt_list[0] self.assertEqual(actual_stmt, expected_stmt) + self.assertEqual(actual_stmt.index_def, create_index_query) @unittest.skip("Skip parser exception handling testcase, moved to binder") def test_create_index_exception_statement(self):