Skip to content

Commit

Permalink
[hma] Add hash compare endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
dougneal committed Oct 3, 2024
1 parent b75cd06 commit ff9381d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
47 changes: 47 additions & 0 deletions hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,53 @@ def index_status():
return status_by_name


@bp.route("/compare", methods=["POST"])
def compare():
    """
    Compare pairs of hashes and get the match distance between them.

    The request body must be a JSON object that maps a signal type name to
    a list of exactly two hashes of that signal type. Each signal type name
    is resolved with `_validate_and_transform_signal_type`, both hashes are
    validated with the signal type's `validate_signal_str`, and the pair is
    compared with its `compare_hash`.

    Example input:
    {
        "pdq": ["facd8b...", "facd8b..."],
        "not_pdq": ["bdec19...", "bdec19..."]
    }
    Example output:
    {
        "pdq": [
            true,
            {
                "distance": 9
            }
        ],
        "not_pdq": [
            true,
            {
                "distance": 341
            }
        ]
    }

    Aborts with HTTP 400 when:
      * the request body is not a JSON object,
      * a value is not a list,
      * a list does not contain exactly two entries, or
      * validating or comparing the hashes raises an exception.
    """
    request_data = request.get_json()
    if not isinstance(request_data, dict):
        abort(400, "Request input was not a dict")
    storage = get_storage()
    results = {}
    for signal_type_str, hashes_to_compare in request_data.items():
        if not isinstance(hashes_to_compare, list):
            abort(400, f"Comparison hashes for {signal_type_str} was not a list")
        if len(hashes_to_compare) != 2:
            abort(400, "Comparison hash list length must be exactly 2")
        signal_type = _validate_and_transform_signal_type(signal_type_str, storage)
        try:
            left = signal_type.validate_signal_str(hashes_to_compare[0])
            right = signal_type.validate_signal_str(hashes_to_compare[1])
            comparison = signal_type.compare_hash(left, right)
        except Exception as e:
            # Any failure while validating or comparing is reported to the
            # caller as a bad request rather than surfacing as a server error.
            abort(400, f"Invalid {signal_type_str} hash: {e}")
        else:
            # Only record a result once both hashes validated and compared.
            results[signal_type_str] = comparison
    return results


def initiate_index_cache(app: Flask, scheduler: APScheduler | None) -> None:
assert not hasattr(app, "signal_type_index_cache"), "Aready initialized?"
storage = get_storage()
Expand Down
28 changes: 28 additions & 0 deletions hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,31 @@ def test_exchange_api_set_auth(app: Flask, client: FlaskClient):
"supports_authentification": True,
"has_set_authentification": False,
}


def test_compare_hashes(app: Flask, client: FlaskClient):
    """Exercise POST /m/compare with one valid pair and several bad payloads."""
    hash_a = "facd8bcb2a49bcebdec1985298d5fe84bcd006c187c598c720c3c087b3fdb318"
    hash_b = "facd8bcb2a49bcebdec1985228d5ae84bcd006c187c598c720c2b087b3fdb318"

    # Happy path: two valid PDQ hashes produce a match result with a distance.
    happy_response = client.post("/m/compare", json={"pdq": [hash_a, hash_b]})
    assert happy_response.json == {"pdq": [True, {"distance": 9}]}

    # Malformed input: every one of these payloads must be rejected.
    not_a_dict = ["banana"]
    value_not_a_list = {"pdq": "banana"}
    empty_hash_list = {"pdq": []}
    invalid_hashes = {"pdq": ["banana", "banana"]}
    # Too many hashes (must be exactly 2)
    three_hashes = {"pdq": [hash_a, hash_b, hash_a]}

    malformed_payloads = (
        not_a_dict,
        value_not_a_list,
        empty_hash_list,
        invalid_hashes,
        three_hashes,
    )
    for payload in malformed_payloads:
        rejected = client.post("/m/compare", json=payload)
        assert rejected.status_code == 400

0 comments on commit ff9381d

Please sign in to comment.