Skip to content

Commit

Permalink
update python metadata interface && full arrow types test
Browse files Browse the repository at this point in the history
Signed-off-by: zenghua <[email protected]>
  • Loading branch information
zenghua committed Nov 2, 2023
1 parent e299795 commit 4dced35
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 64 deletions.
24 changes: 19 additions & 5 deletions python/lakesoul/metadata/lib/lakesoul_metadata_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class NonNull(Structure):


def reload_lib(path):
global lib, execute_query, create_tokio_runtime, free_tokio_runtime, create_tokio_postgres_client, free_tokio_postgres_client, create_prepared_statement, free_prepared_statement
global lib, execute_query, create_tokio_runtime, free_tokio_runtime, create_tokio_postgres_client, free_tokio_postgres_client, create_prepared_statement, free_prepared_statement, export_bytes_result, free_bytes_result
lib = CDLL(path)
# pub extern "C" fn execute_query(
# callback: extern "C" fn(i32, *const c_char),
Expand All @@ -22,13 +22,27 @@ def reload_lib(path):
# prepared: NonNull<Result<PreparedStatement>>,
# query_type: i32,
# joined_string: *const c_char,
# addr: c_ptrdiff_t,
# )
# ) -> NonNull<Result<BytesResult>>
execute_query = lib.execute_query
execute_query.restype = c_void_p
execute_query.restype = POINTER(NonNull)
execute_query.argtypes = [CFUNCTYPE(c_void_p, c_int, c_char_p), POINTER(NonNull), POINTER(NonNull),
POINTER(NonNull),
c_int, c_char_p, c_char_p]
c_int, c_char_p]

# pub extern "C" fn export_bytes_result(
# callback: extern "C" fn(bool, *const c_char),
# bytes: NonNull<Result<BytesResult>>,
# len: i32,
# addr: c_ptrdiff_t,
# )
export_bytes_result = lib.export_bytes_result
export_bytes_result.restype = c_void_p
export_bytes_result.argtypes = [CFUNCTYPE(c_void_p, c_bool, c_char_p), POINTER(NonNull), c_int, c_char_p]

# pub extern "C" fn free_bytes_result(bytes: NonNull<Result<BytesResult>>)
free_bytes_result = lib.free_bytes_result
free_bytes_result.restype = c_void_p
free_bytes_result.argtypes = [POINTER(NonNull)]

# pub extern "C" fn create_tokio_runtime() -> NonNull<Result<TokioRuntime>>
create_tokio_runtime = lib.create_tokio_runtime
Expand Down
52 changes: 33 additions & 19 deletions python/lakesoul/metadata/native_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
global config
config = None


def reset_pg_conf(conf):
global config
config = " ".join(conf)


def get_pg_conf_from_env():
import os
conf = []
Expand All @@ -31,17 +33,18 @@ def get_pg_conf_from_env():
return conf
return None


class NativeMetadataClient:
def __init__(self):
self._lock = threading.Lock()
importlib.reload(lib)
self._buffer = create_string_buffer(4096)
self._large_buffer = create_string_buffer(65536)
self._runtime = lib.lakesoul_metadata_c.create_tokio_runtime()
self._free_tokio_runtime = lib.lakesoul_metadata_c.free_tokio_runtime
self._query_result_len = 0
self._bool = False

def callback(bool, msg):
#print("create connection callback: status={} msg={}".format(bool, msg.decode("utf-8")))
print("create connection callback: status={} msg={}".format(bool, msg.decode("utf-8")))
if not bool:
message = "fail to initialize lakesoul.metadata.native_client.NativeMetadataClient"
raise RuntimeError(message)
Expand Down Expand Up @@ -84,27 +87,38 @@ def __del__(self):

def execute_query(self, query_type, params):
joined_params = PARAM_DELIM.join(params).encode("utf-8")
buffer = self._buffer
if query_type >= DAO_TYPE_QUERY_LIST_OFFSET:
buffer = self._large_buffer
buffer.value = b''

def callback(len, msg):
#print("execute_query query_type={} callback: len={} msg={}".format(query_type, len, msg.decode("utf-8")))
pass
def execute_query_callback(len, msg):
print("execute_query query_type={} callback: len={} msg={}".format(query_type, len, msg.decode("utf-8")))
self._query_result_len = len

def export_bytes_result_callback(bool, msg):
print(
"export_bytes_result callback: bool={} msg={}".format(query_type, bool, msg.decode("utf-8")))
self._bool = bool

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(lib.lakesoul_metadata_c.execute_query,
CFUNCTYPE(c_void_p, c_int, c_char_p)(callback), self._runtime, self._client,
self._prepared, query_type, joined_params, buffer)
CFUNCTYPE(c_void_p, c_int, c_char_p)(execute_query_callback), self._runtime,
self._client,
self._prepared, query_type, joined_params)
bytes = future.result(2.0)

buffer = create_string_buffer(self._query_result_len)
future = executor.submit(lib.lakesoul_metadata_c.export_bytes_result,
CFUNCTYPE(c_void_p, c_bool, c_char_p)(export_bytes_result_callback), bytes,
self._query_result_len, buffer)
future.result(2.0)

if len(buffer.value) == 0:
return None
else:
wrapper = entity_pb2.JniWrapper()
wrapper.ParseFromString(buffer.value)
return wrapper
ret = None
if len(buffer.value) > 0:
wrapper = entity_pb2.JniWrapper()
wrapper.ParseFromString(buffer.value)
ret = wrapper

lib.lakesoul_metadata_c.free_bytes_result(bytes)

return ret

def get_lock(self):
return self._lock
Expand All @@ -120,7 +134,7 @@ def get_instance():
if INSTANCE is None:
import os
dir_path = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(dir_path, 'lib', 'liblakesoul_metadata_c.so')
file_path = os.path.join(dir_path, 'lib', 'liblakesoul_metadata_c.dylib')
lib.reload_lib(file_path)
INSTANCE = NativeMetadataClient()
return INSTANCE
Expand Down
112 changes: 85 additions & 27 deletions python/lakesoul/metadata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,95 @@
import pyarrow


def to_arrow_field(spark_field_json):
spark_type = spark_field_json['type']
arrow_type = None
if spark_type == 'long':
arrow_type = pyarrow.int64()
elif spark_type == 'integer':
arrow_type = pyarrow.int32()
elif spark_type == 'string':
arrow_type = pyarrow.utf8()
elif spark_type == 'float':
arrow_type = pyarrow.float32()
elif spark_type == 'double':
arrow_type = pyarrow.float64()
elif spark_type == "binary":
arrow_type = pyarrow.binary()
elif spark_type.startswith("decimal"):
arrow_type = pyarrow.decimal128(38)
elif spark_type == 'struct':
fields = spark_field_json['fields']
arrow_fields = []
for field in fields:
arrow_fields.append(to_arrow_field(field))
arrow_type = pyarrow.struct(arrow_fields)
def to_arrow_type(arrow_type_json):
if isinstance(arrow_type_json, str):
if arrow_type_json == 'Boolean':
return pyarrow.bool_()
elif arrow_type_json == 'Date32':
return pyarrow.date32()
elif arrow_type_json == 'Date64':
return pyarrow.date64()
elif arrow_type_json == 'Int8':
return pyarrow.int8()
elif arrow_type_json == 'Int16':
return pyarrow.int16()
elif arrow_type_json == 'Int32':
return pyarrow.int32()
elif arrow_type_json == 'Int64':
return pyarrow.int64()
elif arrow_type_json == 'UInt8':
return pyarrow.uint8()
elif arrow_type_json == 'UInt16':
return pyarrow.uint16()
elif arrow_type_json == 'UInt32':
return pyarrow.uint32()
elif arrow_type_json == 'UInt64':
return pyarrow.uint64()
elif arrow_type_json == 'String':
return pyarrow.string()
elif arrow_type_json == 'Utf8':
return pyarrow.utf8()
elif arrow_type_json == 'LargeUtf8':
return pyarrow.large_utf8()
elif arrow_type_json == 'Float32':
return pyarrow.float32()
elif arrow_type_json == 'Float64':
return pyarrow.float64()
elif arrow_type_json == "Binary":
return pyarrow.binary()
elif arrow_type_json == "LargeBinary":
return pyarrow.large_binary()
elif arrow_type_json == "Null":
return pyarrow.null()
elif isinstance(arrow_type_json, dict):
if 'Decimal128' in arrow_type_json:
return pyarrow.decimal128(arrow_type_json['Decimal128'][0], arrow_type_json['Decimal128'][1])
elif 'Decimal256' in arrow_type_json:
return pyarrow.decimal256(arrow_type_json['Decimal256'][0], arrow_type_json['Decimal256'][1])
elif 'Interval' in arrow_type_json:
if arrow_type_json['Interval'] == 'DayTime':
return pyarrow.month_day_nano_interval()
elif arrow_type_json['Interval'] == 'YearMonth':
return pyarrow.month_day_nano_interval()
elif 'List' in arrow_type_json:
return pyarrow.list_(to_arrow_type(arrow_type_json['List']['data_type']))
elif 'FixedSizeList' in arrow_type_json:
return pyarrow.list_(to_arrow_type(arrow_type_json['FixedSizeList'][0]['data_type']),
arrow_type_json['FixedSizeList'][1])
elif 'Dictionary' in arrow_type_json:
return pyarrow.dictionary(arrow_type_json['Dictionary'][0], arrow_type_json['Dictionary'][1])
elif 'FixedSizeBinary' in arrow_type_json:
return pyarrow.binary(arrow_type_json['FixedSizeBinary'])
elif 'Map' in arrow_type_json:
return pyarrow.map_(to_arrow_type(arrow_type_json['Map'][0]['data_type']['Struct'][0]['data_type']),
to_arrow_type(arrow_type_json['Map'][0]['data_type']['Struct'][1]['data_type']),
arrow_type_json['Map'][1])
elif 'Struct' in arrow_type_json:
arrow_fields = []
for field in arrow_type_json['Struct']:
arrow_fields.append(to_arrow_field(field))
return pyarrow.struct(arrow_fields)
elif 'Time32' in arrow_type_json:
return pyarrow.time32('ms' if arrow_type_json['Time32'] == 'Millisecond' else 's')
elif 'Time64' in arrow_type_json:
return pyarrow.time64('us' if arrow_type_json['Time64'] == 'Microsecond' else 'ns')
elif 'Timestamp' in arrow_type_json:
unit = arrow_type_json['Timestamp'][0]
unit = 's' if unit == 'Second' else 'ms' if unit == 'Millisecond' else 'us' if unit == 'Microsecond' else 'ns'
return pyarrow.timestamp(unit, arrow_type_json['Timestamp'][1])
else:
raise IOError("Not supported spark type " + str(spark_type))
return pyarrow.field(spark_field_json['name'], arrow_type, spark_field_json['nullable'])
raise IOError("Not supported spark type " + str(arrow_type_json))


def to_arrow_schema(spark_schema_str, exclude_columns=None):
def to_arrow_field(arrow_field_json):
return pyarrow.field(arrow_field_json['name'], to_arrow_type(arrow_field_json['data_type']),
arrow_field_json['nullable'])


def to_arrow_schema(schema_json_str, exclude_columns=None):
exclude_columns = frozenset(exclude_columns or frozenset())
fields = json.loads(spark_schema_str)['fields']
_json = json.loads(schema_json_str)
fields = json.loads(schema_json_str)['fields']
arrow_fields = []
for field in fields:
if field['name'] in exclude_columns:
Expand Down
21 changes: 11 additions & 10 deletions python/metadata_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@

if __name__ == '__main__':
reset_pg_conf(
["host=localhost", "port=5432", " dbname=lakesoul_test", " user=lakesoul_test", "password=lakesoul_test"])
["host=localhost", "port=5433", " dbname=test_lakesoul_meta", " user=yugabyte", "password=yugabyte"])

db_manager = DBManager()
data_files = db_manager.get_data_files_by_table_name("titanic")
table_name = "test_datatypes"
data_files = db_manager.get_data_files_by_table_name(table_name)
print(data_files)
data_files = db_manager.get_data_files_by_table_name("titanic", partitions={"split": "train"})
data_files = db_manager.get_data_files_by_table_name(table_name)
print(data_files)
arrow_schema = db_manager.get_arrow_schema_by_table_name("titanic")
print(arrow_schema)
data_files = db_manager.get_data_files_by_table_name("imdb")
print(data_files)
data_files = db_manager.get_data_files_by_table_name("imdb", partitions={"split": "train"})
print(data_files)
arrow_schema = db_manager.get_arrow_schema_by_table_name("imdb")
arrow_schema = db_manager.get_arrow_schema_by_table_name(table_name)
print(arrow_schema)
# data_files = db_manager.get_data_files_by_table_name("imdb")
# print(data_files)
# data_files = db_manager.get_data_files_by_table_name("imdb", partitions={"split": "train"})
# print(data_files)
# arrow_schema = db_manager.get_arrow_schema_by_table_name("imdb")
# print(arrow_schema)
Loading

0 comments on commit 4dced35

Please sign in to comment.