Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retrieve_multiple to backends #101

Merged
merged 6 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pipestat/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,24 @@ def retrieve(
_LOGGER.warning("Not implemented yet for this backend")
pass

def retrieve_multiple(
    self,
    record_identifier: Optional[List[str]] = None,
    result_identifier: Optional[List[str]] = None,
    limit: Optional[int] = 1000,
    offset: Optional[int] = 0,
) -> Union[Any, Dict[str, Any]]:
    """
    Placeholder for retrieving results of several records at once.

    This base implementation performs no retrieval: it only logs a warning
    and returns None.

    :param List[str] record_identifier: list of record identifiers
    :param List[str] result_identifier: list of result identifiers to be retrieved
    :param int limit: limit number of records to this amount
    :param int offset: offset records by this amount
    :return Dict[str, any]: a mapping with filtered results reported for the record
    """
    _LOGGER.warning("Not implemented yet for this backend")
    return None

def remove(
self,
record_identifier: Optional[str] = None,
Expand Down
66 changes: 66 additions & 0 deletions pipestat/backends/dbbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,72 @@ def retrieve(
}
raise RecordNotFoundError(f"Record '{record_identifier}' not found")

def retrieve_multiple(
    self,
    record_identifier: Optional[List[str]] = None,
    result_identifier: Optional[List[str]] = None,
    limit: Optional[int] = 1000,
    offset: Optional[int] = 0,
) -> Union[Any, Dict[str, Any]]:
    """
    Retrieve selected results for several records from the database.

    An empty list is treated like None for either identifier argument:
    no record identifiers means "all records", no result identifiers means
    "all columns". Only keys present in the parsed schema's results are kept
    in each returned record. Requested records for which the query returns
    no row are skipped.

    :param List[str] record_identifier: list of record identifiers
    :param List[str] result_identifier: list of result identifiers to be retrieved
    :param int limit: limit number of records to this amount
    :param int offset: offset records by this amount
    :return Dict[str, any]: a mapping with the keys "count", "limit",
        "offset", "record_identifiers", "result_identifiers" and "records",
        where "records" is a list of {record_identifier: {result: value}}
    """
    # Empty lists carry no restriction, so normalise them to None.
    record_identifier = record_identifier or None
    result_identifier = result_identifier or None

    # The model itself is not used below; the lookup is kept because it was
    # part of the original call sequence.
    # NOTE(review): presumably this validates that the table has a model - confirm.
    self.get_model(table_name=self.table_name)

    schema_keys = set(self.parsed_schema.results_data.keys())
    record_list = []

    if record_identifier is not None:
        # Paginate over the requested identifiers. Forwarding limit/offset to
        # each single-record query would make any offset > 0 skip the only
        # matching row.
        start = offset or 0
        stop = None if limit is None else start + limit
        for r_id in record_identifier[start:stop]:
            rows = self.select(
                columns=result_identifier,
                filter_conditions=[("record_identifier", "eq", r_id)],
                limit=1,
                offset=0,
            )
            if not rows:
                # Unknown record: skip it instead of failing on rows[0].
                continue
            record_list.append(
                {r_id: {k: v for k, v in dict(rows[0]).items() if k in schema_keys}}
            )
    else:
        # The identifier column is needed to key each record, but it is kept
        # out of the caller-facing result_identifier list.
        columns = (
            ["record_identifier"] + result_identifier
            if result_identifier is not None
            else None
        )
        rows = self.select(
            columns=columns, filter_conditions=None, limit=limit, offset=offset
        )
        for row in rows:
            record_list.append(
                {
                    row.record_identifier: {
                        k: v for k, v in dict(row).items() if k in schema_keys
                    }
                }
            )

    return {
        "count": len(record_list),
        "limit": limit,
        "offset": offset,
        "record_identifiers": record_identifier,
        "result_identifiers": result_identifier
        or list(self.parsed_schema.results_data.keys()) + [CREATED_TIME] + [MODIFIED_TIME],
        "records": record_list,
    }

def select(
self,
columns: Optional[List[str]] = None,
Expand Down
52 changes: 52 additions & 0 deletions pipestat/backends/filebackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,58 @@ def report(

return results_formatted

def retrieve_multiple(
    self,
    record_identifier: Optional[List[str]] = None,
    result_identifier: Optional[List[str]] = None,
    limit: Optional[int] = 1000,
    offset: Optional[int] = 0,
) -> Union[Any, Dict[str, Any]]:
    """
    Retrieve selected results for several records from the results file data.

    An empty list or None for result_identifier selects every schema result
    plus the created/modified timestamps; an empty list or None for
    record_identifier selects every stored record. Stored records are first
    filtered by the requested identifiers and then paginated with
    offset/limit. Records without any of the requested results are left out.

    :param List[str] record_identifier: list of record identifiers
    :param List[str] result_identifier: list of result identifiers to be retrieved
    :param int limit: limit number of records to this amount; None means no limit
    :param int offset: offset records by this amount
    :return Dict[str, any]: a mapping with the keys "count", "limit",
        "offset", "record_identifiers", "result_identifiers" and "records",
        where "records" is a list of {record_identifier: {result: value}}
    """
    stored = self._data.data[self.pipeline_name][self.pipeline_type]

    if not result_identifier:
        result_identifier = (
            list(self.parsed_schema.results_data.keys()) + [CREATED_TIME] + [MODIFIED_TIME]
        )
    if not record_identifier:
        record_identifier = list(stored.keys())

    # Sets give constant-time membership tests inside the loops below.
    wanted_records = set(record_identifier)
    wanted_results = set(result_identifier)

    # Filter before paginating so a requested record is not lost merely
    # because it sits outside the first page of all stored keys.
    matching = [r_id for r_id in stored.keys() if r_id in wanted_records]
    start = offset or 0
    stop = None if limit is None else start + limit

    record_list = []
    for r_id in matching[start:stop]:
        selected = {
            key: value for key, value in stored[r_id].items() if key in wanted_results
        }
        if selected:
            record_list.append({r_id: selected})

    return {
        "count": len(record_list),
        "limit": limit,
        "offset": offset,
        "record_identifiers": record_identifier,
        "result_identifiers": result_identifier,
        "records": record_list,
    }

def retrieve(
self,
record_identifier: Optional[str] = None,
Expand Down
30 changes: 25 additions & 5 deletions pipestat/pipestat.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,22 +497,42 @@ def retrieve_distinct(
@require_backend
def retrieve(
    self,
    record_identifier: Optional[Union[str, List[str]]] = None,
    result_identifier: Optional[Union[str, List[str]]] = None,
    limit: Optional[int] = 1000,
    offset: Optional[int] = 0,
) -> Union[Any, Dict[str, Any]]:
    """
    Retrieve a result for a record, or results for several records.

    With both identifiers omitted, all records and results are requested
    from the backend's retrieve_multiple. If either identifier is a list,
    the request goes to retrieve_multiple, unless both resolve to exactly
    one value, in which case the single-value retrieve is used. Otherwise
    a single record is retrieved, defaulting to this manager's
    record_identifier.

    :param str | List[str] record_identifier: name(s) of the sample_level record
    :param str | List[str] result_identifier: name(s) of the result to be retrieved
    :param int limit: limit number of records to this amount (multiple retrieval only)
    :param int offset: offset records by this amount (multiple retrieval only)
    :return any | Dict[str, any]: a single result or a mapping with filtered
        results reported for the record
    """
    if record_identifier is None and result_identifier is None:
        # This will retrieve all records and columns.
        return self.backend.retrieve_multiple(
            record_identifier, result_identifier, limit, offset
        )

    if isinstance(record_identifier, list) or isinstance(result_identifier, list):
        # Wrap a lone string so the backend never iterates it character by
        # character when the other argument is a list.
        if isinstance(record_identifier, str):
            record_identifier = [record_identifier]
        if isinstance(result_identifier, str):
            result_identifier = [result_identifier]

        # Either side may still be None here, so guard before taking len().
        if (
            record_identifier is not None
            and result_identifier is not None
            and len(record_identifier) == 1
            and len(result_identifier) == 1
        ):
            # If user gives single values, just use retrieve.
            return self.backend.retrieve(record_identifier[0], result_identifier[0])

        # If user gives lists, retrieve_multiple
        return self.backend.retrieve_multiple(
            record_identifier, result_identifier, limit, offset
        )

    r_id = record_identifier or self.record_identifier

    return self.backend.retrieve(r_id, result_identifier)

@require_backend
Expand Down
62 changes: 62 additions & 0 deletions tests/test_pipestat.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,68 @@ def test_retrieve_basic(
assert isinstance(psm.retrieve(record_identifier=rec_id), Mapping)

@pytest.mark.parametrize("backend", ["file", "db"])
def test_retrieve_multiple(
    self,
    config_file_path,
    results_file_path,
    schema_file_path,
    backend,
):
    """Check list-based retrieval through ``retrieve`` for both backends."""
    values_sample = [
        {"sample1": {"name_of_something": "string 1"}},
        {"sample1": {"number_of_things": 1}},
        {"sample2": {"name_of_something": "string 2"}},
        {"sample2": {"number_of_things": 20}},
        {"sample3": {"name_of_something": "string 3"}},
        {"sample3": {"number_of_things": 300}},
    ]

    with NamedTemporaryFile() as f, ContextManagerDBTesting(DB_URL):
        results_file_path = f.name
        args = dict(schema_path=schema_file_path, database_only=False)
        if backend == "db":
            args.update({"config_file": config_file_path})
        else:
            args.update({"results_file_path": results_file_path})
        psm = SamplePipestatManager(**args)

        for reported in values_sample:
            for sample_name, sample_values in reported.items():
                psm.report(
                    record_identifier=sample_name,
                    values=sample_values,
                    force_overwrite=True,
                )

        # Test singular list works as expected
        ((r_id, first_values),) = values_sample[0].items()
        ((res_id, expected_value),) = first_values.items()
        results = psm.retrieve(record_identifier=[r_id], result_identifier=[res_id])
        assert results == expected_value

        # Use list of results
        r_ids = ["sample1", "sample2"]
        res_id = ["md5sum", "number_of_things"]
        results = psm.retrieve(record_identifier=r_ids, result_identifier=res_id)
        first_record, second_record = results["records"][0], results["records"][1]
        assert r_ids[0] == list(first_record.keys())[0]
        ((_, fourth_values),) = values_sample[3].items()
        assert (
            list(fourth_values.values())[0]
            == list(second_record.values())[0]["number_of_things"]
        )

        # Test combinations of empty list for either record or result identifiers.
        results = psm.retrieve(record_identifier=r_ids, result_identifier=[])
        assert len(results["result_identifiers"]) == 9
        assert len(results["records"]) == 2

        results = psm.retrieve(record_identifier=[], result_identifier=[])
        assert len(results["result_identifiers"]) == 9
        assert len(results["records"]) == 3

        results = psm.retrieve(record_identifier=[], result_identifier=res_id)
        assert "md5sum" in results["result_identifiers"]
        assert len(results["records"]) == 3

@pytest.mark.parametrize("backend", ["db"])
def test_get_records(
self,
config_file_path,
Expand Down
Loading