diff --git a/pipestat/backends/dbbackend.py b/pipestat/backends/dbbackend.py index 2ec36295..9c516b8a 100644 --- a/pipestat/backends/dbbackend.py +++ b/pipestat/backends/dbbackend.py @@ -552,7 +552,8 @@ def retrieve_multiple( "limit": limit, "offset": offset, "record_identifiers": record_identifier, - "result_identifiers": result_identifier, + "result_identifiers": result_identifier + or list(self.parsed_schema.results_data.keys()) + [CREATED_TIME] + [MODIFIED_TIME], "records": record_list, } diff --git a/pipestat/backends/filebackend.py b/pipestat/backends/filebackend.py index 14516c24..e14636b3 100644 --- a/pipestat/backends/filebackend.py +++ b/pipestat/backends/filebackend.py @@ -462,7 +462,42 @@ def retrieve_multiple( limit: Optional[int] = 1000, offset: Optional[int] = 0, ) -> Union[Any, Dict[str, Any]]: - pass + record_list = [] + + if result_identifier == []: + result_identifier = ( + list(self.parsed_schema.results_data.keys()) + [CREATED_TIME] + [MODIFIED_TIME] + ) + if record_identifier == [] or record_identifier is None: + record_identifier = list( + self._data.data[self.pipeline_name][self.pipeline_type].keys() + ) + + for k in list(self._data.data[self.pipeline_name][self.pipeline_type].keys())[ + offset : offset + limit + ]: + if k in record_identifier: + retrieved_record = {} + retrieved_results = {} + for key, value in self._data.data[self.pipeline_name][self.pipeline_type][ + k + ].items(): + if key in result_identifier: + retrieved_results.update({key: value}) + + if retrieved_results != {}: + retrieved_record.update({k: retrieved_results}) + record_list.append(retrieved_record) + + records_dict = { + "count": len(record_list), + "limit": limit, + "offset": offset, + "record_identifiers": record_identifier, + "result_identifiers": result_identifier, + "records": record_list, + } + return records_dict def retrieve( self, diff --git a/tests/test_pipestat.py b/tests/test_pipestat.py index a51844d4..617a0124 100644 --- a/tests/test_pipestat.py +++ b/tests/test_pipestat.py @@ -394,7 +394,7 @@ def test_retrieve_basic( # Test Retrieve Whole Record assert isinstance(psm.retrieve(record_identifier=rec_id), Mapping) - @pytest.mark.parametrize("backend", ["db"]) + @pytest.mark.parametrize("backend", ["file", "db"]) def test_retrieve_multiple( self, config_file_path, @@ -445,11 +445,11 @@ def test_retrieve_multiple( # Test combinations of empty list for either record or result identifiers. results = psm.retrieve(record_identifier=r_ids, result_identifier=[]) - assert results["result_identifiers"] is None + assert len(results["result_identifiers"]) == 9 assert len(results["records"]) == 2 results = psm.retrieve(record_identifier=[], result_identifier=[]) - assert results["result_identifiers"] is None + assert len(results["result_identifiers"]) == 9 assert len(results["records"]) == 3 results = psm.retrieve(record_identifier=[], result_identifier=res_id)