Skip to content

Commit

Permalink
Merge pull request #634 from voxel51/release/v0.13.0
Browse files Browse the repository at this point in the history
Release v0.13.0
  • Loading branch information
benjaminpkane authored Sep 16, 2024
2 parents 3f14178 + d72657d commit b90a7cb
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 49 deletions.
159 changes: 111 additions & 48 deletions eta/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,12 +1038,17 @@ class Model(Serializable):
Attributes:
base_name: the base name of the model (no version info)
base_filename: the base filename of the model (if any, no version info)
base_filename: the base filename or directory of the model (if any)
(no version info)
subdir: the model's subdirectory (if any)
manager: the ModelManager instance that describes the remote storage
location of the models_dir (if any)
version: the version of the model (if any)
author (optional): the author of the model
version: (optional) the model version
url (optional): the URL where the model is hosted
source (optional): the source of the model
license (optional): the license under which the model is distributed
description: the description of the model (if any)
source: the source of the model (if any)
size_bytes: the size of the model on disk (if any)
default_deployment_config_dict: a dictionary representation of an
`eta.core.learning.ModelConfig` describing the recommended settings
Expand All @@ -1061,10 +1066,14 @@ def __init__(
self,
base_name,
base_filename=None,
subdir=None,
manager=None,
author=None,
version=None,
description=None,
url=None,
source=None,
license=None,
description=None,
size_bytes=None,
default_deployment_config_dict=None,
requirements=None,
Expand All @@ -1076,10 +1085,14 @@ def __init__(
Args:
base_name: the base name of the model
base_filename (optional): the base filename for the model
subdir: the model's subdirectory (if any)
manager (optional): the ModelManager for the model
author (optional): the author of the model
version: (optional) the model version
url (optional): the URL where the model is hosted
source (optional): the source of the model
license (optional): the license under which the model is distributed
description: (optional) the description of the model
source: (optional) the source of the model
size_bytes: (optional) the size of the model on disk
default_deployment_config_dict: (optional) a dictionary
representation of an `eta.core.learning.ModelConfig` describing
Expand All @@ -1090,10 +1103,14 @@ def __init__(
"""
self.base_name = base_name
self.base_filename = base_filename
self.subdir = subdir
self.manager = manager
self.author = author
self.version = version or None
self.description = description
self.url = url
self.source = source
self.license = license
self.description = description
self.size_bytes = size_bytes
self.default_deployment_config_dict = default_deployment_config_dict
self.requirements = requirements
Expand All @@ -1112,14 +1129,19 @@ def name(self):
@property
def filename(self):
"""The version-aware filename of the model."""
if not self.has_version:
return self.base_filename

if self.base_filename is None:
return None

base, ext = os.path.splitext(self.base_filename)
return base + "-v" + self.version + ext
if self.has_version:
base, ext = os.path.splitext(self.base_filename)
filename = base + "-v" + self.version + ext
else:
filename = self.base_filename

if self.subdir is not None:
filename = os.path.join(self.subdir, filename)

return filename

@property
def has_manager(self):
Expand Down Expand Up @@ -1383,17 +1405,11 @@ def parse_name(name):
Returns:
base_name: the base name of the model
version: the version of the model, or None if no version was found
Raises:
ModelError: if the model name was invalid
"""
chunks = name.split("@")
chunks = name.rsplit("@", 1)
if len(chunks) == 1:
return name, None

if chunks[1] == "" or len(chunks) > 2:
raise ModelError("Invalid model name '%s'" % name)

return chunks[0], chunks[1]

@staticmethod
Expand All @@ -1406,7 +1422,7 @@ def has_version_str(name):
Returns:
True/False
"""
return bool(Model.parse_name(name)[1])
return Model.parse_name(name)[1] is not None

def attributes(self):
"""Returns a list of class attributes to be serialized.
Expand All @@ -1417,9 +1433,12 @@ def attributes(self):
return [
"base_name",
"base_filename",
"author",
"version",
"description",
"url",
"source",
"license",
"description",
"size_bytes",
"manager",
"default_deployment_config_dict",
Expand All @@ -1429,11 +1448,12 @@ def attributes(self):
]

@classmethod
def from_dict(cls, d):
def from_dict(cls, d, subdir=None):
"""Constructs a Model from a JSON dictionary.
Args:
d: a JSON dictionary
subdir (optional): a subdirectory for the model
Returns:
a Model instance
Expand All @@ -1453,10 +1473,14 @@ def from_dict(cls, d):
return cls(
d["base_name"],
base_filename=d.get("base_filename", None),
subdir=subdir,
manager=manager,
author=d.get("author", None),
version=d.get("version", None),
description=d.get("description", None),
url=d.get("url", None),
source=d.get("source", None),
license=d.get("license", None),
description=d.get("description", None),
size_bytes=d.get("size_bytes", None),
default_deployment_config_dict=d.get(
"default_deployment_config_dict", None
Expand All @@ -1472,70 +1496,103 @@ class ModelsManifest(Serializable):

_MODEL_CLS = Model

def __init__(self, models=None):
def __init__(self, models=None, name=None, url=None):
"""Creates a ModelsManifest instance.
Args:
models: a list of Model instances
name (optional): a name for the manifest
url (optional): the source location of the manifest
"""
self.models = models or []
if models is None:
models = []

if name is not None:
subdir = os.path.join(*name.split("/"))
for model in models:
model.subdir = subdir
else:
subdir = None

self.models = models
self.name = name
self.url = url
self._subdir = subdir

def __iter__(self):
return iter(self.models)

def add_model(self, model):
@property
def subdir(self):
return self._subdir

def add_model(self, model, error_level=0):
"""Adds the given model to the manifest.
Args:
model: a Model instance
error_level: the error level to use, defined as:
Raises:
ModelError: if the model conflicts with an existing model in the
manifest
0: raise error if the model cannot be added
1: log warning if the model cannot be added
2: ignore models that cannot be added
"""
if self.has_model_with_name(model.name):
raise ModelError(
error_msg = (
"Manifest already contains model called '%s'" % model.name
)
etau.handle_error(ModelError(error_msg), error_level)
return

if model.filename is not None and self.has_model_with_filename(
model.filename
):
raise ModelError(
if self.has_model_with_filename(model):
error_msg = (
"Manifest already contains model with filename '%s'"
% (model.filename)
% model.filename
)
etau.handle_error(ModelError(error_msg), error_level)
return

if self.has_model_with_name(model.base_name):
raise ModelError(
error_msg = (
"Manifest already contains a versionless model called '%s', "
"so a versioned model is not allowed" % model.base_name
)
"so a versioned model is not allowed"
) % model.base_name
etau.handle_error(ModelError(error_msg), error_level)
return

self.models.append(model)

def remove_model(self, name):
def remove_model(self, name, error_level=0):
"""Removes the model with the given name from the ModelsManifest.
Args:
name: the name of the model
error_level: the error level to use, defined as:
Raises:
ModelError: if the model was not found
0: raise error if the model cannot be added
1: log warning if the model cannot be added
2: ignore models that cannot be added
"""
if not self.has_model_with_name(name):
raise ModelError("Manifest does not contain model '%s'" % name)
error_msg = "Manifest does not contain model '%s'" % name
etau.handle_error(ModelError(error_msg), error_level)
return

self.models = [model for model in self.models if model.name != name]

def merge(self, models_manifest):
def merge(self, models_manifest, error_level=0):
"""Merges the models manifest into this one.
Args:
models_manifest: a ModelsManifest
error_level: the error level to use, defined as:
0: raise error if a model cannot be added
1: log warning if a model cannot be added
2: ignore models that cannot be added
"""
for model in models_manifest:
self.add_model(model)
self.add_model(model, error_level=error_level)

def get_model_with_name(self, name):
"""Gets the model with the given name.
Expand Down Expand Up @@ -1593,17 +1650,20 @@ def has_model_with_name(self, name):
"""
return any(name == model.name for model in self.models)

def has_model_with_filename(self, filename):
"""Determines whether this manifest contains a model with the given
def has_model_with_filename(self, model):
"""Determines whether this manifest contains a model with a conflicting
filename.
Args:
filename: the filename
model: a Model instance
Returns:
True/False
"""
return any(filename == model.filename for model in self.models)
if model.filename is None:
return False

return any(model.filename == m.filename for m in self.models)

@staticmethod
def make_manifest_path(models_dir):
Expand Down Expand Up @@ -1664,7 +1724,10 @@ def from_dict(cls, d):
Returns:
a ModelsManifest
"""
return cls(models=[cls._MODEL_CLS.from_dict(md) for md in d["models"]])
models = [cls._MODEL_CLS.from_dict(md) for md in d.get("models", [])]
name = d.get("name", None)
url = d.get("url", None)
return cls(models=models, name=name, url=url)


class ModelManager(Configurable, Serializable):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from wheel.bdist_wheel import bdist_wheel


VERSION = "0.12.7"
VERSION = "0.13.0"


class BdistWheelCustom(bdist_wheel):
Expand Down

0 comments on commit b90a7cb

Please sign in to comment.