Skip to content

Commit

Permalink
Fix zipfile's persnickety path specifications
Browse files: browse the repository at this point in the history
  • Loading branch information
caiw committed Dec 2, 2023
1 parent ed110e6 commit 04a78d5
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 20 deletions.
4 changes: 2 additions & 2 deletions kymata/entities/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self,
data_layers[layer] = [data_layers[layer]]

assert len(functions) == len(set(functions)), "Duplicated functions in input"
for layer, data in data_layers:
for layer, data in data_layers.items():
assert len(functions) == len(data), _length_mismatch_message
assert all_equal([arr.shape for _layer, arrs in data_layers.items() for arr in arrs])

Expand All @@ -84,7 +84,7 @@ def __init__(self,
datasets = []
for i, f in enumerate(functions):
dataset_dict = dict()
for layer, data in data_layers:
for layer, data in data_layers.items():
# Get this function's data
data = data[i]
data = self._init_prep_data(data)
Expand Down
28 changes: 12 additions & 16 deletions kymata/io/nkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class _Keys:
expressionset_type = "expressionset-type"


class _ExpressionSetType:
class _ExpressionSetTypeIdentifier:
hexel = "hexel"
sensor = "sensor"

Expand Down Expand Up @@ -54,7 +54,7 @@ def from_expression_set(cls, expression_set: ExpressionSet):

def file_version(from_path_or_file: path_type | file_type) -> version.Version:
with open_or_use(from_path_or_file, mode="rb") as archive, ZipFile(archive, "r") as zf:
with TextIOWrapper(zf.open("/_metadata/format-version.txt"), encoding="utf-8") as f:
with TextIOWrapper(zf.open("_metadata/format-version.txt"), encoding="utf-8") as f:
return version.parse(str(f.read()).strip())


Expand All @@ -63,7 +63,7 @@ def load_expression_set(from_path_or_file: path_type | file_type) -> ExpressionS

type_identifier = data_dict[_Keys.expressionset_type]

if type_identifier == _ExpressionSetType.hexel:
if type_identifier == _ExpressionSetTypeIdentifier.hexel:
return HexelExpressionSet(
functions=data_dict[_Keys.functions],
hexels=[HexelDType(c) for c in data_dict[_Keys.channels]],
Expand All @@ -73,7 +73,7 @@ def load_expression_set(from_path_or_file: path_type | file_type) -> ExpressionS
data_rh=[data_dict[_Keys.data][LAYER_RIGHT][:, :, i]
for i in range(len(data_dict[_Keys.functions]))],
)
elif type_identifier == _ExpressionSetType.sensor:
elif type_identifier == _ExpressionSetTypeIdentifier.sensor:
return SensorExpressionSet(
functions=data_dict[_Keys.functions],
sensors=[SensorDType(c) for c in data_dict[_Keys.channels]],
Expand Down Expand Up @@ -106,7 +106,7 @@ def save_expression_set(expression_set: ExpressionSet,

with open_or_use(to_path_or_file, mode="wb") as f, ZipFile(f, "w", compression=compression) as zf:
zf.writestr("_metadata/format-version.txt", CURRENT_VERSION)
zf.writestr("_metadata/expression-set-type.txt", _ExpressionSetType.from_expression_set(expression_set))
zf.writestr("_metadata/expression-set-type.txt", _ExpressionSetTypeIdentifier.from_expression_set(expression_set))
zf.writestr("/channels.txt", "\n".join(str(x) for x in expression_set._channels))
zf.writestr("/latencies.txt", "\n".join(str(x) for x in expression_set.latencies))
zf.writestr("/functions.txt", "\n".join(str(x) for x in expression_set.functions))
Expand Down Expand Up @@ -158,7 +158,7 @@ def _load_data_current(from_path_or_file: path_type | file_type) -> dict[str, An
return_dict = dict()
with open_or_use(from_path_or_file, mode="rb") as archive, ZipFile(archive, "r") as zf:

with TextIOWrapper(zf.open("/_metadata/expression-set-type.txt"), encoding="utf-8") as f:
with TextIOWrapper(zf.open("_metadata/expression-set-type.txt"), encoding="utf-8") as f:
return_dict[_Keys.expressionset_type] = str(f.read()).strip()
with TextIOWrapper(zf.open("/layers.txt"), encoding="utf-8") as f:
layers = [str(l.strip()) for l in f.readlines()]
Expand All @@ -176,14 +176,14 @@ def _load_data_current(from_path_or_file: path_type | file_type) -> dict[str, An
data: ndarray = frombuffer(f.read(), dtype=float)
with TextIOWrapper(zf.open(f"/{layer}/coo-shape.txt"), encoding="utf-8") as f:
shape: tuple[int, ...] = tuple(int(s.strip()) for s in f.readlines())
return_dict[_Keys.data][layer] = COO(coords=coords, data=data, shape=shape, prune=True, fill_value=1.0)
sparse_data = COO(coords=coords, data=data, shape=shape, prune=True, fill_value=1.0)
# In case there was only 1 function and we have a 2-d data matrix
# TODO: does this ever actually happen?
if len(shape) == 2:
return_dict[_Keys.data][layer] = expand_dims(return_dict[_Keys.data][layer])
sparse_data = expand_dims(sparse_data)
assert shape == (len(return_dict[_Keys.channels]),
len(return_dict[_Keys.latencies]),
len(return_dict[_Keys.functions]))
return_dict[_Keys.data][layer] = sparse_data
return return_dict


Expand Down Expand Up @@ -217,19 +217,15 @@ def _load_data_0_1(from_path_or_file: path_type | file_type) -> dict[str, Any]:
data: ndarray = frombuffer(f.read(), dtype=float)
with TextIOWrapper(zf.open(f"/{layer}/coo-shape.txt"), encoding="utf-8") as f:
shape: tuple[int, ...] = tuple(int(s.strip()) for s in f.readlines())
sparse_data = COO(coords=coords, data=data, shape=shape,
prune=True, fill_value=1.0)
sparse_data = COO(coords=coords, data=data, shape=shape, prune=True, fill_value=1.0)
# In case there was only 1 function and we have a 2-d data matrix
if len(shape) == 2:
# TODO: does this ever actually happen?
if len(sparse_data.shape) == 2:
sparse_data = expand_dims(sparse_data)

assert sparse_data.shape == (
len(return_dict["hexels"]),
len(return_dict["latencies"]),
len(return_dict["functions"]),
)
# Split by function (this is the expected format)
return_dict[layer]["data"] = [sparse_data[:, :, i] for i in range(len(return_dict["functions"]))]

return_dict["data"][layer] = sparse_data
return return_dict
5 changes: 3 additions & 2 deletions kymata/io/yaml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import yaml

def load_config_parameters(file_location: String):

def load_config_parameters(file_location: str):
'''Load config parameters'''
with open(file_location, "r") as stream:
return yaml.safe_load(stream)
return yaml.safe_load(stream)

0 comments on commit 04a78d5

Please sign in to comment.