Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add transformation for object_agg function #1

Merged
merged 4 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fakesnow/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _execute(
.transform(transforms.timestamp_ntz_ns)
.transform(transforms.float_to_double)
.transform(transforms.integer_precision)
# TODO(selman): Broken, failes on CTAS queries with CASTs;
# TODO(selman): Broken, fails on CTAS queries with CASTs;
# CREATE TABLE SOME_TABLE AS (
# SELECT
# R1 AS C1,
Expand Down Expand Up @@ -234,6 +234,7 @@ def _execute(
.transform(transforms.to_variant)
.transform(transforms.show_users)
.transform(transforms.create_user)
.transform(transforms.object_agg)
)
sql = transformed.sql(dialect="duckdb")
result_sql = None
Expand Down
18 changes: 14 additions & 4 deletions fakesnow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,13 +1425,13 @@ def to_json_extract_scalar(expression: exp.JSONExtract) -> exp.Expression:
left = expression.left
right = expression.right

# <json-extact> = <string-literal>
# <json-extract> = <string-literal>
if is_json_extract(left) and isinstance(right, exp.Literal) and right.is_string:
json_extract_scalar = exp.Paren(this=to_json_extract_scalar(cast(exp.JSONExtract, left)))

return exp.EQ(this=json_extract_scalar, expression=right)

# <string-literal> = <json-extact>
# <string-literal> = <json-extract>
elif is_json_extract(right) and isinstance(left, exp.Literal) and left.is_string:
json_extract_scalar = exp.Paren(this=to_json_extract_scalar(cast(exp.JSONExtract, right)))

Expand All @@ -1441,7 +1441,7 @@ def to_json_extract_scalar(expression: exp.JSONExtract) -> exp.Expression:


def json_extract_in_string_literals(expression: exp.Expression) -> exp.Expression:
"""Snowflake does implicit casting on JSON extract value on an IN caluse;
"""Snowflake does implicit casting on JSON extract value on an IN clause;

Snowflake;
SELECT
Expand Down Expand Up @@ -1483,7 +1483,7 @@ def json_extract_in_string_literals(expression: exp.Expression) -> exp.Expressio
via `->` operator thus FALSE in first case of the first query.

To keep things simple, if all values in IN clause are string literals, we
can extract JSON value as string/VARCHAR to achive similar behaviour.
can extract JSON value as string/VARCHAR to achieve similar behaviour.

SELECT
TO_JSON({'k': '10'}) AS D,
Expand Down Expand Up @@ -1514,3 +1514,13 @@ def json_extract_in_string_literals(expression: exp.Expression) -> exp.Expressio

json_extract_scalar = exp.JSONExtractScalar(this=je.this, expression=path)
return exp.In(this=json_extract_scalar, expressions=expression.expressions)


def object_agg(expression: exp.Expression) -> exp.Expression:
    """Rename an anonymous ``OBJECT_AGG`` call to ``JSON_GROUP_OBJECT``.

    The function name is compared case-insensitively and the call's arguments
    are carried over untouched. Every other expression is returned as-is.

    Args:
        expression: The expression node to inspect.

    Returns:
        A new ``exp.Anonymous`` node named ``JSON_GROUP_OBJECT`` when the input
        is an ``OBJECT_AGG`` call, otherwise the input expression itself.
    """
    # Only anonymous (non-builtin) function calls are candidates.
    if not isinstance(expression, exp.Anonymous):
        return expression

    func_name = expression.this
    if not isinstance(func_name, str) or func_name.upper() != "OBJECT_AGG":
        return expression

    return exp.Anonymous(this="JSON_GROUP_OBJECT", expressions=expression.expressions)
14 changes: 14 additions & 0 deletions tests/test_fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,20 @@ def test_json_extract_cast_as_varchar(dcur: snowflake.connector.cursor.DictCurso
assert dcur.fetchall() == [{"C_STR_NUMBER": 100, "C_NUM_NUMBER": 100}]


def test_json_group_object(dcur: snowflake.connector.cursor.DictCursor):
    """object_agg(key, value) grouped by id yields one JSON object per id."""
    dcur.execute("create table table1 (id number, key varchar, value varchar)")

    rows = [(1, "a", "1"), (1, "b", "2"), (1, "c", "3"), (2, "e", "1"), (2, "f", "1"), (3, "a", "2")]
    dcur.executemany("insert into table1 values (%s, %s, %s)", rows)

    dcur.execute("select id, object_agg(key, value) as obj from table1 group by 1")
    actual = dindent(dcur.fetchall())

    assert actual == [
        {"ID": 1, "OBJ": '{\n "a": "1",\n "b": "2",\n "c": "3"\n}'},
        {"ID": 2, "OBJ": '{\n "e": "1",\n "f": "1"\n}'},
        {"ID": 3, "OBJ": '{\n "a": "2"\n}'},
    ]


def test_write_pandas_quoted_column_names(conn: snowflake.connector.SnowflakeConnection):
with conn.cursor(snowflake.connector.cursor.DictCursor) as dcur:
# column names with spaces
Expand Down
2 changes: 1 addition & 1 deletion tests/test_info_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_info_schema_columns_other(cur: snowflake.connector.cursor.SnowflakeCurs
]


@pytest.mark.xfail(reason="NOTE(selman): removed extact_extract_text_length transformation")
@pytest.mark.xfail(reason="NOTE(selman): removed extract_extract_text_length transformation")
def test_info_schema_columns_text(cur: snowflake.connector.cursor.SnowflakeCursor):
# see https://docs.snowflake.com/en/sql-reference/data-types-text
cur.execute(
Expand Down
145 changes: 7 additions & 138 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
json_extract_eq_string_literal,
json_extract_in_string_literals,
json_extract_precedence,
object_agg,
object_construct,
random,
regex_replace,
Expand Down Expand Up @@ -831,144 +832,6 @@ def test_trim_cast_varchar() -> None:
)


def test__get_to_number_args() -> None:
    """Check the 3-tuple returned by _get_to_number_args for each argument shape.

    Positions that are not supplied in the call are expected to be ``None``.
    """
    tm9 = exp.Literal(this="TM9", is_string=True)
    ten = exp.Literal(this="10", is_string=False)
    two = exp.Literal(this="2", is_string=False)

    # (snowflake expression, expected tuple)
    cases = [
        ("to_number('100')", (None, None, None)),
        ("to_number('100', 10)", (None, ten, None)),
        ("to_number('100', 10,2)", (None, ten, two)),
        ("to_number('100', 'TM9')", (tm9, None, None)),
        ("to_number('100', 'TM9', 10)", (tm9, ten, None)),
        ("to_number('100', 'TM9', 10, 2)", (tm9, ten, two)),
    ]

    for sql, expected in cases:
        parsed = sqlglot.parse_one(sql, read="snowflake")
        assert isinstance(parsed, exp.ToNumber)
        assert _get_to_number_args(parsed) == expected


def test_to_number() -> None:
    """to_number becomes a DECIMAL cast; calls with a format string raise."""
    # snowflake input -> expected duckdb output
    supported = [
        ("SELECT to_number('100')", "SELECT CAST('100' AS DECIMAL(38, 0))"),
        ("SELECT to_number('100', 10)", "SELECT CAST('100' AS DECIMAL(10, 0))"),
        ("SELECT to_number('100', 10,2)", "SELECT CAST('100' AS DECIMAL(10, 2))"),
    ]
    for snowflake_sql, duckdb_sql in supported:
        tree = sqlglot.parse_one(snowflake_sql, read="snowflake")
        assert tree.transform(to_decimal).sql(dialect="duckdb") == duckdb_sql

    # A format argument ('TM9') is not implemented by the transform.
    unsupported = [
        "SELECT to_number('100', 'TM9')",
        "SELECT to_number('100', 'TM9', 10)",
        "SELECT to_number('100', 'TM9', 10, 2)",
    ]
    for snowflake_sql in unsupported:
        with pytest.raises(NotImplementedError):
            sqlglot.parse_one(snowflake_sql, read="snowflake").transform(to_decimal).sql(dialect="duckdb")


def test_to_number_decimal() -> None:
    """to_decimal becomes a DECIMAL cast; calls with a format string raise."""
    # snowflake input -> expected duckdb output
    supported = [
        ("SELECT to_decimal('100')", "SELECT CAST('100' AS DECIMAL(38, 0))"),
        ("SELECT to_decimal('100', 10)", "SELECT CAST('100' AS DECIMAL(10, 0))"),
        ("SELECT to_decimal('100', 10,2)", "SELECT CAST('100' AS DECIMAL(10, 2))"),
    ]
    for snowflake_sql, duckdb_sql in supported:
        tree = sqlglot.parse_one(snowflake_sql, read="snowflake")
        assert tree.transform(to_decimal).sql(dialect="duckdb") == duckdb_sql

    # A format argument ('TM9') is not implemented by the transform.
    unsupported = [
        "SELECT to_decimal('100', 'TM9')",
        "SELECT to_decimal('100', 'TM9', 10)",
        "SELECT to_decimal('100', 'TM9', 10, 2)",
    ]
    for snowflake_sql in unsupported:
        with pytest.raises(NotImplementedError):
            sqlglot.parse_one(snowflake_sql, read="snowflake").transform(to_decimal).sql(dialect="duckdb")


def test_to_number_numeric() -> None:
    """to_numeric becomes a DECIMAL cast; calls with a format string raise."""
    # snowflake input -> expected duckdb output
    supported = [
        ("SELECT to_numeric('100')", "SELECT CAST('100' AS DECIMAL(38, 0))"),
        ("SELECT to_numeric('100', 10)", "SELECT CAST('100' AS DECIMAL(10, 0))"),
        ("SELECT to_numeric('100', 10,2)", "SELECT CAST('100' AS DECIMAL(10, 2))"),
    ]
    for snowflake_sql, duckdb_sql in supported:
        tree = sqlglot.parse_one(snowflake_sql, read="snowflake")
        assert tree.transform(to_decimal).sql(dialect="duckdb") == duckdb_sql

    # A format argument ('TM9') is not implemented by the transform.
    unsupported = [
        "SELECT to_numeric('100', 'TM9')",
        "SELECT to_numeric('100', 'TM9', 10)",
        "SELECT to_numeric('100', 'TM9', 10, 2)",
    ]
    for snowflake_sql in unsupported:
        with pytest.raises(NotImplementedError):
            sqlglot.parse_one(snowflake_sql, read="snowflake").transform(to_decimal).sql(dialect="duckdb")


def test_upper_case_unquoted_identifiers() -> None:
assert (
sqlglot.parse_one("select name, name as fname from table1").transform(upper_case_unquoted_identifiers).sql()
Expand Down Expand Up @@ -1082,3 +945,9 @@ def test_to_variant() -> None:
.sql(dialect="duckdb")
== "SELECT TO_JSON('str')"
)


def test_object_agg() -> None:
    """OBJECT_AGG is emitted as JSON_GROUP_OBJECT after the object_agg transform."""
    parsed = sqlglot.parse_one(
        "SELECT ID, OBJECT_AGG(KEY, VALUE) AS POSTCALC FROM POSTCALC_WITH_UPDATED_AT GROUP BY 1",
        read="snowflake",
    )
    rendered = parsed.transform(object_agg).sql(dialect="duckdb")

    assert rendered == "SELECT ID, JSON_GROUP_OBJECT(KEY, VALUE) AS POSTCALC FROM POSTCALC_WITH_UPDATED_AT GROUP BY 1"