diff --git a/pipestat/backends/file_backend/filebackend.py b/pipestat/backends/file_backend/filebackend.py index 6d77b0aa..29f40883 100644 --- a/pipestat/backends/file_backend/filebackend.py +++ b/pipestat/backends/file_backend/filebackend.py @@ -556,21 +556,21 @@ def get_nested_column(result_value: dict, key_list: list, retrieved_operator: Ca if bool_operator.lower() == "or" and filtered_records_list: shared_keys = list(set(chain(*filtered_records_list))) - + record = {} if shared_keys: for record_identifier in sorted(shared_keys): if columns: # Did the user specify a list of columns as well? - for key in list( + for key, value in list( self._data.data[self.pipeline_name][self.pipeline_type][ record_identifier - ].keys() + ].items() ): - if key not in columns: - self._data.data[self.pipeline_name][self.pipeline_type][ - record_identifier - ].pop(key) - - record = self._data.data[self.pipeline_name][self.pipeline_type][record_identifier] + if key in columns: + record.update({key: value}) + else: + record = self._data.data[self.pipeline_name][self.pipeline_type][ + record_identifier + ] record.update({"record_identifier": record_identifier}) records_list.append(record) diff --git a/tests/test_pipestat.py b/tests/test_pipestat.py index 5b8aef6d..19cc0fe7 100644 --- a/tests/test_pipestat.py +++ b/tests/test_pipestat.py @@ -89,7 +89,9 @@ def test_basics( # with pytest.raises(RecordNotFoundError): # psm.retrieve(record_identifier=rec_id) if backend == "db": - assert getattr(psm.retrieve(record_identifier=rec_id), val_name, None) is None + assert ( + psm.retrieve_one(record_identifier=rec_id)["records"][0].get(val_name) is None + ) psm.remove(record_identifier=rec_id) # with pytest.raises(RecordNotFoundError): # psm.retrieve(record_identifier=rec_id) @@ -442,9 +444,6 @@ def test_retrieve_basic( args.update(backend_data) psm = SamplePipestatManager(**args) psm.report(record_identifier=rec_id, values=val, force_overwrite=True) - # retrieved_val = psm.retrieve( - # record_identifier=rec_id, result_identifier=list(val.keys())[0] - # ) retrieved_val = psm.select_records( filter_conditions=[ { @@ -1333,29 +1332,28 @@ def test_basic_time_stamp( psm.report(record_identifier=rec_id, values=val, force_overwrite=True) # CHECK CREATION AND MODIFY TIME EXIST - created = psm.retrieve(record_identifier=rec_id, result_identifier=CREATED_TIME) - # - # created = psm.select_records(filter_conditions=[ - # { - # "key": "record_identifier", - # "operator": "eq", - # "value": rec_id, - # }, - # ], - # columns=[CREATED_TIME] - # )["records"][0][CREATED_TIME] - - modified = psm.retrieve(record_identifier=rec_id, result_identifier=MODIFIED_TIME) - - # modified = psm.select_records(filter_conditions=[ - # { - # "key": "record_identifier", - # "operator": "eq", - # "value": rec_id, - # }, - # ], - # columns=[MODIFIED_TIME] - # )["records"][0][MODIFIED_TIME] + + created = psm.select_records( + filter_conditions=[ + { + "key": "record_identifier", + "operator": "eq", + "value": rec_id, + }, + ], + columns=[CREATED_TIME], + )["records"][0][CREATED_TIME] + + modified = psm.select_records( + filter_conditions=[ + { + "key": "record_identifier", + "operator": "eq", + "value": rec_id, + }, + ], + columns=[MODIFIED_TIME], + )["records"][0][MODIFIED_TIME] assert created is not None assert modified is not None @@ -1367,8 +1365,28 @@ def test_basic_time_stamp( ) # The filebackend is so fast that the updated time will equal the created time psm.report(record_identifier="sample1", values=val, force_overwrite=True) # CHECK MODIFY TIME DIFFERS FROM CREATED TIME - created = psm.retrieve(record_identifier=rec_id, result_identifier=CREATED_TIME) - modified = psm.retrieve(record_identifier=rec_id, result_identifier=MODIFIED_TIME) + created = psm.select_records( + filter_conditions=[ + { + "key": "record_identifier", + "operator": "eq", + "value": rec_id, + }, + ], + columns=[CREATED_TIME], + )["records"][0][CREATED_TIME] + + modified = psm.select_records( + filter_conditions=[ + { + "key": "record_identifier", + "operator": "eq", + "value": rec_id, + }, + ], + columns=[MODIFIED_TIME], + )["records"][0][MODIFIED_TIME] + assert created != modified @pytest.mark.parametrize("backend", ["db", "file"])