Skip to content

Commit

Permalink
forbid extra fields in config
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Feb 22, 2024
1 parent 2797278 commit fa22b14
Showing 1 changed file with 24 additions and 20 deletions.
44 changes: 24 additions & 20 deletions luxonis_train/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@

from luxonis_ml.data import BucketStorage, BucketType
from luxonis_ml.utils import Environ, LuxonisConfig, LuxonisFileSystem, setup_logging
from pydantic import BaseModel, Field, field_serializer, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator

from luxonis_train.utils.general import is_acyclic
from luxonis_train.utils.registry import MODELS

logger = logging.getLogger(__name__)


class AttachedModuleConfig(BaseModel):
class CustomBaseModel(BaseModel):
    """Shared base for every configuration model in this module.

    Sets pydantic's ``extra="forbid"`` so that a config containing a key
    not declared on the model raises a ``ValidationError`` instead of
    being silently ignored — typos in config files fail loudly.
    """

    model_config = ConfigDict(extra="forbid")


class AttachedModuleConfig(CustomBaseModel):
name: str
attached_to: str
alias: str | None = None
Expand All @@ -28,20 +32,20 @@ class MetricModuleConfig(AttachedModuleConfig):
is_main_metric: bool = False


class FreezingConfig(BaseModel):
class FreezingConfig(CustomBaseModel):
    """Options controlling whether (and when) a model node is frozen."""

    # Whether freezing is enabled for the node; off by default.
    active: bool = False
    # When to unfreeze the node; the int vs. float union presumably
    # distinguishes an epoch index from a fraction of total training —
    # TODO(review): confirm against the trainer's freezing logic.
    unfreeze_after: int | float | None = None


class ModelNodeConfig(BaseModel):
class ModelNodeConfig(CustomBaseModel):
    """Configuration of a single node in the model graph."""

    # Name of the node implementation to instantiate.
    name: str
    # Optional alternative identifier for referencing this node.
    alias: str | None = None
    # Names of the nodes whose outputs feed this node — presumably used
    # to build the (acyclic) model graph; verify against graph assembly.
    inputs: list[str] = []
    # Keyword arguments for the node; NOTE(review): assumed to be passed
    # to the node's constructor — confirm at the call site.
    params: dict[str, Any] = {}
    # Freezing behavior for this node; disabled by default.
    freezing: FreezingConfig = FreezingConfig()


class PredefinedModelConfig(BaseModel):
class PredefinedModelConfig(CustomBaseModel):
name: str
params: dict[str, Any] = {}
include_nodes: bool = True
Expand All @@ -50,7 +54,7 @@ class PredefinedModelConfig(BaseModel):
include_visualizers: bool = True


class ModelConfig(BaseModel):
class ModelConfig(CustomBaseModel):
name: str
predefined_model: PredefinedModelConfig | None = None
weights: str | None = None
Expand Down Expand Up @@ -114,7 +118,7 @@ def check_unique_names(self):
return self


class TrackerConfig(BaseModel):
class TrackerConfig(CustomBaseModel):
project_name: str | None = None
project_id: str | None = None
run_name: str | None = None
Expand All @@ -126,7 +130,7 @@ class TrackerConfig(BaseModel):
is_mlflow: bool = False


class DatasetConfig(BaseModel):
class DatasetConfig(CustomBaseModel):
name: str | None = None
id: str | None = None
team_name: str | None = None
Expand All @@ -143,20 +147,20 @@ def get_enum_value(self, v: Enum, _) -> str:
return str(v.value)


class NormalizeAugmentationConfig(BaseModel):
class NormalizeAugmentationConfig(CustomBaseModel):
    """Settings for the normalization augmentation."""

    # Normalization is applied by default.
    active: bool = True
    # Default mean/std are the standard ImageNet channel statistics.
    params: dict[str, Any] = {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
    }


class AugmentationConfig(BaseModel):
class AugmentationConfig(CustomBaseModel):
    """A single augmentation entry: its name and parameters."""

    # Name identifying the augmentation to apply.
    name: str
    # Keyword parameters for the augmentation; empty means defaults.
    params: dict[str, Any] = {}


class PreprocessingConfig(BaseModel):
class PreprocessingConfig(CustomBaseModel):
train_image_size: Annotated[
list[int], Field(default=[256, 256], min_length=2, max_length=2)
] = [256, 256]
Expand All @@ -174,23 +178,23 @@ def check_normalize(self):
return self


class CallbackConfig(BaseModel):
class CallbackConfig(CustomBaseModel):
    """A single training callback entry."""

    # Name identifying the callback to use.
    name: str
    # Callbacks are enabled unless explicitly deactivated.
    active: bool = True
    # Keyword parameters for the callback; empty means defaults.
    params: dict[str, Any] = {}


class OptimizerConfig(BaseModel):
class OptimizerConfig(CustomBaseModel):
    """Optimizer selection and its parameters."""

    # Optimizer name; defaults to Adam.
    name: str = "Adam"
    # Keyword parameters for the optimizer (e.g. learning rate);
    # empty means the optimizer's own defaults.
    params: dict[str, Any] = {}


class SchedulerConfig(BaseModel):
class SchedulerConfig(CustomBaseModel):
    """Learning-rate scheduler selection and its parameters."""

    # Scheduler name; ConstantLR keeps the learning rate fixed.
    name: str = "ConstantLR"
    # Keyword parameters for the scheduler; empty means defaults.
    params: dict[str, Any] = {}


class TrainerConfig(BaseModel):
class TrainerConfig(CustomBaseModel):
preprocessing: PreprocessingConfig = PreprocessingConfig()

accelerator: Literal["auto", "cpu", "gpu"] = "auto"
Expand Down Expand Up @@ -229,17 +233,17 @@ def check_num_workes_platform(self):
return self


class OnnxExportConfig(BaseModel):
class OnnxExportConfig(CustomBaseModel):
    """ONNX-specific export settings."""

    # ONNX opset version to target during export.
    opset_version: int = 12
    # Mapping declaring variable-sized input/output dimensions —
    # presumably forwarded to ``torch.onnx.export``'s ``dynamic_axes``;
    # TODO(review): confirm at the export call site.
    dynamic_axes: dict[str, Any] | None = None


class BlobconverterExportConfig(BaseModel):
class BlobconverterExportConfig(CustomBaseModel):
    """Settings for converting the exported model with blobconverter."""

    # Blob conversion is opt-in.
    active: bool = False
    # Number of SHAVE cores to compile the blob for — assumed to map to
    # blobconverter's ``shaves`` argument; verify at the conversion call.
    shaves: int = 6


class ExportConfig(BaseModel):
class ExportConfig(CustomBaseModel):
export_save_directory: str = "output_export"
input_shape: list[int] | None = None
export_model_name: str = "model"
Expand All @@ -265,12 +269,12 @@ def pad_values(values: float | list[float] | None):
return self


class StorageConfig(BaseModel):
class StorageConfig(CustomBaseModel):
    """Storage backend selection.

    NOTE(review): exact consumer not visible in this chunk — appears to
    choose between local and remote storage; confirm where it is read.
    """

    # Storage is enabled by default.
    active: bool = True
    # Backend kind; restricted to exactly "local" or "remote".
    storage_type: Literal["local", "remote"] = "local"


class TunerConfig(BaseModel):
class TunerConfig(CustomBaseModel):
study_name: str = "test-study"
use_pruner: bool = True
n_trials: int | None = 15
Expand Down

0 comments on commit fa22b14

Please sign in to comment.