Skip to content

Commit

Permalink
Merge pull request #843 from CitrineInformatics/feature/pla-12391-ret…
Browse files Browse the repository at this point in the history
…rain-stale

[PLA-12391] Support retraining stale predictors.
  • Loading branch information
anoto-moniz authored Apr 12, 2023
2 parents e350a3b + 60d7e95 commit 6bc78b4
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.11.4'
__version__ = '2.12.0'
43 changes: 41 additions & 2 deletions src/citrine/resources/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,18 @@ def list_archived(self,
collection_builder=self._build_collection_elements,
per_page=per_page)

def archive(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER):
def archive(self,
uid: Union[UUID, str],
*,
version: Union[int, str] = MOST_RECENT_VER) -> Predictor:
url = self._construct_path(uid, version, "archive")
entity = self.session.put_resource(url, {}, version=self._api_version)
return self.build(entity)

def restore(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER):
def restore(self,
uid: Union[UUID, str],
*,
version: Union[int, str] = MOST_RECENT_VER) -> Predictor:
url = self._construct_path(uid, version, "restore")
entity = self.session.put_resource(url, {}, version=self._api_version)
return self.build(entity)
Expand All @@ -147,6 +153,22 @@ def convert_to_graph(self,
raise exc
return self.build(entity)

def is_stale(self,
uid: Union[UUID, str],
*,
version: Union[int, str] = MOST_RECENT_VER) -> bool:
path = self._construct_path(uid, version, "is-stale")
response = self.session.get_resource(path, version=self._api_version)
return response["is_stale"]

def retrain_stale(self,
uid: Union[UUID, str],
*,
version: Union[int, str] = MOST_RECENT_VER) -> Predictor:
path = self._construct_path(uid, version, "retrain-stale")
entity = self.session.put_resource(path, {}, version=self._api_version)
return self.build(entity)


class PredictorCollection(AbstractModuleCollection[Predictor]):
"""Represents the collection of all predictors for a project.
Expand Down Expand Up @@ -416,3 +438,20 @@ def convert_and_update(self,
"""
new_pred = self.convert_to_graph(uid, version=version, retrain_if_needed=retrain_if_needed)
return self.update(new_pred) if new_pred else None

def is_stale(self, uid: Union[UUID, str], *, version: Union[int, str]) -> bool:
"""Returns True if a predictor is stale, False otherwise.
A predictor is stale if it's in the READY state, but the platform cannot load the
previously trained object.
"""
return self._versions_collection.is_stale(uid, version=version)

def retrain_stale(self, uid: Union[UUID, str], *, version: Union[int, str]) -> Predictor:
"""Begins retraining a stale predictor.
This can only be used on a stale predictor, which is when it's in the READY state, but the
platform cannot load the previously trained object. Using it on a non-stale predictor will
result in an error.
"""
return self._versions_collection.retrain_stale(uid, version=version)
37 changes: 37 additions & 0 deletions tests/resources/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,3 +628,40 @@ def test_restore_invalid_version(valid_graph_predictor_data, version):

with pytest.raises(ValueError):
pc.restore_version(uuid.uuid4(), version=version)


@pytest.mark.parametrize("is_stale", (True, False))
def test_is_stale(valid_graph_predictor_data, is_stale):
session = FakeSession()
pc = PredictorCollection(uuid.uuid4(), session)
pred_id = valid_graph_predictor_data["id"]
pred_version = valid_graph_predictor_data["metadata"]["version"]
response = {
"id": pred_id,
"version": pred_version,
"status": "READY",
"is_stale": is_stale
}
session.set_response(response)

resp = pc.is_stale(pred_id, version=pred_version)

versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id)
assert session.calls == [FakeCall(method='GET', path=f"{versions_path}/{pred_version}/is-stale")]
assert resp == is_stale

def test_retrain_stale(valid_graph_predictor_data):
session = FakeSession()
pc = PredictorCollection(uuid.uuid4(), session)
pred_id = valid_graph_predictor_data["id"]
pred_version = valid_graph_predictor_data["metadata"]["version"]

response = deepcopy(valid_graph_predictor_data)
response["metadata"]["status"]["name"] = "VALIDATING"
response["metadata"]["status"]["detail"] = []
session.set_response(response)

pc.retrain_stale(pred_id, version=pred_version)

versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id)
assert session.calls == [FakeCall(method='PUT', path=f"{versions_path}/{pred_version}/retrain-stale", json={})]

0 comments on commit 6bc78b4

Please sign in to comment.