Skip to content

Commit 3186982

Browse files
FEAT add TargetConfiguration & pieces (microsoft#1573)
1 parent e5f18ab commit 3186982

10 files changed

Lines changed: 1117 additions & 4 deletions

pyrit/message_normalizer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pyrit.message_normalizer.chat_message_normalizer import ChatMessageNormalizer
99
from pyrit.message_normalizer.conversation_context_normalizer import ConversationContextNormalizer
1010
from pyrit.message_normalizer.generic_system_squash import GenericSystemSquashNormalizer
11+
from pyrit.message_normalizer.history_squash_normalizer import HistorySquashNormalizer
1112
from pyrit.message_normalizer.message_normalizer import (
1213
MessageListNormalizer,
1314
MessageStringNormalizer,
@@ -18,6 +19,7 @@
1819
"MessageListNormalizer",
1920
"MessageStringNormalizer",
2021
"GenericSystemSquashNormalizer",
22+
"HistorySquashNormalizer",
2123
"TokenizerTemplateNormalizer",
2224
"ConversationContextNormalizer",
2325
"ChatMessageNormalizer",
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
from pyrit.message_normalizer.message_normalizer import MessageListNormalizer
5+
from pyrit.models import Message
6+
7+
8+
class HistorySquashNormalizer(MessageListNormalizer[Message]):
9+
"""
10+
Squashes a multi-turn conversation into a single user message.
11+
12+
Previous turns are formatted as labeled context and prepended to the
13+
latest message. Used by the normalization pipeline to adapt prompts
14+
for targets that do not support multi-turn conversations.
15+
"""
16+
17+
async def normalize_async(self, messages: list[Message]) -> list[Message]:
18+
"""
19+
Combine all messages into a single user message.
20+
21+
When there is only one message it is returned unchanged. Otherwise
22+
all prior turns are formatted as ``Role: content`` lines under a
23+
``[Conversation History]`` header and the last message's content
24+
appears under a ``[Current Message]`` header.
25+
26+
Args:
27+
messages: The conversation messages to squash.
28+
29+
Returns:
30+
list[Message]: A single-element list containing the squashed message.
31+
32+
Raises:
33+
ValueError: If the messages list is empty.
34+
"""
35+
if not messages:
36+
raise ValueError("Messages list cannot be empty")
37+
38+
if len(messages) == 1:
39+
return list(messages)
40+
41+
history_lines = self._format_history(messages=messages[:-1])
42+
current_parts = [piece.converted_value for piece in messages[-1].message_pieces]
43+
44+
combined = (
45+
"[Conversation History]\n" + "\n".join(history_lines) + "\n\n[Current Message]\n" + "\n".join(current_parts)
46+
)
47+
48+
return [Message.from_prompt(prompt=combined, role="user")]
49+
50+
def _format_history(self, *, messages: list[Message]) -> list[str]:
51+
"""
52+
Format prior messages as ``Role: content`` lines.
53+
54+
Args:
55+
messages: The history messages to format.
56+
57+
Returns:
58+
list[str]: One line per message piece.
59+
"""
60+
lines: list[str] = []
61+
for msg in messages:
62+
lines.extend(f"{piece.api_role.capitalize()}: {piece.converted_value}" for piece in msg.message_pieces)
63+
return lines

pyrit/prompt_target/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,16 @@
1010

1111
from pyrit.prompt_target.azure_blob_storage_target import AzureBlobStorageTarget
1212
from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget
13+
from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline
1314
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
1415
from pyrit.prompt_target.common.prompt_target import PromptTarget
15-
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
16+
from pyrit.prompt_target.common.target_capabilities import (
17+
CapabilityHandlingPolicy,
18+
CapabilityName,
19+
TargetCapabilities,
20+
UnsupportedCapabilityBehavior,
21+
)
22+
from pyrit.prompt_target.common.target_configuration import TargetConfiguration
1623
from pyrit.prompt_target.common.utils import limit_requests_per_minute
1724
from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget
1825
from pyrit.prompt_target.http_target.http_target import HTTPTarget
@@ -41,7 +48,10 @@
4148
__all__ = [
4249
"AzureBlobStorageTarget",
4350
"AzureMLChatTarget",
51+
"CapabilityName",
52+
"CapabilityHandlingPolicy",
4453
"CopilotType",
54+
"ConversationNormalizationPipeline",
4555
"GandalfLevel",
4656
"GandalfTarget",
4757
"get_http_target_json_response_callback_function",
@@ -66,6 +76,8 @@
6676
"PromptTarget",
6777
"RealtimeTarget",
6878
"TargetCapabilities",
79+
"TargetConfiguration",
80+
"UnsupportedCapabilityBehavior",
6981
"TextTarget",
7082
"WebSocketCopilotTarget",
7183
]
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
import logging
5+
6+
from pyrit.message_normalizer import (
7+
GenericSystemSquashNormalizer,
8+
HistorySquashNormalizer,
9+
MessageListNormalizer,
10+
)
11+
from pyrit.models import Message
12+
from pyrit.prompt_target.common.target_capabilities import (
13+
CapabilityHandlingPolicy,
14+
CapabilityName,
15+
TargetCapabilities,
16+
UnsupportedCapabilityBehavior,
17+
)
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
# ---------------------------------------------------------------------------
23+
# Single registry: add new normalizable capabilities here and nowhere else.
24+
# Order in the list determines pipeline execution order.
25+
# ---------------------------------------------------------------------------
26+
_NORMALIZER_REGISTRY: list[tuple[CapabilityName, MessageListNormalizer[Message]]] = [
27+
(CapabilityName.SYSTEM_PROMPT, GenericSystemSquashNormalizer()),
28+
(CapabilityName.MULTI_TURN, HistorySquashNormalizer()),
29+
]
30+
31+
# Derived constant — no manual maintenance required.
32+
NORMALIZABLE_CAPABILITIES: frozenset[CapabilityName] = frozenset(cap for cap, _ in _NORMALIZER_REGISTRY)
33+
34+
35+
class ConversationNormalizationPipeline:
36+
"""
37+
Ordered sequence of message normalizers that adapt conversations when
38+
the target lacks certain capabilities.
39+
40+
The pipeline is constructed via ``from_capabilities``, which resolves
41+
capabilities and policy into a concrete, ordered tuple of normalizers.
42+
``normalize_async`` then simply executes that tuple in order.
43+
44+
To add a new normalizable capability, add a single entry to
45+
``_NORMALIZER_REGISTRY``. ``NORMALIZABLE_CAPABILITIES``,
46+
pipeline ordering, and default normalizers are all derived from it.
47+
"""
48+
49+
def __init__(self, normalizers: tuple[MessageListNormalizer[Message], ...] = ()) -> None:
50+
"""
51+
Initialize the normalization pipeline with an ordered sequence of normalizers.
52+
53+
Args:
54+
normalizers (tuple[MessageListNormalizer[Message], ...]):
55+
Ordered normalizers to apply during ``normalize_async``.
56+
Defaults to an empty tuple (pass-through).
57+
"""
58+
self._normalizers = normalizers
59+
60+
@classmethod
61+
def from_capabilities(
62+
cls,
63+
*,
64+
capabilities: TargetCapabilities,
65+
policy: CapabilityHandlingPolicy,
66+
normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None,
67+
) -> "ConversationNormalizationPipeline":
68+
"""
69+
Resolve capabilities and policy into a concrete pipeline of normalizers.
70+
71+
For each capability in ``_NORMALIZER_REGISTRY`` (in order):
72+
73+
* If the target already supports the capability, no normalizer is added.
74+
* If the capability is missing and the policy is ``ADAPT``, the
75+
corresponding normalizer (from overrides or defaults) is added.
76+
* If the capability is missing and the policy is ``RAISE``, a
77+
``ValueError`` is raised immediately.
78+
79+
Args:
80+
capabilities (TargetCapabilities): The target's declared capabilities.
81+
policy (CapabilityHandlingPolicy): How to handle each missing capability.
82+
normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None):
83+
Optional overrides for specific capability normalizers.
84+
Falls back to the defaults from ``_NORMALIZER_REGISTRY``.
85+
86+
Returns:
87+
ConversationNormalizationPipeline: A pipeline with the resolved
88+
ordered tuple of normalizers.
89+
90+
Raises:
91+
ValueError: If a required capability is missing and the policy is RAISE.
92+
"""
93+
overrides = normalizer_overrides or {}
94+
normalizers: list[MessageListNormalizer[Message]] = []
95+
96+
for capability, default_normalizer in _NORMALIZER_REGISTRY:
97+
if capabilities.includes(capability=capability):
98+
continue
99+
100+
behavior = policy.get_behavior(capability=capability)
101+
102+
if behavior == UnsupportedCapabilityBehavior.RAISE:
103+
raise ValueError(f"Target does not support '{capability.value}' and the handling policy is RAISE.")
104+
105+
normalizer = overrides.get(capability, default_normalizer)
106+
107+
normalizers.append(normalizer)
108+
109+
return cls(normalizers=tuple(normalizers))
110+
111+
async def normalize_async(self, *, messages: list[Message]) -> list[Message]:
112+
"""
113+
Run the pre-resolved normalizer sequence over the messages.
114+
115+
Args:
116+
messages (list[Message]): The full conversation to normalize.
117+
118+
Returns:
119+
list[Message]: The (possibly adapted) message list.
120+
"""
121+
result = list(messages)
122+
for normalizer in self._normalizers:
123+
result = await normalizer.normalize_async(result)
124+
return result
125+
126+
@property
127+
def normalizers(self) -> tuple[MessageListNormalizer[Message], ...]:
128+
"""
129+
The ordered normalizers in this pipeline.
130+
131+
Returns:
132+
tuple[MessageListNormalizer[Message], ...]: The normalizer sequence.
133+
"""
134+
return self._normalizers

pyrit/prompt_target/common/target_capabilities.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,109 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4-
from dataclasses import dataclass
5-
from typing import Optional, cast
4+
from collections.abc import Mapping
5+
from dataclasses import dataclass, field
6+
from enum import Enum
7+
from types import MappingProxyType
8+
from typing import NoReturn, Optional, cast
69

710
from pyrit.models import PromptDataType
811

912

13+
class CapabilityName(str, Enum):
14+
"""
15+
Canonical identifiers for target capabilities.
16+
17+
This keeps capability identity in one place so policy, requirements, and
18+
normalization code do not duplicate string field names.
19+
"""
20+
21+
MULTI_TURN = "supports_multi_turn"
22+
MULTI_MESSAGE_PIECES = "supports_multi_message_pieces"
23+
JSON_SCHEMA = "supports_json_schema"
24+
JSON_OUTPUT = "supports_json_output"
25+
EDITABLE_HISTORY = "supports_editable_history"
26+
SYSTEM_PROMPT = "supports_system_prompt"
27+
28+
29+
class UnsupportedCapabilityBehavior(str, Enum):
30+
"""
31+
Defines what happens when a caller requires a capability the target does not support.
32+
33+
ADAPT: apply a normalization step to work around the unsupported capability.
34+
RAISE: fail immediately with an error.
35+
"""
36+
37+
ADAPT = "adapt"
38+
RAISE = "raise"
39+
40+
41+
@dataclass(frozen=True)
42+
class CapabilityHandlingPolicy:
43+
"""
44+
Per-capability policy consulted only when a capability is unsupported.
45+
46+
Design invariants
47+
-----------------
48+
* The policy is never consulted if the capability is already supported.
49+
* Non-adaptable capabilities (e.g. ``supports_editable_history``) are not
50+
represented here; requesting them on a target that lacks them always
51+
raises immediately.
52+
"""
53+
54+
behaviors: Mapping[CapabilityName, UnsupportedCapabilityBehavior] = field(
55+
default_factory=lambda: {
56+
CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE,
57+
CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE,
58+
}
59+
)
60+
61+
def get_behavior(self, *, capability: CapabilityName) -> UnsupportedCapabilityBehavior:
62+
"""
63+
Return the configured handling behavior for a capability.
64+
65+
Args:
66+
capability: The capability to look up.
67+
68+
Returns:
69+
UnsupportedCapabilityBehavior: The configured behavior.
70+
71+
Raises:
72+
KeyError: If no behavior exists for the capability. This occurs for
73+
non-adaptable capabilities (e.g., supports_editable_history).
74+
"""
75+
try:
76+
return self.behaviors[capability]
77+
except KeyError:
78+
supported = ", ".join(sorted(cap.value for cap in self.behaviors))
79+
raise KeyError(
80+
f"No policy for capability '{capability.value}'. Supported capabilities: {supported}."
81+
) from None
82+
83+
def __getattr__(self, name: str) -> NoReturn:
84+
"""
85+
Guard against accessing policies for non-adaptable or unknown capabilities.
86+
87+
Raises:
88+
AttributeError: If the capability is not part of this policy.
89+
"""
90+
for capability in CapabilityName:
91+
if capability.value == name:
92+
supported_names = ", ".join(sorted(cap.value for cap in self.behaviors))
93+
raise AttributeError(
94+
f"'{type(self).__name__}' has no policy for '{name}'. "
95+
f"Only the following capabilities have handling policies: "
96+
f"{supported_names}."
97+
)
98+
99+
raise AttributeError(name)
100+
101+
def __post_init__(self) -> None:
102+
"""Create a defensive read-only copy of the behaviors mapping."""
103+
# object.__setattr__ is required because the dataclass is frozen.
104+
object.__setattr__(self, "behaviors", MappingProxyType(dict(self.behaviors)))
105+
106+
10107
@dataclass(frozen=True)
11108
class TargetCapabilities:
12109
"""
@@ -47,6 +144,18 @@ class attribute. Users can override individual capabilities per instance
47144
# The output modalities supported by the target (e.g., "text", "image").
48145
output_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])})
49146

147+
def includes(self, *, capability: CapabilityName) -> bool:
148+
"""
149+
Return whether this target supports the given capability.
150+
151+
Args:
152+
capability: The capability to check.
153+
154+
Returns:
155+
bool: True if supported, otherwise False.
156+
"""
157+
return bool(getattr(self, capability.value))
158+
50159
@staticmethod
51160
def get_known_capabilities(underlying_model: str) -> "Optional[TargetCapabilities]":
52161
"""

0 commit comments

Comments
 (0)