diff --git a/open-media-match/pyproject.toml b/open-media-match/pyproject.toml index 1c597990c..f2feb483d 100644 --- a/open-media-match/pyproject.toml +++ b/open-media-match/pyproject.toml @@ -20,13 +20,17 @@ dependencies = [ "flask_sqlalchemy", "flask_migrate", "psycopg2", + "threatexchange", ] [project.optional-dependencies] all = [ "mypy", "black", - "pytest" + "pytest", + "types-Flask-Migrate", + "types-requests", + ] test = [ "pytest" ] diff --git a/open-media-match/src/OpenMediaMatch/__init__.py b/open-media-match/src/OpenMediaMatch/__init__.py index 090f1d815..702288c34 100644 --- a/open-media-match/src/OpenMediaMatch/__init__.py +++ b/open-media-match/src/OpenMediaMatch/__init__.py @@ -1,6 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import os +import sys + import flask import flask_migrate import flask_sqlalchemy @@ -16,7 +18,14 @@ def create_app(): Create and configure the Flask app """ app = flask.Flask(__name__) - app.config.from_envvar("OMM_CONFIG") + if "OMM_CONFIG" in os.environ: + app.config.from_envvar("OMM_CONFIG") + elif sys.argv[0].endswith("/flask"): # Default for flask CLI + # The devcontainer settings. If you are using the CLI outside + # the devcontainer and getting an error, just override the env + app.config.from_pyfile("/workspace/.devcontainer/omm_config.py") + else: + raise RuntimeError("No flask config given - try populating OMM_CONFIG env") app.config.update( SQLALCHEMY_DATABASE_URI=app.config.get("DATABASE_URI"), SQLALCHEMY_TRACK_MODIFICATIONS=False, diff --git a/open-media-match/src/OpenMediaMatch/app_resources.py b/open-media-match/src/OpenMediaMatch/app_resources.py new file mode 100644 index 000000000..21fe0f273 --- /dev/null +++ b/open-media-match/src/OpenMediaMatch/app_resources.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +""" +Accessors for various "global" resources, usually cached by request lifetime + +I can't tell if these should just be in app.py, so I'm sticking it here for now, +since one advantage of putting these in functions is we can type the output. +""" + +from flask import g + +from OpenMediaMatch.storage.interface import IUnifiedStore +from OpenMediaMatch.storage.default import DefaultOMMStore + + +def get_storage() -> IUnifiedStore: + """ + Get the storage object, which is just a wrapper around the real storage. + """ + if "storage" not in g: + # dougneal, you'll need to eventually add constructor arguments + # for this to pass in the postgres/database object. We're just + # hiding flask bits from pytx bits + g.storage = DefaultOMMStore() + return g.storage diff --git a/open-media-match/src/OpenMediaMatch/blueprints/hashing.py b/open-media-match/src/OpenMediaMatch/blueprints/hashing.py index 5097c7a76..1106cb0bd 100644 --- a/open-media-match/src/OpenMediaMatch/blueprints/hashing.py +++ b/open-media-match/src/OpenMediaMatch/blueprints/hashing.py @@ -1,5 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +""" +Endpoints for hashing content +""" + +from pathlib import Path +import tempfile +import typing as t + from flask import Blueprint -from flask import abort, request +from flask import abort, request, current_app +import requests + +from threatexchange.content_type.content_base import ContentType +from threatexchange.signal_type.signal_base import FileHasher, SignalType + +from OpenMediaMatch import app_resources bp = Blueprint("hashing", __name__) @@ -10,18 +26,65 @@ def hash_media(): Fetch content and return its hash. TODO: implement """ + + content_type = _parse_request_content_type() + signal_types = _parse_request_signal_type(content_type) + media_url = request.args.get("url", None) if media_url is None: - # path is required, otherwise we don't know what we're hashing. - # TODO: a more helpful message - abort(400) - - hash_types = request.args.get("types", None) - if hash_types is not None: - # TODO: parse this into a list of hash types - pass - - # TODO - # - download the media - # - decode it - # - hash it + abort(400, "url is required") + + download_resp = requests.get(media_url, allow_redirects=True, timeout=30 * 1000) + download_resp.raise_for_status() + + ret = {} + + # For images, we may need to copy the file suffix (.png, jpeg, etc) for it to work + with tempfile.NamedTemporaryFile("wb") as tmp: + current_app.logger.debug("Writing to %s", tmp.name) + tmp.write(download_resp.content) + path = Path(tmp.name) + for st in signal_types.values(): + # At this point, every BytesHasher is a FileHasher, but we could + # explicitly pull those out to avoiding storing any copies of + # data locally, even temporarily + if issubclass(st, FileHasher): + ret[st.get_name()] = st.hash_from_file(path) + return ret + + +def _parse_request_content_type() -> ContentType: + storage = app_resources.get_storage() + arg = request.args.get("content_type", "") + content_type_config = storage.get_content_type_configs().get(arg) + if content_type_config is None: + abort(400, f"no such content_type: '{arg}'") + + if not content_type_config.enabled: + abort(400, f"content_type {arg} is disabled") + + return content_type_config.content_type + + +def _parse_request_signal_type(content_type: ContentType) -> t.Mapping[str, SignalType]: + storage = app_resources.get_storage() + signal_types = storage.get_enabled_signal_types_for_content_type(content_type) + if not signal_types: + abort(500, "No signal types configured!") + signal_type_args = request.args.get("types", None) + if signal_type_args is None: + return signal_types + + ret = {} + for st_name in signal_type_args.split(","): + st_name = st_name.strip() + if not st_name: + continue + if st_name not in signal_types: + abort(400, f"signal type '{st_name}' doesn't exist or is disabled") + ret[st_name] = signal_types[st_name] + + if not ret: + abort(400, "empty signal type selection") + + return ret diff --git a/open-media-match/src/OpenMediaMatch/migrations/env.py b/open-media-match/src/OpenMediaMatch/migrations/env.py index a4418335c..31e250503 100644 --- a/open-media-match/src/OpenMediaMatch/migrations/env.py +++ b/open-media-match/src/OpenMediaMatch/migrations/env.py @@ -1,17 +1,17 @@ import logging from logging.config import fileConfig -from flask import current_app - from alembic import context +from flask import current_app # this is the Alembic Config object, which provides # access to the values within the .ini file in use. +# It's also impossible to type! config = context.config # Interpret the config file for Python logging. # This line sets up loggers basically. -fileConfig(config.config_file_name) +fileConfig(config.config_file_name) # type: ignore logger = logging.getLogger("alembic.env") diff --git a/open-media-match/src/OpenMediaMatch/models.py b/open-media-match/src/OpenMediaMatch/models.py index a759e1786..08b154417 100644 --- a/open-media-match/src/OpenMediaMatch/models.py +++ b/open-media-match/src/OpenMediaMatch/models.py @@ -1,16 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -from . import database as db +from OpenMediaMatch import database as db -class Bank(db.Model): +class Bank(db.Model): # type: ignore[name-defined] # mypy not smart enough __tablename__ = "banks" id = db.Column(db.Integer, primary_key=True, autoincrement=True) name = db.Column(db.String(255), nullable=False) enabled = db.Column(db.Boolean, nullable=False) -class Hash(db.Model): +class Hash(db.Model): # type: ignore[name-defined] # mypy not smart enough __tablename__ = "hashes" id = db.Column(db.Integer, primary_key=True, autoincrement=True) enabled = db.Column(db.Boolean, nullable=False) diff --git a/open-media-match/src/OpenMediaMatch/storage/interface.py b/open-media-match/src/OpenMediaMatch/storage/interface.py index a1b5615d4..a278ccd49 100644 --- a/open-media-match/src/OpenMediaMatch/storage/interface.py +++ b/open-media-match/src/OpenMediaMatch/storage/interface.py @@ -33,7 +33,7 @@ class ContentTypeConfig: # Content types that are not enabled should not be used in hashing/matching enabled: bool - signal_type: ContentType + content_type: ContentType class IContentTypeConfigStore(metaclass=abc.ABCMeta): @@ -62,9 +62,27 @@ class ISignalTypeConfigStore(metaclass=abc.ABCMeta): @abc.abstractmethod def get_signal_type_configs(self) -> t.Mapping[str, SignalTypeConfig]: - """ - Return all installed signal types. - """ + """Return all installed signal types.""" + + @t.final + def get_enabled_signal_types(self) -> t.Mapping[str, SignalType]: + """Helper shortcut for getting only enabled SignalTypes""" + return { + k: v.signal_type + for k, v in self.get_signal_type_configs().items() + if v.enabled + } + + @t.final + def get_enabled_signal_types_for_content_type( + self, content_type: ContentType + ) -> t.Mapping[str, SignalType]: + """Helper shortcut for getting enabled types for a piece of content""" + return { + k: v.signal_type + for k, v in self.get_signal_type_configs().items() + if v.enabled and content_type in v.signal_type.get_content_types() + } class ISignalExchangeConfigStore(metaclass=abc.ABCMeta):