Skip to content

Commit

Permalink
[py-tx] Implement "File" Content Type for Image or Video subclassing (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Mackay-Fisher authored Dec 23, 2024
1 parent e37e1a5 commit 148d8cc
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 2 deletions.
7 changes: 7 additions & 0 deletions python-threatexchange/threatexchange/cli/hash_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from threatexchange.cli.exceptions import CommandError
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.photo import PhotoContent
from threatexchange.content_type.file import FileContent
from threatexchange.content_type.content_base import RotationType

from threatexchange.signal_type.signal_base import FileHasher, SignalType
Expand Down Expand Up @@ -127,6 +128,12 @@ def __init__(
raise CommandError(
"--photo-preprocess flag is only available for Photo content type", 2
)
if issubclass(self.content_type, FileContent):
try:
# Use the first file to determine content type
self.content_type = FileContent.map_to_content_type(self.files[0])
except ValueError as e:
raise CommandError(f"{e}", returncode=2)

def execute(self, settings: CLISettings) -> None:
hashers = [
Expand Down
10 changes: 8 additions & 2 deletions python-threatexchange/threatexchange/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
TATSignalExchangeAPI,
)

from threatexchange.content_type import photo, video, text, url
from threatexchange.content_type import photo, video, text, url, file
from threatexchange.exchanges.signal_exchange_api import SignalExchangeAPI
from threatexchange.exchanges.auth import (
SignalExchangeAPIInvalidAuthException,
Expand Down Expand Up @@ -244,7 +244,13 @@ def _get_settings(
extensions = _get_extended_functionality(config)

signals = interface_validation.SignalTypeMapping(
[photo.PhotoContent, video.VideoContent, url.URLContent, text.TextContent]
[
photo.PhotoContent,
video.VideoContent,
url.URLContent,
text.TextContent,
file.FileContent,
]
+ extensions.content_types,
list(_DEFAULT_SIGNAL_TYPES) + extensions.signal_types,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_config_signal(cli: ThreatExchangeCLIE2eHelper) -> None:

def test_config_content(cli: ThreatExchangeCLIE2eHelper) -> None:
expected = [
"file threatexchange.content_type.file.FileContent",
"photo threatexchange.content_type.photo.PhotoContent",
"text threatexchange.content_type.text.TextContent",
"url threatexchange.content_type.url.URLContent",
Expand Down
120 changes: 120 additions & 0 deletions python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import pathlib
import tempfile
import pytest
from PIL import Image, ImageSequence
from threatexchange.cli.tests.e2e_test_helper import (
ThreatExchangeCLIE2eHelper,
te_cli,
)
from threatexchange.content_type.file import FileContent
from threatexchange.signal_type.pdq.pdq_hasher import pdq_from_bytes


@pytest.fixture
Expand Down Expand Up @@ -155,3 +158,120 @@ def test_unletterbox_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper):
"pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22",
],
)


def test_file_content(hash_cli: ThreatExchangeCLIE2eHelper):
"""
Test that FileContent correctly maps to PhotoContent or VideoContent
and raises errors for unsupported file types.
"""
resources_dir = (
pathlib.Path(__file__).parent.parent.parent / "tests/hashing/resources"
)
# Paths for existing test images
photo_jpg = resources_dir / "sample-b.jpg"
photo_png = resources_dir / "LA.png" # Replace with correct PNG file
photo_jpeg_rgb = resources_dir / "rgb.jpeg"

# JPEG Test Case
hash_cli.assert_cli_output(
("file", str(photo_jpg)),
[
"pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22",
],
)

# PNG Test Case
hash_cli.assert_cli_output(
("file", str(photo_png)),
[
"pdq accb6d39648035f8125c8ce6ba65007de7b54c67a2d93ef7b8f33b0611306715",
],
)

# JPEG with RGB Profile Test Case
hash_cli.assert_cli_output(
("file", str(photo_jpeg_rgb)),
[
"pdq fb4eed46cb8a6c78819ca06b756c541f7b07ef6d02c82fccd00f862166272cda",
],
)

# Create and test a temporary empty MP4 file (Video)
with tempfile.NamedTemporaryFile(suffix=".mp4") as tmp_video_file:
hash_cli.assert_cli_output(
("file", tmp_video_file.name),
[
"video_md5 d41d8cd98f00b204e9800998ecf8427e",
],
)

# Create and test a temporary empty AVI file (Video)
with tempfile.NamedTemporaryFile(suffix=".avi") as tmp_avi_file:
hash_cli.assert_cli_output(
("file", tmp_avi_file.name),
[
"video_md5 d41d8cd98f00b204e9800998ecf8427e",
],
)

# Create and test a temporary empty MOV file (Video)
with tempfile.NamedTemporaryFile(suffix=".mov") as tmp_mov_file:
hash_cli.assert_cli_output(
("file", tmp_mov_file.name),
[
"video_md5 d41d8cd98f00b204e9800998ecf8427e",
],
)

# Create and test a temporary static GIF file (1x1 pixel)
with tempfile.NamedTemporaryFile(suffix=".gif") as tmp_static_gif:
# Create a 100x100 multi-colored image
static_img = Image.new("RGB", (100, 100))
pixels = static_img.load()

# Fill the image with colors to improve quality
if pixels:
for i in range(100):
for j in range(100):
pixels[i, j] = ((i * 5) % 256, (j * 5) % 256, ((i + j) * 5) % 256)

# Save the image as a static GIF
static_img.save(tmp_static_gif.name, format="GIF")
hash_cli.assert_cli_output(
("file", tmp_static_gif.name),
[
"pdq 77ffdd3a9405fbb0805027270fa7d7065e7cf8da0c0d795881002667e44f266f",
],
)

# Create and test a temporary animated GIF file (2 frames)
with tempfile.NamedTemporaryFile(suffix=".gif") as tmp_animated_gif:
animated_frames = [
Image.new("RGB", (1, 1), color=(255, 0, 0)), # Red frame
Image.new("RGB", (1, 1), color=(0, 255, 0)), # Green frame
]
animated_frames[0].save(
tmp_animated_gif.name,
format="GIF",
save_all=True,
append_images=animated_frames[1:],
duration=200, # Frame duration
loop=0,
)
hash_cli.assert_cli_output(
("file", tmp_animated_gif.name),
[
"video_md5 ec82a2d0d4d99a623ec2a939accc7de5",
],
)

# Create and test a temporary unsupported .txt file
with tempfile.NamedTemporaryFile(suffix=".txt") as tmp_unsupported_file:
tmp_unsupported_file.write(b"This is a test file.") # Write dummy text
tmp_unsupported_file.flush()
# Assert that the CLI raises a CommandError for unsupported file type
hash_cli.assert_cli_usage_error(
("file", tmp_unsupported_file.name),
msg_regex="Unsupported file type: .txt",
)
58 changes: 58 additions & 0 deletions python-threatexchange/threatexchange/content_type/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Wrapper around the file content type.
"""
import logging
from pathlib import Path
from .photo import PhotoContent
from .video import VideoContent
from .content_base import ContentType
from PIL import Image
import typing as t

# Initialize the logger for this module
logger = logging.getLogger(__name__)


class FileContent(ContentType):
"""
Content type for general files, capable of routing to the appropriate
specific content type (e.g., PhotoContent or VideoContent) based on file extension.
"""

@classmethod
def map_to_content_type(cls, file_path: Path) -> t.Type[ContentType]:
"""
Map the file to a specific content type based on its extension by taking in file path.
Returns the ContentType subclass or raises error if the file type is unsupported.
"""
extension = file_path.suffix.lower()
logger.info(f"Processing file: {file_path}")
logger.info(f"Detected file extension: {extension}")
content_type: t.Type[ContentType]

if extension in {".jpg", ".jpeg", ".png"}:
content_type = PhotoContent
elif extension in {".mp4", ".avi", ".mov"}:
content_type = VideoContent
elif extension == ".gif":
try:
with Image.open(file_path) as img:
# Check if the GIF is animated
is_animated = getattr(img, "is_animated", False)
if is_animated:
logger.info("File is an animated GIF.")
content_type = VideoContent
else:
logger.info("File is a static GIF.")
content_type = PhotoContent
except Exception as e:
raise ValueError(f"Error processing GIF: {e}")
else:
raise ValueError(f"Unsupported file type: {extension}")

logger.info(f"Content type set to: {content_type.__name__}")
return content_type

0 comments on commit 148d8cc

Please sign in to comment.