Skip to content

Commit

Permalink
[HMA] Remove ISignalTypeConfigStore from HMA (#1724)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mackay-Fisher authored Dec 20, 2024
1 parent eefe31a commit 3163d99
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from OpenMediaMatch.background_tasks.development import get_apscheduler
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.storage.interface import ISignalExchangeStore, SignalTypeConfig
from OpenMediaMatch.storage.interface import ISignalExchangeStore
from threatexchange.cli.storage.interfaces import SignalTypeConfig
from OpenMediaMatch.utils.time_utils import duration_to_human_str

logger = logging.getLogger(__name__)
Expand Down
66 changes: 1 addition & 65 deletions hasher-matcher-actioner/src/OpenMediaMatch/storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import time

import flask

from threatexchange.cli.storage.interfaces import ISignalTypeConfigStore
from threatexchange.utils import dataclass_json
from threatexchange.content_type.content_base import ContentType
from threatexchange.signal_type.signal_base import SignalType
Expand Down Expand Up @@ -60,70 +60,6 @@ def get_content_type_configs(self) -> t.Mapping[str, ContentTypeConfig]:
"""


@dataclass
class SignalTypeConfig:
"""
Holder for SignalType configuration
"""

# Signal types that are not enabled should not be used in hashing/matching
enabled_ratio: float
signal_type: t.Type[SignalType]

@property
def enabled(self) -> bool:
# TODO do a coin flip here, but also refactor this to do seeding
return self.enabled_ratio >= 0.0


class ISignalTypeConfigStore(metaclass=abc.ABCMeta):
"""Interface for accessing SignalType configuration"""

@abc.abstractmethod
def get_signal_type_configs(self) -> t.Mapping[str, SignalTypeConfig]:
"""Return all installed signal types."""

@abc.abstractmethod
def _create_or_update_signal_type_override(
self, signal_type: str, enabled_ratio: float
) -> None:
"""Create or update database entry for a signal type, setting a new value."""

@t.final
def create_or_update_signal_type_override(
self, signal_type: str, enabled_ratio: float
) -> None:
"""Update enabled ratio of an installed signal type."""
installed_signal_types = self.get_signal_type_configs()
if signal_type not in installed_signal_types:
raise ValueError(f"Unknown signal type {signal_type}")
if not (0.0 <= enabled_ratio <= 1.0):
raise ValueError(
f"Invalid enabled ratio {enabled_ratio}. Must be in the range 0.0-1.0 inclusive."
)
self._create_or_update_signal_type_override(signal_type, enabled_ratio)

@t.final
def get_enabled_signal_types(self) -> t.Mapping[str, t.Type[SignalType]]:
"""Helper shortcut for getting only enabled SignalTypes"""
return {
k: v.signal_type
for k, v in self.get_signal_type_configs().items()
if v.enabled
}

@t.final
def get_enabled_signal_types_for_content_type(
self, content_type: t.Type[ContentType]
) -> t.Mapping[str, t.Type[SignalType]]:
"""Helper shortcut for getting enabled types for a piece of content"""
return {
k: v.signal_type
for k, v in self.get_signal_type_configs().items()
if v.enabled and content_type in v.signal_type.get_content_types()
}


@dataclass
class SignalTypeIndexBuildCheckpoint:
"""
Expand Down
4 changes: 2 additions & 2 deletions hasher-matcher-actioner/src/OpenMediaMatch/storage/mocked.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)

from OpenMediaMatch.storage import interface
from OpenMediaMatch.storage.interface import SignalTypeConfig
from threatexchange.cli.storage.interfaces import SignalTypeConfig


class MockedUnifiedStore(interface.IUnifiedStore):
Expand Down Expand Up @@ -49,7 +49,7 @@ def get_content_type_configs(self) -> t.Mapping[str, interface.ContentTypeConfig
def get_signal_type_configs(self) -> t.Mapping[str, SignalTypeConfig]:
# Needed to bamboozle mypy into working
s_types: t.Sequence[t.Type[SignalType]] = (PdqSignal, VideoMD5Signal)
return {s.get_name(): interface.SignalTypeConfig(1.0, s) for s in s_types}
return {s.get_name(): SignalTypeConfig(1.0, s) for s in s_types}

def _create_or_update_signal_type_override(
self, signal_type: str, enabled_ratio: float
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)

from OpenMediaMatch.storage import interface
from threatexchange.cli.storage.interfaces import SignalTypeConfig
from OpenMediaMatch.storage.postgres import database, flask_utils


Expand Down Expand Up @@ -140,12 +141,12 @@ def exchange_api_config_update(
sesh.add(config)
sesh.commit()

def get_signal_type_configs(self) -> t.Mapping[str, interface.SignalTypeConfig]:
def get_signal_type_configs(self) -> t.Mapping[str, SignalTypeConfig]:
# If a signal is installed, then it is enabled by default. But it may be disabled by an
# override in the database.
signal_type_overrides = self._query_signal_type_overrides()
return {
name: interface.SignalTypeConfig(
name: SignalTypeConfig(
signal_type_overrides.get(name, 1.0),
st,
)
Expand Down
42 changes: 21 additions & 21 deletions python-threatexchange/threatexchange/cli/storage/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,6 @@
from threatexchange.signal_type.signal_base import SignalType


@dataclass
class ContentTypeConfig:
"""
Holder for ContentType configuration.
"""

# Content types that are not enabled should not be used in hashing/matching
enabled: bool
content_type: t.Type[ContentType]


class IContentTypeConfigStore(metaclass=abc.ABCMeta):
"""Interface for accessing ContentType configuration"""

@abc.abstractmethod
def get_content_type_configs(self) -> t.Mapping[str, ContentTypeConfig]:
"""
Return all installed content types.
"""


@dataclass
class SignalTypeConfig:
"""
Expand Down Expand Up @@ -117,3 +96,24 @@ def get_enabled_signal_types_for_content_type(
for k, v in self.get_signal_type_configs().items()
if v.enabled and content_type in v.signal_type.get_content_types()
}


@dataclass
class ContentTypeConfig:
"""
Holder for ContentType configuration.
"""

# Content types that are not enabled should not be used in hashing/matching
enabled: bool
content_type: t.Type[ContentType]


class IContentTypeConfigStore(metaclass=abc.ABCMeta):
"""Interface for accessing ContentType configuration"""

@abc.abstractmethod
def get_content_type_configs(self) -> t.Mapping[str, ContentTypeConfig]:
"""
Return all installed content types.
"""

0 comments on commit 3163d99

Please sign in to comment.