diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index f15bcb892..0b288f83b 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = '2.11.4' +__version__ = '2.12.0' diff --git a/src/citrine/resources/predictor.py b/src/citrine/resources/predictor.py index 03ba50cb6..556cbbc01 100644 --- a/src/citrine/resources/predictor.py +++ b/src/citrine/resources/predictor.py @@ -121,12 +121,18 @@ def list_archived(self, collection_builder=self._build_collection_elements, per_page=per_page) - def archive(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER): + def archive(self, + uid: Union[UUID, str], + *, + version: Union[int, str] = MOST_RECENT_VER) -> Predictor: url = self._construct_path(uid, version, "archive") entity = self.session.put_resource(url, {}, version=self._api_version) return self.build(entity) - def restore(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER): + def restore(self, + uid: Union[UUID, str], + *, + version: Union[int, str] = MOST_RECENT_VER) -> Predictor: url = self._construct_path(uid, version, "restore") entity = self.session.put_resource(url, {}, version=self._api_version) return self.build(entity) @@ -147,6 +153,22 @@ def convert_to_graph(self, raise exc return self.build(entity) + def is_stale(self, + uid: Union[UUID, str], + *, + version: Union[int, str] = MOST_RECENT_VER) -> bool: + path = self._construct_path(uid, version, "is-stale") + response = self.session.get_resource(path, version=self._api_version) + return response["is_stale"] + + def retrain_stale(self, + uid: Union[UUID, str], + *, + version: Union[int, str] = MOST_RECENT_VER) -> Predictor: + path = self._construct_path(uid, version, "retrain-stale") + entity = self.session.put_resource(path, {}, version=self._api_version) + return self.build(entity) + class PredictorCollection(AbstractModuleCollection[Predictor]): """Represents the collection of all predictors for a project. @@ -416,3 +438,20 @@ def convert_and_update(self, """ new_pred = self.convert_to_graph(uid, version=version, retrain_if_needed=retrain_if_needed) return self.update(new_pred) if new_pred else None + + def is_stale(self, uid: Union[UUID, str], *, version: Union[int, str]) -> bool: + """Returns True if a predictor is stale, False otherwise. + + A predictor is stale if it's in the READY state, but the platform cannot load the + previously trained object. + """ + return self._versions_collection.is_stale(uid, version=version) + + def retrain_stale(self, uid: Union[UUID, str], *, version: Union[int, str]) -> Predictor: + """Begins retraining a stale predictor. + + This can only be used on a stale predictor, which is when it's in the READY state, but the + platform cannot load the previously trained object. Using it on a non-stale predictor will + result in an error. + """ + return self._versions_collection.retrain_stale(uid, version=version) diff --git a/tests/resources/test_predictor.py b/tests/resources/test_predictor.py index 3f2de5342..1f32643c0 100644 --- a/tests/resources/test_predictor.py +++ b/tests/resources/test_predictor.py @@ -628,3 +628,40 @@ def test_restore_invalid_version(valid_graph_predictor_data, version): with pytest.raises(ValueError): pc.restore_version(uuid.uuid4(), version=version) + + +@pytest.mark.parametrize("is_stale", (True, False)) +def test_is_stale(valid_graph_predictor_data, is_stale): + session = FakeSession() + pc = PredictorCollection(uuid.uuid4(), session) + pred_id = valid_graph_predictor_data["id"] + pred_version = valid_graph_predictor_data["metadata"]["version"] + response = { + "id": pred_id, + "version": pred_version, + "status": "READY", + "is_stale": is_stale + } + session.set_response(response) + + resp = pc.is_stale(pred_id, version=pred_version) + + versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + assert session.calls == [FakeCall(method='GET', path=f"{versions_path}/{pred_version}/is-stale")] + assert resp == is_stale + +def test_retrain_stale(valid_graph_predictor_data): + session = FakeSession() + pc = PredictorCollection(uuid.uuid4(), session) + pred_id = valid_graph_predictor_data["id"] + pred_version = valid_graph_predictor_data["metadata"]["version"] + + response = deepcopy(valid_graph_predictor_data) + response["metadata"]["status"]["name"] = "VALIDATING" + response["metadata"]["status"]["detail"] = [] + session.set_response(response) + + pc.retrain_stale(pred_id, version=pred_version) + + versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + assert session.calls == [FakeCall(method='PUT', path=f"{versions_path}/{pred_version}/retrain-stale", json={})]