Skip to content

Commit cdcab06

Browse files
Robert FitzpatrickRobert Fitzpatrick
authored andcommitted
refactor: improve message structure and extract multimodal handling
Address Roman's feedback items #2 and #3: - Change _build_adversarial_prompt to return Message instead of Union type - Extract message construction logic into separate helper methods - Add _build_text_message() for simple text prompts - Add _build_multimodal_message() for media responses - Simplify caller code by removing tuple handling logic - Improve logging to work with Message objects These architectural improvements prepare the code to integrate with the modality support detection system from separate PR.
1 parent 5cb97e1 commit cdcab06

1 file changed

Lines changed: 69 additions & 45 deletions

File tree

pyrit/executor/attack/multi_turn/red_teaming.py

Lines changed: 69 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import uuid
99
from pathlib import Path
10-
from typing import Any, Callable, Optional, Union
10+
from typing import Any, Callable, Optional
1111

1212
from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults
1313
from pyrit.common.path import EXECUTOR_RED_TEAM_PATH
@@ -357,41 +357,20 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any]
357357
# Generate prompt using adversarial chat
358358
logger.debug(f"Generating prompt for turn {context.executed_turns + 1}")
359359

360-
# Prepare prompt for the adversarial chat
361-
attack_message = await self._build_adversarial_prompt(context)
362-
363-
# Build the message for the adversarial chat.
364-
# For file/media responses, construct a multimodal message with both
365-
# the textual feedback and the actual media (image/video) so the
366-
# adversarial chat (e.g. GPT-4o) can see what the target generated.
367-
if isinstance(prompt_result, tuple):
368-
feedback_text, media_piece = prompt_result
369-
# Use a shared conversation_id so Message validation passes
370-
shared_conversation_id = str(uuid.uuid4())
371-
pieces = [
372-
MessagePiece(
373-
original_value=feedback_text,
374-
role="user",
375-
conversation_id=shared_conversation_id,
376-
)
377-
]
378-
if media_piece is not None:
379-
pieces.append(
380-
MessagePiece(
381-
original_value=media_piece.converted_value,
382-
role="user",
383-
original_value_data_type=media_piece.converted_value_data_type,
384-
conversation_id=shared_conversation_id,
385-
)
386-
)
387-
prompt_message = Message(message_pieces=pieces)
360+
# Build the message for the adversarial chat
361+
prompt_message = await self._build_adversarial_prompt(context)
362+
363+
# Log the message being sent
364+
if prompt_message.is_multimodal():
365+
text_piece = prompt_message.get_first_piece_by_data_type("text")
366+
media_pieces = [p for p in prompt_message.message_pieces if p.converted_value_data_type != "text"]
367+
feedback_text = text_piece.converted_value if text_piece else "No text content"
368+
media_info = f"{len(media_pieces)} media piece(s)" if media_pieces else "no media"
388369
logger.debug(
389-
f"Sending multimodal prompt to adversarial chat: {feedback_text[:50]}... "
390-
f"+ {media_piece.converted_value_data_type if media_piece else 'no'} media"
370+
f"Sending multimodal prompt to adversarial chat: {feedback_text[:50]}... + {media_info}"
391371
)
392372
else:
393-
prompt_text = prompt_result
394-
prompt_message = Message.from_prompt(prompt=prompt_text, role="user")
373+
prompt_text = prompt_message.get_first_piece().converted_value
395374
logger.debug(f"Sending prompt to adversarial chat: {prompt_text[:50]}...")
396375

397376
with execution_context(
@@ -420,33 +399,35 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any]
420399
async def _build_adversarial_prompt(
421400
self,
422401
context: MultiTurnAttackContext[Any],
423-
) -> Union[str, tuple[str, Optional[MessagePiece]]]:
402+
) -> Message:
424403
"""
425-
Build a prompt for the adversarial chat based on the last response.
404+
Build a prompt message for the adversarial chat based on the last response.
426405
427-
For text responses, returns a plain string. For file/media responses (images, video, etc.),
428-
returns a tuple of (feedback_text, media_piece) so the caller can construct a multimodal
429-
message that includes the actual generated media alongside the textual feedback.
406+
For text responses, creates a simple text message. For file/media responses (images, video, etc.),
407+
creates a multimodal message that includes both the textual feedback and the actual generated
408+
media so the adversarial chat can see what the target produced.
430409
431410
Args:
432411
context (MultiTurnAttackContext): The attack context containing the current state and configuration.
433412
434413
Returns:
435-
Union[str, tuple[str, Optional[MessagePiece]]]: Either a plain text prompt string,
436-
or a tuple of (feedback_text, media_piece) when the target returned media content.
414+
Message: A message ready to be sent to the adversarial chat.
437415
"""
438416
# If no last response, return the seed prompt (rendered with objective if template exists)
439417
if not context.last_response:
440-
return self._adversarial_chat_seed_prompt.render_template_value_silent(objective=context.objective)
418+
prompt_text = self._adversarial_chat_seed_prompt.render_template_value_silent(objective=context.objective)
419+
return Message.from_prompt(prompt=prompt_text, role="user")
441420

442421
# Get the last assistant piece from the response
443422
response_piece = context.last_response.get_piece()
444423

445-
# Text/error responses return str; file responses return tuple[str, Optional[MessagePiece]]
424+
# Build message based on response type (text vs file/media)
446425
if response_piece.converted_value_data_type in ("text", "error"):
447-
return self._handle_adversarial_text_response(context=context)
448-
449-
return self._handle_adversarial_file_response(context=context)
426+
feedback_text = self._handle_adversarial_text_response(context=context)
427+
return self._build_text_message(feedback_text)
428+
else:
429+
feedback_text, media_piece = self._handle_adversarial_file_response(context=context)
430+
return self._build_multimodal_message(feedback_text, media_piece)
450431

451432
def _handle_adversarial_text_response(self, *, context: MultiTurnAttackContext[Any]) -> str:
452433
"""
@@ -538,6 +519,49 @@ def _handle_adversarial_file_response(
538519

539520
return (feedback, response_piece)
540521

522+
def _build_text_message(self, feedback_text: str) -> Message:
523+
"""
524+
Build a simple text message for the adversarial chat.
525+
526+
Args:
527+
feedback_text (str): The text content for the message.
528+
529+
Returns:
530+
Message: A text message ready to be sent to the adversarial chat.
531+
"""
532+
return Message.from_prompt(prompt=feedback_text, role="user")
533+
534+
def _build_multimodal_message(self, feedback_text: str, media_piece: Optional[MessagePiece]) -> Message:
535+
"""
536+
Build a multimodal message for the adversarial chat containing both text and media.
537+
538+
Args:
539+
feedback_text (str): The textual feedback to include.
540+
media_piece (Optional[MessagePiece]): The media piece from the target response, if any.
541+
542+
Returns:
543+
Message: A multimodal message ready to be sent to the adversarial chat.
544+
"""
545+
# Use a shared conversation_id so Message validation passes
546+
shared_conversation_id = str(uuid.uuid4())
547+
pieces = [
548+
MessagePiece(
549+
original_value=feedback_text,
550+
role="user",
551+
conversation_id=shared_conversation_id,
552+
)
553+
]
554+
if media_piece is not None:
555+
pieces.append(
556+
MessagePiece(
557+
original_value=media_piece.converted_value,
558+
role="user",
559+
original_value_data_type=media_piece.converted_value_data_type,
560+
conversation_id=shared_conversation_id,
561+
)
562+
)
563+
return Message(message_pieces=pieces)
564+
541565
async def _send_prompt_to_objective_target_async(
542566
self,
543567
*,

0 commit comments

Comments
 (0)