Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve db integration #1114

Merged
merged 6 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions evadb/binder/binder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,12 @@ def check_data_source_and_table_are_valid(
db_catalog_entry = catalog.get_database_catalog_entry(database_name)

if db_catalog_entry is not None:
handler = get_database_handler(
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
)
handler.connect()
) as handler:
# Get table definition.
resp = handler.get_tables()

# Get table definition.
resp = handler.get_tables()
if resp.error is not None:
error = "There is no table in data source {}. Create the table using native query.".format(
database_name,
Expand All @@ -90,7 +89,7 @@ def check_data_source_and_table_are_valid(


def create_table_catalog_entry_for_data_source(
table_name: str, column_info: pd.DataFrame
table_name: str, database_name: str, column_info: pd.DataFrame
):
column_name_list = list(column_info["name"])
column_type_list = [
Expand All @@ -107,6 +106,7 @@ def create_table_catalog_entry_for_data_source(
file_url=None,
table_type=TableType.NATIVE_DATA,
columns=column_list,
database_name=database_name,
)
return table_catalog_entry

Expand Down Expand Up @@ -134,14 +134,14 @@ def bind_native_table_info(catalog: CatalogManager, table_info: TableInfo):
)

db_catalog_entry = catalog.get_database_catalog_entry(table_info.database_name)
handler = get_database_handler(db_catalog_entry.engine, **db_catalog_entry.params)
handler.connect()

# Assemble columns.
column_df = handler.get_columns(table_info.table_name).data
table_info.table_obj = create_table_catalog_entry_for_data_source(
table_info.table_name, column_df
)
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
) as handler:
# Assemble columns.
column_df = handler.get_columns(table_info.table_name).data
table_info.table_obj = create_table_catalog_entry_for_data_source(
table_info.table_name, table_info.database_name, column_df
)


def bind_evadb_table_info(catalog: CatalogManager, table_info: TableInfo):
Expand Down
16 changes: 7 additions & 9 deletions evadb/binder/statement_binder_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,14 @@ def add_table_alias(self, alias: str, database_name: str, table_name: str):
)

db_catalog_entry = self._catalog().get_database_catalog_entry(database_name)
handler = get_database_handler(
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
)
handler.connect()

# Assemble columns.
column_df = handler.get_columns(table_name).data
table_obj = create_table_catalog_entry_for_data_source(
table_name, column_df
)
) as handler:
# Assemble columns.
column_df = handler.get_columns(table_name).data
table_obj = create_table_catalog_entry_for_data_source(
table_name, database_name, column_df
)
else:
table_obj = self._catalog().get_table_catalog_entry(table_name)

Expand Down
1 change: 1 addition & 0 deletions evadb/catalog/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class TableCatalogEntry:
identifier_column: str = "id"
columns: List[ColumnCatalogEntry] = field(compare=False, default_factory=list)
row_id: int = None
database_name: str = "EvaDB"


@dataclass(unsafe_hash=True)
Expand Down
6 changes: 2 additions & 4 deletions evadb/executor/create_database_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ def exec(self, *args, **kwargs):
raise ExecutorError(f"{self.node.database_name} already exists.")

# Check the validity of database entry.
handler = get_database_handler(self.node.engine, **self.node.param_dict)
resp = handler.connect()
if not resp.status:
raise ExecutorError(f"Cannot establish connection due to {resp.error}")
with get_database_handler(self.node.engine, **self.node.param_dict):
pass

logger.debug(f"Creating database {self.node}")
self.catalog().insert_database_catalog_entry(
Expand Down
22 changes: 21 additions & 1 deletion evadb/executor/executor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
from pathlib import Path
from typing import TYPE_CHECKING, Generator, List

from evadb.catalog.catalog_utils import xform_column_definitions_to_catalog_entries
from evadb.catalog.models.utils import TableCatalogEntry
from evadb.parser.create_statement import ColumnDefinition

if TYPE_CHECKING:
from evadb.catalog.catalog_manager import CatalogManager

from evadb.catalog.catalog_type import VectorStoreType
from evadb.catalog.catalog_type import TableType, VectorStoreType
from evadb.expression.abstract_expression import AbstractExpression
from evadb.expression.function_expression import FunctionExpression
from evadb.models.storage.batch import Batch
Expand Down Expand Up @@ -169,3 +173,19 @@ def handle_vector_store_params(
return {"index_db": str(Path(index_path).parent)}
else:
raise ValueError("Unsupported vector store type: {}".format(vector_store_type))


def create_table_catalog_entry_for_native_table(
    table_info: TableInfo, column_list: List[ColumnDefinition]
):
    """
    Build the catalog entry for a table that lives in a native data source.

    Arguments:
        table_info (TableInfo): provides the table name and the name of the
            database the table belongs to.
        column_list (List[ColumnDefinition]): column definitions, converted to
            column catalog entries before being attached to the table.

    Returns:
        TableCatalogEntry: entry of type ``TableType.NATIVE_DATA`` with no file url.
    """
    # Convert the parsed column definitions first, then assemble the table.
    native_columns = xform_column_definitions_to_catalog_entries(column_list)
    return TableCatalogEntry(
        name=table_info.table_name,
        file_url=None,
        table_type=TableType.NATIVE_DATA,
        columns=native_columns,
        database_name=table_info.database_name,
    )
4 changes: 1 addition & 3 deletions evadb/executor/storage_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def exec(self, *args, **kwargs) -> Iterator[Batch]:
elif self.node.table.table_type == TableType.STRUCTURED_DATA:
return storage_engine.read(self.node.table, self.node.batch_mem_size)
elif self.node.table.table_type == TableType.NATIVE_DATA:
return storage_engine.read(
self.node.table_ref.table.database_name, self.node.table
)
return storage_engine.read(self.node.table)
elif self.node.table.table_type == TableType.PDF_DATA:
return storage_engine.read(self.node.table)
else:
Expand Down
14 changes: 5 additions & 9 deletions evadb/executor/use_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,12 @@ def exec(self, *args, **kwargs) -> Iterator[Batch]:
f"{self._database_name} data source does not exist. Use CREATE DATABASE to add a new data source."
)

handler = get_database_handler(
db_catalog_entry.engine,
**db_catalog_entry.params,
)

handler.connect()
resp = handler.execute_native_query(self._query_string)
handler.disconnect()
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
) as handler:
resp = handler.execute_native_query(self._query_string)

if resp.error is None:
if resp and resp.error is None:
return Batch(resp.data)
else:
raise ExecutorError(resp.error)
35 changes: 16 additions & 19 deletions evadb/storage/native_storage_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,30 @@ class NativeStorageEngine(AbstractStorageEngine):
def __init__(self, db: EvaDBDatabase):
super().__init__(db)

def create(self, table: TableCatalogEntry):
pass

def write(self, table: TableCatalogEntry, rows: Batch):
pass

def read(self, database_name: str, table: TableCatalogEntry) -> Iterator[Batch]:
def read(self, table: TableCatalogEntry) -> Iterator[Batch]:
try:
db_catalog_entry = self.db.catalog().get_database_catalog_entry(
database_name
table.database_name
)
handler = get_database_handler(
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
)
handler.connect()

data_df = handler.execute_native_query(f"SELECT * FROM {table.name}").data

# Handling case-sensitive databases like SQLite can be tricky. Currently,
# EvaDB converts all columns to lowercase, which may result in issues with
# these databases. As we move forward, we are actively working on improving
# this aspect within Binder.
# For more information, please refer to https://github.com/georgia-tech-db/evadb/issues/1079.
data_df.columns = data_df.columns.str.lower()
yield Batch(pd.DataFrame(data_df))
) as handler:
data_df = handler.execute_native_query(
f"SELECT * FROM {table.name}"
).data

# Handling case-sensitive databases like SQLite can be tricky.
# Currently, EvaDB converts all columns to lowercase, which may result
# in issues with these databases. As we move forward, we are actively
# working on improving this aspect within Binder. For more information,
# please refer to https://github.com/georgia-tech-db/evadb/issues/1079.
data_df.columns = data_df.columns.str.lower()
yield Batch(pd.DataFrame(data_df))

except Exception as e:
err_msg = f"Failed to read the table {table.name} in data source {database_name} with exception {str(e)}"
err_msg = f"Failed to read the table {table.name} in data source {table.database_name} with exception {str(e)}"
logger.exception(err_msg)
raise Exception(err_msg)
15 changes: 14 additions & 1 deletion evadb/third_party/databases/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# limitations under the License.
import importlib
import os
from contextlib import contextmanager


def get_database_handler(engine: str, **kwargs):
def _get_database_handler(engine: str, **kwargs):
"""
Return the database handler. User should modify this function for
their new integrated handlers.
Expand All @@ -43,6 +44,18 @@ def get_database_handler(engine: str, **kwargs):
raise NotImplementedError(f"Engine {engine} is not supported")


@contextmanager
def get_database_handler(engine: str, **kwargs):
    """
    Context manager that yields a connected database handler and always
    disconnects it on exit.

    Arguments:
        engine (str): name of the database engine to build a handler for.
        **kwargs: connection parameters forwarded to the handler.

    Yields:
        The connected handler.

    Raises:
        Exception: if the connection cannot be established, either because
            ``connect()`` raised or because it returned a failed status.
            Exceptions raised inside the ``with`` body propagate unchanged.
    """
    handler = _get_database_handler(engine, **kwargs)
    try:
        try:
            resp = handler.connect()
        except Exception as e:
            raise Exception(f"Error connecting to the database: {str(e)}") from e
        # connect() reports failures through the returned status object rather
        # than by raising, so the status must be checked explicitly; otherwise
        # a bad connection would go unnoticed (e.g. when validating a new
        # database entry with an empty ``with`` body).
        if not resp.status:
            raise Exception(f"Error connecting to the database: {resp.error}")
        # Errors raised by the caller while using the handler are not
        # connection errors, so they are deliberately not wrapped here.
        yield handler
    finally:
        handler.disconnect()


def dynamic_import(handler_dir):
    """
    Import and return the handler module for the given handler directory.

    The module path is
    ``evadb.third_party.databases.<handler_dir>.<handler_dir>_handler``.
    """
    module_path = "evadb.third_party.databases.{0}.{0}_handler".format(handler_dir)
    handler_module = importlib.import_module(module_path)
    return handler_module
3 changes: 3 additions & 0 deletions evadb/third_party/databases/mariadb/mariadb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def disconnect(self):
if self.connection:
self.connection.close()

def get_sqlalchmey_uri(self) -> str:
    """
    Return the sqlalchemy uri (``mysql+pymysql`` dialect) built from this
    handler's user, password, host, port and database.

    The user and password are percent-encoded so that reserved characters
    such as ``@``, ``:`` or ``/`` do not corrupt the uri.
    """
    from urllib.parse import quote_plus

    user = quote_plus(str(self.user))
    password = quote_plus(str(self.password))
    return f"mysql+pymysql://{user}:{password}@{self.host}:{self.port}/{self.database}"

def check_connection(self) -> DBHandlerStatus:
"""
Method for checking the status of database connection.
Expand Down
3 changes: 3 additions & 0 deletions evadb/third_party/databases/mysql/mysql_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def disconnect(self):
if self.connection:
self.connection.close()

def get_sqlalchmey_uri(self) -> str:
    """
    Return the sqlalchemy uri (``mysql`` dialect) built from this handler's
    user, password, host, port and database.

    The user and password are percent-encoded so that reserved characters
    such as ``@``, ``:`` or ``/`` do not corrupt the uri.
    """
    from urllib.parse import quote_plus

    user = quote_plus(str(self.user))
    password = quote_plus(str(self.password))
    return f"mysql://{user}:{password}@{self.host}:{self.port}/{self.database}"

def check_connection(self) -> DBHandlerStatus:
if self.connection:
return DBHandlerStatus(status=True)
Expand Down
3 changes: 3 additions & 0 deletions evadb/third_party/databases/postgres/postgres_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def disconnect(self):
if self.connection:
self.connection.close()

def get_sqlalchmey_uri(self) -> str:
    """
    Return the sqlalchemy uri (``postgresql`` dialect) built from this
    handler's user, password, host, port and database.

    The user and password are percent-encoded so that reserved characters
    such as ``@``, ``:`` or ``/`` do not corrupt the uri.
    """
    from urllib.parse import quote_plus

    user = quote_plus(str(self.user))
    password = quote_plus(str(self.password))
    return f"postgresql://{user}:{password}@{self.host}:{self.port}/{self.database}"

def check_connection(self) -> DBHandlerStatus:
"""
Check connection to the handler.
Expand Down
3 changes: 3 additions & 0 deletions evadb/third_party/databases/sqlite/sqlite_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def disconnect(self):
if self.connection:
self.connection.close()

def get_sqlalchmey_uri(self) -> str:
    """
    Return the sqlalchemy uri for this handler: the ``sqlite:///`` prefix
    followed by the handler's ``database`` value.
    """
    database_path = self.database
    return "sqlite:///{}".format(database_path)

def check_connection(self) -> DBHandlerStatus:
"""
Check connection to the handler.
Expand Down
9 changes: 9 additions & 0 deletions evadb/third_party/databases/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def disconnect(self):
"""
raise NotImplementedError()

def get_sqlalchmey_uri(self) -> str:
    """
    Build the sqlalchemy uri used to connect to the database.

    This base implementation provides no uri and always raises.

    Raises:
        NotImplementedError: This method should be implemented in derived classes.
    """
    raise NotImplementedError

def check_connection(self) -> DBHandlerStatus:
"""
Checks the status of the database connection.
Expand Down