diff --git a/sqlalchemy_trino/datatype.py b/sqlalchemy_trino/datatype.py index c90bf7c..c4525ac 100644 --- a/sqlalchemy_trino/datatype.py +++ b/sqlalchemy_trino/datatype.py @@ -142,7 +142,8 @@ def parse_sqltype(type_str: str) -> TypeEngine: elif type_name == "row": attr_types: Dict[str, SQLType] = {} for attr_str in split(type_opts): - name, attr_type_str = split(attr_str.strip(), delimiter=' ') + outputs = list(split(attr_str.strip(), delimiter=' ')) + name, attr_type_str = outputs[0], " ".join(outputs[1:]) attr_type = parse_sqltype(attr_type_str) attr_types[name] = attr_type return ROW(attr_types) @@ -156,3 +157,4 @@ def parse_sqltype(type_str: str) -> TypeEngine: type_kwargs = dict(timezone=type_str.endswith("with time zone")) return type_class(**type_kwargs) # TODO: handle time/timestamp(p) precision return type_class(*type_args) + diff --git a/tests/test_datatype_parse.py b/tests/test_datatype_parse.py index 5585643..75c2fe9 100644 --- a/tests/test_datatype_parse.py +++ b/tests/test_datatype_parse.py @@ -79,6 +79,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY): 'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))), 'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))': ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))), + 'row(min timestamp(6) with time zone, max timestamp(6) with time zone, null_count bigint)': + ROW(dict(min=TIMESTAMP(), max=TIMESTAMP(), null_count=BIGINT())), }