Skip to content

Commit 729edfa

Browse files
romanlutz and Copilot authored
BUG: Stop leaking absolute media paths and SAS tokens in Attack History 'Last Message' (#1865)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 23b862d commit 729edfa

14 files changed

Lines changed: 443 additions & 37 deletions

pyrit/backend/mappers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Centralizes all translation logic so domain models can evolve independently of the API contract.
99
"""
1010

11+
from pyrit.backend.mappers._preview import format_last_message_preview
1112
from pyrit.backend.mappers.attack_mappers import (
1213
attack_result_to_summary,
1314
pyrit_messages_to_dto_async,
@@ -25,6 +26,7 @@
2526
__all__ = [
2627
"attack_result_to_summary",
2728
"converter_object_to_instance",
29+
"format_last_message_preview",
2830
"pyrit_messages_to_dto_async",
2931
"pyrit_scores_to_dto",
3032
"request_piece_to_pyrit_message_piece",

pyrit/backend/mappers/_preview.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Presentation-layer formatter for ``ConversationStats.last_message_preview``.
6+
7+
Lives in the backend mapper package because the formatting it produces
8+
(``[Image: <basename>]`` etc.) is purely a display concern for the GUI API
9+
responses — the memory layer stays data-agnostic and just stores the raw
10+
value + data type.
11+
12+
The motivating bug: ``converted_value`` for media-path data types
13+
(``image_path`` / ``audio_path`` / ``video_path`` / ``binary_path``) is a
14+
filesystem path or blob URL. Rendering it raw in the Attack History preview
15+
leaks the absolute on-disk location of memory artifacts
16+
(e.g. ``C:\\Users\\<name>\\git\\PyRIT\\dbdata\\...\\1780.mp3``).
17+
"""
18+
19+
from pathlib import PureWindowsPath
20+
from urllib.parse import urlparse
21+
22+
from pyrit.models import MEDIA_PATH_DATA_TYPES, ConversationStats
23+
24+
# Friendly label per media-path data type. Kept here next to the formatter
25+
# so the labels live in one place; keys must stay in sync with MEDIA_PATH_DATA_TYPES.
26+
_MEDIA_LABEL: dict[str, str] = {
27+
"image_path": "Image",
28+
"audio_path": "Audio",
29+
"video_path": "Video",
30+
"binary_path": "File",
31+
}
32+
33+
34+
def _derive_basename(value: str) -> str | None:
35+
"""
36+
Return a display-safe basename for *value*.
37+
38+
Args:
39+
value: A filesystem path, URL, or other reference.
40+
41+
Returns:
42+
The basename (filename portion) of *value*, or ``None`` if one can't
43+
be derived (e.g. data URI, empty value).
44+
"""
45+
if not value or value.startswith("data:"):
46+
return None
47+
if value.startswith(("http://", "https://")):
48+
# Strip query string (e.g. SAS tokens) before taking the basename.
49+
parsed = urlparse(value)
50+
name = PureWindowsPath(parsed.path).name
51+
return name or None
52+
# Local path — PureWindowsPath treats both ``/`` and ``\`` as separators,
53+
# so Windows-style paths stored from a Windows host are split correctly
54+
# even when this code runs on a POSIX host (CI, Linux deployments).
55+
return PureWindowsPath(value).name or None
56+
57+
58+
def format_last_message_preview(
59+
*,
60+
value: str | None,
61+
data_type: str | None,
62+
max_len: int = ConversationStats.PREVIEW_MAX_LEN,
63+
) -> str | None:
64+
"""
65+
Build a display string for ``ConversationStats.last_message_preview``.
66+
67+
Media-path data types are rendered as ``[Image: <basename>]`` (and
68+
variants) so the absolute filesystem path of memory artifacts is never
69+
exposed through API responses or UI previews. Text-like data types pass
70+
through with truncation and an ellipsis suffix when they exceed
71+
*max_len*.
72+
73+
Args:
74+
value: Raw ``converted_value`` for the last piece (or ``None``).
75+
data_type: ``converted_value_data_type`` for that piece. ``None``
76+
falls back to the text path.
77+
max_len: Maximum length for text previews before truncation.
78+
79+
Returns:
80+
The formatted preview string, or ``None`` when there is nothing
81+
meaningful to show.
82+
"""
83+
if data_type in MEDIA_PATH_DATA_TYPES:
84+
# MEDIA_PATH_DATA_TYPES guarantees ``data_type`` is a key in
85+
# ``_MEDIA_LABEL`` — both are derived from the same source list.
86+
label = _MEDIA_LABEL[data_type]
87+
basename = _derive_basename(value or "")
88+
return f"[{label}: {basename}]" if basename else f"[{label}]"
89+
90+
if not value:
91+
return None
92+
93+
if len(value) > max_len:
94+
return value[:max_len] + "..."
95+
return value

pyrit/backend/mappers/attack_mappers.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from azure.storage.blob import ContainerSasPermissions, generate_container_sas
2525
from azure.storage.blob.aio import BlobServiceClient
2626

27+
from pyrit.backend.mappers._preview import format_last_message_preview
2728
from pyrit.backend.models.attacks import (
2829
AddMessageRequest,
2930
AttackSummary,
@@ -35,7 +36,7 @@
3536
TargetInfo,
3637
)
3738
from pyrit.common.deprecation import print_deprecation_message
38-
from pyrit.models import AttackResult, ChatMessageRole, PromptDataType
39+
from pyrit.models import MEDIA_PATH_DATA_TYPES, AttackResult, ChatMessageRole, PromptDataType
3940
from pyrit.models import Message as PyritMessage
4041
from pyrit.models import MessagePiece as PyritMessagePiece
4142
from pyrit.models import Score as PyritScore
@@ -50,9 +51,6 @@
5051
# Domain → DTO (for API responses)
5152
# ============================================================================
5253

53-
# Media data types whose values are file paths (local or Azure Blob URLs)
54-
_MEDIA_PATH_TYPES = frozenset({"image_path", "audio_path", "video_path", "binary_path"})
55-
5654
# ---------------------------------------------------------------------------
5755
# Azure Blob SAS token cache
5856
# ---------------------------------------------------------------------------
@@ -172,7 +170,7 @@ def _resolve_media_url(*, value: Optional[str], data_type: str) -> Optional[str]
172170
The value unchanged for non-media types, a ``/api/media?path=...``
173171
URL for local file paths, or the original value for blob URLs / data URIs.
174172
"""
175-
if not value or data_type not in _MEDIA_PATH_TYPES:
173+
if not value or data_type not in MEDIA_PATH_DATA_TYPES:
176174
return value
177175
# Already a URL or data URI — pass through
178176
if value.startswith(("http://", "https://", "data:")):
@@ -227,7 +225,10 @@ def attack_result_to_summary(
227225
AttackSummary DTO ready for the API response.
228226
"""
229227
message_count = stats.message_count
230-
last_preview = stats.last_message_preview
228+
last_preview = format_last_message_preview(
229+
value=stats.last_message_preview,
230+
data_type=stats.last_message_data_type,
231+
)
231232

232233
# Merge attack-result labels with conversation-level labels.
233234
# Conversation labels take precedence on key collision.
@@ -297,7 +298,9 @@ def pyrit_scores_to_dto(scores: list[PyritScore]) -> list[Score]:
297298
return [
298299
Score(
299300
score_id=str(score.id),
300-
scorer_type=score.scorer_class_identifier.class_name,
301+
scorer_type=(
302+
score.scorer_class_identifier.class_name or "Unknown" if score.scorer_class_identifier else "Unknown"
303+
),
301304
score_type=score.score_type,
302305
score_value=score.score_value,
303306
score_category=score.score_category,

pyrit/backend/services/attack_service.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
from typing import Any, Literal, cast
2525
from urllib.parse import parse_qs, urlparse
2626

27-
from pyrit.backend.mappers.attack_mappers import (
27+
from pyrit.backend.mappers import (
2828
attack_result_to_summary,
29+
format_last_message_preview,
2930
pyrit_messages_to_dto_async,
3031
request_piece_to_pyrit_message_piece,
3132
request_to_pyrit_message,
@@ -177,11 +178,13 @@ async def list_attacks_async(
177178

178179
total_count = (main_stats.message_count if main_stats else 0) + sum(s.message_count for s in pruned_stats)
179180
preview = main_stats.last_message_preview if main_stats else None
181+
preview_data_type = main_stats.last_message_data_type if main_stats else None
180182
conv_labels = (main_stats.labels if main_stats else None) or {}
181183

182184
merged = ConversationStats(
183185
message_count=total_count,
184186
last_message_preview=preview,
187+
last_message_data_type=preview_data_type,
185188
labels=conv_labels,
186189
)
187190

@@ -419,7 +422,10 @@ async def get_conversations_async(self, *, attack_result_id: str) -> AttackConve
419422
ConversationSummary(
420423
conversation_id=conv_id,
421424
message_count=stats.message_count if stats else 0,
422-
last_message_preview=stats.last_message_preview if stats else None,
425+
last_message_preview=format_last_message_preview(
426+
value=stats.last_message_preview if stats else None,
427+
data_type=stats.last_message_data_type if stats else None,
428+
),
423429
created_at=created_at,
424430
)
425431
)

pyrit/memory/azure_sql_memory.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -615,18 +615,23 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str
615615
placeholders = ", ".join(f":cid{i}" for i in range(len(conversation_ids)))
616616
params = {f"cid{i}": cid for i, cid in enumerate(conversation_ids)}
617617

618-
max_len = ConversationStats.PREVIEW_MAX_LEN
619618
sql = text(
620619
f"""
621620
SELECT
622621
pme.conversation_id,
623622
COUNT(DISTINCT pme.sequence) AS msg_count,
624623
(
625-
SELECT TOP 1 LEFT(p2.converted_value, {max_len + 3})
624+
SELECT TOP 1 LEFT(p2.converted_value, {ConversationStats.PREVIEW_FETCH_MAX_LEN})
626625
FROM "PromptMemoryEntries" p2
627626
WHERE p2.conversation_id = pme.conversation_id
628627
ORDER BY p2.sequence DESC, p2.id DESC
629628
) AS last_preview,
629+
(
630+
SELECT TOP 1 p2b.converted_value_data_type
631+
FROM "PromptMemoryEntries" p2b
632+
WHERE p2b.conversation_id = pme.conversation_id
633+
ORDER BY p2b.sequence DESC, p2b.id DESC
634+
) AS last_data_type,
630635
(
631636
SELECT TOP 1 p3.labels
632637
FROM "PromptMemoryEntries" p3
@@ -648,11 +653,7 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str
648653

649654
result: dict[str, ConversationStats] = {}
650655
for row in rows:
651-
conv_id, msg_count, last_preview, raw_labels, raw_created_at = row
652-
653-
preview = None
654-
if last_preview:
655-
preview = last_preview[:max_len] + "..." if len(last_preview) > max_len else last_preview
656+
conv_id, msg_count, last_preview, last_data_type, raw_labels, raw_created_at = row
656657

657658
labels: dict[str, str] = {}
658659
if raw_labels and raw_labels not in ("null", "{}"):
@@ -668,7 +669,8 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str
668669

669670
result[conv_id] = ConversationStats(
670671
message_count=msg_count,
671-
last_message_preview=preview,
672+
last_message_preview=last_preview,
673+
last_message_data_type=last_data_type,
672674
labels=labels,
673675
created_at=created_at,
674676
)

pyrit/memory/sqlite_memory.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -739,19 +739,25 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str
739739
placeholders = ", ".join(f":cid{i}" for i in range(len(conversation_ids)))
740740
params = {f"cid{i}": cid for i, cid in enumerate(conversation_ids)}
741741

742-
max_len = ConversationStats.PREVIEW_MAX_LEN
743742
sql = text(
744743
f"""
745744
SELECT
746745
pme.conversation_id,
747746
COUNT(DISTINCT pme.sequence) AS msg_count,
748747
(
749-
SELECT SUBSTR(p2.converted_value, 1, {max_len + 3})
748+
SELECT SUBSTR(p2.converted_value, 1, {ConversationStats.PREVIEW_FETCH_MAX_LEN})
750749
FROM "PromptMemoryEntries" p2
751750
WHERE p2.conversation_id = pme.conversation_id
752751
ORDER BY p2.sequence DESC, p2.id DESC
753752
LIMIT 1
754753
) AS last_preview,
754+
(
755+
SELECT p2b.converted_value_data_type
756+
FROM "PromptMemoryEntries" p2b
757+
WHERE p2b.conversation_id = pme.conversation_id
758+
ORDER BY p2b.sequence DESC, p2b.id DESC
759+
LIMIT 1
760+
) AS last_data_type,
755761
(
756762
SELECT p3.labels
757763
FROM "PromptMemoryEntries" p3
@@ -774,11 +780,7 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str
774780

775781
result: dict[str, ConversationStats] = {}
776782
for row in rows:
777-
conv_id, msg_count, last_preview, raw_labels, raw_created_at = row
778-
779-
preview = None
780-
if last_preview:
781-
preview = last_preview[:max_len] + "..." if len(last_preview) > max_len else last_preview
783+
conv_id, msg_count, last_preview, last_data_type, raw_labels, raw_created_at = row
782784

783785
labels: dict[str, str] = {}
784786
if raw_labels and raw_labels not in ("null", "{}"):
@@ -794,7 +796,8 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str
794796

795797
result[conv_id] = ConversationStats(
796798
message_count=msg_count,
797-
last_message_preview=preview,
799+
last_message_preview=last_preview,
800+
last_message_data_type=last_data_type,
798801
labels=labels,
799802
created_at=created_at,
800803
)

pyrit/models/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,14 @@
6262
snake_case_to_class_name,
6363
validate_registry_name,
6464
)
65-
from pyrit.models.literals import ChatMessageRole, Modality, PromptDataType, PromptResponseError, SeedType
65+
from pyrit.models.literals import (
66+
MEDIA_PATH_DATA_TYPES,
67+
ChatMessageRole,
68+
Modality,
69+
PromptDataType,
70+
PromptResponseError,
71+
SeedType,
72+
)
6673
from pyrit.models.messages import (
6774
Message,
6875
MessagePiece,
@@ -141,6 +148,7 @@
141148
"IdentifierFilter",
142149
"IdentifierType",
143150
"ImagePathDataTypeSerializer",
151+
"MEDIA_PATH_DATA_TYPES",
144152
"Message",
145153
"MessagePiece",
146154
"Modality",

pyrit/models/conversation_stats.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from pydantic import BaseModel, ConfigDict, Field
88

9+
from pyrit.models.literals import PromptDataType
10+
911

1012
class ConversationStats(BaseModel):
1113
"""
@@ -17,8 +19,17 @@ class ConversationStats(BaseModel):
1719
model_config = ConfigDict(frozen=True)
1820

1921
PREVIEW_MAX_LEN: ClassVar[int] = 100
22+
PREVIEW_FETCH_MAX_LEN: ClassVar[int] = 1024
23+
"""
24+
Upper bound (in characters) for the raw ``last_message_preview`` value
25+
fetched from storage. Larger than ``PREVIEW_MAX_LEN`` so that downstream
26+
presentation code (see ``pyrit.backend.mappers._preview``) has enough
27+
characters to extract a basename from a long media path or signed blob
28+
URL before applying display-level truncation.
29+
"""
2030

2131
message_count: int = 0
2232
last_message_preview: Optional[str] = None
33+
last_message_data_type: Optional[PromptDataType] = None
2334
labels: dict[str, str] = Field(default_factory=dict)
2435
created_at: Optional[datetime] = None

pyrit/models/literals.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
"function_call_output",
1919
]
2020

21+
# Subset of ``PromptDataType`` values whose stored ``value`` is a path or URL
22+
# pointing at media content (rather than the content itself). Useful for
23+
# treating these specially — e.g. avoiding raw filesystem-path leaks in API
24+
# previews, or signing blob storage URLs before exposing them to the frontend.
25+
MEDIA_PATH_DATA_TYPES: frozenset[PromptDataType] = frozenset({"image_path", "audio_path", "video_path", "binary_path"})
26+
2127
"""
2228
The type of the error in the prompt response
2329
blocked: blocked by an external filter e.g. Azure Filters

0 commit comments

Comments
 (0)