From a15af73e0aa03b3f9022d7613224db242fc87086 Mon Sep 17 00:00:00 2001 From: biefan <70761325+biefan@users.noreply.github.com> Date: Sun, 28 Jun 2026 08:19:02 +0000 Subject: [PATCH 1/2] Fix media upload extension fallback --- pyrit/backend/services/attack_service.py | 27 ++++++--- tests/unit/backend/test_attack_service.py | 69 ++++++++++++++++++++++- 2 files changed, 85 insertions(+), 11 deletions(-) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index e2b09149d5..9c2f21560d 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -73,6 +73,13 @@ class AttackService: Uses PyRIT memory (database) as the source of truth via AttackResult. """ + _DATA_TYPE_EXTENSION = { + "image_path": ".png", + "audio_path": ".wav", + "video_path": ".mp4", + "binary_path": ".bin", + } + def __init__(self) -> None: """Initialize the attack service.""" self._memory = CentralMemory.get_memory_instance() @@ -935,22 +942,26 @@ async def _persist_base64_pieces_async(request: AddMessageRequest) -> None: except (OSError, ValueError): pass - # Derive file extension from the MIME type sent by the frontend - ext = None - if piece.mime_type: - ext = mimetypes.guess_extension(piece.mime_type, strict=False) - if not ext: - ext = ".bin" - # Strip data URI prefix if present (e.g. "data:image/png;base64,...") # The backend itself returns data URIs from pyrit_messages_to_dto_async, # so the client may echo them back. value = piece.original_value + data_uri_mime_type = None if value.startswith("data:"): # Format: data:;base64, - _, _, payload = value.partition(",") + header, _, payload = value.partition(",") + data_uri_mime_type = header.split(":", 1)[1].split(";", 1)[0] if ":" in header else None value = payload + # Derive file extension from MIME metadata, then fall back to data_type. + ext = None + if piece.mime_type: + ext = mimetypes.guess_extension(piece.mime_type, strict=False) + if not ext and data_uri_mime_type: + ext = mimetypes.guess_extension(data_uri_mime_type, strict=False) + if not ext: + ext = AttackService._DATA_TYPE_EXTENSION.get(piece.data_type, ".bin") + serializer = data_serializer_factory( category="prompt-memory-entries", data_type=cast("PromptDataType", piece.data_type), diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index cb30c4fcef..54dd30c464 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -68,8 +68,8 @@ def make_attack_result( has_target: bool = True, name: str = "Test Attack", outcome: AttackOutcome = AttackOutcome.UNDETERMINED, - created_at: datetime = None, - updated_at: datetime = None, + created_at: datetime | None = None, + updated_at: datetime | None = None, ) -> AttackResult: """Create a mock AttackResult for testing.""" now = datetime.now(timezone.utc) @@ -125,7 +125,7 @@ def make_mock_piece( sequence: int = 0, original_value: str = "test", converted_value: str = "test", - timestamp: datetime = None, + timestamp: datetime | None = None, ): """Create a mock message piece.""" piece = MagicMock() @@ -1639,6 +1639,69 @@ async def test_data_uri_prefix_is_stripped_before_saving(self, attack_service) - mock_serializer.save_b64_image_async.assert_awaited_once_with(data="aW1hZ2VkYXRh") assert request.pieces[0].original_value == "/saved/image.png" + async def test_data_uri_mime_type_supplies_extension_when_mime_type_missing(self, attack_service) -> None: + """Data URI media type should prevent image uploads from falling back to blocked .bin files.""" + request = AddMessageRequest( + role="user", + pieces=[ + MessagePieceRequest( + data_type="image_path", + original_value="data:image/png;base64,aW1hZ2VkYXRh", + ), + ], + send=False, + target_conversation_id="test-id", + ) + + mock_serializer = MagicMock() + mock_serializer.save_b64_image_async = AsyncMock() + mock_serializer.value = "/saved/image.png" + + with patch( + "pyrit.backend.services.attack_service.data_serializer_factory", + return_value=mock_serializer, + ) as factory_mock: + await AttackService._persist_base64_pieces_async(request) + + factory_mock.assert_called_once_with( + category="prompt-memory-entries", + data_type="image_path", + extension=".png", + ) + mock_serializer.save_b64_image_async.assert_awaited_once_with(data="aW1hZ2VkYXRh") + assert request.pieces[0].original_value == "/saved/image.png" + + async def test_path_data_type_supplies_extension_when_mime_type_missing(self, attack_service) -> None: + """Raw image base64 without MIME metadata should still use a media-serving extension.""" + request = AddMessageRequest( + role="user", + pieces=[ + MessagePieceRequest( + data_type="image_path", + original_value="aW1hZ2VkYXRh", + ), + ], + send=False, + target_conversation_id="test-id", + ) + + mock_serializer = MagicMock() + mock_serializer.save_b64_image_async = AsyncMock() + mock_serializer.value = "/saved/image.png" + + with patch( + "pyrit.backend.services.attack_service.data_serializer_factory", + return_value=mock_serializer, + ) as factory_mock: + await AttackService._persist_base64_pieces_async(request) + + factory_mock.assert_called_once_with( + category="prompt-memory-entries", + data_type="image_path", + extension=".png", + ) + assert request.pieces[0].original_value == "/saved/image.png" + async def test_http_url_is_kept_as_is(self, attack_service) -> None: """HTTPS blob URLs should not be re-persisted.""" request = AddMessageRequest( From 23f22e5797c7e9c04fbefa1b121d54544f901f08 Mon Sep 17 00:00:00 2001 From: Copilot <223556219+Copilot@users.noreply.github.com> Date: Mon, 29 Jun 2026 07:14:00 -0700 Subject: [PATCH 2/2] Consolidate media extension fallback into shared DEFAULT_MEDIA_EXTENSIONS Replace the duplicated media-type to extension maps in AttackService and ConverterService (and the prefix-keyed _DEFAULT_EXTENSIONS in _media.py) with a single shared DEFAULT_MEDIA_EXTENSIONS constant exported from pyrit.backend.models. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/models/__init__.py | 3 +++ pyrit/backend/models/_media.py | 13 ++++++++++--- pyrit/backend/services/attack_service.py | 10 ++-------- pyrit/backend/services/converter_service.py | 14 ++++---------- tests/unit/backend/test_converter_service.py | 6 +++--- 5 files changed, 22 insertions(+), 24 deletions(-) diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 3ab8571505..026e5d87f0 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -7,6 +7,7 @@ Pydantic models for API requests and responses. """ +from pyrit.backend.models._media import DEFAULT_MEDIA_EXTENSIONS from pyrit.backend.models.attacks import ( AddMessageRequest, AddMessageResponse, @@ -66,6 +67,8 @@ ) __all__ = [ + # Media + "DEFAULT_MEDIA_EXTENSIONS", # Attacks "AddMessageRequest", "AddMessageResponse", diff --git a/pyrit/backend/models/_media.py b/pyrit/backend/models/_media.py index e06efacf15..e91b2a56b1 100644 --- a/pyrit/backend/models/_media.py +++ b/pyrit/backend/models/_media.py @@ -29,8 +29,15 @@ "binary_path": "file", } -# Fallback extension per prefix when the value carries no usable suffix. -_DEFAULT_EXTENSIONS = {"image": ".png", "audio": ".wav", "video": ".mp4", "file": ".bin"} +# Default file extension per media data type, used when a value carries no usable +# suffix and no MIME metadata is available. Centralized here so the backend response +# models and the attack/converter services share a single source of truth. +DEFAULT_MEDIA_EXTENSIONS: dict[str, str] = { + "image_path": ".png", + "audio_path": ".wav", + "video_path": ".mp4", + "binary_path": ".bin", +} def infer_mime_type(*, value: str | None, data_type: PromptDataType) -> str | None: @@ -79,6 +86,6 @@ def build_filename(*, data_type: str, sha256: str | None, value: str | None) -> ext = Path(source).suffix if not ext: - ext = _DEFAULT_EXTENSIONS.get(prefix, ".bin") + ext = DEFAULT_MEDIA_EXTENSIONS.get(data_type, ".bin") return f"{prefix}_{short_hash}{ext}" diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 9c2f21560d..bd7e1b841a 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -31,6 +31,7 @@ request_piece_to_pyrit_message_piece, request_to_pyrit_message, ) +from pyrit.backend.models import DEFAULT_MEDIA_EXTENSIONS from pyrit.backend.models.attacks import ( AddMessageRequest, AddMessageResponse, @@ -73,13 +74,6 @@ class AttackService: Uses PyRIT memory (database) as the source of truth via AttackResult. """ - _DATA_TYPE_EXTENSION = { - "image_path": ".png", - "audio_path": ".wav", - "video_path": ".mp4", - "binary_path": ".bin", - } - def __init__(self) -> None: """Initialize the attack service.""" self._memory = CentralMemory.get_memory_instance() @@ -960,7 +954,7 @@ async def _persist_base64_pieces_async(request: AddMessageRequest) -> None: if not ext and data_uri_mime_type: ext = mimetypes.guess_extension(data_uri_mime_type, strict=False) if not ext: - ext = AttackService._DATA_TYPE_EXTENSION.get(piece.data_type, ".bin") + ext = DEFAULT_MEDIA_EXTENSIONS.get(piece.data_type, ".bin") serializer = data_serializer_factory( category="prompt-memory-entries", diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 981ba271da..54cc92a56e 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -19,10 +19,11 @@ import uuid from functools import lru_cache from pathlib import Path -from typing import Any, ClassVar, Literal, Union, get_args, get_origin +from typing import Any, Literal, Union, get_args, get_origin from urllib.parse import parse_qs, urlparse from pyrit.backend.mappers.converter_mappers import converter_object_to_instance +from pyrit.backend.models import DEFAULT_MEDIA_EXTENSIONS from pyrit.backend.models.converters import ( ConverterCatalogEntry, ConverterCatalogResponse, @@ -79,13 +80,6 @@ class ConverterService: API metadata is derived from the converter objects. """ - _DATA_TYPE_EXTENSION: ClassVar[dict[str, str]] = { - "image_path": ".png", - "audio_path": ".wav", - "video_path": ".mp4", - "binary_path": ".bin", - } - def __init__(self) -> None: """Initialize the converter service.""" self._registry = ConverterRegistry.get_registry_singleton() @@ -254,7 +248,7 @@ async def preview_conversion_async(self, *, request: ConverterPreviewRequest) -> elif original_value.startswith("data:"): _, _, value = original_value.partition(",") - ext = self._DATA_TYPE_EXTENSION.get(str(data_type), ".bin") + ext = DEFAULT_MEDIA_EXTENSIONS.get(str(data_type), ".bin") serializer = data_serializer_factory( category="prompt-memory-entries", @@ -268,7 +262,7 @@ async def preview_conversion_async(self, *, request: ConverterPreviewRequest) -> pass else: # Treat as raw base64 - ext = self._DATA_TYPE_EXTENSION.get(str(data_type), ".bin") + ext = DEFAULT_MEDIA_EXTENSIONS.get(str(data_type), ".bin") serializer = data_serializer_factory( category="prompt-memory-entries", diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 4cc21e52fd..5be877b234 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -376,7 +376,7 @@ async def test_preview_conversion_chains_multiple_converters(self) -> None: mock_converter2.convert_async.assert_called_with(prompt="step1_output", input_type="text") async def test_preview_conversion_persists_data_uri_for_image_path(self) -> None: - """Data URIs on *_path types are decoded via the _DATA_TYPE_EXTENSION map and persisted.""" + """Data URIs on *_path types are decoded via the DEFAULT_MEDIA_EXTENSIONS map and persisted.""" service = ConverterService() mock_converter = MagicMock(spec=prompt_converter.PromptConverter) @@ -403,7 +403,7 @@ async def test_preview_conversion_persists_data_uri_for_image_path(self) -> None await service.preview_conversion_async(request=request) mock_factory.assert_called_once() - # ext is the image_path mapping from _DATA_TYPE_EXTENSION + # ext is the image_path mapping from DEFAULT_MEDIA_EXTENSIONS assert mock_factory.call_args.kwargs["extension"] == ".png" assert mock_factory.call_args.kwargs["data_type"] == "image_path" mock_serializer.save_b64_image_async.assert_awaited_once_with(data="iVBORw0KGgo=") @@ -437,7 +437,7 @@ async def test_preview_conversion_persists_raw_base64_for_audio_path(self) -> No await service.preview_conversion_async(request=request) mock_factory.assert_called_once() - # ext is the audio_path mapping from _DATA_TYPE_EXTENSION + # ext is the audio_path mapping from DEFAULT_MEDIA_EXTENSIONS assert mock_factory.call_args.kwargs["extension"] == ".wav" assert mock_factory.call_args.kwargs["data_type"] == "audio_path" mock_serializer.save_b64_image_async.assert_awaited_once_with(data=raw_b64)