diff --git a/h1st/model/model.py b/h1st/model/model.py index 9608d72d..1831dd39 100644 --- a/h1st/model/model.py +++ b/h1st/model/model.py @@ -47,9 +47,8 @@ def __init__(self): self.stats = {} self.metrics = {} self.base_model = None - self.model_repo = ModelRepository() - def persist(self, version=None) -> str: + def persist(self, path: str, version: str = None) -> str: """ Persist this model's properties to the ModelRepository. Currently, only `stats`, `metrics`, `model` properties are supported. @@ -59,15 +58,15 @@ def persist(self, version=None) -> str: :param version: model version, leave blank for autogeneration :returns: model version """ - repo = self.model_repo.get_model_repo(self) + repo = ModelRepository(storage=path) return repo.persist(model=self, version=version) - def load(self, version: str = None) -> Any: + def load(self, path: str, version: str = None) -> Any: """ Load parameters from the specified `version` from the ModelRepository. Leave version blank to load latest version. """ - repo = self.model_repo.get_model_repo(self) + repo = ModelRepository(storage=path) repo.load(model=self, version=version) return self diff --git a/h1st/model/repository/model_repository.py b/h1st/model/repository/model_repository.py index 53aecd04..02d3a3ab 100644 --- a/h1st/model/repository/model_repository.py +++ b/h1st/model/repository/model_repository.py @@ -525,14 +525,15 @@ def _get_key(self, model, version): return key - def get_model_repo(self, ref=None): + @classmethod + def get_model_repo(cls, ref=None): """ Retrieve the default model repository for the project :param ref: target model :returns: Model repository instance """ - if not hasattr(self, "MODEL_REPO"): # ModelRepository.MODEL_REPO + if not hasattr(cls, "MODEL_REPO"): # global ModelRepository.MODEL_REPO repo_path = None if ref is not None: # root module @@ -566,9 +567,9 @@ def get_model_repo(self, ref=None): if not repo_path: raise RuntimeError("Please set MODEL_REPO_PATH in config.py") - self.MODEL_REPO = ModelRepository(storage=repo_path) + setattr(cls, "MODEL_REPO", ModelRepository(storage=repo_path)) - return self.MODEL_REPO + return getattr(cls, "MODEL_REPO") def _tar_create(target, source): diff --git a/h1st/model/repository/storage/s3.py b/h1st/model/repository/storage/s3.py index 4d2dc2f9..b8a10af8 100644 --- a/h1st/model/repository/storage/s3.py +++ b/h1st/model/repository/storage/s3.py @@ -57,7 +57,7 @@ def set_obj(self, name: str, value: Any) -> NoReturn: :param value: value in python object """ key = self._to_key(name) - logger.info(f"---Saving obj {key} to S3, value {value}") + logger.info(f"---Saving obj {key} to S3") with self.fs.open(key, 'wb') as f: return cloudpickle.dump(value, f) @@ -70,7 +70,7 @@ def set_bytes(self, name: str, value: bytes) -> NoReturn: :param value: value in bytes """ key = self._to_key(name) - logger.info(f"---Saving bytes {key} to S3, value {value}") + logger.info(f"---Saving bytes {key} to S3") with self.fs.open(key, 'wb') as f: f.write(value) diff --git a/pyproject.toml b/pyproject.toml index 84b55c1a..26c23dd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "h1st" -version = "2.1.5" +version = "2.1.6" description = "Human-First AI (H1st)" authors = ["Aitomatic, Inc. "] license = "Apache-2.0"