diff --git a/pipestat/backends/abstract.py b/pipestat/backends/abstract.py index 235605fa..08133159 100644 --- a/pipestat/backends/abstract.py +++ b/pipestat/backends/abstract.py @@ -184,6 +184,24 @@ def retrieve( _LOGGER.warning("Not implemented yet for this backend") pass + def retrieve_multiple( + self, + record_identifier: Optional[List[str]] = None, + result_identifier: Optional[List[str]] = None, + limit: Optional[int] = 1000, + offset: Optional[int] = 0, + ) -> Union[Any, Dict[str, Any]]: + """ + :param List[str] record_identifier: list of record identifiers + :param List[str] result_identifier: list of result identifiers to be retrieved + :param int limit: limit number of records to this amount + :param int offset: offset records by this amount + :return Dict[str, any]: a mapping with filtered results reported for the record + """ + + _LOGGER.warning("Not implemented yet for this backend") + pass + def remove( self, record_identifier: Optional[str] = None, diff --git a/pipestat/backends/dbbackend.py b/pipestat/backends/dbbackend.py index fdc6c4d8..bc1c19ea 100644 --- a/pipestat/backends/dbbackend.py +++ b/pipestat/backends/dbbackend.py @@ -501,6 +501,72 @@ def retrieve( } raise RecordNotFoundError(f"Record '{record_identifier}' not found") + def retrieve_multiple( + self, + record_identifier: Optional[List[str]] = None, + result_identifier: Optional[List[str]] = None, + limit: Optional[int] = 1000, + offset: Optional[int] = 0, + ) -> Union[Any, Dict[str, Any]]: + """ + :param List[str] record_identifier: list of record identifiers + :param List[str] result_identifier: list of result identifiers to be retrieved + :param int limit: limit number of records to this amount + :param int offset: offset records by this amount + :return Dict[str, any]: a mapping with filtered results reported for the record + """ + + record_list = [] + + if result_identifier == []: + result_identifier = None + if record_identifier == []: + record_identifier = None + + ORM = self.get_model(table_name=self.table_name) + + if record_identifier is not None: + for r_id in record_identifier: + filter = [("record_identifier", "eq", r_id)] + result = self.select( + columns=result_identifier, filter_conditions=filter, limit=limit, offset=offset + ) + retrieved_record = {} + result_dict = dict(result[0]) + for k, v in list(result_dict.items()): + if k not in self.parsed_schema.results_data.keys(): + result_dict.pop(k) + retrieved_record.update({r_id: result_dict}) + record_list.append(retrieved_record) + if record_identifier is None: + if result_identifier is not None: + result_identifier = ["record_identifier"] + result_identifier + record_list = [] + records = self.select( + columns=result_identifier, filter_conditions=None, limit=limit, offset=offset + ) + for record in records: + retrieved_record = {} + r_id = record.record_identifier + record_dict = dict(record) + for k, v in list(record_dict.items()): + if k not in self.parsed_schema.results_data.keys(): + record_dict.pop(k) + retrieved_record.update({r_id: record_dict}) + record_list.append(retrieved_record) + + records_dict = { + "count": len(record_list), + "limit": limit, + "offset": offset, + "record_identifiers": record_identifier, + "result_identifiers": result_identifier + or list(self.parsed_schema.results_data.keys()) + [CREATED_TIME] + [MODIFIED_TIME], + "records": record_list, + } + + return records_dict + def select( self, columns: Optional[List[str]] = None, diff --git a/pipestat/backends/filebackend.py b/pipestat/backends/filebackend.py index 3364d92e..9fa3c6e3 100644 --- a/pipestat/backends/filebackend.py +++ b/pipestat/backends/filebackend.py @@ -455,6 +455,58 @@ def report( return results_formatted + def retrieve_multiple( + self, + record_identifier: Optional[List[str]] = None, + result_identifier: Optional[List[str]] = None, + limit: Optional[int] = 1000, + offset: Optional[int] = 0, + ) -> Union[Any, Dict[str, Any]]: + """ + :param List[str] record_identifier: list of record identifiers + :param List[str] result_identifier: list of result identifiers to be retrieved + :param int limit: limit number of records to this amount + :param int offset: offset records by this amount + :return Dict[str, any]: a mapping with filtered results reported for the record + """ + + record_list = [] + + if result_identifier == [] or result_identifier is None: + result_identifier = ( + list(self.parsed_schema.results_data.keys()) + [CREATED_TIME] + [MODIFIED_TIME] + ) + if record_identifier == [] or record_identifier is None: + record_identifier = list( + self._data.data[self.pipeline_name][self.pipeline_type].keys() + ) + + for k in list(self._data.data[self.pipeline_name][self.pipeline_type].keys())[ + offset : offset + limit + ]: + if k in record_identifier: + retrieved_record = {} + retrieved_results = {} + for key, value in self._data.data[self.pipeline_name][self.pipeline_type][ + k + ].items(): + if key in result_identifier: + retrieved_results.update({key: value}) + + if retrieved_results != {}: + retrieved_record.update({k: retrieved_results}) + record_list.append(retrieved_record) + + records_dict = { + "count": len(record_list), + "limit": limit, + "offset": offset, + "record_identifiers": record_identifier, + "result_identifiers": result_identifier, + "records": record_list, + } + return records_dict + def retrieve( self, record_identifier: Optional[str] = None, diff --git a/pipestat/pipestat.py b/pipestat/pipestat.py index eba78cc4..e8a58e32 100644 --- a/pipestat/pipestat.py +++ b/pipestat/pipestat.py @@ -497,8 +497,10 @@ def retrieve_distinct( @require_backend def retrieve( self, - record_identifier: Optional[str] = None, - result_identifier: Optional[str] = None, + record_identifier: Optional[Union[str, List[str]]] = None, + result_identifier: Optional[Union[str, List[str]]] = None, + limit: Optional[int] = 1000, + offset: Optional[int] = 0, ) -> Union[Any, Dict[str, Any]]: """ Retrieve a result for a record. @@ -506,13 +508,31 @@ def retrieve( If no result ID specified, results for the entire record will be returned. - :param str record_identifier: name of the sample_level record - :param str result_identifier: name of the result to be retrieved - :return any | Dict[str, any]: a single result or a mapping with all the + :param str | List[str] record_identifier: name of the sample_level record + :param str | List[str] result_identifier: name of the result to be retrieved + :param int limit: limit number of records to this amount + :param int offset: offset records by this amount + :return any | Dict[str, any]: a single result or a mapping with filtered results reported for the record """ + if record_identifier is None and result_identifier is None: + # This will retrieve all records and columns. + return self.backend.retrieve_multiple( + record_identifier, result_identifier, limit, offset + ) + + if type(record_identifier) is list or type(result_identifier) is list: + if len(record_identifier) == 1 and len(result_identifier) == 1: + # If user gives single values, just use retrieve. + return self.backend.retrieve(record_identifier[0], result_identifier[0]) + else: + # If user gives lists, retrieve_multiple + return self.backend.retrieve_multiple( + record_identifier, result_identifier, limit, offset + ) r_id = record_identifier or self.record_identifier + return self.backend.retrieve(r_id, result_identifier) @require_backend diff --git a/tests/test_pipestat.py b/tests/test_pipestat.py index 95037b98..617a0124 100644 --- a/tests/test_pipestat.py +++ b/tests/test_pipestat.py @@ -395,6 +395,68 @@ def test_retrieve_basic( assert isinstance(psm.retrieve(record_identifier=rec_id), Mapping) @pytest.mark.parametrize("backend", ["file", "db"]) + def test_retrieve_multiple( + self, + config_file_path, + results_file_path, + schema_file_path, + backend, + ): + values_sample = [ + {"sample1": {"name_of_something": "string 1"}}, + {"sample1": {"number_of_things": 1}}, + {"sample2": {"name_of_something": "string 2"}}, + {"sample2": {"number_of_things": 20}}, + {"sample3": {"name_of_something": "string 3"}}, + {"sample3": {"number_of_things": 300}}, + ] + + with NamedTemporaryFile() as f, ContextManagerDBTesting(DB_URL): + results_file_path = f.name + args = dict(schema_path=schema_file_path, database_only=False) + backend_data = ( + {"config_file": config_file_path} + if backend == "db" + else {"results_file_path": results_file_path} + ) + args.update(backend_data) + psm = SamplePipestatManager(**args) + + for i in values_sample: + for k, v in i.items(): + psm.report(record_identifier=k, values=v, force_overwrite=True) + + # Test singular list works as expected + r_id = list(values_sample[0].keys())[0] + res_id = list(list(values_sample[0].values())[0].keys())[0] + results = psm.retrieve(record_identifier=[r_id], result_identifier=[res_id]) + assert results == list(list(values_sample[0].values())[0].values())[0] + + # Use list of results + r_ids = ["sample1", "sample2"] + res_id = ["md5sum", "number_of_things"] + # res_id = list(list(values_sample[0].values())[0].keys())[0] + results = psm.retrieve(record_identifier=r_ids, result_identifier=res_id) + assert r_ids[0] == list(results["records"][0].keys())[0] + assert ( + list(list(values_sample[3].values())[0].values())[0] + == list(results["records"][1].values())[0]["number_of_things"] + ) + + # Test combinations of empty list for either record or result identifiers. + results = psm.retrieve(record_identifier=r_ids, result_identifier=[]) + assert len(results["result_identifiers"]) == 9 + assert len(results["records"]) == 2 + + results = psm.retrieve(record_identifier=[], result_identifier=[]) + assert len(results["result_identifiers"]) == 9 + assert len(results["records"]) == 3 + + results = psm.retrieve(record_identifier=[], result_identifier=res_id) + assert "md5sum" in results["result_identifiers"] + assert len(results["records"]) == 3 + + @pytest.mark.parametrize("backend", ["db"]) def test_get_records( self, config_file_path,