python deserialized schema from java (#361)
Signed-off-by: zenghua <[email protected]>
Co-authored-by: zenghua <[email protected]>
Ceng23333 and zenghua authored Nov 7, 2023
1 parent aac6e24 commit f933d5c
Showing 1 changed file with 136 additions and 10 deletions.
146 changes: 136 additions & 10 deletions python/lakesoul/metadata/utils.py
@@ -6,7 +6,7 @@
import pyarrow


def deserialize_from_rust_arrow_type(arrow_type_json):
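    # DataType JSON emitted by the Rust side, e.g. 'Boolean' or {'List': {...}};
    # map it onto the corresponding pyarrow type.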
if isinstance(arrow_type_json, str):
if arrow_type_json == 'Boolean':
return pyarrow.bool_()
@@ -57,18 +57,19 @@ def to_arrow_type(arrow_type_json):
elif arrow_type_json['Interval'] == 'YearMonth':
return pyarrow.month_day_nano_interval()
elif 'List' in arrow_type_json:
return pyarrow.list_(deserialize_from_rust_arrow_type(arrow_type_json['List']['data_type']))
elif 'FixedSizeList' in arrow_type_json:
return pyarrow.list_(deserialize_from_rust_arrow_type(arrow_type_json['FixedSizeList'][0]['data_type']),
arrow_type_json['FixedSizeList'][1])
elif 'Dictionary' in arrow_type_json:
return pyarrow.dictionary(arrow_type_json['Dictionary'][0], arrow_type_json['Dictionary'][1])
elif 'FixedSizeBinary' in arrow_type_json:
return pyarrow.binary(arrow_type_json['FixedSizeBinary'])
elif 'Map' in arrow_type_json:
return pyarrow.map_(
deserialize_from_rust_arrow_type(arrow_type_json['Map'][0]['data_type']['Struct'][0]['data_type']),
deserialize_from_rust_arrow_type(arrow_type_json['Map'][0]['data_type']['Struct'][1]['data_type']),
arrow_type_json['Map'][1])
elif 'Struct' in arrow_type_json:
arrow_fields = []
for field in arrow_type_json['Struct']:
@@ -82,13 +83,137 @@ def to_arrow_type(arrow_type_json):
unit = arrow_type_json['Timestamp'][0]
unit = 's' if unit == 'Second' else 'ms' if unit == 'Millisecond' else 'us' if unit == 'Microsecond' else 'ns'
return pyarrow.timestamp(unit, arrow_type_json['Timestamp'][1])
raise IOError("Failed at deserialize_from_rust_arrow_type: " + str(arrow_type_json))


def deserialize_from_java_arrow_field(arrow_field_json):
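    # Field JSON emitted by the Java side: 'type' carries a 'name' discriminator
    # plus any type parameters, and nested fields live in 'children'.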
arrow_type_json = arrow_field_json['type']
    field_name = arrow_field_json['name']
nullable = arrow_field_json['nullable']
arrow_type = None
if isinstance(arrow_type_json, dict):
name = arrow_type_json['name']
if name == 'null':
arrow_type = pyarrow.null()
if name == 'struct':
arrow_fields = []
for child_field in arrow_field_json['children']:
arrow_fields.append(deserialize_from_java_arrow_field(child_field))
arrow_type = pyarrow.struct(arrow_fields)
if name == 'list':
child_field = arrow_field_json['children'][0]
arrow_type = pyarrow.list_(deserialize_from_java_arrow_field(child_field).type)
if name == 'largelist':
child_field = arrow_field_json['children'][0]
arrow_type = pyarrow.large_list(deserialize_from_java_arrow_field(child_field).type)
if name == 'fixedsizelist':
child_field = arrow_field_json['children'][0]
list_size = arrow_type_json['listSize']
arrow_type = pyarrow.list_(deserialize_from_java_arrow_field(child_field).type, list_size)
if name == 'union':
pass
if name == 'map':
keys_sorted = arrow_type_json['keysSorted']
child_field = arrow_field_json['children'][0]
child_type = deserialize_from_java_arrow_field(child_field).type
            arrow_type = pyarrow.map_(child_type.field(0).type, child_type.field(1).type, keys_sorted)
if name == 'int':
if arrow_type_json['isSigned']:
if arrow_type_json['bitWidth'] == 8:
arrow_type = pyarrow.int8()
elif arrow_type_json['bitWidth'] == 16:
arrow_type = pyarrow.int16()
elif arrow_type_json['bitWidth'] == 32:
arrow_type = pyarrow.int32()
elif arrow_type_json['bitWidth'] == 64:
arrow_type = pyarrow.int64()
else:
if arrow_type_json['bitWidth'] == 8:
arrow_type = pyarrow.uint8()
elif arrow_type_json['bitWidth'] == 16:
arrow_type = pyarrow.uint16()
elif arrow_type_json['bitWidth'] == 32:
arrow_type = pyarrow.uint32()
elif arrow_type_json['bitWidth'] == 64:
arrow_type = pyarrow.uint64()
if name == 'floatingpoint':
precision = arrow_type_json['precision']
if precision == 'HALF':
arrow_type = pyarrow.float16()
elif precision == 'SINGLE':
arrow_type = pyarrow.float32()
elif precision == 'DOUBLE':
arrow_type = pyarrow.float64()
if name == 'utf8':
arrow_type = pyarrow.utf8()
if name == 'largeutf8':
arrow_type = pyarrow.large_utf8()
if name == 'binary':
arrow_type = pyarrow.binary()
if name == 'largebinary':
arrow_type = pyarrow.large_binary()
if name == 'fixedsizebinary':
bit_width = arrow_type_json['bitWidth']
arrow_type = pyarrow.binary(bit_width)
if name == 'bool':
arrow_type = pyarrow.bool_()
if name == 'decimal':
precision = arrow_type_json['precision']
scale = arrow_type_json['scale']
bit_width = arrow_type_json['bitWidth']
if bit_width > 128:
arrow_type = pyarrow.decimal256(precision, scale)
else:
arrow_type = pyarrow.decimal128(precision, scale)
if name == 'date':
unit = arrow_type_json['unit']
if unit == 'DAY':
arrow_type = pyarrow.date32()
else:
arrow_type = pyarrow.date64()
if name == 'time':
            unit = arrow_type_json['unit']
if unit == 'SECOND':
unit = 's'
elif unit == 'MILLISECOND':
unit = 'ms'
elif unit == 'MICROSECOND':
unit = 'us'
elif unit == 'NANOSECOND':
unit = 'ns'
bit_width = arrow_type_json['bitWidth']
if bit_width > 32:
arrow_type = pyarrow.time64(unit)
else:
arrow_type = pyarrow.time32(unit)
if name == 'timestamp':
unit = arrow_type_json['unit']
if unit == 'SECOND':
unit = 's'
elif unit == 'MILLISECOND':
unit = 'ms'
elif unit == 'MICROSECOND':
unit = 'us'
elif unit == 'NANOSECOND':
unit = 'ns'
timezone = arrow_type_json['timezone']
arrow_type = pyarrow.timestamp(unit, timezone)
if name == 'interval':
pass
if name == 'duration':
pass
if arrow_type is None:
raise IOError("Failed at deserialize_from_java_arrow_type: " + str(arrow_type_json))
    return pyarrow.field(field_name, arrow_type, nullable)


def to_arrow_field(arrow_field_json):
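    # Rust-serialized fields carry a 'data_type' key; Java-serialized fields
    # use 'type'/'children' and go through deserialize_from_java_arrow_field.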
if 'data_type' in arrow_field_json:
return pyarrow.field(arrow_field_json['name'], deserialize_from_rust_arrow_type(arrow_field_json['data_type']),
arrow_field_json['nullable'])
else:
return deserialize_from_java_arrow_field(arrow_field_json)


def to_arrow_schema(schema_json_str, exclude_columns=None):
@@ -100,4 +225,5 @@ def to_arrow_schema(schema_json_str, exclude_columns=None):
if field['name'] in exclude_columns:
continue
arrow_fields.append(to_arrow_field(field))
return pyarrow.schema(arrow_fields)
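
For context, a minimal usage sketch of the updated to_arrow_field entry point. The field JSON values below are hypothetical examples shaped like the Java (Jackson) and Rust (serde) layouts handled above, not data taken from the repository, and they assume the package is importable as lakesoul (matching the python/lakesoul/metadata/utils.py path):

from lakesoul.metadata.utils import to_arrow_field

# Hypothetical Java-serialized field: 'type' holds a 'name' discriminator,
# nested fields sit in 'children'.
java_field_json = {
    "name": "scores",
    "nullable": True,
    "type": {"name": "list"},
    "children": [
        {"name": "element", "nullable": True,
         "type": {"name": "floatingpoint", "precision": "DOUBLE"},
         "children": []},
    ],
}
print(to_arrow_field(java_field_json).type)  # list<item: double>

# Hypothetical Rust-serialized field: the presence of 'data_type' routes it
# through deserialize_from_rust_arrow_type.
rust_field_json = {"name": "flag", "nullable": False, "data_type": "Boolean"}
print(to_arrow_field(rust_field_json).type)  # bool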
