Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyrit/backend/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Pydantic models for API requests and responses.
"""

from pyrit.backend.models._media import DEFAULT_MEDIA_EXTENSIONS
from pyrit.backend.models.attacks import (
AddMessageRequest,
AddMessageResponse,
Expand Down Expand Up @@ -66,6 +67,8 @@
)

__all__ = [
# Media
"DEFAULT_MEDIA_EXTENSIONS",
# Attacks
"AddMessageRequest",
"AddMessageResponse",
Expand Down
13 changes: 10 additions & 3 deletions pyrit/backend/models/_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,15 @@
"binary_path": "file",
}

# Fallback extension per prefix when the value carries no usable suffix.
_DEFAULT_EXTENSIONS = {"image": ".png", "audio": ".wav", "video": ".mp4", "file": ".bin"}
# Default file extension per media data type, used when a value carries no usable
# suffix and no MIME metadata is available. Centralized here so the backend response
# models and the attack/converter services share a single source of truth.
#
# Keys are the full ``*_path`` data-type names and values include the leading dot.
# Lookups are done with ``.get(<data_type>, ".bin")``, so a data type that has no
# entry in this mapping resolves to ``.bin``.
DEFAULT_MEDIA_EXTENSIONS: dict[str, str] = {
"image_path": ".png",
"audio_path": ".wav",
"video_path": ".mp4",
"binary_path": ".bin",
}


def infer_mime_type(*, value: str | None, data_type: PromptDataType) -> str | None:
Expand Down Expand Up @@ -79,6 +86,6 @@ def build_filename(*, data_type: str, sha256: str | None, value: str | None) ->
ext = Path(source).suffix

if not ext:
ext = _DEFAULT_EXTENSIONS.get(prefix, ".bin")
ext = DEFAULT_MEDIA_EXTENSIONS.get(data_type, ".bin")

return f"{prefix}_{short_hash}{ext}"
21 changes: 13 additions & 8 deletions pyrit/backend/services/attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
request_piece_to_pyrit_message_piece,
request_to_pyrit_message,
)
from pyrit.backend.models import DEFAULT_MEDIA_EXTENSIONS
from pyrit.backend.models.attacks import (
AddMessageRequest,
AddMessageResponse,
Expand Down Expand Up @@ -935,22 +936,26 @@ async def _persist_base64_pieces_async(request: AddMessageRequest) -> None:
except (OSError, ValueError):
pass

# Derive file extension from the MIME type sent by the frontend
ext = None
if piece.mime_type:
ext = mimetypes.guess_extension(piece.mime_type, strict=False)
if not ext:
ext = ".bin"

# Strip data URI prefix if present (e.g. "data:image/png;base64,...")
# The backend itself returns data URIs from pyrit_messages_to_dto_async,
# so the client may echo them back.
value = piece.original_value
data_uri_mime_type = None
if value.startswith("data:"):
# Format: data:<mime>;base64,<payload>
_, _, payload = value.partition(",")
header, _, payload = value.partition(",")
data_uri_mime_type = header.split(":", 1)[1].split(";", 1)[0] if ":" in header else None
value = payload

# Derive file extension from MIME metadata, then fall back to data_type.
ext = None
if piece.mime_type:
ext = mimetypes.guess_extension(piece.mime_type, strict=False)
if not ext and data_uri_mime_type:
ext = mimetypes.guess_extension(data_uri_mime_type, strict=False)
if not ext:
ext = DEFAULT_MEDIA_EXTENSIONS.get(piece.data_type, ".bin")

serializer = data_serializer_factory(
category="prompt-memory-entries",
data_type=cast("PromptDataType", piece.data_type),
Expand Down
14 changes: 4 additions & 10 deletions pyrit/backend/services/converter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
import uuid
from functools import lru_cache
from pathlib import Path
from typing import Any, ClassVar, Literal, Union, get_args, get_origin
from typing import Any, Literal, Union, get_args, get_origin
from urllib.parse import parse_qs, urlparse

from pyrit.backend.mappers.converter_mappers import converter_object_to_instance
from pyrit.backend.models import DEFAULT_MEDIA_EXTENSIONS
from pyrit.backend.models.converters import (
ConverterCatalogEntry,
ConverterCatalogResponse,
Expand Down Expand Up @@ -79,13 +80,6 @@ class ConverterService:
API metadata is derived from the converter objects.
"""

_DATA_TYPE_EXTENSION: ClassVar[dict[str, str]] = {
"image_path": ".png",
"audio_path": ".wav",
"video_path": ".mp4",
"binary_path": ".bin",
}

def __init__(self) -> None:
"""Initialize the converter service."""
self._registry = ConverterRegistry.get_registry_singleton()
Expand Down Expand Up @@ -254,7 +248,7 @@ async def preview_conversion_async(self, *, request: ConverterPreviewRequest) ->
elif original_value.startswith("data:"):
_, _, value = original_value.partition(",")

ext = self._DATA_TYPE_EXTENSION.get(str(data_type), ".bin")
ext = DEFAULT_MEDIA_EXTENSIONS.get(str(data_type), ".bin")

serializer = data_serializer_factory(
category="prompt-memory-entries",
Expand All @@ -268,7 +262,7 @@ async def preview_conversion_async(self, *, request: ConverterPreviewRequest) ->
pass
else:
# Treat as raw base64
ext = self._DATA_TYPE_EXTENSION.get(str(data_type), ".bin")
ext = DEFAULT_MEDIA_EXTENSIONS.get(str(data_type), ".bin")

serializer = data_serializer_factory(
category="prompt-memory-entries",
Expand Down
69 changes: 66 additions & 3 deletions tests/unit/backend/test_attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def make_attack_result(
has_target: bool = True,
name: str = "Test Attack",
outcome: AttackOutcome = AttackOutcome.UNDETERMINED,
created_at: datetime = None,
updated_at: datetime = None,
created_at: datetime | None = None,
updated_at: datetime | None = None,
) -> AttackResult:
"""Create a mock AttackResult for testing."""
now = datetime.now(timezone.utc)
Expand Down Expand Up @@ -125,7 +125,7 @@ def make_mock_piece(
sequence: int = 0,
original_value: str = "test",
converted_value: str = "test",
timestamp: datetime = None,
timestamp: datetime | None = None,
):
"""Create a mock message piece."""
piece = MagicMock()
Expand Down Expand Up @@ -1639,6 +1639,69 @@ async def test_data_uri_prefix_is_stripped_before_saving(self, attack_service) -
mock_serializer.save_b64_image_async.assert_awaited_once_with(data="aW1hZ2VkYXRh")
assert request.pieces[0].original_value == "/saved/image.png"

async def test_data_uri_mime_type_supplies_extension_when_mime_type_missing(self, attack_service) -> None:
    """The media type embedded in a data URI picks the extension when ``mime_type`` is absent.

    Without it, an image upload would fall back to a blocked ``.bin`` file.
    """
    # The piece deliberately omits ``mime_type``; only the data URI header says "image/png".
    image_piece = MessagePieceRequest(
        data_type="image_path",
        original_value="data:image/png;base64,aW1hZ2VkYXRh",
    )
    add_request = AddMessageRequest(
        role="user",
        pieces=[image_piece],
        send=False,
        target_conversation_id="test-id",
    )

    # Stand-in serializer: records the save call and reports a fixed saved path.
    serializer_stub = MagicMock(
        save_b64_image_async=AsyncMock(),
        value="/saved/image.png",
    )
    factory_patcher = patch(
        "pyrit.backend.services.attack_service.data_serializer_factory",
        return_value=serializer_stub,
    )

    with factory_patcher as factory_spy:
        await AttackService._persist_base64_pieces_async(add_request)

    # The factory must have been asked for a ``.png`` serializer exactly once.
    expected_factory_kwargs = {
        "category": "prompt-memory-entries",
        "data_type": "image_path",
        "extension": ".png",
    }
    factory_spy.assert_called_once_with(**expected_factory_kwargs)

    # Only the payload after the comma is saved, and the piece now points at the saved file.
    serializer_stub.save_b64_image_async.assert_awaited_once_with(data="aW1hZ2VkYXRh")
    saved_piece = add_request.pieces[0]
    assert saved_piece.original_value == "/saved/image.png"

async def test_path_data_type_supplies_extension_when_mime_type_missing(self, attack_service) -> None:
    """Raw base64 with no MIME metadata still gets a media extension derived from ``data_type``."""
    # Neither ``mime_type`` nor a data URI header is present, so only ``data_type`` can drive the extension.
    raw_piece = MessagePieceRequest(
        data_type="image_path",
        original_value="aW1hZ2VkYXRh",
    )
    add_request = AddMessageRequest(
        role="user",
        pieces=[raw_piece],
        send=False,
        target_conversation_id="test-id",
    )

    # Stand-in serializer that reports a fixed saved path.
    serializer_stub = MagicMock(
        save_b64_image_async=AsyncMock(),
        value="/saved/image.png",
    )
    factory_patcher = patch(
        "pyrit.backend.services.attack_service.data_serializer_factory",
        return_value=serializer_stub,
    )

    with factory_patcher as factory_spy:
        await AttackService._persist_base64_pieces_async(add_request)

    # The factory must have been asked for a ``.png`` serializer exactly once.
    expected_factory_kwargs = {
        "category": "prompt-memory-entries",
        "data_type": "image_path",
        "extension": ".png",
    }
    factory_spy.assert_called_once_with(**expected_factory_kwargs)

    saved_piece = add_request.pieces[0]
    assert saved_piece.original_value == "/saved/image.png"

async def test_http_url_is_kept_as_is(self, attack_service) -> None:
"""HTTPS blob URLs should not be re-persisted."""
request = AddMessageRequest(
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/backend/test_converter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ async def test_preview_conversion_chains_multiple_converters(self) -> None:
mock_converter2.convert_async.assert_called_with(prompt="step1_output", input_type="text")

async def test_preview_conversion_persists_data_uri_for_image_path(self) -> None:
"""Data URIs on *_path types are decoded via the _DATA_TYPE_EXTENSION map and persisted."""
"""Data URIs on *_path types are decoded via the DEFAULT_MEDIA_EXTENSIONS map and persisted."""
service = ConverterService()

mock_converter = MagicMock(spec=prompt_converter.PromptConverter)
Expand All @@ -403,7 +403,7 @@ async def test_preview_conversion_persists_data_uri_for_image_path(self) -> None
await service.preview_conversion_async(request=request)

mock_factory.assert_called_once()
# ext is the image_path mapping from _DATA_TYPE_EXTENSION
# ext is the image_path mapping from DEFAULT_MEDIA_EXTENSIONS
assert mock_factory.call_args.kwargs["extension"] == ".png"
assert mock_factory.call_args.kwargs["data_type"] == "image_path"
mock_serializer.save_b64_image_async.assert_awaited_once_with(data="iVBORw0KGgo=")
Expand Down Expand Up @@ -437,7 +437,7 @@ async def test_preview_conversion_persists_raw_base64_for_audio_path(self) -> No
await service.preview_conversion_async(request=request)

mock_factory.assert_called_once()
# ext is the audio_path mapping from _DATA_TYPE_EXTENSION
# ext is the audio_path mapping from DEFAULT_MEDIA_EXTENSIONS
assert mock_factory.call_args.kwargs["extension"] == ".wav"
assert mock_factory.call_args.kwargs["data_type"] == "audio_path"
mock_serializer.save_b64_image_async.assert_awaited_once_with(data=raw_b64)
Expand Down
Loading