diff --git a/fakesnow/transforms.py b/fakesnow/transforms.py index d6e5ba3..418f3f1 100644 --- a/fakesnow/transforms.py +++ b/fakesnow/transforms.py @@ -803,16 +803,74 @@ def to_date(expression: exp.Expression) -> exp.Expression: return expression +def _get_to_number_args(e: exp.ToNumber) -> tuple[exp.Expression | None, exp.Expression | None, exp.Expression | None]: + arg_format = e.args.get("format") + arg_precision = e.args.get("precision") + arg_scale = e.args.get("scale") + + _format = None + _precision = None + _scale = None + + # to_number(value, format, precision, scale) + if arg_format: + if arg_format.is_string: + # to_number('100', 'TM9' ...) + _format = arg_format + + # to_number('100', 'TM9', 10 ...) + if arg_precision: + _precision = arg_precision + + # to_number('100', 'TM9', 10, 2) + if arg_scale: + _scale = arg_scale + else: + pass + else: + # to_number('100', 10, ...) + # arg_format is not a string, so it must be precision. + _precision = arg_format + + # to_number('100', 10, 2) + # And arg_precision must be scale + if arg_precision: + _scale = arg_precision + else: + # If format is not provided, just check for precision and scale directly + if arg_precision: + _precision = arg_precision + if arg_scale: + _scale = arg_scale + + return _format, _precision, _scale + + def to_decimal(expression: exp.Expression) -> exp.Expression: """Transform to_decimal, to_number, to_numeric expressions from snowflake to duckdb. 
See https://docs.snowflake.com/en/sql-reference/functions/to_decimal """ + if isinstance(expression, exp.ToNumber): + format_, precision, scale = _get_to_number_args(expression) + if format_: + raise NotImplementedError(f"{expression.this} with format argument") + + if not precision: + precision = exp.Literal(this="38", is_string=False) + if not scale: + scale = exp.Literal(this="0", is_string=False) + + return exp.Cast( + this=expression.this, + to=exp.DataType(this=exp.DataType.Type.DECIMAL, expressions=[precision, scale], nested=False, prefix=False), + ) + if ( isinstance(expression, exp.Anonymous) and isinstance(expression.this, str) - and expression.this.upper() in ["TO_DECIMAL", "TO_NUMBER", "TO_NUMERIC"] + and expression.this.upper() in ["TO_DECIMAL", "TO_NUMERIC"] ): expressions: list[exp.Expression] = expression.expressions diff --git a/pyproject.toml b/pyproject.toml index 7bb430d..95b324d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "duckdb~=0.10.0", "pyarrow", "snowflake-connector-python", - "sqlglot~=21.2.0", + "sqlglot~=23.3.0", ] [project.urls] diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 88ce83c..cd06425 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,10 +1,12 @@ from pathlib import Path +import pytest import sqlglot from sqlglot import exp from fakesnow.transforms import ( SUCCESS_NOP, + _get_to_number_args, array_size, create_database, describe_table, @@ -245,7 +247,7 @@ def test_json_extract_precedence() -> None: ) .transform(json_extract_precedence) .sql(dialect="duckdb") - == """SELECT {'K1': {'K2': 1}} AS col WHERE (col -> '$.K1' -> '$.K2') > 0""" + == """SELECT {'K1': {'K2': 1}} AS col WHERE (col -> '$.K1.K2') > 0""" ) @@ -396,6 +398,144 @@ def test_use() -> None: ) +def test__get_to_number_args() -> None: + e = sqlglot.parse_one("to_number('100')", read="snowflake") + assert isinstance(e, exp.ToNumber) + assert _get_to_number_args(e) == (None, None, 
None) + + e = sqlglot.parse_one("to_number('100', 10)", read="snowflake") + assert isinstance(e, exp.ToNumber) + assert _get_to_number_args(e) == (None, exp.Literal(this="10", is_string=False), None) + + e = sqlglot.parse_one("to_number('100', 10,2)", read="snowflake") + assert isinstance(e, exp.ToNumber) + assert _get_to_number_args(e) == ( + None, + exp.Literal(this="10", is_string=False), + exp.Literal(this="2", is_string=False), + ) + + e = sqlglot.parse_one("to_number('100', 'TM9')", read="snowflake") + assert isinstance(e, exp.ToNumber) + assert _get_to_number_args(e) == (exp.Literal(this="TM9", is_string=True), None, None) + + e = sqlglot.parse_one("to_number('100', 'TM9', 10)", read="snowflake") + assert isinstance(e, exp.ToNumber) + assert _get_to_number_args(e) == ( + exp.Literal(this="TM9", is_string=True), + exp.Literal(this="10", is_string=False), + None, + ) + + e = sqlglot.parse_one("to_number('100', 'TM9', 10, 2)", read="snowflake") + assert isinstance(e, exp.ToNumber) + assert _get_to_number_args(e) == ( + exp.Literal(this="TM9", is_string=True), + exp.Literal(this="10", is_string=False), + exp.Literal(this="2", is_string=False), + ) + + +def test_to_number() -> None: + assert ( + sqlglot.parse_one("SELECT to_number('100')", read="snowflake").transform(to_decimal).sql(dialect="duckdb") + == "SELECT CAST('100' AS DECIMAL(38, 0))" + ) + + assert ( + sqlglot.parse_one("SELECT to_number('100', 10)", read="snowflake").transform(to_decimal).sql(dialect="duckdb") + == "SELECT CAST('100' AS DECIMAL(10, 0))" + ) + + assert ( + sqlglot.parse_one("SELECT to_number('100', 10,2)", read="snowflake").transform(to_decimal).sql(dialect="duckdb") + == "SELECT CAST('100' AS DECIMAL(10, 2))" + ) + + with pytest.raises(NotImplementedError): + sqlglot.parse_one("SELECT to_number('100', 'TM9')", read="snowflake").transform(to_decimal).sql( + dialect="duckdb" + ) + + with pytest.raises(NotImplementedError): + sqlglot.parse_one("SELECT to_number('100', 'TM9', 10)", 
read="snowflake").transform(to_decimal).sql( + dialect="duckdb" + ) + + with pytest.raises(NotImplementedError): + sqlglot.parse_one("SELECT to_number('100', 'TM9', 10, 2)", read="snowflake").transform(to_decimal).sql( + dialect="duckdb" + ) + + +def test_to_number_decimal() -> None: + assert ( + sqlglot.parse_one("SELECT to_decimal('100')", read="snowflake").transform(to_decimal).sql(dialect="duckdb") + == "SELECT CAST('100' AS DECIMAL(38, 0))" + ) + + assert ( + sqlglot.parse_one("SELECT to_decimal('100', 10)", read="snowflake").transform(to_decimal).sql(dialect="duckdb") + == "SELECT CAST('100' AS DECIMAL(10, 0))" + ) + + assert ( + sqlglot.parse_one("SELECT to_decimal('100', 10,2)", read="snowflake") + .transform(to_decimal) + .sql(dialect="duckdb") + == "SELECT CAST('100' AS DECIMAL(10, 2))" + ) + + with pytest.raises(NotImplementedError): + sqlglot.parse_one("SELECT to_decimal('100', 'TM9')", read="snowflake").transform(to_decimal).sql( + dialect="duckdb" + ) + + with pytest.raises(NotImplementedError): + sqlglot.parse_one("SELECT to_decimal('100', 'TM9', 10)", read="snowflake").transform(to_decimal).sql( + dialect="duckdb" + ) + + with pytest.raises(NotImplementedError): + sqlglot.parse_one("SELECT to_decimal('100', 'TM9', 10, 2)", read="snowflake").transform(to_decimal).sql( + dialect="duckdb" + ) + + +def test_to_number_numeric() -> None: + assert ( + sqlglot.parse_one("SELECT to_numeric('100')", read="snowflake").transform(to_decimal).sql(dialect="duckdb") + == "SELECT CAST('100' AS DECIMAL(38, 0))" + ) + + assert ( + sqlglot.parse_one("SELECT to_numeric('100', 10)", read="snowflake").transform(to_decimal).sql(dialect="duckdb") + == "SELECT CAST('100' AS DECIMAL(10, 0))" + ) + + assert ( + sqlglot.parse_one("SELECT to_numeric('100', 10,2)", read="snowflake") + .transform(to_decimal) + .sql(dialect="duckdb") + == "SELECT CAST('100' AS DECIMAL(10, 2))" + ) + + with pytest.raises(NotImplementedError): + sqlglot.parse_one("SELECT to_numeric('100', 'TM9')", 
read="snowflake").transform(to_decimal).sql( + dialect="duckdb" + ) + + with pytest.raises(NotImplementedError): + sqlglot.parse_one("SELECT to_numeric('100', 'TM9', 10)", read="snowflake").transform(to_decimal).sql( + dialect="duckdb" + ) + + with pytest.raises(NotImplementedError): + sqlglot.parse_one("SELECT to_numeric('100', 'TM9', 10, 2)", read="snowflake").transform(to_decimal).sql( + dialect="duckdb" + ) + + def test_upper_case_unquoted_identifiers() -> None: assert ( sqlglot.parse_one("select name, name as fname from table1").transform(upper_case_unquoted_identifiers).sql()