diff --git a/fakesnow/fakes.py b/fakesnow/fakes.py index 6566b35..cc165fd 100644 --- a/fakesnow/fakes.py +++ b/fakesnow/fakes.py @@ -195,6 +195,8 @@ def _execute( .transform(transforms.array_size) .transform(transforms.random) .transform(transforms.identifier) + .transform(transforms.array_agg_within_group) + .transform(transforms.array_agg_to_json) .transform(lambda e: transforms.show_schemas(e, self._conn.database)) .transform(lambda e: transforms.show_objects_tables(e, self._conn.database)) # TODO collapse into a single show_keys function diff --git a/fakesnow/transforms.py b/fakesnow/transforms.py index 418f3f1..4083651 100644 --- a/fakesnow/transforms.py +++ b/fakesnow/transforms.py @@ -22,6 +22,39 @@ def array_size(expression: exp.Expression) -> exp.Expression: return expression +def array_agg_to_json(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.ArrayAgg): + return exp.Anonymous(this="TO_JSON", expressions=[expression]) + + return expression + + +def array_agg_within_group(expression: exp.Expression) -> exp.Expression: + """Convert ARRAY_AGG() WITHIN GROUP () to ARRAY_AGG( ) + Snowflake uses ARRAY_AGG() WITHIN GROUP (ORDER BY ) + to order the array, but DuckDB uses ARRAY_AGG( ). + See; + - https://docs.snowflake.com/en/sql-reference/functions/array_agg + - https://duckdb.org/docs/sql/aggregates.html#order-by-clause-in-aggregate-functions + Note; Snowflake has following restriction; + If you specify DISTINCT and WITHIN GROUP, both must refer to the same column. + Transformation does not handle this restriction. + """ + if ( + isinstance(expression, exp.WithinGroup) + and (agg := expression.find(exp.ArrayAgg)) + and (order := expression.expression) + ): + return exp.ArrayAgg( + this=exp.Order( + this=agg.this, + expressions=order.expressions, + ) + ) + + return expression + + # TODO: move this into a Dialect as a transpilation def create_database(expression: exp.Expression, db_path: Path | None = None) -> exp.Expression: """Transform create database to attach database. diff --git a/tests/test_fakes.py b/tests/test_fakes.py index 21cc2b9..0daa890 100644 --- a/tests/test_fakes.py +++ b/tests/test_fakes.py @@ -7,6 +7,7 @@ import tempfile from collections.abc import Sequence from decimal import Decimal +from typing import cast import pandas as pd import pytest @@ -35,6 +36,46 @@ def test_array_size(cur: snowflake.connector.cursor.SnowflakeCursor): assert cur.fetchall() == [(None,)] +def test_array_agg_to_json(dcur: snowflake.connector.cursor.DictCursor): + dcur.execute("create table table1 (id number, name varchar)") + values = [(1, "foo"), (2, "bar"), (1, "baz"), (2, "qux")] + + dcur.executemany("insert into table1 values (%s, %s)", values) + + dcur.execute("select array_agg(name) as names from table1") + assert dindent(dcur.fetchall()) == [{"NAMES": '[\n "foo",\n "bar",\n "baz",\n "qux"\n]'}] + + +def test_array_agg_within_group(dcur: snowflake.connector.cursor.DictCursor): + dcur.execute("CREATE TABLE table1 (ID INT, amount INT)") + + # two unique ids, for id 1 there are 3 amounts, for id 2 there are 2 amounts + values = [ + (2, 40), + (1, 10), + (1, 30), + (2, 50), + (1, 20), + ] + dcur.executemany("INSERT INTO TABLE1 VALUES (%s, %s)", values) + + dcur.execute("SELECT id, ARRAY_AGG(amount) WITHIN GROUP (ORDER BY amount DESC) amounts FROM table1 GROUP BY id") + rows = dcur.fetchall() + + assert dindent(rows) == [ + {"ID": 1, "AMOUNTS": "[\n 30,\n 20,\n 10\n]"}, + {"ID": 2, "AMOUNTS": "[\n 50,\n 40\n]"}, + ] + + dcur.execute("SELECT id, ARRAY_AGG(amount) WITHIN GROUP (ORDER BY amount ASC) amounts FROM table1 GROUP BY id") + rows = dcur.fetchall() + + assert dindent(rows) == [ + {"ID": 1, "AMOUNTS": "[\n 10,\n 20,\n 30\n]"}, + {"ID": 2, "AMOUNTS": "[\n 40,\n 50\n]"}, + ] + + def test_binding_default_paramstyle(conn: snowflake.connector.SnowflakeConnection): assert snowflake.connector.paramstyle == "pyformat" with conn.cursor() as cur: @@ -1373,13 +1414,26 @@ def test_write_pandas_dict_different_keys(conn: snowflake.connector.SnowflakeCon def indent(rows: Sequence[tuple] | Sequence[dict]) -> list[tuple]: - # indent duckdb json strings to match snowflake json strings + # indent duckdb json strings tuple values to match snowflake json strings + assert isinstance(rows[0], tuple) return [ (*[json.dumps(json.loads(c), indent=2) if (isinstance(c, str) and c.startswith(("[", "{"))) else c for c in r],) for r in rows ] +def dindent(rows: Sequence[tuple] | Sequence[dict]) -> list[dict]: + # indent duckdb json strings dict values to match snowflake json strings + assert isinstance(rows[0], dict) + return [ + { + k: json.dumps(json.loads(v), indent=2) if (isinstance(v, str) and v.startswith(("[", "{"))) else v + for k, v in cast(dict, r).items() + } + for r in rows + ] + + def sort_keys(sdict: str, indent: int | None = 2) -> str: return json.dumps( json.loads(sdict, object_pairs_hook=lambda x: dict(sorted(x))), diff --git a/tests/test_transforms.py b/tests/test_transforms.py index cd06425..d0ea466 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -7,6 +7,7 @@ from fakesnow.transforms import ( SUCCESS_NOP, _get_to_number_args, + array_agg_within_group, array_size, create_database, describe_table, @@ -51,6 +52,31 @@ def test_array_size() -> None: ) +def test_array_agg_within_group() -> None: + assert ( + sqlglot.parse_one( + "SELECT someid, ARRAY_AGG(DISTINCT id) WITHIN GROUP (ORDER BY id) AS ids FROM example GROUP BY someid" + ) + .transform(array_agg_within_group) + .sql(dialect="duckdb") + == "SELECT someid, ARRAY_AGG(DISTINCT id ORDER BY id NULLS FIRST) AS ids FROM example GROUP BY someid" + ) + + assert ( + sqlglot.parse_one( + "SELECT someid, ARRAY_AGG(id) WITHIN GROUP (ORDER BY id DESC) AS ids FROM example WHERE someid IS NOT NULL GROUP BY someid" # noqa: E501 + ) + .transform(array_agg_within_group) + .sql(dialect="duckdb") + == "SELECT someid, ARRAY_AGG(id ORDER BY id DESC) AS ids FROM example WHERE NOT someid IS NULL GROUP BY someid" + ) + + assert ( + sqlglot.parse_one("SELECT ARRAY_AGG(id) FROM example").transform(array_agg_within_group).sql(dialect="duckdb") + == "SELECT ARRAY_AGG(id) FROM example" + ) + + def test_create_database() -> None: e = sqlglot.parse_one("create database foobar").transform(create_database) assert e.sql() == "ATTACH DATABASE ':memory:' AS foobar"