diff --git a/pipestat/pipestat.py b/pipestat/pipestat.py index c0be5521..20038e16 100644 --- a/pipestat/pipestat.py +++ b/pipestat/pipestat.py @@ -235,9 +235,10 @@ def __getitem__(self, key): print(key) def __setitem__(self, key, value): - # self.cfg[self._keytransform(key)] = value - print(key) - # self.report() + # This is a wrapper for the report function: + result = self.report(record_identifier=key, values=value) + return result + def __delitem__(self, key): del self.cfg[self._keytransform(key)] diff --git a/tests/test_pipestat.py b/tests/test_pipestat.py index dd053c75..54f76967 100644 --- a/tests/test_pipestat.py +++ b/tests/test_pipestat.py @@ -195,6 +195,50 @@ def test_report_basic( # This is being captured in TestSplitClasses pass + @pytest.mark.parametrize( + ["rec_id", "val"], + [ + ("sample1", {"name_of_something": "test_name"}), + ("sample1", {"number_of_things": 1}), + ], + ) + @pytest.mark.parametrize("backend", ["file", "db"]) + def test_report_setitem( + self, + rec_id, + val, + config_file_path, + schema_file_path, + results_file_path, + backend, + ): + with NamedTemporaryFile() as f, ContextManagerDBTesting(DB_URL): + results_file_path = f.name + args = dict(schema_path=schema_file_path, database_only=False) + backend_data = ( + {"config_file": config_file_path} + if backend == "db" + else {"results_file_path": results_file_path} + ) + args.update(backend_data) + psm = SamplePipestatManager(**args) + #psm.report(record_identifier=rec_id, values=val, force_overwrite=True) + psm[rec_id] = val + if backend == "file": + print(psm.backend._data[STANDARD_TEST_PIPE_ID]) + print("Test if", rec_id, " is in ", psm.backend._data[STANDARD_TEST_PIPE_ID]) + assert rec_id in psm.backend._data[STANDARD_TEST_PIPE_ID][PROJECT_SAMPLE_LEVEL] + print("Test if", list(val.keys())[0], " is in ", rec_id) + assert ( + list(val.keys())[0] + in psm.backend._data[STANDARD_TEST_PIPE_ID][PROJECT_SAMPLE_LEVEL][rec_id] + ) + if backend == "file": + assert_is_in_files(results_file_path, str(list(val.values())[0])) + if backend == "db": + # This is being captured in TestSplitClasses + pass + @pytest.mark.parametrize( ["rec_id", "val"], [