diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py index 48661f7d..591376f8 100644 --- a/luxonis_train/utils/config.py +++ b/luxonis_train/utils/config.py @@ -5,7 +5,7 @@ from luxonis_ml.data import BucketStorage, BucketType from luxonis_ml.utils import Environ, LuxonisConfig, LuxonisFileSystem, setup_logging -from pydantic import BaseModel, Field, field_serializer, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator from luxonis_train.utils.general import is_acyclic from luxonis_train.utils.registry import MODELS @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) -class AttachedModuleConfig(BaseModel): +class CustomBaseModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class AttachedModuleConfig(CustomBaseModel): name: str attached_to: str alias: str | None = None @@ -28,12 +32,12 @@ class MetricModuleConfig(AttachedModuleConfig): is_main_metric: bool = False -class FreezingConfig(BaseModel): +class FreezingConfig(CustomBaseModel): active: bool = False unfreeze_after: int | float | None = None -class ModelNodeConfig(BaseModel): +class ModelNodeConfig(CustomBaseModel): name: str alias: str | None = None inputs: list[str] = [] @@ -41,7 +45,7 @@ class ModelNodeConfig(BaseModel): freezing: FreezingConfig = FreezingConfig() -class PredefinedModelConfig(BaseModel): +class PredefinedModelConfig(CustomBaseModel): name: str params: dict[str, Any] = {} include_nodes: bool = True @@ -50,7 +54,7 @@ class PredefinedModelConfig(BaseModel): include_visualizers: bool = True -class ModelConfig(BaseModel): +class ModelConfig(CustomBaseModel): name: str predefined_model: PredefinedModelConfig | None = None weights: str | None = None @@ -114,7 +118,7 @@ def check_unique_names(self): return self -class TrackerConfig(BaseModel): +class TrackerConfig(CustomBaseModel): project_name: str | None = None project_id: str | None = None run_name: str | None = None @@ -126,7 +130,7 @@ class TrackerConfig(BaseModel): is_mlflow: bool = False -class DatasetConfig(BaseModel): +class DatasetConfig(CustomBaseModel): name: str | None = None id: str | None = None team_name: str | None = None @@ -143,7 +147,7 @@ def get_enum_value(self, v: Enum, _) -> str: return str(v.value) -class NormalizeAugmentationConfig(BaseModel): +class NormalizeAugmentationConfig(CustomBaseModel): active: bool = True params: dict[str, Any] = { "mean": [0.485, 0.456, 0.406], @@ -151,12 +155,12 @@ class NormalizeAugmentationConfig(BaseModel): } -class AugmentationConfig(BaseModel): +class AugmentationConfig(CustomBaseModel): name: str params: dict[str, Any] = {} -class PreprocessingConfig(BaseModel): +class PreprocessingConfig(CustomBaseModel): train_image_size: Annotated[ list[int], Field(default=[256, 256], min_length=2, max_length=2) ] = [256, 256] @@ -174,23 +178,23 @@ def check_normalize(self): return self -class CallbackConfig(BaseModel): +class CallbackConfig(CustomBaseModel): name: str active: bool = True params: dict[str, Any] = {} -class OptimizerConfig(BaseModel): +class OptimizerConfig(CustomBaseModel): name: str = "Adam" params: dict[str, Any] = {} -class SchedulerConfig(BaseModel): +class SchedulerConfig(CustomBaseModel): name: str = "ConstantLR" params: dict[str, Any] = {} -class TrainerConfig(BaseModel): +class TrainerConfig(CustomBaseModel): preprocessing: PreprocessingConfig = PreprocessingConfig() accelerator: Literal["auto", "cpu", "gpu"] = "auto" @@ -229,17 +233,17 @@ def check_num_workes_platform(self): return self -class OnnxExportConfig(BaseModel): +class OnnxExportConfig(CustomBaseModel): opset_version: int = 12 dynamic_axes: dict[str, Any] | None = None -class BlobconverterExportConfig(BaseModel): +class BlobconverterExportConfig(CustomBaseModel): active: bool = False shaves: int = 6 -class ExportConfig(BaseModel): +class ExportConfig(CustomBaseModel): export_save_directory: str = "output_export" input_shape: list[int] | None = None export_model_name: str = "model" @@ -265,12 +269,12 @@ def pad_values(values: float | list[float] | None): return self -class StorageConfig(BaseModel): +class StorageConfig(CustomBaseModel): active: bool = True storage_type: Literal["local", "remote"] = "local" -class TunerConfig(BaseModel): +class TunerConfig(CustomBaseModel): study_name: str = "test-study" use_pruner: bool = True n_trials: int | None = 15