Skip to content

Commit

Permalink
Fix/persist issue (#196)
Browse files Browse the repository at this point in the history
* fix: add path param to persist and load methods

* build: update to v2.1.6
  • Loading branch information
phamhoangtuan authored Jun 29, 2023
1 parent 9f05a87 commit 9d7ccec
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
9 changes: 4 additions & 5 deletions h1st/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ def __init__(self):
self.stats = {}
self.metrics = {}
self.base_model = None
self.model_repo = ModelRepository()

def persist(self, version=None) -> str:
def persist(self, path: str, version: str = None) -> str:
"""
Persist this model's properties to the ModelRepository. Currently, only `stats`, `metrics`, `model` properties are supported.
Expand All @@ -59,15 +58,15 @@ def persist(self, version=None) -> str:
:param version: model version, leave blank for autogeneration
:returns: model version
"""
repo = self.model_repo.get_model_repo(self)
repo = ModelRepository(storage=path)
return repo.persist(model=self, version=version)

def load(self, version: str = None) -> Any:
def load(self, path: str, version: str = None) -> Any:
"""
Load parameters from the specified `version` from the ModelRepository.
Leave version blank to load latest version.
"""
repo = self.model_repo.get_model_repo(self)
repo = ModelRepository(storage=path)
repo.load(model=self, version=version)

return self
Expand Down
9 changes: 5 additions & 4 deletions h1st/model/repository/model_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,14 +525,15 @@ def _get_key(self, model, version):

return key

def get_model_repo(self, ref=None):
@classmethod
def get_model_repo(cls, ref=None):
"""
Retrieve the default model repository for the project
:param ref: target model
:returns: Model repository instance
"""
if not hasattr(self, "MODEL_REPO"): # ModelRepository.MODEL_REPO
if not hasattr(cls, "MODEL_REPO"): # global ModelRepository.MODEL_REPO
repo_path = None
if ref is not None:
# root module
Expand Down Expand Up @@ -566,9 +567,9 @@ def get_model_repo(self, ref=None):
if not repo_path:
raise RuntimeError("Please set MODEL_REPO_PATH in config.py")

self.MODEL_REPO = ModelRepository(storage=repo_path)
setattr(cls, "MODEL_REPO", ModelRepository(storage=repo_path))

return self.MODEL_REPO
return getattr(cls, "MODEL_REPO")


def _tar_create(target, source):
Expand Down
4 changes: 2 additions & 2 deletions h1st/model/repository/storage/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def set_obj(self, name: str, value: Any) -> NoReturn:
:param value: value in python object
"""
key = self._to_key(name)
logger.info(f"---Saving obj {key} to S3, value {value}")
logger.info(f"---Saving obj {key} to S3")

with self.fs.open(key, 'wb') as f:
return cloudpickle.dump(value, f)
Expand All @@ -70,7 +70,7 @@ def set_bytes(self, name: str, value: bytes) -> NoReturn:
:param value: value in bytes
"""
key = self._to_key(name)
logger.info(f"---Saving bytes {key} to S3, value {value}")
logger.info(f"---Saving bytes {key} to S3")

with self.fs.open(key, 'wb') as f:
f.write(value)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "h1st"
version = "2.1.5"
version = "2.1.6"
description = "Human-First AI (H1st)"
authors = ["Aitomatic, Inc. <[email protected]>"]
license = "Apache-2.0"
Expand Down

0 comments on commit 9d7ccec

Please sign in to comment.