
Improve sheet music filtering with a trained classifier #28

Merged: 9 commits, Feb 25, 2024
85 changes: 80 additions & 5 deletions experimental/pull_music_sheets.py
@@ -11,14 +11,21 @@
from pdf2image.exceptions import PDFPageCountError
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, models
import torch

from image2structure.util.credentials_utils import get_credentials
from image2structure.util.hierarchical_logger import htrack_block, hlog
from image2structure.util.image_utils import pdf_to_image, is_mostly_white


# Increase the maximum number of pixels allowed
Image.MAX_IMAGE_PIXELS = 700000000


"""
Pull music sheets from IMSLP (International Music Score Library Project) to generate the MusicSheets2LilyPond dataset.
The sheet music classifier was trained on sheet music up to 2012 and achieved an accuracy of 99.2% on the test set.

Usage:
python experimental/pull_music_sheets.py <start_year> <end_year> -n <num_examples> -o <output_dir> -c <credentials_path>
@@ -28,6 +35,60 @@
"""


class SheetMusicClassifier:
"""
A simple classifier to determine whether an image is sheet music.
"""

def __init__(self, path_to_model: str = "experimental/sheet_music_classifier.pt"):
with htrack_block(f"Loading the sheet music classifier from {path_to_model}"):
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model architecture: a ResNet-18 backbone
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(
num_ftrs, 2
)  # Replace the final layer for binary classification

# Load the trained model weights
model.load_state_dict(torch.load(path_to_model, map_location=self._device))

# Set the model to evaluation mode
model.eval()

self._model = model.to(self._device)
self._transform = transforms.Compose(
[
transforms.Resize(1024),
transforms.CenterCrop(512),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)

def is_sheet_music(self, image_path: str) -> bool:
with torch.no_grad():
# Open the image file
image = Image.open(image_path)

# Convert the image to RGB if it's not already (important for RGBA or grayscale images)
if image.mode != "RGB":
image = image.convert("RGB")

# Apply transformations and move to the appropriate device
image_tensor = self._transform(image).unsqueeze(0).to(self._device)

# Perform inference
outputs = self._model(image_tensor)
_, predicted = torch.max(outputs, 1)

# Return True if predicted class is 1 (sheet music), else False
return predicted.item() == 1


def fetch_music_sheets(
num_examples: int,
year_range: Tuple[int, int],
@@ -62,6 +123,9 @@ def fetch_music_sheets(
c = client.ImslpClient(username=username, password=password)
hlog("Login to IMSLP was successful. Created ImslpClient.\n")

# Initialize the sheet music classifier
model = SheetMusicClassifier()

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

Expand All @@ -70,10 +134,12 @@ def fetch_music_sheets(
imslp_url: str = "https://imslp.org/wiki/"

with htrack_block(
f"Attempting to generate {num_examples} examples from IMSLP between {year_range[0]} and {year_range[1]}..."
):
# `search_works` without any arguments returns all works
with htrack_block(
"Searching for all works. Please be patient as this may take a few minutes."
):
results = c.search_works()
hlog(f"Found {len(results)} works.")

@@ -134,18 +200,27 @@ def fetch_music_sheets(
else:
page_number = 1

generated: bool = generate_sheet_image(
file_path, image_path, page_number
)

# Remove the PDF file
os.remove(file_path)
if generated:
if not model.is_sheet_music(image_path):
hlog(
f"Removing {image_path} as it was identified as not a sheet music."
)
os.remove(image_path)
continue

generated_count += 1
hlog(f"Generated {generated_count} of {num_examples} examples.")
break

# Add a delay to avoid subscription prompt
hlog("Sleeping for 16 seconds...")
time.sleep(16)
hlog("Sleeping for 5 seconds...")
time.sleep(5)

if generated_count >= num_examples:
hlog(f"Generated {num_examples} examples. Exiting...")
Binary file added experimental/sheet_music_classifier.pt
Binary file not shown.
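
Not part of the diff, but as a reading aid: a minimal sketch of how the new SheetMusicClassifier could be exercised on its own. It assumes the experimental directory is importable as a package, that the checkpoint added above sits at its default path experimental/sheet_music_classifier.pt, and that page images live under a hypothetical directory data/pages:

import os

from experimental.pull_music_sheets import SheetMusicClassifier

# Hypothetical directory of page images rendered from IMSLP PDFs.
image_dir = "data/pages"

# Loads experimental/sheet_music_classifier.pt by default.
classifier = SheetMusicClassifier()

for name in sorted(os.listdir(image_dir)):
    if not name.lower().endswith(".png"):
        continue
    path = os.path.join(image_dir, name)
    # is_sheet_music returns True when the predicted class is 1 (sheet music).
    label = "sheet music" if classifier.is_sheet_music(path) else "not sheet music"
    print(f"{path}: {label}")
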
2 changes: 1 addition & 1 deletion src/image2structure/util/image_utils.py
@@ -6,7 +6,7 @@


# Increase the maximum number of pixels allowed
-Image.MAX_IMAGE_PIXELS = 230000000
+Image.MAX_IMAGE_PIXELS = 700000000


def is_mostly_white(
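
For context on the raised limit (an illustrative sketch, not part of the PR): Pillow warns with DecompressionBombWarning once an image exceeds Image.MAX_IMAGE_PIXELS (roughly 89 million pixels by default) and raises DecompressionBombError beyond twice that value, so bumping the limit to 700,000,000 presumably lets very large scanned scores be opened without tripping the check. The file name below is a placeholder:

from PIL import Image

# Mirror the new limit from image_utils.py before opening a large scan.
Image.MAX_IMAGE_PIXELS = 700000000

# "large_scan.png" is a hypothetical file; without the raised limit, a scan
# of several hundred megapixels would trigger Pillow's decompression-bomb check.
with Image.open("large_scan.png") as img:
    print(img.size, "=", img.size[0] * img.size[1], "pixels")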