Skip to content

Commit

Permalink
Move save/load code to a new location to avoid circular references
Browse files Browse the repository at this point in the history
  • Loading branch information
caiw committed Dec 1, 2023
1 parent c77d848 commit e8ab655
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 202 deletions.
10 changes: 5 additions & 5 deletions demos/demo_save_load.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"from tempfile import NamedTemporaryFile\n",
"\n",
"from kymata.datasets.sample import KymataMirror2023Q3Dataset, TVLInsLoudnessOnlyDataset, TVLDeltaInsTC1LoudnessOnlyDataset\n",
"from kymata.entities.expression import HexelExpressionSet"
"from kymata.io.nkg import save_expression_set, load_expression_set"
]
},
{
Expand Down Expand Up @@ -77,7 +77,7 @@
"outputs": [],
"source": [
"# Let's load the KymataMirror2023Q3 .nkg file. This contains around 30 functions.\n",
"expression_data_kymata_mirror = HexelExpressionSet.load(from_path_or_file=nkg_path)"
"expression_data_kymata_mirror = load_expression_set(from_path_or_file=nkg_path)"
],
"metadata": {
"collapsed": false,
Expand All @@ -95,8 +95,8 @@
"source": [
"# Let's seperately load the 'ins_loudness' .nkg file, and then load and add the\n",
"# d_ins_tc1_loudness to it using '+='. 'expression_data_new_results' now contains two functions.\n",
"expression_data_new_results = HexelExpressionSet.load(from_path_or_file=ins_loudness_path)\n",
"expression_data_new_results += HexelExpressionSet.load(from_path_or_file=d_ins_tc1_loudness_path)"
"expression_data_new_results = load_expression_set(from_path_or_file=ins_loudness_path)\n",
"expression_data_new_results += load_expression_set(from_path_or_file=d_ins_tc1_loudness_path)"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -126,7 +126,7 @@
"\n",
"# Save new expressionSet for use again in the future.\n",
"with NamedTemporaryFile() as tf:\n",
" expression_data_extended.save(tf)"
" save_expression_set(expression_data_extended, tf)"
],
"metadata": {
"collapsed": false,
Expand Down
13 changes: 10 additions & 3 deletions kymata/datasets/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from kymata.entities.expression import HexelExpressionSet
from kymata.io.file import path_type
from kymata.io.nkg import load_expression_set

_DATA_PATH_ENVIRONMENT_VAR_NAME = "KYMATA_DATA_ROOT"
_DATA_DIR_NAME = "kymata-toolbox-data"
Expand Down Expand Up @@ -78,7 +79,9 @@ def __init__(self, data_root: Optional[path_type] = None, download: bool = True)
)

def to_expressionset(self) -> HexelExpressionSet:
return HexelExpressionSet.load(from_path_or_file=Path(self.path, self.filenames[0]))
es = load_expression_set(from_path_or_file=Path(self.path, self.filenames[0]))
assert isinstance(es, HexelExpressionSet)
return es


class TVLInsLoudnessOnlyDataset(SampleDataset):
Expand All @@ -95,7 +98,9 @@ def __init__(self, data_root: Optional[path_type] = None, download: bool = True)
)

def to_expressionset(self) -> HexelExpressionSet:
return HexelExpressionSet.load(from_path_or_file=Path(self.path, self.filenames[0]))
es = load_expression_set(from_path_or_file=Path(self.path, self.filenames[0]))
assert isinstance(es, HexelExpressionSet)
return es


class TVLDeltaInsTC1LoudnessOnlyDataset(SampleDataset):
Expand All @@ -112,7 +117,9 @@ def __init__(self, data_root: Optional[path_type] = None, download: bool = True)
)

def to_expressionset(self) -> HexelExpressionSet:
return HexelExpressionSet.load(from_path_or_file=Path(self.path, self.filenames[0]))
es = load_expression_set(from_path_or_file=Path(self.path, self.filenames[0]))
assert isinstance(es, HexelExpressionSet)
return es


def data_root_path(data_root: Optional[path_type] = None) -> Path:
Expand Down
12 changes: 12 additions & 0 deletions kymata/entities/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from numpy import int_, str_, float_


Hexel = int # Todo: change this and others to `type Hexel = int` on dropping support for python <3.12
Sensor = str
Latency = float

# Set consistent dtypes for use in arrays
HexelDType = int_
SensorDType = str_
LatencyDType = float_
FunctionNameDType = str_
Loading

0 comments on commit e8ab655

Please sign in to comment.