Handle no metadata for 'wild' subset upload
JosselinSomervilleRoberts committed Aug 1, 2024
1 parent 392b1e9 commit be615b5
Showing 1 changed file: src/image2structure/upload.py (41 additions, 12 deletions)
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 from tqdm import tqdm
 from datasets import Dataset
 from datasets import DatasetDict, Features, Value, Image as HFImage, Sequence
@@ -12,6 +12,7 @@
 import numpy as np
 import json
 import imagehash
+import uuid
 
 
 def load_image(image_path: str) -> Image.Image:
@@ -38,8 +39,11 @@ def load_archive(archive_path: str) -> str:
 
 def transform(row: dict) -> dict:
     row["image"] = load_image(row["image"])
-    metadata_str: str = load_file(row["metadata"])
-    metadata: Dict[str, Any] = json.loads(metadata_str)
+    metadata: Dict[str, Any] = (
+        json.loads(load_file(row["metadata"]))
+        if isinstance(row["metadata"], str)
+        else row["metadata"]
+    )
     for key in metadata:
         if key != "assets":
             row[key] = json.dumps(metadata[key], indent=4)
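
Since row["metadata"] can now be either a path to a JSON file on disk or an already-built dict (the fallback synthesized below for instances with no metadata folder), the conditional above normalizes both forms. A minimal standalone sketch of that branch, with open() standing in for the repo's load_file helper:

    import json
    from typing import Any, Dict, Union

    def normalize_metadata(metadata: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
        # A path is read and parsed; a dict (the synthesized fallback)
        # is passed through unchanged.
        if isinstance(metadata, str):
            with open(metadata) as f:  # stands in for load_file()
                return json.loads(f.read())
        return metadata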
@@ -115,6 +119,10 @@ def classify_difficulty(dataset, data_type: str, wild_data: bool = False):
     return Dataset.from_pandas(df)
 
 
+def parse_list_or_str(value: str) -> List[str]:
+    return value.replace(" ", "").replace("[", "").replace("]", "").split(",")
+
+
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="Upload collected data to huggingface")
     parser.add_argument(
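
parse_list_or_str accepts either a bare value or a bracketed, comma-separated list; stripping spaces and brackets before splitting makes the two spellings equivalent:

    >>> parse_list_or_str("[webpage, wild]")
    ['webpage', 'wild']
    >>> parse_list_or_str("wild")
    ['wild']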
@@ -135,6 +143,12 @@ def parse_args() -> argparse.Namespace:
         default=-1,
         help="The maximum number of instances to upload",
     )
+    parser.add_argument(
+        "--subset",
+        type=parse_list_or_str,
+        default=None,
+        help="The subset of the data to upload. By default will upload all the data",
+    )
     return parser.parse_args()
 
 
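A hypothetical invocation limiting the upload to the in-the-wild category (only --subset is visible in this diff; the --data-path flag name is an assumption inferred from args.data_path):

    python src/image2structure/upload.py --data-path ./data/webpage --subset "[wild]"

Because the argument is parsed with type=parse_list_or_str, --subset wild and --subset "[wild, webpage]" both arrive as lists in args.subset.
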
@@ -143,25 +157,30 @@ def main():
 
     data_type: str = os.path.basename(args.data_path)
     print(f"\nUploading {data_type} dataset...")
-    for category in os.listdir(args.data_path):
+    categories: List[str] = (
+        os.listdir(args.data_path) if args.subset is None else args.subset
+    )
+    print(f"Categories: {categories}")
+    for category in categories:
         print(f"\nUploading {category} dataset...")
         data_path: str = os.path.join(args.data_path, category)
 
         # There should be 4 folders in the data_path
         # - images
-        # - structures
-        # - metadata
+        # - structures (except for in-the-wild)
+        # - metadata (should be present but we handle the case where it is not)
         # - assets
         image_path = os.path.join(data_path, "images")
         structure_path = os.path.join(data_path, "structures")
         metadata_path = os.path.join(data_path, "metadata")
         assets_path = os.path.join(data_path, "assets")
         text_path = os.path.join(data_path, "text")
-        for path in [image_path, metadata_path, assets_path]:
+        for path in [image_path, assets_path]:
             if not os.path.exists(path):
                 raise FileNotFoundError(f"{path} does not exist")
         has_structure: bool = os.path.exists(structure_path)
         has_text: bool = os.path.exists(text_path)
+        has_metadata: bool = os.path.exists(metadata_path)
 
         num_data_points: int = len(os.listdir(image_path))
 
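After this change only images/ and assets/ are hard requirements; the other three folders merely toggle feature flags. A sketch of an assumed on-disk layout for one category (folder names from the code above; the parent path is hypothetical):

    data/webpage/wild/
        images/      # required
        assets/      # required
        structures/  # optional -> has_structure
        text/        # optional -> has_text
        metadata/    # optional -> has_metadata; synthesized per instance when absent
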
@@ -188,8 +207,9 @@ def main():
         image_set = set()
         for i in tqdm(range(num_data_points), desc="Loading data"):
             try:
-                values = {}
-                file_name: str = file_names[i].replace(".png", "")
+                values: Dict[str, Any] = {}
+                image_name = file_names[i]
+                file_name: str = image_name.replace(".png", "").replace(".jpg", "")
 
                 if has_structure:
                     structure_file = os.path.join(
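
The stem is now derived by stripping either extension, so .jpg screenshots resolve their sibling files (structure, text, metadata) the same way .png ones do:

    >>> "0042.jpg".replace(".png", "").replace(".jpg", "")
    '0042'

Note that str.replace removes the substring anywhere in the name, not only at the end; os.path.splitext would be the stricter alternative.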
@@ -209,14 +229,23 @@ def main():
                     text: str = load_file(os.path.join(text_path, f"{file_name}.txt"))
                     values["text"] = [text]
 
-                image = os.path.join(image_path, f"{file_name}.png")
+                image = os.path.join(image_path, image_name)
                 hashed_img: str = str(imagehash.average_hash(load_image(image)))
                 if hashed_img in image_set:
                     continue
                 image_set.add(hashed_img)
                 values["image"] = [image]
 
-                metadata = os.path.join(metadata_path, f"{file_name}.json")
+                metadata: Union[str, Dict[str, Any]]
+                if has_metadata:
+                    metadata = os.path.join(metadata_path, f"{file_name}.json")
+                else:
+                    metadata = {
+                        "assets": [],
+                        "additional_info": {},
+                        "category": category,
+                        "uuid": str(uuid.uuid4()),
+                    }
                 values["metadata"] = [metadata]
 
                 df = pd.concat([df, pd.DataFrame(values)])
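
When has_metadata is false, transform still receives every field it expects: the synthesized dict mirrors the schema of the JSON files on disk. For a wild image it would look roughly like this (the uuid value is illustrative):

    {
        "assets": [],
        "additional_info": {},
        "category": "wild",
        "uuid": "8f6c2a9e-41d7-4c3b-9b1a-5e0d7c2f4a66",
    }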
@@ -246,7 +275,7 @@ def main():
 
         # Classify the difficulty of the instances
         valid_dataset = classify_difficulty(
-            valid_dataset, data_type, category == "wild"
+            valid_dataset, data_type, "wild" in category.lower()
         )
         # valid_dataset = Dataset.from_pandas(df)
         # Print first 5 instances
