Skip to content

Commit

Permalink
[py-tx] Updated for pr revisions added cobinded unletterboxing and mo…
Browse files Browse the repository at this point in the history
…ved source files for unboxing and pytest
  • Loading branch information
Mackay-Fisher committed Nov 13, 2024
1 parent 926801e commit 01c2bf9
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 252 deletions.
114 changes: 56 additions & 58 deletions python-threatexchange/threatexchange/cli/hash_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from threatexchange.signal_type.signal_base import FileHasher, SignalType
from threatexchange.cli import command_base
from threatexchange.cli.helpers import FlexFilesInputAction
from threatexchange.signal_type.pdq.signal import PdqSignal


class HashCommand(command_base.Command):
Expand Down Expand Up @@ -54,6 +53,7 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
signal_choices = sorted(
s.get_name() for s in signal_types if issubclass(s, FileHasher)
)

ap.add_argument(
"content_type",
**common.argparse_choices_pre_type_kwargs(
Expand Down Expand Up @@ -81,50 +81,50 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
)

ap.add_argument(
"--preprocess",
choices=["unletterbox"],
help="Apply preprocessing steps to the image before hashing.",
"--photo-preprocess",
choices=["unletterbox", "rotations"],
help=(
"Apply one of the preprocessing steps to the image before hashing. "
"'unletterbox' removes black borders, and 'rotations' generates all 8 "
"simple rotations."
),
)

ap.add_argument(
"--black-threshold",
type=int,
default=40,
help="Set the black threshold for unletterboxing. Default is 40.",
default=10,
help=(
"Set the black threshold for unletterboxing (default: 5)."
"Only applies when 'unletterbox' is selected in --preprocess."
),
)

ap.add_argument(
"--save-output",
type=bool,
default=False,
help="If true, save the processed image as a new file.",
)

ap.add_argument(
"--rotations",
"--R",
action="store_true",
help="for photos, generate all 8 simple rotations",
help="If true, saves the processed image as a new file.",
)

def __init__(
self,
content_type: t.Type[ContentType],
signal_type: t.Optional[t.Type[SignalType]],
files: t.List[pathlib.Path],
rotations: bool = False,
preprocess: t.Optional[str] = None,
black_threshold: int = 40,
photo_preprocess: t.Optional[str] = None,
black_threshold: int = 0,
save_output: bool = False,
) -> None:
self.content_type = content_type
self.signal_type = signal_type
self.preprocess = preprocess
self.photo_preprocess = photo_preprocess
self.black_threshold = black_threshold
self.save_output = save_output
self.files = files

self.rotations = rotations
if self.photo_preprocess and not issubclass(self.content_type, PhotoContent):
raise CommandError(
"--photo-preprocess flag is only available for Photo content type", 2
)

def execute(self, settings: CLISettings) -> None:
hashers = [
Expand All @@ -141,46 +141,44 @@ def execute(self, settings: CLISettings) -> None:

hashers = [self.signal_type] # type: ignore # can't detect intersection types

if not self.rotations:
if self.photo_preprocess:
for file in self.files:
for hasher in hashers:
if isinstance(hasher, PdqSignal) and (
self.content_type.get_name() == "photo"
and self.preprocess == "unletterbox"
):
hash_str = PdqSignal.hash_from_bytes(
PhotoContent.unletterbox(
file, self.save_output, self.black_threshold
)
)
else:
hash_str = hasher.hash_from_file(file)
if hash_str:
print(hasher.get_name(), hash_str)
return

if not issubclass(self.content_type, PhotoContent):
raise CommandError(
"--rotations flag is only available for Photo content type", 2
)

for file in self.files:
with open(file, "rb") as f:
if (
self.content_type.get_name() == "photo"
and self.preprocess == "unletterbox"
):
image_bytes = PhotoContent.unletterbox(
file, self.save_output, self.black_threshold
updated_bytes: t.List[bytes] = []
rotation_type = []
if self.photo_preprocess == "unletterbox":
updated_bytes.append(
PhotoContent.unletterbox(str(file), self.black_threshold)
)
else:
image_bytes = f.read()
rotated_images = PhotoContent.all_simple_rotations(image_bytes)
for rotation_type, rotated_bytes in rotated_images.items():
with tempfile.NamedTemporaryFile() as temp_file: # Create a temporary file to hold the byte data
temp_file.write(rotated_bytes)
elif self.photo_preprocess == "rotations":
with open(file, "rb") as f:
image_bytes = f.read()
rotations = PhotoContent.all_simple_rotations(image_bytes)
rotation_type, updated_bytes = list(rotations.keys()), list(
rotations.values()
)
for idx, bytes_data in enumerate(updated_bytes):
with tempfile.NamedTemporaryFile() as temp_file:
temp_file.write(bytes_data)
temp_file_path = pathlib.Path(temp_file.name)
for hasher in hashers:
hash_str = hasher.hash_from_file(temp_file_path)
if hash_str:
print(rotation_type.name, hasher.get_name(), hash_str)
print(
f"{rotation_type[idx].name if rotation_type else ''} {hasher.get_name()} {hash_str}"
)
if self.save_output:
suffix = (
f"_{rotation_type[idx].name}"
if rotation_type
else "_unletterboxed"
)
output_path = file.with_stem(f"{file.stem}{suffix}")
with open(output_path, "wb") as output_file:
output_file.write(bytes_data)
print(f"Processed image saved to: {output_path}")
else:
for file in self.files:
for hasher in hashers:
hash_str = hasher.hash_from_file(file)
if hash_str:
print(hasher.get_name(), hash_str)
52 changes: 49 additions & 3 deletions python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def test_rotations_with_non_photo_content(
"""Test that rotation flag raises error with non-photo content"""
for content_type in ["url", "text", "video"]:
hash_cli.assert_cli_usage_error(
("--rotations", content_type, str(tmp_file)),
msg_regex="--rotations flag is only available for Photo content type",
("--photo-preprocess=rotations", content_type, str(tmp_file)),
msg_regex="--photo-preprocess flag is only available for Photo content type",
)


Expand All @@ -93,7 +93,7 @@ def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper):
test_file = pathlib.Path("threatexchange/tests/hashing/resources/LA.png")

hash_cli.assert_cli_output(
("--rotations", "photo", str(test_file)),
("--photo-preprocess=rotations", "photo", str(test_file)),
[
"ORIGINAL pdq accb6d39648035f8125c8ce6ba65007de7b54c67a2d93ef7b8f33b0611306715",
"ROTATE90 pdq 1f70cbbc77edc5f9524faa1b18f3b76cd0a04a833e20f645d229d0acc8499c56",
Expand All @@ -105,3 +105,49 @@ def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper):
"FLIPMINUS1 pdq 5bb15db9e8a1f03c174a380a55aeaa2985bde9c60abce301bde48df918b5c15b",
],
)


def test_unletterbox_with_non_photo_content(
hash_cli: ThreatExchangeCLIE2eHelper, tmp_file: pathlib.Path
):
"""Test that unletterbox flag raises error with non-photo content"""
for content_type in ["url", "text", "video"]:
hash_cli.assert_cli_usage_error(
("--photo-preprocess=unletterbox", content_type, str(tmp_file)),
msg_regex="--photo-preprocess flag is only available for Photo content type",
)


def test_unletterbox_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper):
"""Test that photo unletterboxing is properly processed"""
test_file = pathlib.Path(
"threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg"
)
clean_file = pathlib.Path("threatexchange/tests/hashing/resources/sample-b.jpg")

hash_cli.assert_cli_output(
("photo", str(clean_file)),
[
"pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22",
],
)

"""Test that photo unletterboxing is chnaged based off of allowed threshold"""
hash_cli.assert_cli_output(
("--photo-preprocess=unletterbox", "photo", str(test_file)),
[
"pdq 58f870cce0f4e84d8e378a32028f63f4b36e26f597621e1d33e6b39c4a9c9b22",
],
)

hash_cli.assert_cli_output(
(
"--photo-preprocess=unletterbox",
"--black-threshold=25",
"photo",
str(test_file),
),
[
"pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22",
],
)
123 changes: 13 additions & 110 deletions python-threatexchange/threatexchange/content_type/photo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import os

from .content_base import ContentType, RotationType
from threatexchange.content_type.preprocess import unletterboxing


class PhotoContent(ContentType):
Expand Down Expand Up @@ -105,118 +106,20 @@ def all_simple_rotations(cls, image_data: bytes):
return rotations

@classmethod
def detect_top_border(
cls, grayscale_img: Image.Image, black_threshold: int = 10
) -> int:
def unletterbox(cls, file_path: str, black_threshold: int = 0) -> bytes:
"""
Detect the top black border by counting rows with only black pixels.
Uses a defualt black threshold of 10 so that only rows with pixel brightness
of 10 or lower will be removed.
Returns the first row that is not all blacked out from the top.
"""
width, height = grayscale_img.size
for y in range(height):
row_pixels = list(grayscale_img.crop((0, y, width, y + 1)).getdata())
if all(pixel < black_threshold for pixel in row_pixels):
continue
return y
return height

@classmethod
def detect_bottom_border(
cls, grayscale_img: Image.Image, black_threshold: int = 10
) -> int:
"""
Detect the bottom black border by counting rows with only black pixels from the bottom up.
Uses a defualt black threshold of 10 so that only rows with pixel brightness
of 10 or lower will be removed.
Returns the first row that is not all blacked out from the bottom.
"""
width, height = grayscale_img.size
for y in range(height - 1, -1, -1):
row_pixels = list(grayscale_img.crop((0, y, width, y + 1)).getdata())
if all(pixel < black_threshold for pixel in row_pixels):
continue
return height - y - 1
return height

@classmethod
def detect_left_border(
cls, grayscale_img: Image.Image, black_threshold: int = 10
) -> int:
Remove black letterbox borders from the sides and top of the image based on the specified black_threshold.
Returns the cleaned image as raw bytes.
"""
Detect the left black border by counting columns with only black pixels.
Uses a defualt black threshold of 10 so that only colums with pixel brightness
of 10 or lower will be removed.
with Image.open(file_path) as image:
top = unletterboxing.detect_top_border(image, black_threshold)
bottom = unletterboxing.detect_bottom_border(image, black_threshold)
left = unletterboxing.detect_left_border(image, black_threshold)
right = unletterboxing.detect_right_border(image, black_threshold)

Returns the first column from the left that is not all blacked out in the column.
"""
width, height = grayscale_img.size
for x in range(width):
col_pixels = list(grayscale_img.crop((x, 0, x + 1, height)).getdata())
if all(pixel < black_threshold for pixel in col_pixels):
continue
return x
return width
width, height = image.size
cropped_img = image.crop((left, top, width - right, height - bottom))

@classmethod
def detect_right_border(
cls, grayscale_img: Image.Image, black_threshold: int = 10
) -> int:
"""
Detect the right black border by counting columns with only black pixels from the right.
Uses a defualt black threshold of 10 so that only colums with pixel brightness
of 10 or lower will be removed.
Returns the first column from the right that is not all blacked out in the column.
"""
width, height = grayscale_img.size
for x in range(width - 1, -1, -1):
col_pixels = list(grayscale_img.crop((x, 0, x + 1, height)).getdata())
if all(pixel < black_threshold for pixel in col_pixels):
continue
return width - x - 1
return width

@classmethod
def unletterbox(
cls, file_path: Path, save_output: bool = False, black_threshold: int = 40
) -> bytes:
"""
Remove black letterbox borders from the sides and top of the image.
Converts the image to grescale then remove the columns and rows that
are all completly blacked out.
Then removing the edges to give back a cleaned image bytes.
Return the new hash of the cleaned image with an option to create a new output file as well
"""
# Open the original image
with Image.open(file_path) as img:
grayscale_img = img.convert("L")

top = cls.detect_top_border(grayscale_img, black_threshold)
bottom = cls.detect_bottom_border(grayscale_img, black_threshold)
left = cls.detect_left_border(grayscale_img, black_threshold)
right = cls.detect_right_border(grayscale_img, black_threshold)

width, height = grayscale_img.size
cropped_box = (left, top, width - right, height - bottom)

cropped_img = img.crop(cropped_box)

# Optionally save the unletterboxed image to a new file in the same directory
if save_output:
path = Path(file_path)
output_path = path.parent / f"{path.stem}_unletterboxed{path.suffix}"
cropped_img.save(output_path)
print(f"Unletterboxed image saved to: {output_path}")

# Convert the cropped image to bytes for hashing
with io.BytesIO() as buffer:
cropped_img.save(buffer, format=img.format)
cropped_image_data = buffer.getvalue()
return cropped_image_data
cropped_img.save(buffer, format=image.format)
return buffer.getvalue()
Loading

0 comments on commit 01c2bf9

Please sign in to comment.