-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(VirtualDataFrame): virtual dataframe to load data on demand and …
…enable direct_sql (#1434) * refactor(pandasai): make pandasai v3 work for dataframe * fix(sql): load and work with dataframe * fix: handle invalid data source type * feat(VirtualDataframe): lazy load data from the schema and fetch on demand
- Loading branch information
1 parent
00f31f4
commit 3f9816a
Showing
50 changed files
with
268 additions
and
2,780 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from typing import Dict, Any, List, Union | ||
|
||
|
||
class QueryBuilder:
    """Builds SQL query strings from a dataset schema dictionary.

    The schema is expected to contain at least ``schema["source"]["table"]``;
    optional keys ``columns``, ``order_by`` and ``limit`` refine the query.
    NOTE(review): table/column names are interpolated directly into SQL —
    schemas are assumed to be trusted configuration, not user input.
    """

    def __init__(self, schema: Dict[str, Any]):
        self.schema = schema

    def build_query(self) -> str:
        """Return a full SELECT query honoring columns, order_by and limit."""
        columns = self._get_columns()
        table_name = self.schema["source"]["table"]
        query = f"SELECT {columns} FROM {table_name}"

        query += self._add_order_by()
        query += self._add_limit()

        return query

    def _get_columns(self) -> str:
        """Return a comma-separated column list, or '*' when unspecified."""
        if "columns" in self.schema:
            return ", ".join(col["name"] for col in self.schema["columns"])
        return "*"

    def _add_order_by(self) -> str:
        """Return an ' ORDER BY ...' clause, or '' when not configured."""
        if "order_by" not in self.schema:
            return ""

        order_by_clause = self._format_order_by(self.schema["order_by"])
        return f" ORDER BY {order_by_clause}"

    def _format_order_by(self, order_by: Union[List[str], str]) -> str:
        """Join a list of order-by columns; pass a plain string through."""
        return ", ".join(order_by) if isinstance(order_by, list) else order_by

    def _add_limit(self, n=None) -> str:
        """Return a ' LIMIT ...' clause, or '' when no limit applies.

        An explicit ``n`` takes precedence over ``schema["limit"]``.
        Bug fix: the clause now uses the resolved ``limit`` value — previously
        it always formatted ``self.schema['limit']``, ignoring ``n`` and
        raising KeyError when ``n`` was given without a schema limit.
        """
        limit = n if n else self.schema.get("limit")
        return f" LIMIT {limit}" if limit else ""

    def get_head_query(self, n=5):
        """Return a query sampling ``n`` random rows from the source table."""
        source = self.schema.get("source", {})
        source_type = source.get("type")

        table_name = self.schema["source"]["table"]
        columns = self._get_columns()

        # Random-ordering function differs by dialect: MySQL uses RAND(),
        # sqlite/postgres use RANDOM().
        order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()"

        return f"SELECT {columns} FROM {table_name} ORDER BY {order_by} LIMIT {n}"

    def get_row_count(self):
        """Return a query counting all rows in the source table."""
        table_name = self.schema["source"]["table"]
        return f"SELECT COUNT(*) FROM {table_name}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import json | ||
|
||
|
||
def is_schema_source_same(schema1: dict, schema2: dict) -> bool:
    """Return True when two dataset schemas point at the same source.

    Two sources match when they share the same ``type`` and an identical
    ``connection`` mapping. Connections are compared via a key-sorted JSON
    dump so that dict key order does not affect the result.

    Robustness fix: missing ``"source"`` keys no longer raise
    AttributeError — an absent source is treated as an empty mapping.
    """
    source1 = schema1.get("source") or {}
    source2 = schema2.get("source") or {}

    if source1.get("type") != source2.get("type"):
        return False

    conn1 = json.dumps(source1.get("connection"), sort_keys=True)
    conn2 = json.dumps(source2.get("connection"), sort_keys=True)
    return conn1 == conn2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
from typing import TYPE_CHECKING, ClassVar | ||
import pandas as pd | ||
from pandasai.dataframe.base import DataFrame | ||
|
||
if TYPE_CHECKING: | ||
from pandasai.data_loader.loader import DatasetLoader | ||
|
||
|
||
class VirtualDataFrame(DataFrame):
    """A lazily-loaded DataFrame backed by a ``DatasetLoader``.

    Only a small head sample is materialized at construction time; full data
    access goes through the loader on demand (row counts, SQL execution).
    """

    # Attribute names that pandas should carry over onto derived objects
    # (pandas propagates names listed in the `_metadata` class variable).
    # NOTE(review): several entries ("name", "description", "schema",
    # "config", "_agent", "_column_hash") are presumably defined on the
    # base DataFrame — not visible in this file; confirm against base class.
    _metadata: ClassVar[list] = [
        "_loader",
        "head",
        "_head",
        "name",
        "description",
        "schema",
        "config",
        "_agent",
        "_column_hash",
    ]

    def __init__(self, *args, **kwargs):
        # The loader is mandatory: virtualization is meaningless without a
        # backend to fetch data from.
        self._loader: DatasetLoader = kwargs.pop("data_loader", None)
        if not self._loader:
            raise Exception("Data loader is required for virtualization!")
        # Cache for the head sample; populated lazily by head().
        self._head = None
        # NOTE(review): `get_head` is not defined in this class — presumably
        # provided by the base DataFrame; confirm. The head sample is used as
        # the underlying data for the pandas constructor.
        super().__init__(self.get_head(), *args, **kwargs)

    def head(self):
        # Overrides the pandas `head` method: fetches a sample from the
        # loader on first call, then serves the cached copy.
        if self._head is None:
            self._head = self._loader.load_head()

        return self._head

    @property
    def rows_count(self) -> int:
        # Delegates to the loader; not cached, so each access may query the
        # underlying source.
        return self._loader.get_row_count()

    def execute_sql_query(self, query: str) -> pd.DataFrame:
        # Runs an arbitrary SQL query against the virtualized source via the
        # loader and returns the result as a pandas DataFrame.
        return self._loader.execute_query(query)
Oops, something went wrong.