From 04a78d560b27fffa4a6e50d0fa46aed3bbdf756b Mon Sep 17 00:00:00 2001 From: Cai Wingfield Date: Sat, 2 Dec 2023 07:12:38 +0000 Subject: [PATCH] Fix `zipfile`'s persnickety path specifications --- kymata/entities/expression.py | 4 ++-- kymata/io/nkg.py | 28 ++++++++++++---------------- kymata/io/yaml.py | 5 +++-- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/kymata/entities/expression.py b/kymata/entities/expression.py index 948c29ab..0e507e21 100644 --- a/kymata/entities/expression.py +++ b/kymata/entities/expression.py @@ -73,7 +73,7 @@ def __init__(self, data_layers[layer] = [data_layers[layer]] assert len(functions) == len(set(functions)), "Duplicated functions in input" - for layer, data in data_layers: + for layer, data in data_layers.items(): assert len(functions) == len(data), _length_mismatch_message assert all_equal([arr.shape for _layer, arrs in data_layers.items() for arr in arrs]) @@ -84,7 +84,7 @@ def __init__(self, datasets = [] for i, f in enumerate(functions): dataset_dict = dict() - for layer, data in data_layers: + for layer, data in data_layers.items(): # Get this function's data data = data[i] data = self._init_prep_data(data) diff --git a/kymata/io/nkg.py b/kymata/io/nkg.py index 48da169c..de6db40a 100644 --- a/kymata/io/nkg.py +++ b/kymata/io/nkg.py @@ -25,7 +25,7 @@ class _Keys: expressionset_type = "expressionset-type" -class _ExpressionSetType: +class _ExpressionSetTypeIdentifier: hexel = "hexel" sensor = "sensor" @@ -54,7 +54,7 @@ def from_expression_set(cls, expression_set: ExpressionSet): def file_version(from_path_or_file: path_type | file_type) -> version.Version: with open_or_use(from_path_or_file, mode="rb") as archive, ZipFile(archive, "r") as zf: - with TextIOWrapper(zf.open("/_metadata/format-version.txt"), encoding="utf-8") as f: + with TextIOWrapper(zf.open("_metadata/format-version.txt"), encoding="utf-8") as f: return version.parse(str(f.read()).strip()) @@ -63,7 +63,7 @@ def load_expression_set(from_path_or_file: path_type | file_type) -> ExpressionS type_identifier = data_dict[_Keys.expressionset_type] - if type_identifier == _ExpressionSetType.hexel: + if type_identifier == _ExpressionSetTypeIdentifier.hexel: return HexelExpressionSet( functions=data_dict[_Keys.functions], hexels=[HexelDType(c) for c in data_dict[_Keys.channels]], @@ -73,7 +73,7 @@ def load_expression_set(from_path_or_file: path_type | file_type) -> ExpressionS data_rh=[data_dict[_Keys.data][LAYER_RIGHT][:, :, i] for i in range(len(data_dict[_Keys.functions]))], ) - elif type_identifier == _ExpressionSetType.sensor: + elif type_identifier == _ExpressionSetTypeIdentifier.sensor: return SensorExpressionSet( functions=data_dict[_Keys.functions], sensors=[SensorDType(c) for c in data_dict[_Keys.channels]], @@ -106,7 +106,7 @@ def save_expression_set(expression_set: ExpressionSet, with open_or_use(to_path_or_file, mode="wb") as f, ZipFile(f, "w", compression=compression) as zf: zf.writestr("_metadata/format-version.txt", CURRENT_VERSION) - zf.writestr("_metadata/expression-set-type.txt", _ExpressionSetType.from_expression_set(expression_set)) + zf.writestr("_metadata/expression-set-type.txt", _ExpressionSetTypeIdentifier.from_expression_set(expression_set)) zf.writestr("/channels.txt", "\n".join(str(x) for x in expression_set._channels)) zf.writestr("/latencies.txt", "\n".join(str(x) for x in expression_set.latencies)) zf.writestr("/functions.txt", "\n".join(str(x) for x in expression_set.functions)) @@ -158,7 +158,7 @@ def _load_data_current(from_path_or_file: path_type | file_type) -> dict[str, An return_dict = dict() with open_or_use(from_path_or_file, mode="rb") as archive, ZipFile(archive, "r") as zf: - with TextIOWrapper(zf.open("/_metadata/expression-set-type.txt"), encoding="utf-8") as f: + with TextIOWrapper(zf.open("_metadata/expression-set-type.txt"), encoding="utf-8") as f: return_dict[_Keys.expressionset_type] = str(f.read()).strip() with TextIOWrapper(zf.open("/layers.txt"), encoding="utf-8") as f: layers = [str(l.strip()) for l in f.readlines()] @@ -176,14 +176,14 @@ def _load_data_current(from_path_or_file: path_type | file_type) -> dict[str, An data: ndarray = frombuffer(f.read(), dtype=float) with TextIOWrapper(zf.open(f"/{layer}/coo-shape.txt"), encoding="utf-8") as f: shape: tuple[int, ...] = tuple(int(s.strip()) for s in f.readlines()) - return_dict[_Keys.data][layer] = COO(coords=coords, data=data, shape=shape, prune=True, fill_value=1.0) + sparse_data = COO(coords=coords, data=data, shape=shape, prune=True, fill_value=1.0) # In case there was only 1 function and we have a 2-d data matrix - # TODO: does this ever actually happen? if len(shape) == 2: - return_dict[_Keys.data][layer] = expand_dims(return_dict[_Keys.data][layer]) + sparse_data = expand_dims(sparse_data) assert shape == (len(return_dict[_Keys.channels]), len(return_dict[_Keys.latencies]), len(return_dict[_Keys.functions])) + return_dict[_Keys.data][layer] = sparse_data return return_dict @@ -217,11 +217,9 @@ def _load_data_0_1(from_path_or_file: path_type | file_type) -> dict[str, Any]: data: ndarray = frombuffer(f.read(), dtype=float) with TextIOWrapper(zf.open(f"/{layer}/coo-shape.txt"), encoding="utf-8") as f: shape: tuple[int, ...] = tuple(int(s.strip()) for s in f.readlines()) - sparse_data = COO(coords=coords, data=data, shape=shape, - prune=True, fill_value=1.0) + sparse_data = COO(coords=coords, data=data, shape=shape, prune=True, fill_value=1.0) # In case there was only 1 function and we have a 2-d data matrix - if len(shape) == 2: - # TODO: does this ever actually happen? + if len(sparse_data.shape) == 2: sparse_data = expand_dims(sparse_data) assert sparse_data.shape == ( @@ -229,7 +227,5 @@ def _load_data_0_1(from_path_or_file: path_type | file_type) -> dict[str, Any]: len(return_dict["latencies"]), len(return_dict["functions"]), ) - # Split by function (this is the expected format) - return_dict[layer]["data"] = [sparse_data[:, :, i] for i in range(len(return_dict["functions"]))] - + return_dict["data"][layer] = sparse_data return return_dict diff --git a/kymata/io/yaml.py b/kymata/io/yaml.py index fa3ec0f8..61c92d6c 100644 --- a/kymata/io/yaml.py +++ b/kymata/io/yaml.py @@ -1,6 +1,7 @@ import yaml -def load_config_parameters(file_location: String): + +def load_config_parameters(file_location: str): '''Load config parameters''' with open(file_location, "r") as stream: - return yaml.safe_load(stream) \ No newline at end of file + return yaml.safe_load(stream)