-
Notifications
You must be signed in to change notification settings - Fork 238
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add table sampling support (#1421)
* feat: add table sampling support * add sqlglot * update QueryComposer * fix python test cases * disable by default an * comments * fix test case
- Loading branch information
Showing
27 changed files
with
789 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Mapping from Querybook languages to SQLGlot languages
QUERYBOOK_TO_SQLGLOT_LANGUAGE_MAPPING = {
    # Languages whose Querybook name is also the SQLGlot dialect name
    **{
        name: name
        for name in (
            "presto",
            "trino",
            "bigquery",
            "clickhouse",
            "hive",
            "mysql",
            "oracle",
            "sqlite",
            "snowflake",
        )
    },
    # Languages that SQLGlot knows under a different dialect name
    "mssql": "tsql",
    "postgresql": "postgres",
    "sparksql": "spark",
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from app.datasource import register | ||
from lib.query_analysis.transform import ( | ||
transform_to_sampled_query, | ||
) | ||
|
||
|
||
@register("/query/transform/sampling/", methods=["POST"])
def query_sampling(
    query: str,
    language: str,
    sampling_tables: dict[str, dict[str, str]],
):
    """Endpoint that rewrites a query so the given tables are sampled.

    Args:
        query: The query text to transform
        language: The Querybook language name of the query
        sampling_tables: Per-table sampling settings, forwarded unchanged to
            transform_to_sampled_query

    Returns:
        Whatever transform_to_sampled_query returns for these arguments
    """
    sampled_query = transform_to_sampled_query(
        query=query,
        language=language,
        sampling_tables=sampling_tables,
    )
    return sampled_query
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
from typing import List, Optional, Union | ||
from sqlglot import exp, parse, parse_one, transpile, errors | ||
|
||
from lib.logger import get_logger | ||
from const.sqlglot import QUERYBOOK_TO_SQLGLOT_LANGUAGE_MAPPING | ||
|
||
LOG = get_logger(__file__) | ||
|
||
|
||
def _get_sqlglot_dialect(language: Optional[str] = None):
    """Translate a Querybook language name into its SQLGlot dialect name.

    Returns None when the language has no entry in
    QUERYBOOK_TO_SQLGLOT_LANGUAGE_MAPPING (including when language is None).
    """
    # dict.get already falls back to None for a missing key
    return QUERYBOOK_TO_SQLGLOT_LANGUAGE_MAPPING.get(language)
|
||
|
||
def _statements_to_query(statements: List[str]): | ||
return "\n".join(statement + ";" for statement in statements) | ||
|
||
|
||
def format_query(query: str, language: Optional[str] = None):
    """Pretty-print a query using SQLGlot.

    The query is transpiled from its dialect back into the same dialect with
    pretty-printing enabled, and the resulting statements are re-joined into a
    single ";"-terminated query string.

    Args:
        query: The query text to format
        language: The Querybook language name of the query
    Returns:
        The formatted query
    """
    sqlglot_dialect = _get_sqlglot_dialect(language)
    return _statements_to_query(
        transpile(query, read=sqlglot_dialect, write=sqlglot_dialect, pretty=True)
    )
|
||
|
||
def get_select_statement_limit(
    statement: Union[exp.Expression, str],
    language: Optional[str] = None,
) -> Union[int, None]:
    """Get the limit of a select/union statement.

    Args:
        statement: The statement, either as a SQLGlot AST or as SQL text.
            Text is parsed first with the dialect mapped from `language`.
        language: The Querybook language name, only used when `statement`
            is a string
    Returns:
        int: The limit of the statement, -1 if it has no limit, or None if it
            is not a select/union statement
    """
    if isinstance(statement, str):
        statement = parse_one(statement, dialect=_get_sqlglot_dialect(language))

    # Only select and union statements carry a limit we care about
    if not isinstance(statement, (exp.Select, exp.Union)):
        return None

    # The "limit" arg holds either a LIMIT clause or a FETCH clause
    clause = statement.args.get("limit")
    if isinstance(clause, exp.Limit):
        return int(clause.expression.this)
    if isinstance(clause, exp.Fetch):
        return int(clause.args.get("count").this)

    # Neither clause present: report "no limit"
    return -1
|
||
|
||
def get_limited_select_statement(statement_ast: exp.Expression, limit: int):
    """Apply a limit to a select/union statement that does not have one yet.

    Statements that are not select/union, or that already have a limit, are
    returned as they are. Otherwise the result of `statement_ast.limit(limit)`
    is returned.

    NOTE(review): the builder is expected to return a new statement and leave
    the original untouched (SQLGlot builders copy by default) - confirm
    against the pinned SQLGlot version.
    """
    existing_limit = get_select_statement_limit(statement_ast)

    # A negative value is the "select/union without a limit" marker
    is_unlimited_select = existing_limit is not None and existing_limit < 0
    if is_unlimited_select:
        return statement_ast.limit(limit)

    return statement_ast
|
||
|
||
def has_query_contains_unlimited_select(query: str, language: str) -> bool:
    """Check if a query contains a select statement without a limit.

    Args:
        query: The query to check
        language: The Querybook language name of the query, used to pick the
            SQLGlot dialect for parsing
    Returns:
        bool: True if the query contains a select statement without a limit,
            False otherwise
    """
    # _get_sqlglot_dialect is a function: it has to be called. Subscripting it
    # (as the previous code did) raises TypeError before any parsing happens.
    statements = parse(query, dialect=_get_sqlglot_dialect(language))

    # get_select_statement_limit returns -1 for a select/union with no limit
    return any(get_select_statement_limit(s) == -1 for s in statements)
|
||
|
||
def transform_to_limited_query(
    query: str, limit: int = None, language: str = None
) -> str:
    """Apply a limit to every select statement in a query that lacks one.

    Args:
        query: The query to transform
        limit: The limit to apply. When falsy (None or 0) the query is
            returned unchanged.
        language: The Querybook language name of the query
    Returns:
        str: The re-rendered query with limits applied, or the original query
            text when no limit was requested or the query fails to parse
    """
    if not limit:
        return query

    try:
        sqlglot_dialect = _get_sqlglot_dialect(language)

        limited_sql = []
        for statement_ast in parse(query, dialect=sqlglot_dialect):
            limited_ast = get_limited_select_statement(statement_ast, limit)
            limited_sql.append(limited_ast.sql(dialect=sqlglot_dialect, pretty=True))

        return _statements_to_query(limited_sql)
    except errors.ParseError as e:
        LOG.error(e, exc_info=True)
        # If parsing fails, return the original query
        return query
|
||
|
||
def _get_sampled_statement(
    statement_ast: exp.Expression,
    sampling_tables: dict[str, dict[str, str]],
):
    """Apply sampling to a sqlglot statement AST for the given tables.

    Tables are matched by "db.name" when the table node has a db, otherwise by
    name alone. For a matched table, "sampled_table" takes precedence: the
    table identifier is replaced and its db is cleared. Failing that,
    "sample_rate" wraps the table in a SYSTEM table sample at that percent.
    """

    def apply_sampling(node):
        # Anything that is not a table reference passes through untouched
        if not isinstance(node, exp.Table):
            return node

        table_key = f"{node.db}.{node.name}" if node.db else node.name
        if table_key not in sampling_tables:
            return node

        table_settings = sampling_tables[table_key]

        replacement_table = table_settings.get("sampled_table")
        if replacement_table is not None:
            # Swap in the pre-sampled table; the db qualifier is dropped
            node.set("this", exp.to_identifier(replacement_table, quoted=False))
            node.set("db", None)
            return node

        rate = table_settings.get("sample_rate")
        if rate is not None:
            return exp.TableSample(this=node, method="SYSTEM", percent=str(rate))

        return node

    return statement_ast.transform(apply_sampling)
|
||
|
||
def transform_to_sampled_query(
    query: str,
    language: str = None,
    sampling_tables: Optional[dict[str, dict[str, str]]] = None,
):
    """Apply sampling to the query for the given tables.

    An example of sampling_tables:
        {
            "db.table1": {"sampled_table": "db.sampled_table1"},
            "db.table2": {"sample_rate": 0.1},
        }
    If sampled_table is provided, the table will be replaced with the
    sampled_table over using the sample_rate.

    Args:
        query: The query to apply sampling to
        language: The language of the query
        sampling_tables: A dictionary of tables to sample and their sampled
            version or sample rates. None (the default) is treated as an
            empty dictionary, i.e. no table is sampled.
    Returns:
        str: The sampled query, or the original query text if it fails to
            parse
    """
    # Use a None sentinel instead of a shared mutable `{}` default
    if sampling_tables is None:
        sampling_tables = {}

    try:
        dialect = _get_sqlglot_dialect(language)
        statements = parse(query, dialect=dialect)
        sampled_statements = [
            _get_sampled_statement(s, sampling_tables) for s in statements
        ]
        return _statements_to_query(
            [s.sql(dialect=dialect, pretty=True) for s in sampled_statements]
        )

    except errors.ParseError as e:
        LOG.error(e, exc_info=True)
        # If parsing fails, return the original query
        return query
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.