diff --git a/object_detection/object_detection/DetectorBase.py b/object_detection/object_detection/DetectorBase.py index 51ffa0e..dd70347 100644 --- a/object_detection/object_detection/DetectorBase.py +++ b/object_detection/object_detection/DetectorBase.py @@ -19,10 +19,12 @@ class DetectorBase(ABC): - def __init__(self, logger) -> None: - self.logger = logger + def __init__(self) -> None: self.predictions = [] + def set_logger(self, logger) -> None: + self.logger = logger + def create_predictions_list(self, class_ids, confidences, boxes): self.predictions = [] for i in range(len(class_ids)): diff --git a/object_detection/object_detection/Detectors/RetinaNet.py b/object_detection/object_detection/Detectors/RetinaNet.py index d873085..ffdb7c2 100755 --- a/object_detection/object_detection/Detectors/RetinaNet.py +++ b/object_detection/object_detection/Detectors/RetinaNet.py @@ -23,11 +23,8 @@ class RetinaNet(DetectorBase): - def __init__(self, logger): - super().__init__(logger) - - # Create a logger instance - self.logger = super().get_logger() + def __init__(self): + super().__init__() def build_model(self, model_dir_path, weight_file_name): model_path = os.path.join(model_dir_path, weight_file_name) diff --git a/object_detection/object_detection/Detectors/YOLOv5.py b/object_detection/object_detection/Detectors/YOLOv5.py index cc5cccb..eda8371 100755 --- a/object_detection/object_detection/Detectors/YOLOv5.py +++ b/object_detection/object_detection/Detectors/YOLOv5.py @@ -21,13 +21,10 @@ class YOLOv5(DetectorBase): - def __init__(self, logger, conf_threshold=0.7): - super().__init__(logger) + def __init__(self, conf_threshold=0.7): + super().__init__() self.conf_threshold = conf_threshold - # Create a logger instance - self.logger = super().get_logger() - def build_model(self, model_dir_path, weight_file_name): try: model_path = os.path.join(model_dir_path, weight_file_name) diff --git a/object_detection/object_detection/Detectors/YOLOv8.py b/object_detection/object_detection/Detectors/YOLOv8.py index d23afda..d67a4c1 100755 --- a/object_detection/object_detection/Detectors/YOLOv8.py +++ b/object_detection/object_detection/Detectors/YOLOv8.py @@ -21,9 +21,9 @@ class YOLOv8(DetectorBase): - def __init__(self, logger): + def __init__(self): - super().__init__(logger) + super().__init__() def build_model(self, model_dir_path, weight_file_name): try: diff --git a/object_detection/object_detection/ObjectDetection.py b/object_detection/object_detection/ObjectDetection.py index 6954ca8..060e725 100644 --- a/object_detection/object_detection/ObjectDetection.py +++ b/object_detection/object_detection/ObjectDetection.py @@ -36,7 +36,7 @@ def __init__(self): self.available_detectors = [] # Create a logger instance - self.logger = super().get_logger() + self.logger = self.get_logger() # Declare parameters with default values self.declare_parameters( @@ -111,8 +111,12 @@ def load_detector(self): detector_mod = importlib.import_module(".Detectors." + self.detector_type, "object_detection") detector_class = getattr(detector_mod, self.detector_type) - self.detector = detector_class(self.logger) + self.detector = detector_class() + # Set the logger for the detector plugins + self.detector.set_logger(self.logger) + + # Load the model and the classes file for the detector plugin self.detector.build_model(self.model_dir_path, self.weight_file_name) self.detector.load_classes(self.model_dir_path)