diff --git a/evadb/binder/binder_utils.py b/evadb/binder/binder_utils.py index 59746445cd..1c6f4ff91d 100644 --- a/evadb/binder/binder_utils.py +++ b/evadb/binder/binder_utils.py @@ -58,13 +58,12 @@ def check_data_source_and_table_are_valid( db_catalog_entry = catalog.get_database_catalog_entry(database_name) if db_catalog_entry is not None: - handler = get_database_handler( + with get_database_handler( db_catalog_entry.engine, **db_catalog_entry.params - ) - handler.connect() + ) as handler: + # Get table definition. + resp = handler.get_tables() - # Get table definition. - resp = handler.get_tables() if resp.error is not None: error = "There is no table in data source {}. Create the table using native query.".format( database_name, @@ -90,7 +89,7 @@ def check_data_source_and_table_are_valid( def create_table_catalog_entry_for_data_source( - table_name: str, column_info: pd.DataFrame + table_name: str, database_name: str, column_info: pd.DataFrame ): column_name_list = list(column_info["name"]) column_type_list = [ @@ -107,6 +106,7 @@ def create_table_catalog_entry_for_data_source( file_url=None, table_type=TableType.NATIVE_DATA, columns=column_list, + database_name=database_name, ) return table_catalog_entry @@ -134,14 +134,14 @@ def bind_native_table_info(catalog: CatalogManager, table_info: TableInfo): ) db_catalog_entry = catalog.get_database_catalog_entry(table_info.database_name) - handler = get_database_handler(db_catalog_entry.engine, **db_catalog_entry.params) - handler.connect() - - # Assemble columns. - column_df = handler.get_columns(table_info.table_name).data - table_info.table_obj = create_table_catalog_entry_for_data_source( - table_info.table_name, column_df - ) + with get_database_handler( + db_catalog_entry.engine, **db_catalog_entry.params + ) as handler: + # Assemble columns. + column_df = handler.get_columns(table_info.table_name).data + table_info.table_obj = create_table_catalog_entry_for_data_source( + table_info.table_name, table_info.database_name, column_df + ) def bind_evadb_table_info(catalog: CatalogManager, table_info: TableInfo): diff --git a/evadb/binder/statement_binder_context.py b/evadb/binder/statement_binder_context.py index 32dc12c7d8..1d038bb3e5 100644 --- a/evadb/binder/statement_binder_context.py +++ b/evadb/binder/statement_binder_context.py @@ -85,16 +85,14 @@ def add_table_alias(self, alias: str, database_name: str, table_name: str): ) db_catalog_entry = self._catalog().get_database_catalog_entry(database_name) - handler = get_database_handler( + with get_database_handler( db_catalog_entry.engine, **db_catalog_entry.params - ) - handler.connect() - - # Assemble columns. - column_df = handler.get_columns(table_name).data - table_obj = create_table_catalog_entry_for_data_source( - table_name, column_df - ) + ) as handler: + # Assemble columns. + column_df = handler.get_columns(table_name).data + table_obj = create_table_catalog_entry_for_data_source( + table_name, database_name, column_df + ) else: table_obj = self._catalog().get_table_catalog_entry(table_name) diff --git a/evadb/catalog/models/utils.py b/evadb/catalog/models/utils.py index 30773af12b..cdcd6c8ece 100644 --- a/evadb/catalog/models/utils.py +++ b/evadb/catalog/models/utils.py @@ -137,6 +137,7 @@ class TableCatalogEntry: identifier_column: str = "id" columns: List[ColumnCatalogEntry] = field(compare=False, default_factory=list) row_id: int = None + database_name: str = "EvaDB" @dataclass(unsafe_hash=True) diff --git a/evadb/executor/create_database_executor.py b/evadb/executor/create_database_executor.py index bd1f0efd7b..ac4211bb25 100644 --- a/evadb/executor/create_database_executor.py +++ b/evadb/executor/create_database_executor.py @@ -42,10 +42,8 @@ def exec(self, *args, **kwargs): raise ExecutorError(f"{self.node.database_name} already exists.") # Check the validity of database entry. - handler = get_database_handler(self.node.engine, **self.node.param_dict) - resp = handler.connect() - if not resp.status: - raise ExecutorError(f"Cannot establish connection due to {resp.error}") + with get_database_handler(self.node.engine, **self.node.param_dict): + pass logger.debug(f"Creating database {self.node}") self.catalog().insert_database_catalog_entry( diff --git a/evadb/executor/executor_utils.py b/evadb/executor/executor_utils.py index 89cf1e4736..24bb65a49f 100644 --- a/evadb/executor/executor_utils.py +++ b/evadb/executor/executor_utils.py @@ -17,10 +17,14 @@ from pathlib import Path from typing import TYPE_CHECKING, Generator, List +from evadb.catalog.catalog_utils import xform_column_definitions_to_catalog_entries +from evadb.catalog.models.utils import TableCatalogEntry +from evadb.parser.create_statement import ColumnDefinition + if TYPE_CHECKING: from evadb.catalog.catalog_manager import CatalogManager -from evadb.catalog.catalog_type import VectorStoreType +from evadb.catalog.catalog_type import TableType, VectorStoreType from evadb.expression.abstract_expression import AbstractExpression from evadb.expression.function_expression import FunctionExpression from evadb.models.storage.batch import Batch @@ -169,3 +173,19 @@ def handle_vector_store_params( return {"index_db": str(Path(index_path).parent)} else: raise ValueError("Unsupported vector store type: {}".format(vector_store_type)) + + +def create_table_catalog_entry_for_native_table( + table_info: TableInfo, column_list: List[ColumnDefinition] +): + column_catalog_entires = xform_column_definitions_to_catalog_entries(column_list) + + # Assemble table. + table_catalog_entry = TableCatalogEntry( + name=table_info.table_name, + file_url=None, + table_type=TableType.NATIVE_DATA, + columns=column_catalog_entires, + database_name=table_info.database_name, + ) + return table_catalog_entry diff --git a/evadb/executor/storage_executor.py b/evadb/executor/storage_executor.py index ce2853d686..45b119e4d7 100644 --- a/evadb/executor/storage_executor.py +++ b/evadb/executor/storage_executor.py @@ -49,9 +49,7 @@ def exec(self, *args, **kwargs) -> Iterator[Batch]: elif self.node.table.table_type == TableType.STRUCTURED_DATA: return storage_engine.read(self.node.table, self.node.batch_mem_size) elif self.node.table.table_type == TableType.NATIVE_DATA: - return storage_engine.read( - self.node.table_ref.table.database_name, self.node.table - ) + return storage_engine.read(self.node.table) elif self.node.table.table_type == TableType.PDF_DATA: return storage_engine.read(self.node.table) else: diff --git a/evadb/executor/use_executor.py b/evadb/executor/use_executor.py index dd1a94448a..9e55d14275 100644 --- a/evadb/executor/use_executor.py +++ b/evadb/executor/use_executor.py @@ -38,16 +38,12 @@ def exec(self, *args, **kwargs) -> Iterator[Batch]: f"{self._database_name} data source does not exist. Use CREATE DATABASE to add a new data source." ) - handler = get_database_handler( - db_catalog_entry.engine, - **db_catalog_entry.params, - ) - - handler.connect() - resp = handler.execute_native_query(self._query_string) - handler.disconnect() + with get_database_handler( + db_catalog_entry.engine, **db_catalog_entry.params + ) as handler: + resp = handler.execute_native_query(self._query_string) - if resp.error is None: + if resp and resp.error is None: return Batch(resp.data) else: raise ExecutorError(resp.error) diff --git a/evadb/storage/native_storage_engine.py b/evadb/storage/native_storage_engine.py index f728524029..5439724c66 100644 --- a/evadb/storage/native_storage_engine.py +++ b/evadb/storage/native_storage_engine.py @@ -28,33 +28,30 @@ class NativeStorageEngine(AbstractStorageEngine): def __init__(self, db: EvaDBDatabase): super().__init__(db) - def create(self, table: TableCatalogEntry): - pass - def write(self, table: TableCatalogEntry, rows: Batch): pass - def read(self, database_name: str, table: TableCatalogEntry) -> Iterator[Batch]: + def read(self, table: TableCatalogEntry) -> Iterator[Batch]: try: db_catalog_entry = self.db.catalog().get_database_catalog_entry( - database_name + table.database_name ) - handler = get_database_handler( + with get_database_handler( db_catalog_entry.engine, **db_catalog_entry.params - ) - handler.connect() - - data_df = handler.execute_native_query(f"SELECT * FROM {table.name}").data - - # Handling case-sensitive databases like SQLite can be tricky. Currently, - # EvaDB converts all columns to lowercase, which may result in issues with - # these databases. As we move forward, we are actively working on improving - # this aspect within Binder. - # For more information, please refer to https://github.com/georgia-tech-db/evadb/issues/1079. - data_df.columns = data_df.columns.str.lower() - yield Batch(pd.DataFrame(data_df)) + ) as handler: + data_df = handler.execute_native_query( + f"SELECT * FROM {table.name}" + ).data + + # Handling case-sensitive databases like SQLite can be tricky. + # Currently, EvaDB converts all columns to lowercase, which may result + # in issues with these databases. As we move forward, we are actively + # working on improving this aspect within Binder. For more information, + # please refer to https://github.com/georgia-tech-db/evadb/issues/1079. + data_df.columns = data_df.columns.str.lower() + yield Batch(pd.DataFrame(data_df)) except Exception as e: - err_msg = f"Failed to read the table {table.name} in data source {database_name} with exception {str(e)}" + err_msg = f"Failed to read the table {table.name} in data source {table.database_name} with exception {str(e)}" logger.exception(err_msg) raise Exception(err_msg) diff --git a/evadb/third_party/databases/interface.py b/evadb/third_party/databases/interface.py index f0c3ee14ec..620b863624 100644 --- a/evadb/third_party/databases/interface.py +++ b/evadb/third_party/databases/interface.py @@ -14,9 +14,10 @@ # limitations under the License. import importlib import os +from contextlib import contextmanager -def get_database_handler(engine: str, **kwargs): +def _get_database_handler(engine: str, **kwargs): """ Return the database handler. User should modify this function for their new integrated handlers. @@ -43,6 +44,18 @@ def get_database_handler(engine: str, **kwargs): raise NotImplementedError(f"Engine {engine} is not supported") +@contextmanager +def get_database_handler(engine: str, **kwargs): + handler = _get_database_handler(engine, **kwargs) + try: + handler.connect() + yield handler + except Exception as e: + raise Exception(f"Error connecting to the database: {str(e)}") + finally: + handler.disconnect() + + def dynamic_import(handler_dir): import_path = f"evadb.third_party.databases.{handler_dir}.{handler_dir}_handler" return importlib.import_module(import_path) diff --git a/evadb/third_party/databases/mariadb/mariadb_handler.py b/evadb/third_party/databases/mariadb/mariadb_handler.py index 8da40e0981..4dbdb7d88c 100644 --- a/evadb/third_party/databases/mariadb/mariadb_handler.py +++ b/evadb/third_party/databases/mariadb/mariadb_handler.py @@ -70,6 +70,9 @@ def disconnect(self): if self.connection: self.connection.close() + def get_sqlalchmey_uri(self) -> str: + return f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" + def check_connection(self) -> DBHandlerStatus: """ Method for checking the status of database connection. diff --git a/evadb/third_party/databases/mysql/mysql_handler.py b/evadb/third_party/databases/mysql/mysql_handler.py index 829cc243c8..967342d314 100644 --- a/evadb/third_party/databases/mysql/mysql_handler.py +++ b/evadb/third_party/databases/mysql/mysql_handler.py @@ -49,6 +49,9 @@ def disconnect(self): if self.connection: self.connection.close() + def get_sqlalchmey_uri(self) -> str: + return f"mysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" + def check_connection(self) -> DBHandlerStatus: if self.connection: return DBHandlerStatus(status=True) diff --git a/evadb/third_party/databases/postgres/postgres_handler.py b/evadb/third_party/databases/postgres/postgres_handler.py index dda08f7da8..43b8e9f95b 100644 --- a/evadb/third_party/databases/postgres/postgres_handler.py +++ b/evadb/third_party/databases/postgres/postgres_handler.py @@ -64,6 +64,9 @@ def disconnect(self): if self.connection: self.connection.close() + def get_sqlalchmey_uri(self) -> str: + return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" + def check_connection(self) -> DBHandlerStatus: """ Check connection to the handler. diff --git a/evadb/third_party/databases/sqlite/sqlite_handler.py b/evadb/third_party/databases/sqlite/sqlite_handler.py index 7256280ada..10078236bf 100644 --- a/evadb/third_party/databases/sqlite/sqlite_handler.py +++ b/evadb/third_party/databases/sqlite/sqlite_handler.py @@ -56,6 +56,9 @@ def disconnect(self): if self.connection: self.connection.close() + def get_sqlalchmey_uri(self) -> str: + return f"sqlite:///{self.database}" + def check_connection(self) -> DBHandlerStatus: """ Check connection to the handler. diff --git a/evadb/third_party/databases/types.py b/evadb/third_party/databases/types.py index 6708cf1a11..09b487185f 100644 --- a/evadb/third_party/databases/types.py +++ b/evadb/third_party/databases/types.py @@ -74,6 +74,15 @@ def disconnect(self): """ raise NotImplementedError() + def get_sqlalchmey_uri(self) -> str: + """ + Return the valid sqlalchemy uri to connect to the database. + + Raises: + NotImplementedError: This method should be implemented in derived classes. + """ + raise NotImplementedError() + def check_connection(self) -> DBHandlerStatus: """ Checks the status of the database connection.