Skip to content

Commit

Permalink
feat: add table sampling support (#1421)
Browse files Browse the repository at this point in the history
* feat: add table sampling support

* add sqlglot

* update QueryComposer

* fix python test cases

* disable by default an

* comments

* fix test case
  • Loading branch information
jczhong84 authored Mar 19, 2024
1 parent 44d8055 commit 4287294
Show file tree
Hide file tree
Showing 27 changed files with 789 additions and 46 deletions.
2 changes: 1 addition & 1 deletion docs_website/docs/configurations/infra_installation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ If you install the required packages, these integrations will be automatically s
- Elasticsearch:
- AWS hosted (via `-r platform/aws.txt`)
- Parsing (transpilation):
- SQLGlot (via `-r parser/sqlglot.txt`)
- SQLGlot (supported by default)

## How to install packages for integration

Expand Down
2 changes: 2 additions & 0 deletions querybook/config/datadoc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ cell_types:
collapsed: false
query_collapsed: false
limit: 0
sample_rate: 0.0
meta_default:
title: ''
engine: null
collapsed: false
query_collapsed: false
limit: 0
sample_rate: 0
chart:
key: 'chart'
name: 'Chart'
Expand Down
6 changes: 6 additions & 0 deletions querybook/config/querybook_public_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@ survey:
# - surface: table_view
# - surface: text_to_sql
# - surface: query_authoring

table_sampling:
enabled: false
sample_rates:
- 0.1
default_sample_rate: 0 # 0 means no sampling
17 changes: 17 additions & 0 deletions querybook/server/const/sqlglot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Mapping from Querybook languages to SQLGlot languages
QUERYBOOK_TO_SQLGLOT_LANGUAGE_MAPPING = {
# same name
"presto": "presto",
"trino": "trino",
"bigquery": "bigquery",
"clickhouse": "clickhouse",
"hive": "hive",
"mysql": "mysql",
"oracle": "oracle",
"sqlite": "sqlite",
"snowflake": "snowflake",
# different name
"mssql": "tsql",
"postgresql": "postgres",
"sparksql": "spark",
}
2 changes: 2 additions & 0 deletions querybook/server/datasources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from . import data_element
from . import comment
from . import survey
from . import query_transform

# Keep this at the end of imports to make sure the plugin APIs override the default ones
try:
Expand Down Expand Up @@ -44,4 +45,5 @@
data_element
comment
survey
query_transform
api_plugin
4 changes: 3 additions & 1 deletion querybook/server/datasources/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def get_table_by_name(
table = logic.get_table_by_name(
schema_name, table_name, metastore_id, session=session
)
api_assert(table, "{}.{} does not exist".format(schema_name, table_name))
if not table:
return None

verify_data_schema_permission(table.schema_id, session=session)
table_dict = table.to_dict(with_schema, with_column, with_warnings)

Expand Down
15 changes: 15 additions & 0 deletions querybook/server/datasources/query_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from app.datasource import register
from lib.query_analysis.transform import (
transform_to_sampled_query,
)


@register("/query/transform/sampling/", methods=["POST"])
def query_sampling(
query: str,
language: str,
sampling_tables: dict[str, dict[str, str]],
):
return transform_to_sampled_query(
query=query, language=language, sampling_tables=sampling_tables
)
169 changes: 169 additions & 0 deletions querybook/server/lib/query_analysis/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from typing import List, Optional, Union
from sqlglot import exp, parse, parse_one, transpile, errors

from lib.logger import get_logger
from const.sqlglot import QUERYBOOK_TO_SQLGLOT_LANGUAGE_MAPPING

LOG = get_logger(__file__)


def _get_sqlglot_dialect(language: Optional[str] = None):
return QUERYBOOK_TO_SQLGLOT_LANGUAGE_MAPPING.get(language, None)


def _statements_to_query(statements: List[str]):
return "\n".join(statement + ";" for statement in statements)


def format_query(query: str, language: Optional[str] = None):
dialect = _get_sqlglot_dialect(language)
statements = transpile(
query,
read=dialect,
write=dialect,
pretty=True,
)

return _statements_to_query(statements)


def get_select_statement_limit(
statement: Union[exp.Expression, str],
language: Optional[str] = None,
) -> Union[int, None]:
"""Get the limit of a select/union statement.
Args:
statement_ast: The select statement ast
Returns:
int: The limit of the select statement. -1 if no limit, or None if not a select/union statement
"""
if isinstance(statement, str):
statement = parse_one(statement, dialect=_get_sqlglot_dialect(language))

if not isinstance(
statement, (exp.Select, exp.Union)
): # not a select or union statement
return None

limit = -1
limit_clause = statement.args.get("limit")

if isinstance(limit_clause, exp.Limit):
limit = limit_clause.expression.this
elif isinstance(limit_clause, exp.Fetch):
limit = limit_clause.args.get("count").this

return int(limit)


def get_limited_select_statement(statement_ast: exp.Expression, limit: int):
"""Apply a limit to a select/union statement if it doesn't already have a limit.
It returns a new statement with the limit applied and the original statement is not modified.
"""
current_limit = get_select_statement_limit(statement_ast)
if current_limit is None or current_limit >= 0:
return statement_ast

return statement_ast.limit(limit)


def has_query_contains_unlimited_select(query: str, language: str) -> bool:
"""Check if a query contains a select statement without a limit.
Args:
query: The query to check
Returns:
bool: True if the query contains a select statement without a limit, False otherwise
"""
statements = parse(query, dialect=_get_sqlglot_dialect[language])
return any(get_select_statement_limit(s) == -1 for s in statements)


def transform_to_limited_query(
query: str, limit: int = None, language: str = None
) -> str:
"""Apply a limit to all select statements in a query if they don't already have a limit.
It returns a new query with the limit applied and the original query is not modified.
"""
if not limit:
return query

try:
dialect = _get_sqlglot_dialect(language)
statements = parse(query, dialect=dialect)

updated_statements = [
get_limited_select_statement(s, limit) for s in statements
]
return _statements_to_query(
[s.sql(dialect=dialect, pretty=True) for s in updated_statements]
)
except errors.ParseError as e:
LOG.error(e, exc_info=True)
# If parsing fails, return the original query
return query


def _get_sampled_statement(
statement_ast: exp.Expression,
sampling_tables: dict[str, dict[str, str]],
):
"""Apply sampling to a sglglot statement AST for the given tables."""

def transformer(node):
if isinstance(node, exp.Table):
full_table_name = f"{node.db}.{node.name}" if node.db else node.name
if full_table_name not in sampling_tables:
return node

if (
sampled_table := sampling_tables[full_table_name].get("sampled_table")
) is not None:
node.set("this", exp.to_identifier(sampled_table, quoted=False))
node.set("db", None)
elif (
sample_rate := sampling_tables[full_table_name].get("sample_rate")
) is not None:
return exp.TableSample(
this=node, method="SYSTEM", percent=str(sample_rate)
)
return node

return statement_ast.transform(transformer)


def transform_to_sampled_query(
query: str,
language: str = None,
sampling_tables: dict[str, dict[str, str]] = {},
):
"""Apply sampling to the query for the given tables.
An example of sampling_tables:
{
"db.table1": {"sampled_table": "db.sampled_table1"},
"db.table2": {"sample_rate": 0.1},
}
If sampled_table is provided, the table will be replaced with the sampled_table over using the sample_rate.
Args:
query: The query to apply sampling to
language: The language of the query
sampling_tables: A dictionary of tables to sample and their sampled version or sample rates
Returns:
str: The sampled query
"""
try:
dialect = _get_sqlglot_dialect(language)
statements = parse(query, dialect=dialect)
sampled_statements = [
_get_sampled_statement(s, sampling_tables) for s in statements
]
return _statements_to_query(
[s.sql(dialect=dialect, pretty=True) for s in sampled_statements]
)

except errors.ParseError as e:
LOG.error(e, exc_info=True)
# If parsing fails, return the original query
return query
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from lib.utils.import_helper import import_modules, import_module_with_default
from lib.query_analysis.transpilation.base_query_transpiler import BaseQueryTranspiler

ALL_PLUGIN_QUERY_VALIDATORS_BY_NAME = import_module_with_default(
ALL_PLUGIN_QUERY_QUERY_TRANSPILERS = import_module_with_default(
"query_transpilation_plugin", "ALL_PLUGIN_QUERY_TRANSPILERS", default=[]
)

Expand All @@ -15,7 +15,7 @@
]
)

ALL_TRANSPILERS: List[BaseQueryTranspiler] = ALL_PLUGIN_QUERY_VALIDATORS_BY_NAME + [
ALL_TRANSPILERS: List[BaseQueryTranspiler] = ALL_PLUGIN_QUERY_QUERY_TRANSPILERS + [
transpiler_cls() for transpiler_cls in PROVIDED_TRANSPILERS
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,9 @@

from typing import List

from const.sqlglot import QUERYBOOK_TO_SQLGLOT_LANGUAGE_MAPPING
from lib.query_analysis.transpilation.base_query_transpiler import BaseQueryTranspiler

# Mapping from Querybook languages to SQLGlot languages
QUERYBOOK_TO_SQLGLOT_LANGUAGE_MAPPING = {
# same name
"presto": "presto",
"trino": "trino",
"bigquery": "bigquery",
"clickhouse": "clickhouse",
"hive": "hive",
"mysql": "mysql",
"oracle": "oracle",
"sqlite": "sqlite",
"snowflake": "snowflake",
# different name
"mssql": "tsql",
"postgresql": "postgres",
"sparksql": "spark",
}


def statements_to_query(statements: List[str]):
return "\n".join(statement + ";" for statement in statements)
Expand Down
Loading

0 comments on commit 4287294

Please sign in to comment.