Skip to content

Commit

Permalink
get tests to pass
Browse files Browse the repository at this point in the history
  • Loading branch information
zigaLuksic committed Aug 18, 2023
1 parent 1abf1aa commit d50d125
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 11 deletions.
1 change: 0 additions & 1 deletion eogrow/pipelines/batch_to_eopatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def filter_patch_list(self, patch_list: PatchList) -> PatchList:
def _get_output_features(self) -> list[tuple[FeatureType, str]]:
"""Lists all features that the pipeline outputs."""
features = [feature_mapping.feature for feature_mapping in self.config.mapping]
features.extend(feature_mapping.feature for feature_mapping in self.config.mapping)

if self.config.userdata_feature_name:
features.append((FeatureType.META_INFO, self.config.userdata_feature_name))
Expand Down
3 changes: 2 additions & 1 deletion eogrow/pipelines/merge_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def build_workflow(self) -> EOWorkflow:
self.storage.get_folder(self.config.input_folder_key),
filesystem=self.storage.filesystem,
features=self.config.features_to_merge,
load_timestamps=True if self.config.include_timestamp else "auto",
)
output_task = OutputTask(name=self._OUTPUT_NAME)
return EOWorkflow(linearly_connect_tasks(load_task, output_task))
Expand Down Expand Up @@ -101,7 +102,7 @@ def merge_and_save_features(self, patches: list[EOPatch]) -> None:
if self.config.include_timestamp:
arrays = []
for patch, sample_num in zip(patches, patch_sample_nums):
arrays.append(np.tile(np.array(patch.timestamps), (sample_num, 1)))
arrays.append(np.tile(np.array(patch.get_timestamps()), (sample_num, 1)))
patch.timestamps = []

self._save_array(np.concatenate(arrays, axis=0), "TIMESTAMPS")
Expand Down
1 change: 1 addition & 0 deletions eogrow/pipelines/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def build_workflow(self) -> EOWorkflow:
self.storage.get_folder(self.config.output_folder_key),
filesystem=self.storage.filesystem,
overwrite_permission=OverwritePermission.OVERWRITE_FEATURES,
save_timestamps=self.config.timestamps is not None,
)
save_node = EONode(save_task, inputs=[previous_node])

Expand Down
2 changes: 1 addition & 1 deletion eogrow/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _calculate_stats(self, folder: str | None = None) -> JsonDict:
if self.filesystem.isdir(content_path):
fs_data_info = get_filesystem_data_info(self.filesystem, content_path)
if fs_data_info.bbox is not None:
eopatch = EOPatch.load(content_path, filesystem=self.filesystem)
eopatch = EOPatch.load(content_path, filesystem=self.filesystem, load_timestamps=True)
stats[content] = self._calculate_eopatch_stats(eopatch)
else: # Probably it is not an EOPatch folder
stats[content] = self._calculate_stats(folder=content_path)
Expand Down
14 changes: 7 additions & 7 deletions tests/pipelines/test_merge_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

@pytest.mark.chain()
@pytest.mark.order(after="test_features.py::test_features_pipeline")
@pytest.mark.parametrize(
("experiment_name", "reset_folder"),
[("merge_features_samples", True), ("merge_reference_samples", False)],
)
def test_merge_samples_pipeline(config_and_stats_paths, experiment_name, reset_folder):
config_path, stats_path = config_and_stats_paths("merge_samples", experiment_name)
output_path = run_config(config_path, reset_output_folder=reset_folder)
def test_merge_samples_pipeline(config_and_stats_paths):
config_path, stats_path = config_and_stats_paths("merge_samples", "merge_features_samples")
output_path = run_config(config_path, reset_output_folder=True)
compare_content(output_path, stats_path)

config_path, stats_path = config_and_stats_paths("merge_samples", "merge_reference_samples")
output_path = run_config(config_path, reset_output_folder=False)
compare_content(output_path, stats_path)
2 changes: 1 addition & 1 deletion tests/pipelines/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.mark.chain()
@pytest.mark.parametrize("experiment_name", ["testing", "timestamps_only"])
def test_features_pipeline(config_and_stats_paths, experiment_name):
def test_data_generating_pipeline(config_and_stats_paths, experiment_name):
config_path, stats_path = config_and_stats_paths("testing", experiment_name)
output_path = run_config(config_path)
compare_content(output_path, stats_path)

0 comments on commit d50d125

Please sign in to comment.