From ff9381d0c0770e11b7d2e8d8f37e1459a7d05b54 Mon Sep 17 00:00:00 2001 From: Doug Neal Date: Thu, 3 Oct 2024 14:58:45 +0100 Subject: [PATCH] [hma] Add hash compare endpoint --- .../src/OpenMediaMatch/blueprints/matching.py | 47 +++++++++++++++++++ .../src/OpenMediaMatch/tests/test_api.py | 28 +++++++++++ 2 files changed, 75 insertions(+) diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py index 42fd7efc4..801377196 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py @@ -306,6 +306,53 @@ def index_status(): return status_by_name +@bp.route("/compare", methods=["POST"]) +def compare(): + """ + Compare pairs of hashes and get the match distance between them. + Example input: + { + "pdq": ["facd8b...", "facd8b..."], + "not_pdq": ["bdec19...","bdec19..."] + } + Example output + { + "pdq": [ + true, + { + "distance": 9 + } + ], + "not_pdq": 20 + true, + { + "distance": 341 + } + } + } + """ + request_data = request.get_json() + if type(request_data) != dict: + abort(400, "Request input was not a dict") + storage = get_storage() + results = {} + for signal_type_str in request_data.keys(): + hashes_to_compare = request_data.get(signal_type_str) + if type(hashes_to_compare) != list: + abort(400, f"Comparison hashes for {signal_type_str} was not a list") + if hashes_to_compare.__len__() != 2: + abort(400, f"Comparison hash list lenght must be exactly 2") + signal_type = _validate_and_transform_signal_type(signal_type_str, storage) + try: + left = signal_type.validate_signal_str(hashes_to_compare[0]) + right = signal_type.validate_signal_str(hashes_to_compare[1]) + comparison = signal_type.compare_hash(left, right) + results[signal_type_str] = comparison + except Exception as e: + abort(400, f"Invalid {signal_type_str} hash: {e}") + return results + + def initiate_index_cache(app: Flask, scheduler: APScheduler | None) -> None: assert not hasattr(app, "signal_type_index_cache"), "Aready initialized?" storage = get_storage() diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py index c09b1c649..d0b1b378a 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py @@ -281,3 +281,31 @@ def test_exchange_api_set_auth(app: Flask, client: FlaskClient): "supports_authentification": True, "has_set_authentification": False, } + + +def test_compare_hashes(app: Flask, client: FlaskClient): + specimen1 = "facd8bcb2a49bcebdec1985298d5fe84bcd006c187c598c720c3c087b3fdb318" + specimen2 = "facd8bcb2a49bcebdec1985228d5ae84bcd006c187c598c720c2b087b3fdb318" + # Happy path + resp = client.post("/m/compare", json={"pdq": [specimen1, specimen2]}) + assert resp.json == {"pdq": [True, {"distance": 9}]} + + # Malformed input + bad_inputs = [ + # Not a dict + ["banana"], + # Dict, but values are not lists + {"pdq": "banana"}, + # List of comparison hashes is empty + {"pdq": []}, + # Hashes are invalid + {"pdq": ["banana", "banana"]}, + # Too many hashes (must be exactly 2) + {"pdq": [specimen1, specimen2, specimen1]}, + ] + for bad_input in bad_inputs: + resp = client.post( + "/m/compare", + json=bad_input, + ) + assert resp.status_code == 400