Skip to content

Commit

Permalink
fix getitem implementation, add test #99
Browse files Browse the repository at this point in the history
  • Loading branch information
donaldcampbelljr committed Oct 26, 2023
1 parent df4eed8 commit 4b96cc9
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 28 deletions.
14 changes: 7 additions & 7 deletions pipestat/pipestat.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __str__(self):
res += f"\nStatus Schema key: {self.cfg[STATUS_SCHEMA_KEY]}"
res += f"\nResults formatter: {str(self.cfg[RESULT_FORMATTER].__name__)}"
res += f"\nResults schema source: {self.cfg[SCHEMA_PATH]}"
res += f"\nStatus schema source: {self.status_schema_source}"
res += f"\nStatus schema source: {self.cfg[STATUS_SCHEMA_SOURCE_KEY]}"
res += f"\nRecords count: {self.record_count}"
if self.cfg[SCHEMA_PATH] is not None:
high_res = self.highlighted_results
Expand All @@ -231,15 +231,15 @@ def __str__(self):
return res

def __getitem__(self, key):
# return self.cfg[self._keytransform(key)]
print(key)
# This is a wrapper for the retrieve function:
result = self.retrieve(record_identifier=key)
return result

def __setitem__(self, key, value):
# This is a wrapper for the report function:
result = self.report(record_identifier=key, values=value)
return result


def __delitem__(self, key):
del self.cfg[self._keytransform(key)]

Expand Down Expand Up @@ -488,7 +488,7 @@ def remove(
:return bool: whether the result has been removed
"""

r_id = record_identifier or self.record_identifier
r_id = record_identifier or self.cfg[RECORD_IDENTIFIER]
return self.backend.remove(
record_identifier=r_id,
result_identifier=result_identifier,
Expand Down Expand Up @@ -518,7 +518,7 @@ def report(
:return str reported_results: return list of formatted string
"""

result_formatter = result_formatter or self[RESULT_FORMATTER]
result_formatter = result_formatter or self.cfg[RESULT_FORMATTER]
values = deepcopy(values)
r_id = record_identifier or self.cfg[RECORD_IDENTIFIER]
if r_id is None:
Expand Down Expand Up @@ -595,7 +595,7 @@ def retrieve(
record_identifier, result_identifier, limit, offset
)

r_id = record_identifier or self.record_identifier
r_id = record_identifier or self.cfg[RECORD_IDENTIFIER]

return self.backend.retrieve(r_id, result_identifier)

Expand Down
22 changes: 11 additions & 11 deletions pipestat/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def create_sample_html(self, sample_stats, navbar, footer, sample_name):
flag = flag_dict["flag"]
highlighted_results = fetch_pipeline_results(
project=self.prj,
pipeline_name=self.prj.pipeline_name,
pipeline_name=self.prj.cfg[PIPELINE_NAME],
sample_name=sample_name,
inclusion_fun=lambda x: x == "file",
highlighted=True,
Expand Down Expand Up @@ -482,7 +482,7 @@ def create_index_html(self, navbar, footer):
# Add stats_summary.tsv button link
stats_file_path = get_file_for_project(
prj=self.prj,
pipeline_name=self.prj.pipeline_name,
pipeline_name=self.prj.cfg[PIPELINE_NAME],
appendix="stats_summary.tsv",
reportdir=self.reports_dir,
)
Expand All @@ -495,7 +495,7 @@ def create_index_html(self, navbar, footer):
# Add objects_summary.yaml button link
objs_file_path = get_file_for_project(
prj=self.prj,
pipeline_name=self.prj.pipeline_name,
pipeline_name=self.prj.cfg[PIPELINE_NAME],
appendix="objs_summary.yaml",
reportdir=self.reports_dir,
)
Expand Down Expand Up @@ -548,7 +548,7 @@ def create_index_html(self, navbar, footer):
)
# Create status page with each sample's status listed
status_tab = create_status_table(
pipeline_name=self.prj.pipeline_name,
pipeline_name=self.prj.cfg[PIPELINE_NAME],
project=self.prj,
pipeline_reports_dir=self.pipeline_reports,
)
Expand All @@ -571,8 +571,8 @@ def create_index_html(self, navbar, footer):
columns=columns,
columns_json=dumps(columns),
table_row_data=table_row_data,
project_name=self.prj.project_name,
pipeline_name=self.prj.pipeline_name,
project_name=self.prj.cfg[PROJECT_NAME],
pipeline_name=self.prj.cfg[PIPELINE_NAME],
stats_json=self._stats_to_json_str(),
footer=footer,
amendments="",
Expand Down Expand Up @@ -619,7 +619,7 @@ def _stats_to_json_str(self):
results[sample_name] = fetch_pipeline_results(
project=self.prj,
sample_name=sample_name,
pipeline_name=self.prj.pipeline_name,
pipeline_name=self.prj.cfg[PIPELINE_NAME],
inclusion_fun=lambda x: x not in OBJECT_TYPES,
casting_fun=str,
)
Expand Down Expand Up @@ -983,10 +983,10 @@ def get_file_for_project(prj, pipeline_name, appendix=None, directory=None, repo
output_dir = output_dir or results_file_path or config_path
output_dir = os.path.dirname(output_dir)
reportdir = os.path.join(output_dir, "reports")
if prj["project_name"] is None:
if prj.cfg["project_name"] is None:
fp = os.path.join(reportdir, directory or "", f"NO_PROJECT_NAME_{pipeline_name}")
else:
fp = os.path.join(reportdir, directory or "", f"{prj['project_name']}_{pipeline_name}")
fp = os.path.join(reportdir, directory or "", f"{prj.cfg['project_name']}_{pipeline_name}")

if hasattr(prj, "amendments") and getattr(prj, "amendments"):
fp += f"_{'_'.join(prj.amendments)}"
Expand Down Expand Up @@ -1040,7 +1040,7 @@ def _create_stats_objs_summaries(prj, pipeline_name) -> List[str]:
reported_stats = []
stats = []

if prj.pipeline_type == "sample":
if prj.cfg[PIPELINE_TYPE] == "sample":
columns = ["Sample Index", "Sample Name", "Results"]
else:
columns = ["Sample Index", "Project Name", "Sample Name", "Results"]
Expand All @@ -1051,7 +1051,7 @@ def _create_stats_objs_summaries(prj, pipeline_name) -> List[str]:
record_index += 1
record_name = record

if prj.pipeline_type == "sample":
if prj.cfg[PIPELINE_TYPE] == "sample":
reported_stats = [record_index, record_name]
rep_data = prj.retrieve(record_identifier=record_name)
else:
Expand Down
54 changes: 44 additions & 10 deletions tests/test_pipestat.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,13 @@ def test_report_basic(
)
@pytest.mark.parametrize("backend", ["file", "db"])
def test_report_setitem(
self,
rec_id,
val,
config_file_path,
schema_file_path,
results_file_path,
backend,
self,
rec_id,
val,
config_file_path,
schema_file_path,
results_file_path,
backend,
):
with NamedTemporaryFile() as f, ContextManagerDBTesting(DB_URL):
results_file_path = f.name
Expand All @@ -222,16 +222,16 @@ def test_report_setitem(
)
args.update(backend_data)
psm = SamplePipestatManager(**args)
#psm.report(record_identifier=rec_id, values=val, force_overwrite=True)
# psm.report(record_identifier=rec_id, values=val, force_overwrite=True)
psm[rec_id] = val
if backend == "file":
print(psm.backend._data[STANDARD_TEST_PIPE_ID])
print("Test if", rec_id, " is in ", psm.backend._data[STANDARD_TEST_PIPE_ID])
assert rec_id in psm.backend._data[STANDARD_TEST_PIPE_ID][PROJECT_SAMPLE_LEVEL]
print("Test if", list(val.keys())[0], " is in ", rec_id)
assert (
list(val.keys())[0]
in psm.backend._data[STANDARD_TEST_PIPE_ID][PROJECT_SAMPLE_LEVEL][rec_id]
list(val.keys())[0]
in psm.backend._data[STANDARD_TEST_PIPE_ID][PROJECT_SAMPLE_LEVEL][rec_id]
)
if backend == "file":
assert_is_in_files(results_file_path, str(list(val.values())[0]))
Expand Down Expand Up @@ -438,6 +438,40 @@ def test_retrieve_basic(
# Test Retrieve Whole Record
assert isinstance(psm.retrieve(record_identifier=rec_id), Mapping)

@pytest.mark.parametrize(
["rec_id", "val"],
[
("sample1", {"name_of_something": "test_name"}),
("sample1", {"number_of_things": 2}),
],
)
@pytest.mark.parametrize("backend", ["file", "db"])
def test_retrieve_getitem(
self,
rec_id,
val,
config_file_path,
results_file_path,
schema_file_path,
backend,
):
with NamedTemporaryFile() as f, ContextManagerDBTesting(DB_URL):
results_file_path = f.name
args = dict(schema_path=schema_file_path, database_only=False)
backend_data = (
{"config_file": config_file_path}
if backend == "db"
else {"results_file_path": results_file_path}
)
args.update(backend_data)
psm = SamplePipestatManager(**args)
psm.report(record_identifier=rec_id, values=val, force_overwrite=True)
retrieved_val = psm[rec_id]
# Test Retrieve Basic
assert str(retrieved_val[str(list(retrieved_val.keys())[0])]) == str(
list(val.values())[0]
)

@pytest.mark.parametrize("backend", ["file", "db"])
def test_retrieve_multiple(
self,
Expand Down

0 comments on commit 4b96cc9

Please sign in to comment.