[py-tx] Add a new match command line option (--rotations) #1672

Merged
merged 21 commits into from
Nov 9, 2024
Changes from 18 commits
Commits
26fe246
Create RotationType enum
haianhng31 Oct 28, 2024
931b81a
Create rotation helpers for PhotoContent
haianhng31 Oct 28, 2024
dfe04d9
Include docstring and comments
haianhng31 Oct 30, 2024
0e44874
Create --rotation command line
haianhng31 Oct 30, 2024
bd532ae
Fix lint and small change to photo rotation helpers
haianhng31 Oct 30, 2024
13bf6e7
Merge branch 'photo-helpers' of https://github.com/haianhng31/ThreatE…
haianhng31 Oct 30, 2024
594782b
create dataclass IndexMatchWithRotation and use that in match_cmd
haianhng31 Oct 30, 2024
f8a5891
Merge branch 'main' into rotation-cmd
haianhng31 Nov 1, 2024
e570ef9
merge main
haianhng31 Nov 2, 2024
02ff144
edit _IndexMatchWithRotation and updates its usage in match_cmd
haianhng31 Nov 2, 2024
a9db9d9
add --rotations to hash cmd
haianhng31 Nov 2, 2024
ad9a165
Minor fix to hash cmd
haianhng31 Nov 2, 2024
a1a39d3
fix small detail in hash cmd
haianhng31 Nov 2, 2024
6a6cea0
Add test cases for Hash command with rotations
haianhng31 Nov 2, 2024
d841f60
fix minor details & remove old test cases for hash with rotations
haianhng31 Nov 4, 2024
3b83a41
checkout hash_cmd.py from main (hash_cmd --rotations is handled in a …
haianhng31 Nov 4, 2024
dff8c1b
Merge branch 'main' of https://github.com/facebook/ThreatExchange int…
haianhng31 Nov 5, 2024
e367f61
match rotation tests
haianhng31 Nov 5, 2024
509a64a
debug s_type in _match_file
haianhng31 Nov 6, 2024
960fea1
Edit match with rotation test cases
haianhng31 Nov 6, 2024
63b460f
delete unused lines
haianhng31 Nov 6, 2024
106 changes: 91 additions & 15 deletions python-threatexchange/threatexchange/cli/match_cmd.py
@@ -5,22 +5,30 @@
Match command for parsing simple data sources against the dataset.
"""

from dataclasses import dataclass, field
import argparse
import logging
import pathlib
import typing as t

import tempfile

from threatexchange import common
from threatexchange.cli.fetch_cmd import FetchCommand
from threatexchange.cli.helpers import FlexFilesInputAction
from threatexchange.exchanges.fetch_state import FetchedSignalMetadata

-from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex
+from threatexchange.signal_type.index import (
+    IndexMatch,
+    SignalTypeIndex,
+    IndexMatchUntyped,
+    SignalSimilarityInfo,
+    T,
+)
from threatexchange.cli.exceptions import CommandError
from threatexchange.signal_type.signal_base import BytesHasher, SignalType
from threatexchange.cli.cli_config import CLISettings
-from threatexchange.content_type.content_base import ContentType
+from threatexchange.content_type.content_base import ContentType, RotationType
+from threatexchange.content_type.photo import PhotoContent

from threatexchange.signal_type.signal_base import MatchesStr, TextHasher, FileHasher
from threatexchange.cli import command_base
@@ -29,6 +37,19 @@
TMatcher = t.Callable[[pathlib.Path], t.List[IndexMatch]]


+@dataclass
+class _IndexMatchWithRotation(t.Generic[T]):
+    match: IndexMatchUntyped[SignalSimilarityInfo, T]
+    rotation_type: t.Optional[RotationType] = field(default=None)
+
+    def __str__(self):
+        # Supposed to be without whitespace, but let's make sure
+        distance_str = "".join(self.match.similarity_info.pretty_str().split())
+        if self.rotation_type is None:
+            return distance_str
+        return f"{distance_str} [{self.rotation_type.name}]"


class MatchCommand(command_base.Command):
"""
Match content to fetched signals
@@ -126,6 +147,12 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
             action="store_true",
             help="show all matches, not just one per collaboration",
         )
+        ap.add_argument(
+            "--rotations",
+            "-R",
+            action="store_true",
+            help="for photos, generate and match all 8 simple rotations",
+        )

def __init__(
self,
@@ -136,6 +163,7 @@ def __init__(
         show_false_positives: bool,
         hide_disputed: bool,
         all: bool,
+        rotations: bool = False,
     ) -> None:
         self.content_type = content_type
         self.only_signal = only_signal
@@ -144,6 +172,7 @@ def __init__(
         self.hide_disputed = hide_disputed
         self.files = files
         self.all = all
+        self.rotations = rotations

         if only_signal and content_type not in only_signal.get_content_types():
             raise CommandError(
@@ -152,6 +181,11 @@ def __init__(
                 2,
             )

+        if self.rotations and not issubclass(content_type, PhotoContent):
+            raise CommandError(
+                "--rotations flag is only available for Photo content type", 2
+            )
+
     def execute(self, settings: CLISettings) -> None:
         if not settings.index.list():
             if not settings.in_demo_mode:
@@ -196,18 +230,23 @@ def execute(self, settings: CLISettings) -> None:
for s_type, index in indices:
Contributor

Before processing this list of files, if rotations is true here, generate new files that you iterate through.

Here's one way to do that

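def handle_rotations() -> t.Iterator[t.Tuple[Path, t.Optional[RotationType]]]:
  for file in self.files:
    if not self.rotations:
      yield file, None
      continue
    # Write each rotated variant to a temporary file and yield its path
    for rot, dat in PhotoContent.all_simple_rotations(file.read_bytes()).items():
      with tempfile.NamedTemporaryFile(...) as f:
        f.write(dat)
        f.flush()
        yield Path(f.name), rot

for path, rot in handle_rotations():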
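
A self-contained sketch of the generator idea above, under stated assumptions: `RotationType` and `fake_rotations` are simplified stand-ins (the real helper is `PhotoContent.all_simple_rotations`), and `iter_inputs` is a hypothetical name. The input's suffix is kept on the temporary file in case format detection relies on the extension.

```python
import pathlib
import tempfile
import typing as t
from enum import Enum


class RotationType(Enum):
    # Stand-in for threatexchange's RotationType; two members for brevity.
    ORIGINAL = 0
    ROTATE180 = 1


def fake_rotations(data: bytes) -> t.Dict[RotationType, bytes]:
    """Hypothetical placeholder for PhotoContent.all_simple_rotations()."""
    return {RotationType.ORIGINAL: data, RotationType.ROTATE180: data[::-1]}


def iter_inputs(
    files: t.Sequence[pathlib.Path], rotations: bool
) -> t.Iterator[t.Tuple[pathlib.Path, t.Optional[RotationType]]]:
    """Yield (path, rotation) pairs; rotation is None when rotations are off."""
    for file in files:
        if not rotations:
            yield file, None
            continue
        for rot, data in fake_rotations(file.read_bytes()).items():
            # The temporary file only exists while the consumer handles this item.
            with tempfile.NamedTemporaryFile(suffix=file.suffix) as tmp:
                tmp.write(data)
                tmp.flush()  # make sure the bytes are on disk before it is read
                yield pathlib.Path(tmp.name), rot
```

With this shape each rotated file is written once per input file, and every signal type can be matched against the yielded path before the generator advances. Reopening an open `NamedTemporaryFile` by name works on POSIX systems but may fail on Windows.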

                 seen = set() # TODO - maybe take the highest certainty?
                 if self.as_hashes:
-                    results = _match_hashes(path, s_type, index)
+                    results: t.Sequence[_IndexMatchWithRotation] = _match_hashes(
+                        path, s_type, index
+                    )
                 else:
-                    results = _match_file(path, s_type, index)
+                    results = _match_file(path, s_type, index, rotations=self.rotations)

                 for r in results:
-                    metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata
+                    metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = (
+                        r.match.metadata
+                    )
+                    distance_str = str(r)
+
                     for collab, fetched_data in metadatas:
                         if not self.all and collab in seen:
                             continue
                         seen.add(collab)
-                        # Supposed to be without whitespace, but let's make sure
-                        distance_str = "".join(r.similarity_info.pretty_str().split())
+
                         print(
                             s_type.get_name(),
                             distance_str,
@@ -217,18 +256,54 @@ def execute(self, settings: CLISettings) -> None:


 def _match_file(
-    path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
-) -> t.Sequence[IndexMatch]:
+    path: pathlib.Path,
+    s_type: t.Type[SignalType],
+    index: SignalTypeIndex,
+    rotations: bool = False,
+) -> t.Sequence[_IndexMatchWithRotation]:
     if issubclass(s_type, MatchesStr):
-        return index.query(path.read_text())
+        matches = index.query(path.read_text())
+        return [_IndexMatchWithRotation(match=match) for match in matches]
+
     assert issubclass(s_type, FileHasher)
-    return index.query(s_type.hash_from_file(path))
+
+    if not rotations or s_type != PhotoContent:
+        matches = index.query(s_type.hash_from_file(path))
+        return [_IndexMatchWithRotation(match=match) for match in matches]
+
+    # Handle rotations for photos
+    with open(path, "rb") as f:
+        image_data = f.read()
+
+    rotated_images: t.Dict[RotationType, bytes] = PhotoContent.all_simple_rotations(
+        image_data
+    )
+    all_matches = []
+
+    for rotation_type, rotated_bytes in rotated_images.items():
+        # Create a temporary file to hold the image bytes
+        with tempfile.NamedTemporaryFile() as temp_file:
Contributor

blocking: This is going to write a lot of files over the course of an execution! We are going to call this method for every single match type.

Another approach is to refactor this so that the rotated images are generated higher up in the stack; then, rather than taking rotations: bool, this function could take the path of the rotated image and an optional enum value indicating which rotation was applied.

Contributor Author

I'm thinking of changing s_type to be a subclass of BytesHasher so that I can pass the bytes directly to hash_from_bytes. All the tests pass and mypy doesn't complain. What do you think of this approach?

Contributor

In the last PR, I mentioned the downside - not all photo formats are knowable without the extension. There's likely a workaround, but let's stay with the current course and see if we can fix it in a followup instead.

+            temp_file.write(rotated_bytes)
+            temp_file_path = pathlib.Path(temp_file.name)
+            matches = index.query(s_type.hash_from_file(temp_file_path))
+            temp_file_path.unlink() # Clean up the temporary file
haianhng31 marked this conversation as resolved.

+        # Add rotation information if any matches were found
+        matches_with_rotations = []
+        for match in matches:
+            matches_with_rotations.append(
+                _IndexMatchWithRotation(match=match, rotation_type=rotation_type)
+            )
+
+        all_matches.extend(matches_with_rotations)
+
+    return all_matches


 def _match_hashes(
     path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
-) -> t.Sequence[IndexMatch]:
-    ret: t.List[IndexMatch] = []
+) -> t.Sequence[_IndexMatchWithRotation]:
+    ret: t.List[_IndexMatchWithRotation] = []
     for hash in path.read_text().splitlines():
         hash = hash.strip()
         if not hash:
@@ -244,5 +319,6 @@ def _match_hashes(
                 f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}",
                 2,
             )
-        ret.extend(index.query(hash))
+        matches = index.query(hash)
+        ret.extend([_IndexMatchWithRotation(match=match) for match in matches])
     return ret
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import pathlib
import tempfile
-from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eTest
+from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eHelper, ThreatExchangeCLIE2eTest
+from threatexchange.content_type.content_base import RotationType
+from threatexchange.content_type.photo import PhotoContent
from threatexchange.signal_type.md5 import VideoMD5Signal


@@ -31,3 +34,33 @@ def test_invalid_hash(self):
             ("-H", "video", "--", not_hash),
             f"{not_hash!r} from .* is not a valid hash for video_md5",
         )
+
+    def test_non_photo_match_with_rotations(self):
+        with tempfile.NamedTemporaryFile() as f:
+            for content_type in ["url", "text", "video"]:
+                self.assert_cli_usage_error(
+                    ("--rotations", content_type, f.name),
+                    msg_regex="--rotations flag is only available for Photo content type",
+                )
+
+    def test_photo_hash_with_rotations(self):
+        test_file = pathlib.Path("threatexchange/tests/hashing/resources/rgb.jpeg")
Contributor

Danger! This photo might not correspond to one of the sample signals. Check out the comments on https://github.com/facebook/ThreatExchange/blob/main/python-threatexchange/threatexchange/signal_type/pdq/signal.py#L90

Additionally, this makes the test not work if run from different directories! Here's how I fixed it:

Suggested change
-        test_file = pathlib.Path("threatexchange/tests/hashing/resources/rgb.jpeg")
+        test_file = pathlib.Path(
+            __file__ + "../../../../../../pdq/data/bridge-mods/aaa-orig.jpg"
+        ).resolve()
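
A minimal sketch of the same idea, resolving a test resource relative to the test file rather than the current working directory. It uses `Path.parents` instead of string concatenation; `resource_path` and its parameters are hypothetical names, and the number of levels to climb depends on the repository layout.

```python
import pathlib


def resource_path(anchor_file: str, levels_up: int, relative: str) -> pathlib.Path:
    """Build a path relative to an ancestor directory of anchor_file.

    levels_up=0 is the directory containing anchor_file, 1 is its parent,
    and so on. Hypothetical helper, not part of threatexchange.
    """
    anchor = pathlib.Path(anchor_file).resolve()
    return anchor.parents[levels_up] / relative
```

In a test module this would be called as `resource_path(__file__, n, "pdq/data/bridge-mods/aaa-orig.jpg")`, with `n` chosen to reach the repository root.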


+        hash_cmd = ThreatExchangeCLIE2eHelper()
+        hash_cmd.COMMON_CALL_ARGS = ("hash",)
+        hash_cmd._state_dir = pathlib.Path()
+
+        hash = hash_cmd.cli_call("photo", str(test_file))
+        assert hash == "pdq fb4eed46cb8a6c78819ca06b756c541f7b07ef6d02c82fccd00f862166272cda\n"
+
+        # rotated_images = PhotoContent.all_simple_rotations(test_file.read_bytes())
+
+        # img = rotated_images[RotationType.ROTATE90] #try with 1 rotated image first
+
+        # with tempfile.NamedTemporaryFile() as tmp_file:
+        #     img = rotated_images[RotationType.ROTATE90]
+        #     tmp_file.write(img)
+        #     self.assert_cli_output(
+        #         ("--rotations", "photo", "--", tmp_file.name),
+        #         "video_md5 - (Sample Signals) INVESTIGATION_SEED",
+        #     )
Contributor Author

@haianhng31 haianhng31 Nov 5, 2024

My current approach for this test:

  1. Hash an image
  2. Fetch
  3. Call match --rotations on the rotated version of the original image

Q: @Dcallies

  • I got the PDQ hash of the image (line 53), but how do I fetch the indices and call the match command on this hash?
  • If a match is found, what should the output look like? When I try it with the CLI, it keeps outputting
    pdq 16 (Sample Signals) INVESTIGATION_SEED
    or, if I'm using the config collab:
    pdq 0 (file.txt) INVESTIGATION_SEED

Contributor

If there is a match, then the output is the matching collaboration and distance.

Consider this output:

pdq 16 (Sample Signals) INVESTIGATION_SEED

signal_type  distance  collab name     confidence
pdq          16        Sample Signals  INVESTIGATION_SEED
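
The column layout above can be checked mechanically. A minimal sketch, assuming the line layout `signal_type distance (collab name) label` seen in this thread, and assuming the optional `[ROTATION]` suffix produced by `_IndexMatchWithRotation.__str__` directly follows the distance; `parse_match_line` is a hypothetical helper, not part of the threatexchange CLI.

```python
import re
import typing as t

# One match line: signal type, distance, optional [ROTATION], (collab), label.
_LINE_RE = re.compile(
    r"^(?P<signal_type>\S+) (?P<distance>\S+)"
    r"(?: \[(?P<rotation>\w+)\])?"
    r" \((?P<collab>.+)\) (?P<label>\S+)$"
)


def parse_match_line(line: str) -> t.Optional[t.Dict[str, t.Optional[str]]]:
    """Split one line of match output into named fields, or None if it does not fit."""
    m = _LINE_RE.match(line.strip())
    return m.groupdict() if m else None
```

A test could then assert on the `rotation` field directly instead of comparing whole output strings.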

Contributor

Here's what I did:

First I sanity checked that we expect the produced hashes to match

$ tx hash --rotations photo /workspaces/devcontainer-ThreatExchange/pdq/data/bridge-mods/aaa-orig.jpg 
ORIGINAL pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22
ROTATE90 pdq 30a10efd71cc3d429013d48d0ffffc52e34e0e17ada952a9d29685211ea9e5af
ROTATE180 pdq adad5a64b5a142e75b62a09857da895ae63b847fc23794b766b319361bc93188
ROTATE270 pdq a5f0a457a48995e8c9065c275aaa5498b61ba4bdf8fcf80387c32f8b1bfc4f05
FLIPX pdq f8f80f31e0f417b20e37f5cd028f980fb36ed02a9662c1e233e64c634e9c64dd
FLIPY pdq 0dad2599b1a1bd1a5362576742da32a5e63b7380c2374b4866b366c91bc9ce77
FLIPPLUS1 pdq f0a5e102f1ccc0bd945308720fff038de34ef1e8ada9a956d2967ade5ea91a50
FLIPMINUS1 pdq a5f05aa8a4896a17c906a2d85aaaab07b61b5b42f8fc07fc87c3d0741bfcb0fa

# Check original
$ tx match photo -H f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22
pdq 16 (Sample Signals) INVESTIGATION_SEED
# Note - this should be distance 0, but it's stopping at the first match, which is one of the variations

# Check with rotation
$ tx match --rotations photo /workspaces/devcontainer-ThreatExchange/pdq/data/bridge-mods/aaa-orig.jpg
pdq 16 (Sample Signals) INVESTIGATION_SEED

However, I can see an error here - why isn't the rotation printed? When I looked to answer this question I found a bug in your code, which I found by adding some helpful print statements. I believe you can find that bug too!
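
On the note in the transcript about stopping at the first match: this relates to the `seen` set and its TODO in `execute`. A minimal sketch of one possible alternative, keeping the lowest distance per collaboration; `best_per_collab` and the `(collab, distance)` tuples are hypothetical simplifications, not the CLI's actual data model.

```python
import typing as t


def best_per_collab(matches: t.Iterable[t.Tuple[str, int]]) -> t.Dict[str, int]:
    """Keep only the smallest distance seen for each collaboration name."""
    best: t.Dict[str, int] = {}
    for collab, distance in matches:
        if collab not in best or distance < best[collab]:
            best[collab] = distance
    return best
```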

Contributor Author

Thank you for your comment! I found the bug: I shouldn't have compared s_type against PhotoContent, since s_type is a SignalType, not a ContentType.
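
A minimal sketch of why that comparison never selects the rotation path, using simplified stand-in classes rather than the real threatexchange ones. The only API assumed from the diff is `get_content_types()`; `should_rotate` is a hypothetical helper.

```python
import typing as t


class ContentType: ...
class PhotoContent(ContentType): ...
class VideoContent(ContentType): ...


class SignalType:
    @classmethod
    def get_content_types(cls) -> t.List[t.Type[ContentType]]:
        raise NotImplementedError


class PdqSignal(SignalType):
    @classmethod
    def get_content_types(cls) -> t.List[t.Type[ContentType]]:
        return [PhotoContent]


class VideoMD5Signal(SignalType):
    @classmethod
    def get_content_types(cls) -> t.List[t.Type[ContentType]]:
        return [VideoContent]


def should_rotate(s_type: t.Type[SignalType], rotations: bool) -> bool:
    # A signal type is never equal to a content type, so ask the signal type
    # which content types it applies to instead of comparing the classes.
    return rotations and PhotoContent in s_type.get_content_types()
```

Because `PdqSignal != PhotoContent` is always true, a guard of the form `not rotations or s_type != PhotoContent` always takes the non-rotation branch.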

Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
This records all the valid signal types for a piece of content.
"""

-from enum import Enum, auto
+from enum import Enum
Dcallies marked this conversation as resolved.
import typing as t

from threatexchange import common
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
 """
 from PIL import Image
 import io
+import typing as t

 from .content_base import ContentType, RotationType

@@ -82,7 +83,7 @@ def flip_minus1(cls, image_data: bytes) -> bytes:
         return buffer.getvalue()

     @classmethod
-    def all_simple_rotations(cls, image_data: bytes):
+    def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]:
Contributor

Good catch

"""
Generate the 8 naive rotations of an image.

Original file line number Diff line number Diff line change
@@ -23,7 +23,6 @@
import pickle
import typing as t


T = t.TypeVar("T")
S_Co = t.TypeVar("S_Co", covariant=True, bound="SignalSimilarityInfo")
CT = t.TypeVar("CT", bound="Comparable")