Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py-tx] Add a new match command line option (--rotations) #1672

Merged
merged 21 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
26fe246
Create RotationType enum
haianhng31 Oct 28, 2024
931b81a
Create rotation helpers for PhotoContent
haianhng31 Oct 28, 2024
dfe04d9
Include docstring and comments
haianhng31 Oct 30, 2024
0e44874
Create --rotation command line
haianhng31 Oct 30, 2024
bd532ae
Fix lint and small change to photo rotation helpers
haianhng31 Oct 30, 2024
13bf6e7
Merge branch 'photo-helpers' of https://github.com/haianhng31/ThreatE…
haianhng31 Oct 30, 2024
594782b
create dataclass IndexMatchWithRotation and use that in match_cmd
haianhng31 Oct 30, 2024
f8a5891
Merge branch 'main' into rotation-cmd
haianhng31 Nov 1, 2024
e570ef9
merge main
haianhng31 Nov 2, 2024
02ff144
edit _IndexMatchWithRotation and updates its usage in match_cmd
haianhng31 Nov 2, 2024
a9db9d9
add --rotations to hash cmd
haianhng31 Nov 2, 2024
ad9a165
Minor fix to hash cmd
haianhng31 Nov 2, 2024
a1a39d3
fix small detail in hash cmd
haianhng31 Nov 2, 2024
6a6cea0
Add test cases for Hash command with rotations
haianhng31 Nov 2, 2024
d841f60
fix minor details & remove old test cases for hash with rotations
haianhng31 Nov 4, 2024
3b83a41
checkout hash_cmd.py from main (hash_cmd --rotations is handled in a …
haianhng31 Nov 4, 2024
dff8c1b
Merge branch 'main' of https://github.com/facebook/ThreatExchange int…
haianhng31 Nov 5, 2024
e367f61
match rotation tests
haianhng31 Nov 5, 2024
509a64a
debug s_type in _match_file
haianhng31 Nov 6, 2024
960fea1
Edit match with rotation test cases
haianhng31 Nov 6, 2024
63b460f
delete unused lines
haianhng31 Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 84 additions & 13 deletions python-threatexchange/threatexchange/cli/match_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,23 @@
import logging
import pathlib
import typing as t

import tempfile

from threatexchange import common
from threatexchange.cli.fetch_cmd import FetchCommand
from threatexchange.cli.helpers import FlexFilesInputAction
from threatexchange.exchanges.fetch_state import FetchedSignalMetadata

from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex
from threatexchange.signal_type.index import (
IndexMatch,
SignalTypeIndex,
IndexMatchWithRotation,
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved
)
from threatexchange.cli.exceptions import CommandError
from threatexchange.signal_type.signal_base import BytesHasher, SignalType
from threatexchange.cli.cli_config import CLISettings
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.content_base import ContentType, RotationType
from threatexchange.content_type.photo import PhotoContent

from threatexchange.signal_type.signal_base import MatchesStr, TextHasher, FileHasher
from threatexchange.cli import command_base
Expand Down Expand Up @@ -126,6 +131,12 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
action="store_true",
help="show all matches, not just one per collaboration",
)
ap.add_argument(
"--rotations",
"-R",
action="store_true",
help="for photos, generate and match all 8 simple rotations",
)

def __init__(
self,
Expand All @@ -136,6 +147,7 @@ def __init__(
show_false_positives: bool,
hide_disputed: bool,
all: bool,
rotations: bool = False,
) -> None:
self.content_type = content_type
self.only_signal = only_signal
Expand All @@ -144,6 +156,7 @@ def __init__(
self.hide_disputed = hide_disputed
self.files = files
self.all = all
self.rotations = rotations

if only_signal and content_type not in only_signal.get_content_types():
raise CommandError(
Expand All @@ -152,6 +165,11 @@ def __init__(
2,
)

if self.rotations and not issubclass(content_type, PhotoContent):
raise CommandError(
"--rotations flag is only available for Photo content type", 2
)

def execute(self, settings: CLISettings) -> None:
if not settings.index.list():
if not settings.in_demo_mode:
Expand Down Expand Up @@ -196,33 +214,86 @@ def execute(self, settings: CLISettings) -> None:
for s_type, index in indices:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before processing this list of files, if rotations is true here, generate new files that you iterate through.

Here's one way to do that

def handle_rotations() -> Iterator[(Path, t.Optional[RotationType)]:
  for file in self.files:
    if not self.rotations:
      yield file, None
      continue
   for rot, dat in PhotoContent.blah():
      with NamedTemporary(...) as f:
         yield Path(f.name), rot

for path in handle_rotations(self.files):

seen = set() # TODO - maybe take the highest certainty?
if self.as_hashes:
results = _match_hashes(path, s_type, index)
results_from_hashes = _match_hashes(path, s_type, index)
results: t.Sequence[IndexMatchWithRotation] = [
IndexMatchWithRotation(match=match)
for match in results_from_hashes
]
Copy link
Contributor Author

@haianhng31 haianhng31 Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created a new dataclass IndexMatchWithRotation which will be the return type of _match_file.

Q: @Dcallies Should I also make it the return type of _match_hashes? Just so the return types of these 2 are similar

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make the python typing pass you may need to be consistent between them (see what mypy says)

else:
results = _match_file(path, s_type, index)
results = _match_file(path, s_type, index, rotations=self.rotations)

for r in results:
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata
if isinstance(r, IndexMatchWithRotation):
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = (
r.match.metadata
)
rotation_info = f" [{r.rotation_type.name}]"
# Supposed to be without whitespace, but let's make sure
distance_str = "".join(
r.match.similarity_info.pretty_str().split()
)

elif isinstance(r, IndexMatch):
metadatas = r.metadata
rotation_info = ""
distance_str = "".join(r.similarity_info.pretty_str().split())

for collab, fetched_data in metadatas:
if not self.all and collab in seen:
continue
seen.add(collab)
# Supposed to be without whitespace, but let's make sure
distance_str = "".join(r.similarity_info.pretty_str().split())

print(
s_type.get_name(),
distance_str,
distance_str + rotation_info,
f"({collab})",
fetched_data,
)


def _match_file(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
) -> t.Sequence[IndexMatch]:
path: pathlib.Path,
s_type: t.Type[SignalType],
index: SignalTypeIndex,
rotations: bool = False,
) -> t.Sequence[IndexMatchWithRotation]:
if issubclass(s_type, MatchesStr):
return index.query(path.read_text())
matches = index.query(path.read_text())
return [IndexMatchWithRotation(match=match) for match in matches]

assert issubclass(s_type, FileHasher)
return index.query(s_type.hash_from_file(path))

if not rotations or s_type != PhotoContent:
matches = index.query(s_type.hash_from_file(path))
return [IndexMatchWithRotation(match=match) for match in matches]

# Handle rotations for photos
with open(path, "rb") as f:
image_data = f.read()

rotated_images: t.Dict[RotationType, bytes] = PhotoContent.all_simple_rotations(
image_data
)
all_matches = []

for rotation_type, rotated_bytes in rotated_images.items():
# Create a temporary file to hold the image bytes
with tempfile.NamedTemporaryFile() as temp_file:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: This is going to write a lot of files over the course of an execution! We are going to call this method for every single match type.

Another approach you could do is refactor this so that the rotated images are inserted higher up in the stack, and then rather than taking rotations: bool, you could pass in the path of the rotation, and an optional enum representing which enum it can take.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking of change s_type to be subclass of BytesHasher so that I can use bytes directly with hash_from_bytes. All the tests passed and mypy doesn't complain. What do you think of this approach?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the last PR, I mentioned the downside - not all photo formats are knowable without the extension. There's likely a workaround, but let's stay with the current course and see if we can fix it in a followup instead.

temp_file.write(rotated_bytes)
temp_file_path = pathlib.Path(temp_file.name)
matches = index.query(s_type.hash_from_file(temp_file_path))
temp_file_path.unlink() # Clean up the temporary file
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved

# Add rotation information if any matches were found
matches_with_rotations = []
for match in matches:
matches_with_rotations.append(
IndexMatchWithRotation(match=match, rotation_type=rotation_type)
)

all_matches.extend(matches_with_rotations)

return all_matches


def _match_hashes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
This records all the valid signal types for a piece of content.
"""

from enum import Enum, auto

from enum import Enum
Dcallies marked this conversation as resolved.
Show resolved Hide resolved
import typing as t

from threatexchange import common
Expand Down Expand Up @@ -49,4 +50,4 @@ class RotationType(Enum):
FLIPX = "flipx" # Flip the object horizontally along the X-axis
FLIPY = "flipy" # Flip the object horizontally along the Y-axis
FLIPPLUS1 = "flipplus1" # Diagonal flip along the line y = x
FLIPMINUS1 = "flipminus1" # Diagonal flip along the line y = -x
FLIPMINUS1 = "flipminus1" # Diagonal flip along the line y = -x
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from PIL import Image
import io
import typing as t

from .content_base import ContentType, RotationType

Expand Down Expand Up @@ -82,7 +83,7 @@ def flip_minus1(cls, image_data: bytes) -> bytes:
return buffer.getvalue()

@classmethod
def all_simple_rotations(cls, image_data: bytes):
def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

"""
Generate the 8 naive rotations of an image.

Expand Down
7 changes: 7 additions & 0 deletions python-threatexchange/threatexchange/signal_type/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pickle
import typing as t

from threatexchange.content_type.content_base import RotationType
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved

T = t.TypeVar("T")
S_Co = t.TypeVar("S_Co", covariant=True, bound="SignalSimilarityInfo")
Expand Down Expand Up @@ -139,6 +140,12 @@ def __eq__(self, other: t.Any) -> bool:
IndexMatch = IndexMatchUntyped[SignalSimilarityInfo, T]


@dataclass
class IndexMatchWithRotation(t.Generic[T]):
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved
match: IndexMatchUntyped[SignalSimilarityInfo, T]
rotation_type: RotationType = RotationType.ORIGINAL
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you move it into the match command, consider making this optional, then adding a __str__ command which handles the output of the match command. If rotation_type is None, then no rotation is printed. If the rotation is there, then it can format the string you put to string.



Self = t.TypeVar("Self", bound="SignalTypeIndex")


Expand Down
Loading