Skip to content

Commit

Permalink
feat: add ARRAY_AGG() WITHIN GROUP() support (tekumara#66)
Browse files Browse the repository at this point in the history
In Snowflake, ordering the elements of an `ARRAY_AGG` aggregation requires a
`WITHIN GROUP (...)` clause:
*
https://docs.snowflake.com/en/sql-reference/functions/array_agg#arguments
    
    ```
    ARRAY_AGG(<expr>) WITHIN GROUP (<order-by-clause>)
    ```
In DuckDB, `LIST()`/`ARRAY_AGG()` takes the ordering as an `ORDER BY` clause inside
the function call and does not support `WITHIN GROUP (...)`:
*
https://duckdb.org/docs/sql/aggregates.html#order-by-clause-in-aggregate-functions
    ```
    ARRAY_AGG( <expr> <order-by-clause> )
    ```

This transformer simply combines the *aggregate* expression in `ARRAY_AGG`
and the *ordering* expression in `WITHIN GROUP` into a single *order* expression.

```sql
ARRAY_AGG(DISTINCT id) WITHIN GROUP (ORDER BY id)
```

```python
WithinGroup(
  this=ArrayAgg(
    this=Distinct(
      expressions=[
        Column(
          this=Identifier(this=id, quoted=False))])),
  expression=Order(
    expressions=[
      Ordered(
        this=Column(
          this=Identifier(this=id, quoted=False)),
        nulls_first=False)]))
```

```python
ArrayAgg(
  this=Order(
    this=Distinct(
      expressions=[
        Column(
          this=Identifier(this=ID, quoted=False))]),
    expressions=[
      Ordered(
        this=Column(
          this=Identifier(this=ID, quoted=False)),
        nulls_first=False)]))
```


```diff
-WithinGroup(
-  this=ArrayAgg(
+ArrayAgg(
+  this=Order(
     this=Distinct(
       expressions=[
         Column(
-          this=Identifier(this=id, quoted=False))])),
-  expression=Order(
+          this=Identifier(this=ID, quoted=False))]),
     expressions=[
       Ordered(
         this=Column(
-          this=Identifier(this=id, quoted=False)),
+          this=Identifier(this=ID, quoted=False)),
         nulls_first=False)]))

```

-----

Snowflake has the following limitation, though:
>
https://docs.snowflake.com/en/sql-reference/functions/array_agg#usage-notes
> If you specify DISTINCT and WITHIN GROUP, both must refer to the same
column. For example:

I've left it out of the scope of this PR to keep it simple.

---------

Co-authored-by: Oliver Mannion <[email protected]>
  • Loading branch information
seruman and tekumara authored Mar 30, 2024
1 parent 9dce794 commit 32afc18
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 1 deletion.
2 changes: 2 additions & 0 deletions fakesnow/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ def _execute(
.transform(transforms.array_size)
.transform(transforms.random)
.transform(transforms.identifier)
.transform(transforms.array_agg_within_group)
.transform(transforms.array_agg_to_json)
.transform(lambda e: transforms.show_schemas(e, self._conn.database))
.transform(lambda e: transforms.show_objects_tables(e, self._conn.database))
# TODO collapse into a single show_keys function
Expand Down
33 changes: 33 additions & 0 deletions fakesnow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,39 @@ def array_size(expression: exp.Expression) -> exp.Expression:
return expression


def array_agg_to_json(expression: exp.Expression) -> exp.Expression:
    """Wrap an ARRAY_AGG node in a TO_JSON call; pass every other node through.

    Args:
        expression: The expression node to inspect.

    Returns:
        An anonymous ``TO_JSON`` function call whose only argument is
        ``expression`` when it is an ``exp.ArrayAgg``; otherwise ``expression``
        itself, unchanged.
    """
    # Guard clause: anything that is not an ARRAY_AGG is returned untouched.
    if not isinstance(expression, exp.ArrayAgg):
        return expression

    return exp.Anonymous(this="TO_JSON", expressions=[expression])


def array_agg_within_group(expression: exp.Expression) -> exp.Expression:
    """Convert ARRAY_AGG(<expr>) WITHIN GROUP (<order-by-clause>) to ARRAY_AGG(<expr> <order-by-clause>).

    Snowflake orders the aggregated array with a trailing
    WITHIN GROUP (ORDER BY ...) clause, whereas DuckDB takes the ORDER BY
    inside the ARRAY_AGG call itself.

    See:
    - https://docs.snowflake.com/en/sql-reference/functions/array_agg
    - https://duckdb.org/docs/sql/aggregates.html#order-by-clause-in-aggregate-functions

    Note: Snowflake has the following restriction:
        If you specify DISTINCT and WITHIN GROUP, both must refer to the same column.
    This transformation does not enforce that restriction.

    Args:
        expression: The expression node to inspect.

    Returns:
        A new ``exp.ArrayAgg`` whose argument is an ``exp.Order`` combining the
        original aggregate argument with the WITHIN GROUP ordering, when
        ``expression`` is a WITHIN GROUP applied directly to an ARRAY_AGG;
        otherwise ``expression`` itself, unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    # Only rewrite when ARRAY_AGG is the function WITHIN GROUP is attached to.
    # Searching the whole subtree (expression.find) would also match an
    # ARRAY_AGG nested somewhere inside a different aggregate or inside the
    # ordering itself, and would then discard that outer aggregate.
    agg = expression.this
    order = expression.expression
    if not isinstance(agg, exp.ArrayAgg) or not order:
        return expression

    return exp.ArrayAgg(
        this=exp.Order(
            this=agg.this,
            expressions=order.expressions,
        )
    )


# TODO: move this into a Dialect as a transpilation
def create_database(expression: exp.Expression, db_path: Path | None = None) -> exp.Expression:
"""Transform create database to attach database.
Expand Down
56 changes: 55 additions & 1 deletion tests/test_fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tempfile
from collections.abc import Sequence
from decimal import Decimal
from typing import cast

import pandas as pd
import pytest
Expand Down Expand Up @@ -35,6 +36,46 @@ def test_array_size(cur: snowflake.connector.cursor.SnowflakeCursor):
assert cur.fetchall() == [(None,)]


def test_array_agg_to_json(dcur: snowflake.connector.cursor.DictCursor):
    """ARRAY_AGG over a varchar column is returned as an indented JSON array string."""
    dcur.execute("create table table1 (id number, name varchar)")
    values = [(1, "foo"), (2, "bar"), (1, "baz"), (2, "qux")]

    dcur.executemany("insert into table1 values (%s, %s)", values)

    dcur.execute("select array_agg(name) as names from table1")
    # dindent re-serialises JSON strings with indent=2, so every array element
    # in the expected value is indented by exactly two spaces.
    assert dindent(dcur.fetchall()) == [{"NAMES": '[\n  "foo",\n  "bar",\n  "baz",\n  "qux"\n]'}]


def test_array_agg_within_group(dcur: snowflake.connector.cursor.DictCursor):
    """ARRAY_AGG(...) WITHIN GROUP (ORDER BY ...) orders each group's array, both DESC and ASC."""
    dcur.execute("CREATE TABLE table1 (ID INT, amount INT)")

    # two unique ids, for id 1 there are 3 amounts, for id 2 there are 2 amounts
    values = [
        (2, 40),
        (1, 10),
        (1, 30),
        (2, 50),
        (1, 20),
    ]
    dcur.executemany("INSERT INTO TABLE1 VALUES (%s, %s)", values)

    dcur.execute("SELECT id, ARRAY_AGG(amount) WITHIN GROUP (ORDER BY amount DESC) amounts FROM table1 GROUP BY id")
    # GROUP BY gives no guarantee about the order of the groups, so sort the
    # rows by ID before comparing. dindent re-serialises with indent=2, hence
    # the two-space indentation in the expected JSON strings.
    rows = sorted(dindent(dcur.fetchall()), key=lambda r: r["ID"])

    assert rows == [
        {"ID": 1, "AMOUNTS": "[\n  30,\n  20,\n  10\n]"},
        {"ID": 2, "AMOUNTS": "[\n  50,\n  40\n]"},
    ]

    dcur.execute("SELECT id, ARRAY_AGG(amount) WITHIN GROUP (ORDER BY amount ASC) amounts FROM table1 GROUP BY id")
    rows = sorted(dindent(dcur.fetchall()), key=lambda r: r["ID"])

    assert rows == [
        {"ID": 1, "AMOUNTS": "[\n  10,\n  20,\n  30\n]"},
        {"ID": 2, "AMOUNTS": "[\n  40,\n  50\n]"},
    ]


def test_binding_default_paramstyle(conn: snowflake.connector.SnowflakeConnection):
assert snowflake.connector.paramstyle == "pyformat"
with conn.cursor() as cur:
Expand Down Expand Up @@ -1373,13 +1414,26 @@ def test_write_pandas_dict_different_keys(conn: snowflake.connector.SnowflakeCon


def indent(rows: Sequence[tuple] | Sequence[dict]) -> list[tuple]:
# indent duckdb json strings to match snowflake json strings
# indent duckdb json strings tuple values to match snowflake json strings
assert isinstance(rows[0], tuple)
return [
(*[json.dumps(json.loads(c), indent=2) if (isinstance(c, str) and c.startswith(("[", "{"))) else c for c in r],)
for r in rows
]


def dindent(rows: Sequence[tuple] | Sequence[dict]) -> list[dict]:
# indent duckdb json strings dict values to match snowflake json strings
assert isinstance(rows[0], dict)
return [
{
k: json.dumps(json.loads(v), indent=2) if (isinstance(v, str) and v.startswith(("[", "{"))) else v
for k, v in cast(dict, r).items()
}
for r in rows
]


def sort_keys(sdict: str, indent: int | None = 2) -> str:
return json.dumps(
json.loads(sdict, object_pairs_hook=lambda x: dict(sorted(x))),
Expand Down
26 changes: 26 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fakesnow.transforms import (
SUCCESS_NOP,
_get_to_number_args,
array_agg_within_group,
array_size,
create_database,
describe_table,
Expand Down Expand Up @@ -51,6 +52,31 @@ def test_array_size() -> None:
)


def test_array_agg_within_group() -> None:
    """Check the DuckDB SQL generated after applying array_agg_within_group.

    Each case pairs an input statement with the SQL expected from
    ``.sql(dialect="duckdb")`` once the transform has been applied. The last
    case has no WITHIN GROUP clause and must come out unchanged.
    """
    cases = [
        (
            "SELECT someid, ARRAY_AGG(DISTINCT id) WITHIN GROUP (ORDER BY id) AS ids FROM example GROUP BY someid",
            "SELECT someid, ARRAY_AGG(DISTINCT id ORDER BY id NULLS FIRST) AS ids FROM example GROUP BY someid",
        ),
        (
            "SELECT someid, ARRAY_AGG(id) WITHIN GROUP (ORDER BY id DESC) AS ids FROM example WHERE someid IS NOT NULL GROUP BY someid",  # noqa: E501
            "SELECT someid, ARRAY_AGG(id ORDER BY id DESC) AS ids FROM example WHERE NOT someid IS NULL GROUP BY someid",
        ),
        (
            "SELECT ARRAY_AGG(id) FROM example",
            "SELECT ARRAY_AGG(id) FROM example",
        ),
    ]

    for source_sql, expected_sql in cases:
        rewritten = sqlglot.parse_one(source_sql).transform(array_agg_within_group)
        assert rewritten.sql(dialect="duckdb") == expected_sql


def test_create_database() -> None:
e = sqlglot.parse_one("create database foobar").transform(create_database)
assert e.sql() == "ATTACH DATABASE ':memory:' AS foobar"
Expand Down

0 comments on commit 32afc18

Please sign in to comment.