Skip to content

Commit

Permalink
add transformation for object_agg function (#1)
Browse files Browse the repository at this point in the history
* add transformation for object_agg function

* add more tests for object_agg

* rm duplicate tests

* ruff format
  • Loading branch information
ilkinulas authored Apr 5, 2024
1 parent f3e1e7a commit e5dd229
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 144 deletions.
3 changes: 2 additions & 1 deletion fakesnow/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _execute(
.transform(transforms.timestamp_ntz_ns)
.transform(transforms.float_to_double)
.transform(transforms.integer_precision)
# TODO(selman): Broken, failes on CTAS queries with CASTs;
# TODO(selman): Broken, fails on CTAS queries with CASTs;
# CREATE TABLE SOME_TABLE AS (
# SELECT
# R1 AS C1,
Expand Down Expand Up @@ -234,6 +234,7 @@ def _execute(
.transform(transforms.to_variant)
.transform(transforms.show_users)
.transform(transforms.create_user)
.transform(transforms.object_agg)
)
sql = transformed.sql(dialect="duckdb")
result_sql = None
Expand Down
18 changes: 14 additions & 4 deletions fakesnow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,13 +1425,13 @@ def to_json_extract_scalar(expression: exp.JSONExtract) -> exp.Expression:
left = expression.left
right = expression.right

# <json-extact> = <string-literal>
# <json-extract> = <string-literal>
if is_json_extract(left) and isinstance(right, exp.Literal) and right.is_string:
json_extract_scalar = exp.Paren(this=to_json_extract_scalar(cast(exp.JSONExtract, left)))

return exp.EQ(this=json_extract_scalar, expression=right)

# <string-literal> = <json-extact>
# <string-literal> = <json-extract>
elif is_json_extract(right) and isinstance(left, exp.Literal) and left.is_string:
json_extract_scalar = exp.Paren(this=to_json_extract_scalar(cast(exp.JSONExtract, right)))

Expand All @@ -1441,7 +1441,7 @@ def to_json_extract_scalar(expression: exp.JSONExtract) -> exp.Expression:


def json_extract_in_string_literals(expression: exp.Expression) -> exp.Expression:
"""Snowflake does implicit casting on JSON extract value on an IN caluse;
"""Snowflake does implicit casting on JSON extract value on an IN clause;
Snowflake;
SELECT
Expand Down Expand Up @@ -1483,7 +1483,7 @@ def json_extract_in_string_literals(expression: exp.Expression) -> exp.Expressio
via `->` operator thus FALSE in first case of the first query.
To keep things simple, if all values in IN clause are string literals, we
can extract JSON value as string/VARCHAR to achive similar behaviour.
can extract JSON value as string/VARCHAR to achieve similar behaviour.
SELECT
TO_JSON({'k': '10'}) AS D,
Expand Down Expand Up @@ -1514,3 +1514,13 @@ def json_extract_in_string_literals(expression: exp.Expression) -> exp.Expressio

json_extract_scalar = exp.JSONExtractScalar(this=je.this, expression=path)
return exp.In(this=json_extract_scalar, expressions=expression.expressions)


def object_agg(expression: exp.Expression) -> exp.Expression:
    """Rename an anonymous OBJECT_AGG call to JSON_GROUP_OBJECT.

    Args:
        expression: Any node of the parsed SQL expression tree.

    Returns:
        A new anonymous function node named JSON_GROUP_OBJECT that carries the
        same argument expressions when ``expression`` is an anonymous function
        call whose name equals OBJECT_AGG (compared case-insensitively);
        otherwise ``expression`` itself, unchanged.
    """
    if not isinstance(expression, exp.Anonymous):
        return expression

    func_name = expression.this
    # Only string function names are considered; anything else is left untouched.
    if not isinstance(func_name, str) or func_name.upper() != "OBJECT_AGG":
        return expression

    return exp.Anonymous(this="JSON_GROUP_OBJECT", expressions=expression.expressions)
14 changes: 14 additions & 0 deletions tests/test_fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,20 @@ def test_json_extract_cast_as_varchar(dcur: snowflake.connector.cursor.DictCurso
assert dcur.fetchall() == [{"C_STR_NUMBER": 100, "C_NUM_NUMBER": 100}]


def test_json_group_object(dcur: snowflake.connector.cursor.DictCursor):
    """object_agg(key, value) grouped by id yields one JSON object per id."""
    dcur.execute("create table table1 (id number, key varchar, value varchar)")

    rows = [(1, "a", "1"), (1, "b", "2"), (1, "c", "3"), (2, "e", "1"), (2, "f", "1"), (3, "a", "2")]
    dcur.executemany("insert into table1 values (%s, %s, %s)", rows)

    dcur.execute("select id, object_agg(key, value) as obj from table1 group by 1")
    actual = dindent(dcur.fetchall())

    # NOTE(review): the exact whitespace inside these strings must match what
    # dindent emits - confirm against the helper before relying on them.
    assert actual == [
        {"ID": 1, "OBJ": '{\n "a": "1",\n "b": "2",\n "c": "3"\n}'},
        {"ID": 2, "OBJ": '{\n "e": "1",\n "f": "1"\n}'},
        {"ID": 3, "OBJ": '{\n "a": "2"\n}'},
    ]


def test_write_pandas_quoted_column_names(conn: snowflake.connector.SnowflakeConnection):
with conn.cursor(snowflake.connector.cursor.DictCursor) as dcur:
# column names with spaces
Expand Down
2 changes: 1 addition & 1 deletion tests/test_info_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_info_schema_columns_other(cur: snowflake.connector.cursor.SnowflakeCurs
]


@pytest.mark.xfail(reason="NOTE(selman): removed extact_extract_text_length transformation")
@pytest.mark.xfail(reason="NOTE(selman): removed extract_extract_text_length transformation")
def test_info_schema_columns_text(cur: snowflake.connector.cursor.SnowflakeCursor):
# see https://docs.snowflake.com/en/sql-reference/data-types-text
cur.execute(
Expand Down
145 changes: 7 additions & 138 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
json_extract_eq_string_literal,
json_extract_in_string_literals,
json_extract_precedence,
object_agg,
object_construct,
random,
regex_replace,
Expand Down Expand Up @@ -831,144 +832,6 @@ def test_trim_cast_varchar() -> None:
)


def test__get_to_number_args() -> None:
    """Check the 3-tuple _get_to_number_args returns for each to_number call shape.

    The tuple slots are presumably (format, precision, scale) - confirm
    against the helper's definition.
    """
    ten = exp.Literal(this="10", is_string=False)
    two = exp.Literal(this="2", is_string=False)
    tm9 = exp.Literal(this="TM9", is_string=True)

    cases = [
        ("to_number('100')", (None, None, None)),
        ("to_number('100', 10)", (None, ten, None)),
        ("to_number('100', 10,2)", (None, ten, two)),
        ("to_number('100', 'TM9')", (tm9, None, None)),
        ("to_number('100', 'TM9', 10)", (tm9, ten, None)),
        ("to_number('100', 'TM9', 10, 2)", (tm9, ten, two)),
    ]

    for sql, expected in cases:
        parsed = sqlglot.parse_one(sql, read="snowflake")
        assert isinstance(parsed, exp.ToNumber)
        assert _get_to_number_args(parsed) == expected


def test_to_number() -> None:
    """to_number becomes a DECIMAL cast; calls with a format string raise NotImplementedError."""
    supported = [
        ("SELECT to_number('100')", "SELECT CAST('100' AS DECIMAL(38, 0))"),
        ("SELECT to_number('100', 10)", "SELECT CAST('100' AS DECIMAL(10, 0))"),
        ("SELECT to_number('100', 10,2)", "SELECT CAST('100' AS DECIMAL(10, 2))"),
    ]
    for query, expected in supported:
        transformed = sqlglot.parse_one(query, read="snowflake").transform(to_decimal)
        assert transformed.sql(dialect="duckdb") == expected

    unsupported = [
        "SELECT to_number('100', 'TM9')",
        "SELECT to_number('100', 'TM9', 10)",
        "SELECT to_number('100', 'TM9', 10, 2)",
    ]
    for query in unsupported:
        with pytest.raises(NotImplementedError):
            sqlglot.parse_one(query, read="snowflake").transform(to_decimal).sql(dialect="duckdb")


def test_to_number_decimal() -> None:
    """to_decimal becomes a DECIMAL cast; calls with a format string raise NotImplementedError."""
    supported = [
        ("SELECT to_decimal('100')", "SELECT CAST('100' AS DECIMAL(38, 0))"),
        ("SELECT to_decimal('100', 10)", "SELECT CAST('100' AS DECIMAL(10, 0))"),
        ("SELECT to_decimal('100', 10,2)", "SELECT CAST('100' AS DECIMAL(10, 2))"),
    ]
    for query, expected in supported:
        transformed = sqlglot.parse_one(query, read="snowflake").transform(to_decimal)
        assert transformed.sql(dialect="duckdb") == expected

    unsupported = [
        "SELECT to_decimal('100', 'TM9')",
        "SELECT to_decimal('100', 'TM9', 10)",
        "SELECT to_decimal('100', 'TM9', 10, 2)",
    ]
    for query in unsupported:
        with pytest.raises(NotImplementedError):
            sqlglot.parse_one(query, read="snowflake").transform(to_decimal).sql(dialect="duckdb")


def test_to_number_numeric() -> None:
    """to_numeric becomes a DECIMAL cast; calls with a format string raise NotImplementedError."""
    supported = [
        ("SELECT to_numeric('100')", "SELECT CAST('100' AS DECIMAL(38, 0))"),
        ("SELECT to_numeric('100', 10)", "SELECT CAST('100' AS DECIMAL(10, 0))"),
        ("SELECT to_numeric('100', 10,2)", "SELECT CAST('100' AS DECIMAL(10, 2))"),
    ]
    for query, expected in supported:
        transformed = sqlglot.parse_one(query, read="snowflake").transform(to_decimal)
        assert transformed.sql(dialect="duckdb") == expected

    unsupported = [
        "SELECT to_numeric('100', 'TM9')",
        "SELECT to_numeric('100', 'TM9', 10)",
        "SELECT to_numeric('100', 'TM9', 10, 2)",
    ]
    for query in unsupported:
        with pytest.raises(NotImplementedError):
            sqlglot.parse_one(query, read="snowflake").transform(to_decimal).sql(dialect="duckdb")


def test_upper_case_unquoted_identifiers() -> None:
assert (
sqlglot.parse_one("select name, name as fname from table1").transform(upper_case_unquoted_identifiers).sql()
Expand Down Expand Up @@ -1082,3 +945,9 @@ def test_to_variant() -> None:
.sql(dialect="duckdb")
== "SELECT TO_JSON('str')"
)


def test_object_agg() -> None:
    """The object_agg transform renames OBJECT_AGG to JSON_GROUP_OBJECT in the generated SQL."""
    snowflake_sql = "SELECT ID, OBJECT_AGG(KEY, VALUE) AS POSTCALC FROM POSTCALC_WITH_UPDATED_AT GROUP BY 1"
    expected = "SELECT ID, JSON_GROUP_OBJECT(KEY, VALUE) AS POSTCALC FROM POSTCALC_WITH_UPDATED_AT GROUP BY 1"

    parsed = sqlglot.parse_one(snowflake_sql, read="snowflake")
    generated = parsed.transform(object_agg).sql(dialect="duckdb")

    assert generated == expected

0 comments on commit e5dd229

Please sign in to comment.