diff --git a/src/image2structure/upload.py b/src/image2structure/upload.py
index cd32b0a..5fd0e08 100644
--- a/src/image2structure/upload.py
+++ b/src/image2structure/upload.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 from tqdm import tqdm
 from datasets import Dataset
 from datasets import DatasetDict, Features, Value, Image as HFImage, Sequence
@@ -12,6 +12,7 @@
 import numpy as np
 import json
 import imagehash
+import uuid


 def load_image(image_path: str) -> Image.Image:
@@ -38,8 +39,11 @@ def load_archive(archive_path: str) -> str:

 def transform(row: dict) -> dict:
     row["image"] = load_image(row["image"])
-    metadata_str: str = load_file(row["metadata"])
-    metadata: Dict[str, Any] = json.loads(metadata_str)
+    metadata: Dict[str, Any] = (
+        json.loads(load_file(row["metadata"]))
+        if isinstance(row["metadata"], str)
+        else row["metadata"]
+    )
     for key in metadata:
         if key != "assets":
             row[key] = json.dumps(metadata[key], indent=4)
@@ -115,6 +119,10 @@ def classify_difficulty(dataset, data_type: str, wild_data: bool = False):
     return Dataset.from_pandas(df)


+def parse_list_or_str(value: str) -> List[str]:
+    return value.replace(" ", "").replace("[", "").replace("]", "").split(",")
+
+
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="Upload collected data to huggingface")
     parser.add_argument(
@@ -135,6 +143,12 @@
         default=-1,
         help="The maximum number of instances to upload",
     )
+    parser.add_argument(
+        "--subset",
+        type=parse_list_or_str,
+        default=None,
+        help="The subset of the data to upload. By default will upload all the data",
+    )

     return parser.parse_args()

@@ -143,25 +157,30 @@ def main():
     data_type: str = os.path.basename(args.data_path)
     print(f"\nUploading {data_type} dataset...")

-    for category in os.listdir(args.data_path):
+    categories: List[str] = (
+        os.listdir(args.data_path) if args.subset is None else args.subset
+    )
+    print(f"Categories: {categories}")
+    for category in categories:
         print(f"\nUploading {category} dataset...")
         data_path: str = os.path.join(args.data_path, category)

         # There should be 4 folders in the data_path
         # - images
-        # - structures
-        # - metadata
+        # - structures (except for in-the-wild)
+        # - metadata (should be present but we handle the case where it is not)
         # - assets
         image_path = os.path.join(data_path, "images")
         structure_path = os.path.join(data_path, "structures")
         metadata_path = os.path.join(data_path, "metadata")
         assets_path = os.path.join(data_path, "assets")
         text_path = os.path.join(data_path, "text")
-        for path in [image_path, metadata_path, assets_path]:
+        for path in [image_path, assets_path]:
             if not os.path.exists(path):
                 raise FileNotFoundError(f"{path} does not exist")
         has_structure: bool = os.path.exists(structure_path)
         has_text: bool = os.path.exists(text_path)
+        has_metadata: bool = os.path.exists(metadata_path)

         num_data_points: int = len(os.listdir(image_path))

@@ -188,8 +207,9 @@ def main():
         image_set = set()
         for i in tqdm(range(num_data_points), desc="Loading data"):
             try:
-                values = {}
-                file_name: str = file_names[i].replace(".png", "")
+                values: Dict[str, Any] = {}
+                image_name = file_names[i]
+                file_name: str = image_name.replace(".png", "").replace(".jpg", "")

                 if has_structure:
                     structure_file = os.path.join(
@@ -209,14 +229,23 @@
                     text: str = load_file(os.path.join(text_path, f"{file_name}.txt"))
                     values["text"] = [text]

-                image = os.path.join(image_path, f"{file_name}.png")
+                image = os.path.join(image_path, image_name)
                 hashed_img: str = str(imagehash.average_hash(load_image(image)))
                 if hashed_img in image_set:
                     continue
                 image_set.add(hashed_img)
                 values["image"] = [image]

-                metadata = os.path.join(metadata_path, f"{file_name}.json")
+                metadata: Union[str, Dict[str, Any]]
+                if has_metadata:
+                    metadata = os.path.join(metadata_path, f"{file_name}.json")
+                else:
+                    metadata = {
+                        "assets": [],
+                        "additional_info": {},
+                        "category": category,
+                        "uuid": str(uuid.uuid4()),
+                    }
                 values["metadata"] = [metadata]

                 df = pd.concat([df, pd.DataFrame(values)])
@@ -246,7 +275,7 @@
         # Classify the difficulty of the instances
         valid_dataset = classify_difficulty(
-            valid_dataset, data_type, category == "wild"
+            valid_dataset, data_type, "wild" in category.lower()
         )
         # valid_dataset = Dataset.from_pandas(df)

         # Print first 5 instances
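For reviewers, here is a minimal, runnable sketch of the two behavioral changes the patch introduces. `parse_list_or_str` is copied verbatim from the diff; the category names other than `wild` are hypothetical placeholders, not names from the dataset.

```python
import uuid
from typing import Any, Dict, List, Union


def parse_list_or_str(value: str) -> List[str]:
    # Same parser the new --subset flag uses: strip spaces and brackets,
    # then split on commas, so both "a,b" and "[a, b]" are accepted.
    return value.replace(" ", "").replace("[", "").replace("]", "").split(",")


assert parse_list_or_str("wild") == ["wild"]
assert parse_list_or_str("css,html") == ["css", "html"]      # hypothetical names
assert parse_list_or_str("[css, html]") == ["css", "html"]

# When a category ships no metadata/ folder, each row now carries a synthesized
# dict instead of a path to a .json file, which is why transform() gained the
# isinstance(row["metadata"], str) branch and metadata is typed as a Union:
metadata: Union[str, Dict[str, Any]] = {
    "assets": [],
    "additional_info": {},
    "category": "wild",
    "uuid": str(uuid.uuid4()),
}
assert not isinstance(metadata, str)  # transform() uses the dict as-is
```

Typing `metadata` as a `Union` keeps a single `transform` code path serving both the on-disk and in-the-wild cases.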