Skip to content
This repository has been archived by the owner on May 5, 2022. It is now read-only.

Fix data type with timestamps and timezone #22

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sqlalchemy_trino/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def parse_sqltype(type_str: str) -> TypeEngine:
elif type_name == "row":
attr_types: Dict[str, SQLType] = {}
for attr_str in split(type_opts):
name, attr_type_str = split(attr_str.strip(), delimiter=' ')
outputs = list(split(attr_str.strip(), delimiter=' '))
name, attr_type_str = outputs[:2]
fcomuniz marked this conversation as resolved.
Show resolved Hide resolved
attr_type = parse_sqltype(attr_type_str)
attr_types[name] = attr_type
return ROW(attr_types)
Expand All @@ -156,3 +157,6 @@ def parse_sqltype(type_str: str) -> TypeEngine:
type_kwargs = dict(timezone=type_str.endswith("with time zone"))
return type_class(**type_kwargs) # TODO: handle time/timestamp(p) precision
return type_class(*type_args)

if __name__ == "__main__":
parse_sqltype("row(min timestamp(6) with time zone, max timestamp(6) with time zone, null_count bigint)")
fcomuniz marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion tests/test_datatype_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from assertpy import assert_that
from sqlalchemy.sql.sqltypes import *
from sqlalchemy.sql.type_api import TypeEngine

fcomuniz marked this conversation as resolved.
Show resolved Hide resolved
from sqlalchemy_trino import datatype
from sqlalchemy_trino.datatype import MAP, ROW

Expand Down Expand Up @@ -79,6 +78,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY):
'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))),
'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))':
ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))),
'row(x timestamp(6) with time zone)':
ROW(dict(x=TIMESTAMP())),
}


Expand Down