feat(VirtualDataFrame): virtual dataframe to load data on demand and enable direct_sql (#1434)

* refactor(pandasai): make pandasai v3 work for dataframe

* fix(sql): load and work with dataframe

* fix: handle invalid data source type

* feat(VirtualDataframe): lazy load data from the schema and fetch on demand
ArslanSaleem authored Nov 20, 2024
1 parent 00f31f4 commit 3f9816a
Showing 50 changed files with 268 additions and 2,780 deletions.
6 changes: 3 additions & 3 deletions pandasai/__init__.py
@@ -7,7 +7,7 @@
from .agent import Agent
from .helpers.cache import Cache
from .dataframe.base import DataFrame
from .dataframe.loader import DatasetLoader
from .data_loader.loader import DatasetLoader

# Global variable to store the current agent
_current_agent = None
@@ -61,7 +61,7 @@ def follow_up(query: str):
_dataset_loader = DatasetLoader()


def load(dataset_path: str) -> DataFrame:
def load(dataset_path: str, virtualized=False) -> DataFrame:
"""
Load data based on the provided dataset path.
@@ -72,7 +72,7 @@ def load(dataset_path: str) -> DataFrame:
DataFrame: A new PandasAI DataFrame instance with loaded data.
"""
global _dataset_loader
return _dataset_loader.load(dataset_path)
return _dataset_loader.load(dataset_path, virtualized)


__all__ = [
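The new `virtualized` flag is the public entry point for lazy loading. A minimal usage sketch (the dataset path `acme/sales` is hypothetical and assumes a matching `datasets/acme/sales/schema.yaml` under the project root, as `_load_schema()` expects):

```python
import pandasai as pai

# Eager path: fetches from the source, caches locally, returns a DataFrame.
df = pai.load("acme/sales")

# Virtualized path: returns a VirtualDataFrame that defers data fetching.
vdf = pai.load("acme/sales", virtualized=True)
```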
49 changes: 40 additions & 9 deletions pandasai/agent/base.py
@@ -5,6 +5,8 @@

import pandas as pd
from pandasai.agent.base_security import BaseSecurity

from pandasai.data_loader.schema_validator import is_schema_source_same
from pandasai.llm.bamboo_llm import BambooLLM
from pandasai.pipelines.chat.chat_pipeline_input import ChatPipelineInput
from pandasai.pipelines.chat.code_execution_pipeline_input import (
@@ -62,17 +64,13 @@ def __init__(

self.dfs = dfs if isinstance(dfs, list) else [dfs]

# Validate SQL connectors
sql_connectors = [
df
for df in self.dfs
if hasattr(df, "type") and df.type in ["sql", "postgresql"]
]
if len(sql_connectors) > 1:
raise InvalidConfigError("Cannot use multiple SQL connectors")

# Instantiate the context
self.config = self.get_config(config)

# Validate df input with configurations
self.validate_input()

# Initialize the context
self.context = PipelineContext(
dfs=self.dfs,
config=self.config,
@@ -106,6 +104,39 @@ def __init__(
self.pipeline = None
self.security = security

def validate_input(self):
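# Imported locally, presumably to avoid a circular import at module load time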
from pandasai.dataframe.virtual_dataframe import VirtualDataFrame

# Check if all DataFrames are VirtualDataFrame, and set direct_sql accordingly
all_virtual = all(isinstance(df, VirtualDataFrame) for df in self.dfs)
if all_virtual:
self.config.direct_sql = True

# When direct_sql is enabled, ensure all DataFrames share the same source
if self.config.direct_sql and all_virtual:
base_schema_source = self.dfs[0].schema
for df in self.dfs[1:]:
# Ensure all DataFrames have the same source in direct_sql mode
if not is_schema_source_same(base_schema_source, df.schema):
raise InvalidConfigError(
"Direct SQL requires all connectors to be of the same type, "
"belong to the same datasource, and have the same credentials."
)
else:
# direct_sql was not preset: any VirtualDataFrames must still share the same source
if any(isinstance(df, VirtualDataFrame) for df in self.dfs):
base_schema_source = self.dfs[0].schema
for df in self.dfs[1:]:
if not is_schema_source_same(base_schema_source, df.schema):
raise InvalidConfigError(
"All DataFrames must belong to the same source."
)
self.config.direct_sql = True
else:
# All DataFrames are plain (non-virtual); disable direct_sql
self.config.direct_sql = False

def configure(self):
# Add project root path if save_charts_path is default
if (
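A sketch of what `validate_input` implies for callers (dataset names hypothetical): all-virtual inputs backed by one source flip `direct_sql` on automatically, while mismatched sources fail fast.

```python
import pandasai as pai
from pandasai.agent import Agent

# Both datasets are assumed to point at the same postgres source with
# identical connection credentials in their schema.yaml files.
sales = pai.load("acme/sales", virtualized=True)
users = pai.load("acme/users", virtualized=True)

agent = Agent([sales, users])
assert agent.config.direct_sql  # enabled automatically for all-virtual inputs

# A DataFrame backed by a different datasource or different credentials
# would make the Agent constructor raise InvalidConfigError instead.
```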
103 changes: 80 additions & 23 deletions pandasai/dataframe/loader.py → pandasai/data_loader/loader.py
@@ -1,12 +1,14 @@
import copy
import os
import yaml
import pandas as pd
from datetime import datetime, timedelta
import hashlib

from pandasai.dataframe.base import DataFrame
from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.helpers.path import find_project_root
from .base import DataFrame
import importlib
from typing import Any
from .query_builder import QueryBuilder
@@ -18,27 +20,35 @@ def __init__(self):
self.schema = None
self.dataset_path = None

def load(self, dataset_path: str, lazy=False) -> DataFrame:
def load(self, dataset_path: str, virtualized=False) -> DataFrame:
self.dataset_path = dataset_path
self._load_schema()
self._validate_source_type()
if not virtualized:
cache_file = self._get_cache_file_path()

cache_file = self._get_cache_file_path()
if self._is_cache_valid(cache_file):
return self._read_cache(cache_file)

if self._is_cache_valid(cache_file):
return self._read_cache(cache_file)
df = self._load_from_source()
df = self._apply_transformations(df)
self._cache_data(df, cache_file)

df = self._load_from_source()
df = self._apply_transformations(df)
self._cache_data(df, cache_file)
table_name = self.schema["source"]["table"]

return DataFrame(df, schema=self.schema)
return DataFrame(df, schema=self.schema, name=table_name)
else:
# Initialize new dataset loader for virtualization
data_loader = self.copy()
table_name = self.schema["source"]["table"]
return VirtualDataFrame(
schema=self.schema, data_loader=data_loader, name=table_name
)

def _load_schema(self):
schema_path = os.path.join(
find_project_root(), "datasets", self.dataset_path, "schema.yaml"
)
if not os.path.exists(schema_path):
raise FileNotFoundError(f"Schema file not found: {schema_path}")

@@ -82,32 +92,67 @@ def _read_cache(self, cache_file: str) -> DataFrame:
else:
raise ValueError(f"Unsupported cache format: {cache_format}")

def _load_from_source(self) -> pd.DataFrame:
source_type = self.schema["source"]["type"]
connection_info = self.schema["source"].get("connection", {})
query_builder = QueryBuilder(self.schema)
query = query_builder.build_query()

def _get_loader_function(self, source_type: str):
"""
Get the loader function for a specified data source type.
"""
try:
module_name = SUPPORTED_SOURCES[source_type]
module = importlib.import_module(module_name)

if source_type in [
if source_type not in {
"mysql",
"postgres",
"cockroach",
"sqlite",
"cockroachdb",
]:
load_function = getattr(module, f"load_from_{source_type}")
return load_function(connection_info, query)
else:
raise InvalidDataSourceType("Invalid data source type")
}:
raise InvalidDataSourceType(
f"Unsupported data source type: {source_type}"
)

return getattr(module, f"load_from_{source_type}")

except KeyError:
raise InvalidDataSourceType(f"Unsupported data source type: {source_type}")

except ImportError as e:
raise ImportError(
f"{source_type.capitalize()} connector not found. "
f"Please install the {module_name} library."
f"Please install the {SUPPORTED_SOURCES[source_type]} library."
) from e

def _load_from_source(self) -> pd.DataFrame:
query_builder = QueryBuilder(self.schema)
query = query_builder.build_query()
return self.execute_query(query)

def load_head(self) -> pd.DataFrame:
query_builder = QueryBuilder(self.schema)
query = query_builder.get_head_query()
return self.execute_query(query)

def get_row_count(self) -> int:
query_builder = QueryBuilder(self.schema)
query = query_builder.get_row_count()
result = self.execute_query(query)
return result.iloc[0, 0]

def execute_query(self, query: str) -> pd.DataFrame:
source = self.schema.get("source", {})
source_type = source.get("type")
connection_info = source.get("connection", {})

if not source_type:
raise ValueError("Source type is missing in the schema.")

load_function = self._get_loader_function(source_type)

try:
return load_function(connection_info, query)
except Exception as e:
raise RuntimeError(
f"Failed to execute query for source type '{source_type}' with query: {query}"
) from e

def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -140,3 +185,15 @@ def _cache_data(self, df: pd.DataFrame, cache_file: str):
df.to_csv(cache_file, index=False)
else:
raise ValueError(f"Unsupported cache format: {cache_format}")

def copy(self) -> "DatasetLoader":
"""
Create a new independent copy of the current DatasetLoader instance.
Returns:
DatasetLoader: A new instance with the same state.
"""
new_loader = DatasetLoader()
new_loader.schema = copy.deepcopy(self.schema)
new_loader.dataset_path = self.dataset_path
return new_loader
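A sketch of the two loader paths (dataset path hypothetical); `copy()` exists so each `VirtualDataFrame` owns an independent loader with a deep-copied schema:

```python
from pandasai.data_loader.loader import DatasetLoader

loader = DatasetLoader()

# Eager: builds the query, applies transformations, caches, returns a DataFrame.
df = loader.load("acme/sales")

# Virtualized: no bulk fetch; the returned VirtualDataFrame keeps a copy of
# this loader and pulls rows on demand via load_head() / execute_query().
vdf = loader.load("acme/sales", virtualized=True)
```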
55 changes: 55 additions & 0 deletions pandasai/data_loader/query_builder.py
@@ -0,0 +1,55 @@
from typing import Dict, Any, List, Union


class QueryBuilder:
def __init__(self, schema: Dict[str, Any]):
self.schema = schema

def build_query(self) -> str:
columns = self._get_columns()
table_name = self.schema["source"]["table"]
query = f"SELECT {columns} FROM {table_name}"

query += self._add_order_by()
query += self._add_limit()

return query

def _get_columns(self) -> str:
if "columns" in self.schema:
return ", ".join([col["name"] for col in self.schema["columns"]])
else:
return "*"

def _add_order_by(self) -> str:
if "order_by" not in self.schema:
return ""

order_by = self.schema["order_by"]
order_by_clause = self._format_order_by(order_by)
return f" ORDER BY {order_by_clause}"

def _format_order_by(self, order_by: Union[List[str], str]) -> str:
return ", ".join(order_by) if isinstance(order_by, list) else order_by

def _add_limit(self, n=None) -> str:
limit = n if n else self.schema.get("limit", "")
return f" LIMIT {limit}" if limit else ""

def get_head_query(self, n=5):
source = self.schema.get("source", {})
source_type = source.get("type")

table_name = self.schema["source"]["table"]

columns = self._get_columns()

order_by = "RAND()"
if source_type in {"sqlite", "postgres"}:
order_by = "RANDOM()"

return f"SELECT {columns} FROM {table_name} ORDER BY {order_by} LIMIT {n}"

def get_row_count(self):
table_name = self.schema["source"]["table"]
return f"SELECT COUNT(*) FROM {table_name}"
9 changes: 9 additions & 0 deletions pandasai/data_loader/schema_validator.py
@@ -0,0 +1,9 @@
import json


def is_schema_source_same(schema1: dict, schema2: dict) -> bool:
source1 = schema1.get("source", {})
source2 = schema2.get("source", {})
return source1.get("type") == source2.get("type") and json.dumps(
source1.get("connection"), sort_keys=True
) == json.dumps(source2.get("connection"), sort_keys=True)
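Because connections are compared via a key-sorted JSON dump, key order in the connection dict does not matter:

```python
a = {"source": {"type": "postgres", "connection": {"host": "db1", "port": 5432}}}
b = {"source": {"type": "postgres", "connection": {"port": 5432, "host": "db1"}}}

assert is_schema_source_same(a, b)  # same type, same credentials
```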
41 changes: 41 additions & 0 deletions pandasai/dataframe/virtual_dataframe.py
@@ -0,0 +1,41 @@
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
import pandas as pd
from pandasai.dataframe.base import DataFrame

if TYPE_CHECKING:
from pandasai.data_loader.loader import DatasetLoader


class VirtualDataFrame(DataFrame):
_metadata: ClassVar[list] = [
"_loader",
"head",
"_head",
"name",
"description",
"schema",
"config",
"_agent",
"_column_hash",
]

def __init__(self, *args, **kwargs):
self._loader: DatasetLoader = kwargs.pop("data_loader", None)
if not self._loader:
raise Exception("Data loader is required for virtualization!")
self._head = None
# Build the pandas base from a small on-demand sample instead of the full table
super().__init__(self.head(), *args, **kwargs)

def head(self):
if self._head is None:
self._head = self._loader.load_head()

return self._head

@property
def rows_count(self) -> int:
return self._loader.get_row_count()

def execute_sql_query(self, query: str) -> pd.DataFrame:
return self._loader.execute_query(query)
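Behavior sketch (continuing the hypothetical `acme/sales` dataset from above): construction materializes only a small sample, while row counts and arbitrary SQL are pushed down to the source.

```python
import pandasai as pai

vdf = pai.load("acme/sales", virtualized=True)

sample = vdf.head()      # small sample, fetched once and memoized in _head
total = vdf.rows_count   # SELECT COUNT(*) executed at the source
rows = vdf.execute_sql_query("SELECT * FROM sales LIMIT 10")  # table name from schema.yaml
```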
