diff --git a/src/labelformat/errors.py b/src/labelformat/errors.py new file mode 100644 index 0000000..57dcf3e --- /dev/null +++ b/src/labelformat/errors.py @@ -0,0 +1,4 @@ +class LabelWithoutImageError(Exception): + """Raised when a label is found without a corresponding image.""" + + pass diff --git a/src/labelformat/formats/lightly.py b/src/labelformat/formats/lightly.py index 8ff9be3..f721e65 100644 --- a/src/labelformat/formats/lightly.py +++ b/src/labelformat/formats/lightly.py @@ -6,6 +6,7 @@ from labelformat import utils from labelformat.cli.registry import Task, cli_register +from labelformat.errors import LabelWithoutImageError from labelformat.model.bounding_box import BoundingBox, BoundingBoxFormat from labelformat.model.category import Category from labelformat.model.image import Image @@ -15,7 +16,6 @@ ObjectDetectionOutput, SingleObjectDetection, ) -from labelformat.types import JsonDict @cli_register(format="lightly", task=Task.OBJECT_DETECTION) @@ -34,14 +34,21 @@ def add_cli_arguments(parser: ArgumentParser) -> None: default="../images", help="Relative path to images folder from label folder", ) + parser.add_argument( + "--skip-labels-without-image", + action="store_true", + help="Skip labels without corresponding image", + ) def __init__( self, input_folder: Path, images_rel_path: str = "../images", + skip_labels_without_image: bool = False, ) -> None: self._input_folder = input_folder self._images_rel_path = images_rel_path + self._skip_labels_without_image = skip_labels_without_image self._categories = self._get_categories() def get_categories(self) -> Iterable[Category]: @@ -62,6 +69,12 @@ def get_labels(self) -> Iterable[ImageObjectDetection]: if json_path.name == "schema.json": continue data = json.loads(json_path.read_text()) + if data["file_name"] not in filename_to_image: + if self._skip_labels_without_image: + continue + raise LabelWithoutImageError( + f"Label '{json_path.name}' does not have a corresponding image." + ) image = filename_to_image[data["file_name"]] objects = [] for prediction in data["predictions"]: diff --git a/tests/unit/formats/test_lightly.py b/tests/unit/formats/test_lightly.py new file mode 100644 index 0000000..385f270 --- /dev/null +++ b/tests/unit/formats/test_lightly.py @@ -0,0 +1,172 @@ +import json +from pathlib import Path + +import pytest +from pytest_mock import MockerFixture + +from labelformat.errors import LabelWithoutImageError +from labelformat.formats.lightly import ( + LightlyObjectDetectionInput, + LightlyObjectDetectionOutput, +) +from labelformat.model.bounding_box import BoundingBox +from labelformat.model.category import Category +from labelformat.model.image import Image +from labelformat.model.object_detection import ( + ImageObjectDetection, + SingleObjectDetection, +) + +from ...simple_object_detection_label_input import SimpleObjectDetectionInput + + +def _create_label_file(tmp_path: Path) -> Path: + """Create a dummy label file in the given directory.""" + annotation = json.dumps( + { + "file_name": "image.jpg", + "predictions": [ + { + "category_id": 1, + "bbox": [10.0, 20.0, 20.0, 20.0], + }, + { + "category_id": 0, + "bbox": [50.0, 60.0, 20.0, 20.0], + }, + ], + } + ) + label_path = tmp_path / "labels" / "image.json" + label_path.parent.mkdir(parents=True, exist_ok=True) + label_path.write_text(annotation) + return label_path + + +def _create_schema_file(tmp_path: Path) -> Path: + """Create a dummy schema file in the given directory.""" + schema = json.dumps( + { + "task_type": "object-detection", + "categories": [ + {"name": "cat", "id": 0}, + {"name": "dog", "id": 1}, + {"name": "cow", "id": 2}, + ], + } + ) + schema_path = tmp_path / "labels" / "schema.json" + schema_path.write_text(schema) + return schema_path + + +class TestLightlyObjectDetectionInput: + def test_get_labels(self, tmp_path: Path, mocker: MockerFixture) -> None: + # Prepare inputs. + _create_label_file(tmp_path=tmp_path) + _create_schema_file(tmp_path=tmp_path) + + # Mock the image file. + (tmp_path / "images").mkdir() + (tmp_path / "images/image.jpg").touch() + mocker.patch("PIL.Image.open", autospec=True).return_value.size = (100, 200) + + # Convert. + label_input = LightlyObjectDetectionInput( + input_folder=tmp_path / "labels", + images_rel_path="../images", + ) + labels = list(label_input.get_labels()) + assert labels == [ + ImageObjectDetection( + image=Image(id=0, filename="image.jpg", width=100, height=200), + objects=[ + SingleObjectDetection( + category=Category(id=1, name="dog"), + box=BoundingBox( + xmin=10.0, + ymin=20.0, + xmax=30.0, + ymax=40.0, + ), + ), + SingleObjectDetection( + category=Category(id=0, name="cat"), + box=BoundingBox( + xmin=50.0, + ymin=60.0, + xmax=70.0, + ymax=80.0, + ), + ), + ], + ) + ] + + def test_get_labels__raises_label_without_image(self, tmp_path: Path) -> None: + # Prepare inputs. + _create_label_file(tmp_path=tmp_path) + _create_schema_file(tmp_path=tmp_path) + + # Try to convert. + label_input = LightlyObjectDetectionInput( + input_folder=tmp_path / "labels", + images_rel_path="../images", + ) + with pytest.raises( + LabelWithoutImageError, + match="Label 'image.json' does not have a corresponding image.", + ): + list(label_input.get_labels()) + + def test_get_labels__skip_label_without_image( + self, tmp_path: Path, mocker: MockerFixture + ) -> None: + # Prepare inputs. + _create_label_file(tmp_path=tmp_path) + _create_schema_file(tmp_path=tmp_path) + + # Convert. + label_input = LightlyObjectDetectionInput( + input_folder=tmp_path / "labels", + images_rel_path="../images", + skip_labels_without_image=True, + ) + assert list(label_input.get_labels()) == [] + + +class TestLightlyObjectDetectionOutput: + def test_save(self, tmp_path: Path) -> None: + output_folder = tmp_path / "labels" + LightlyObjectDetectionOutput(output_folder=output_folder).save( + label_input=SimpleObjectDetectionInput() + ) + assert output_folder.exists() + assert output_folder.is_dir() + + filepaths = sorted(list(output_folder.glob("**/*"))) + assert filepaths == [ + tmp_path / "labels" / "image.json", + tmp_path / "labels" / "schema.json", + ] + + contents = (tmp_path / "labels" / "image.json").read_text() + expected = json.dumps( + { + "file_name": "image.jpg", + "predictions": [ + { + "category_id": 1, + "bbox": [10.0, 20.0, 20.0, 20.0], + "score": 0.0, # default + }, + { + "category_id": 0, + "bbox": [50.0, 60.0, 20.0, 20.0], + "score": 0.0, # default + }, + ], + }, + indent=2, + ) + assert contents == expected