-
Notifications
You must be signed in to change notification settings - Fork 321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[py-tx] Add a new match command line option (--rotations) #1672
Changes from 18 commits
26fe246
931b81a
dfe04d9
0e44874
bd532ae
13bf6e7
594782b
f8a5891
e570ef9
02ff144
a9db9d9
ad9a165
a1a39d3
6a6cea0
d841f60
3b83a41
dff8c1b
e367f61
509a64a
960fea1
63b460f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,22 +5,30 @@ | |
Match command for parsing simple data sources against the dataset. | ||
""" | ||
|
||
from dataclasses import dataclass, field | ||
import argparse | ||
import logging | ||
import pathlib | ||
import typing as t | ||
|
||
import tempfile | ||
|
||
from threatexchange import common | ||
from threatexchange.cli.fetch_cmd import FetchCommand | ||
from threatexchange.cli.helpers import FlexFilesInputAction | ||
from threatexchange.exchanges.fetch_state import FetchedSignalMetadata | ||
|
||
from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex | ||
from threatexchange.signal_type.index import ( | ||
IndexMatch, | ||
SignalTypeIndex, | ||
IndexMatchUntyped, | ||
SignalSimilarityInfo, | ||
T, | ||
) | ||
from threatexchange.cli.exceptions import CommandError | ||
from threatexchange.signal_type.signal_base import BytesHasher, SignalType | ||
from threatexchange.cli.cli_config import CLISettings | ||
from threatexchange.content_type.content_base import ContentType | ||
from threatexchange.content_type.content_base import ContentType, RotationType | ||
from threatexchange.content_type.photo import PhotoContent | ||
|
||
from threatexchange.signal_type.signal_base import MatchesStr, TextHasher, FileHasher | ||
from threatexchange.cli import command_base | ||
|
@@ -29,6 +37,19 @@ | |
TMatcher = t.Callable[[pathlib.Path], t.List[IndexMatch]] | ||
|
||
|
||
@dataclass | ||
class _IndexMatchWithRotation(t.Generic[T]): | ||
match: IndexMatchUntyped[SignalSimilarityInfo, T] | ||
rotation_type: t.Optional[RotationType] = field(default=None) | ||
|
||
def __str__(self): | ||
# Supposed to be without whitespace, but let's make sure | ||
distance_str = "".join(self.match.similarity_info.pretty_str().split()) | ||
if self.rotation_type is None: | ||
return distance_str | ||
return f"{distance_str} [{self.rotation_type.name}]" | ||
|
||
|
||
class MatchCommand(command_base.Command): | ||
""" | ||
Match content to fetched signals | ||
|
@@ -126,6 +147,12 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No | |
action="store_true", | ||
help="show all matches, not just one per collaboration", | ||
) | ||
ap.add_argument( | ||
"--rotations", | ||
"-R", | ||
action="store_true", | ||
help="for photos, generate and match all 8 simple rotations", | ||
) | ||
|
||
def __init__( | ||
self, | ||
|
@@ -136,6 +163,7 @@ def __init__( | |
show_false_positives: bool, | ||
hide_disputed: bool, | ||
all: bool, | ||
rotations: bool = False, | ||
) -> None: | ||
self.content_type = content_type | ||
self.only_signal = only_signal | ||
|
@@ -144,6 +172,7 @@ def __init__( | |
self.hide_disputed = hide_disputed | ||
self.files = files | ||
self.all = all | ||
self.rotations = rotations | ||
|
||
if only_signal and content_type not in only_signal.get_content_types(): | ||
raise CommandError( | ||
|
@@ -152,6 +181,11 @@ def __init__( | |
2, | ||
) | ||
|
||
if self.rotations and not issubclass(content_type, PhotoContent): | ||
raise CommandError( | ||
"--rotations flag is only available for Photo content type", 2 | ||
) | ||
|
||
def execute(self, settings: CLISettings) -> None: | ||
if not settings.index.list(): | ||
if not settings.in_demo_mode: | ||
|
@@ -196,18 +230,23 @@ def execute(self, settings: CLISettings) -> None: | |
for s_type, index in indices: | ||
seen = set() # TODO - maybe take the highest certainty? | ||
if self.as_hashes: | ||
results = _match_hashes(path, s_type, index) | ||
results: t.Sequence[_IndexMatchWithRotation] = _match_hashes( | ||
path, s_type, index | ||
) | ||
else: | ||
results = _match_file(path, s_type, index) | ||
results = _match_file(path, s_type, index, rotations=self.rotations) | ||
|
||
for r in results: | ||
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata | ||
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = ( | ||
r.match.metadata | ||
) | ||
distance_str = str(r) | ||
|
||
for collab, fetched_data in metadatas: | ||
if not self.all and collab in seen: | ||
continue | ||
seen.add(collab) | ||
# Supposed to be without whitespace, but let's make sure | ||
distance_str = "".join(r.similarity_info.pretty_str().split()) | ||
|
||
print( | ||
s_type.get_name(), | ||
distance_str, | ||
|
@@ -217,18 +256,54 @@ def execute(self, settings: CLISettings) -> None: | |
|
||
|
||
def _match_file( | ||
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex | ||
) -> t.Sequence[IndexMatch]: | ||
path: pathlib.Path, | ||
s_type: t.Type[SignalType], | ||
index: SignalTypeIndex, | ||
rotations: bool = False, | ||
) -> t.Sequence[_IndexMatchWithRotation]: | ||
if issubclass(s_type, MatchesStr): | ||
return index.query(path.read_text()) | ||
matches = index.query(path.read_text()) | ||
return [_IndexMatchWithRotation(match=match) for match in matches] | ||
|
||
assert issubclass(s_type, FileHasher) | ||
return index.query(s_type.hash_from_file(path)) | ||
|
||
if not rotations or s_type != PhotoContent: | ||
matches = index.query(s_type.hash_from_file(path)) | ||
return [_IndexMatchWithRotation(match=match) for match in matches] | ||
|
||
# Handle rotations for photos | ||
with open(path, "rb") as f: | ||
image_data = f.read() | ||
|
||
rotated_images: t.Dict[RotationType, bytes] = PhotoContent.all_simple_rotations( | ||
image_data | ||
) | ||
all_matches = [] | ||
|
||
for rotation_type, rotated_bytes in rotated_images.items(): | ||
# Create a temporary file to hold the image bytes | ||
with tempfile.NamedTemporaryFile() as temp_file: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blocking: This is going to write a lot of files over the course of an execution! We are going to call this method for every single match type. Another approach you could do is refactor this so that the rotated images are inserted higher up in the stack, and then rather than taking rotations: bool, you could pass in the path of the rotation, and an optional enum representing which enum it can take. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm thinking of change s_type to be subclass of BytesHasher so that I can use bytes directly with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the last PR, I mentioned the downside - not all photo formats are knowable without the extension. There's likely a workaround, but let's stay with the current course and see if we can fix it in a followup instead. |
||
temp_file.write(rotated_bytes) | ||
temp_file_path = pathlib.Path(temp_file.name) | ||
matches = index.query(s_type.hash_from_file(temp_file_path)) | ||
temp_file_path.unlink() # Clean up the temporary file | ||
haianhng31 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Add rotation information if any matches were found | ||
matches_with_rotations = [] | ||
for match in matches: | ||
matches_with_rotations.append( | ||
_IndexMatchWithRotation(match=match, rotation_type=rotation_type) | ||
) | ||
|
||
all_matches.extend(matches_with_rotations) | ||
|
||
return all_matches | ||
|
||
|
||
def _match_hashes( | ||
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex | ||
) -> t.Sequence[IndexMatch]: | ||
ret: t.List[IndexMatch] = [] | ||
) -> t.Sequence[_IndexMatchWithRotation]: | ||
ret: t.List[_IndexMatchWithRotation] = [] | ||
for hash in path.read_text().splitlines(): | ||
hash = hash.strip() | ||
if not hash: | ||
|
@@ -244,5 +319,6 @@ def _match_hashes( | |
f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}", | ||
2, | ||
) | ||
ret.extend(index.query(hash)) | ||
matches = index.query(hash) | ||
ret.extend([_IndexMatchWithRotation(match=match) for match in matches]) | ||
return ret |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -1,7 +1,10 @@ | ||||||||||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||||
|
||||||||||
import pathlib | ||||||||||
import tempfile | ||||||||||
from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eTest | ||||||||||
from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eHelper, ThreatExchangeCLIE2eTest | ||||||||||
from threatexchange.content_type.content_base import RotationType | ||||||||||
from threatexchange.content_type.photo import PhotoContent | ||||||||||
from threatexchange.signal_type.md5 import VideoMD5Signal | ||||||||||
|
||||||||||
|
||||||||||
|
@@ -31,3 +34,33 @@ def test_invalid_hash(self): | |||||||||
("-H", "video", "--", not_hash), | ||||||||||
f"{not_hash!r} from .* is not a valid hash for video_md5", | ||||||||||
) | ||||||||||
|
||||||||||
def test_non_photo_match_with_rotations(self): | ||||||||||
with tempfile.NamedTemporaryFile() as f: | ||||||||||
for content_type in ["url", "text", "video"]: | ||||||||||
self.assert_cli_usage_error( | ||||||||||
("--rotations", content_type, f.name), | ||||||||||
msg_regex="--rotations flag is only available for Photo content type", | ||||||||||
) | ||||||||||
|
||||||||||
def test_photo_hash_with_rotations(self): | ||||||||||
test_file = pathlib.Path("threatexchange/tests/hashing/resources/rgb.jpeg") | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Danger! This photo might not correspond to one of the sample signals. Check out the comments on https://github.com/facebook/ThreatExchange/blob/main/python-threatexchange/threatexchange/signal_type/pdq/signal.py#L90 Additionally, this makes the test not work if run from different directories! Here's how I fixed it:
Suggested change
|
||||||||||
|
||||||||||
hash_cmd = ThreatExchangeCLIE2eHelper() | ||||||||||
hash_cmd.COMMON_CALL_ARGS = ("hash",) | ||||||||||
hash_cmd._state_dir = pathlib.Path() | ||||||||||
|
||||||||||
hash = hash_cmd.cli_call("photo", str(test_file)) | ||||||||||
assert hash == "pdq fb4eed46cb8a6c78819ca06b756c541f7b07ef6d02c82fccd00f862166272cda\n" | ||||||||||
|
||||||||||
# rotated_images = PhotoContent.all_simple_rotations(test_file.read_bytes()) | ||||||||||
|
||||||||||
# img = rotated_images[RotationType.ROTATE90] #try with 1 rotated image first | ||||||||||
|
||||||||||
# with tempfile.NamedTemporaryFile() as tmp_file: | ||||||||||
# img = rotated_images[RotationType.ROTATE90] | ||||||||||
# tmp_file.write(img) | ||||||||||
# self.assert_cli_output( | ||||||||||
# ("--rotations", "photo", "--", tmp_file.name), | ||||||||||
# "video_md5 - (Sample Signals) INVESTIGATION_SEED", | ||||||||||
# ) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My current approach for this test:
Q: @Dcallies
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there is a match, then the output is the matching collaboration and distance. Consider this ouptput:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here's what I did: First I sanity checked that we expect the produced hashes to match
However, I can see an error here - why isn't the rotation printed? When I looked to answer this question I found a bug in your code, which I found by adding some helpful print statements. I believe you can find that bug too! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for your comment! I found the bug. I shouldn't have checked s_type == PhotoContent since it's SignalType not ContentType |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
""" | ||
from PIL import Image | ||
import io | ||
import typing as t | ||
|
||
from .content_base import ContentType, RotationType | ||
|
||
|
@@ -82,7 +83,7 @@ def flip_minus1(cls, image_data: bytes) -> bytes: | |
return buffer.getvalue() | ||
|
||
@classmethod | ||
def all_simple_rotations(cls, image_data: bytes): | ||
def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch |
||
""" | ||
Generate the 8 naive rotations of an image. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before processing this list of files, if rotations is true here, generate new files that you iterate through.
Here's one way to do that