From 80f0549158b0d3bc107efb55e05b4f05e64d7ca1 Mon Sep 17 00:00:00 2001 From: Francisco Muniz Date: Wed, 18 Aug 2021 17:11:30 -0300 Subject: [PATCH 1/2] data type with timestamps and timezone --- sqlalchemy_trino/datatype.py | 6 +++++- tests/test_datatype_parse.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sqlalchemy_trino/datatype.py b/sqlalchemy_trino/datatype.py index c90bf7c..4129fe1 100644 --- a/sqlalchemy_trino/datatype.py +++ b/sqlalchemy_trino/datatype.py @@ -142,7 +142,8 @@ def parse_sqltype(type_str: str) -> TypeEngine: elif type_name == "row": attr_types: Dict[str, SQLType] = {} for attr_str in split(type_opts): - name, attr_type_str = split(attr_str.strip(), delimiter=' ') + outputs = list(split(attr_str.strip(), delimiter=' ')) + name, attr_type_str = outputs[:2] attr_type = parse_sqltype(attr_type_str) attr_types[name] = attr_type return ROW(attr_types) @@ -156,3 +157,6 @@ def parse_sqltype(type_str: str) -> TypeEngine: type_kwargs = dict(timezone=type_str.endswith("with time zone")) return type_class(**type_kwargs) # TODO: handle time/timestamp(p) precision return type_class(*type_args) + +if __name__ == "__main__": + parse_sqltype("row(min timestamp(6) with time zone, max timestamp(6) with time zone, null_count bigint)") diff --git a/tests/test_datatype_parse.py b/tests/test_datatype_parse.py index 5585643..df67012 100644 --- a/tests/test_datatype_parse.py +++ b/tests/test_datatype_parse.py @@ -2,7 +2,6 @@ from assertpy import assert_that from sqlalchemy.sql.sqltypes import * from sqlalchemy.sql.type_api import TypeEngine - from sqlalchemy_trino import datatype from sqlalchemy_trino.datatype import MAP, ROW @@ -79,6 +78,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY): 'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))), 'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))': ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))), + 'row(x timestamp(6) with time zone)': + ROW(dict(x=TIMESTAMP())), } From ee8d5a0df83f131f60382039eaddec975d9caaa9 Mon Sep 17 00:00:00 2001 From: Francisco Muniz Date: Thu, 2 Sep 2021 09:44:40 -0300 Subject: [PATCH 2/2] moved test into test_datatype_parse.py and fixed attr_type_str to contain timestamp(6) with time zone instead of only timestamp --- sqlalchemy_trino/datatype.py | 4 +--- tests/test_datatype_parse.py | 5 +++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sqlalchemy_trino/datatype.py b/sqlalchemy_trino/datatype.py index 4129fe1..c4525ac 100644 --- a/sqlalchemy_trino/datatype.py +++ b/sqlalchemy_trino/datatype.py @@ -143,7 +143,7 @@ def parse_sqltype(type_str: str) -> TypeEngine: attr_types: Dict[str, SQLType] = {} for attr_str in split(type_opts): outputs = list(split(attr_str.strip(), delimiter=' ')) - name, attr_type_str = outputs[:2] + name, attr_type_str = outputs[0], " ".join(outputs[1:]) attr_type = parse_sqltype(attr_type_str) attr_types[name] = attr_type return ROW(attr_types) @@ -158,5 +158,3 @@ def parse_sqltype(type_str: str) -> TypeEngine: return type_class(**type_kwargs) # TODO: handle time/timestamp(p) precision return type_class(*type_args) -if __name__ == "__main__": - parse_sqltype("row(min timestamp(6) with time zone, max timestamp(6) with time zone, null_count bigint)") diff --git a/tests/test_datatype_parse.py b/tests/test_datatype_parse.py index df67012..75c2fe9 100644 --- a/tests/test_datatype_parse.py +++ b/tests/test_datatype_parse.py @@ -2,6 +2,7 @@ from assertpy import assert_that from sqlalchemy.sql.sqltypes import * from sqlalchemy.sql.type_api import TypeEngine + from sqlalchemy_trino import datatype from sqlalchemy_trino.datatype import MAP, ROW @@ -78,8 +79,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY): 'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))), 'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))': ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))), - 'row(x timestamp(6) with time zone)': - ROW(dict(x=TIMESTAMP())), + 'row(min timestamp(6) with time zone, max timestamp(6) with time zone, null_count bigint)': + ROW(dict(min=TIMESTAMP(), max=TIMESTAMP(), null_count=BIGINT())), }