|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +import logging |
| 5 | +from typing import Any, Literal, Optional |
| 6 | + |
| 7 | +from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults |
| 8 | +from pyrit.exceptions import ComponentRole, execution_context |
| 9 | +from pyrit.executor.attack.core.attack_config import AttackConverterConfig, AttackScoringConfig |
| 10 | +from pyrit.executor.attack.core.attack_parameters import AttackParameters |
| 11 | +from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack |
| 12 | +from pyrit.executor.attack.single_turn.single_turn_attack_strategy import SingleTurnAttackContext |
| 13 | +from pyrit.models import ( |
| 14 | + AttackResult, |
| 15 | + ConversationReference, |
| 16 | + ConversationType, |
| 17 | + Message, |
| 18 | + build_atomic_attack_identifier, |
| 19 | +) |
| 20 | +from pyrit.prompt_converter.bijection_converter import BijectionConverter |
| 21 | +from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer |
| 22 | +from pyrit.prompt_target import PromptTarget |
| 23 | + |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# BijectionLearningAttack constructs its own encoded messages, so callers
# cannot inject pre-built next_message or prepended_conversation. This type is
# handed to PromptSendingAttack as ``params_type`` by BijectionLearningAttack.
# NOTE(review): ``excluding`` presumably returns an AttackParameters variant
# without the named fields — confirm against AttackParameters.
BijectionLearningParameters = AttackParameters.excluding("prepended_conversation", "next_message")
| 29 | + |
| 30 | + |
class BijectionLearningAttack(PromptSendingAttack):
    """
    Implement the Bijection Learning jailbreak [@liu2024bijectionlearning].

    Each attempt generates a fresh random bijection and threads two paired
    converters through PyRIT's normal converter pipeline:

    * **Request side** – a ``BijectionConverter(direction="encode")`` appended
      after any user-supplied request converters. It wraps the objective in the
      teaching preamble and encodes it before the prompt reaches the target.
    * **Response side** – a matching ``BijectionConverter(direction="decode")``
      built from that same attempt's mapping, prepended before any user-supplied
      response converters. The normalizer applies it to the raw target response
      so the scorer always receives decoded plaintext.

    Repeating with independent mappings (best-of-n) more than doubles the
    single-attempt attack success rate reported in the paper.
    """

    @apply_defaults
    def __init__(
        self,
        *,
        objective_target: PromptTarget = REQUIRED_VALUE,  # type: ignore[ty:invalid-parameter-default]
        attack_converter_config: Optional[AttackConverterConfig] = None,
        attack_scoring_config: Optional[AttackScoringConfig] = None,
        prompt_normalizer: Optional[PromptNormalizer] = None,
        max_attempts_on_failure: int = 0,
        mapping_type: Literal["letter", "digit"] = "digit",
        fixed_points: int = 13,
        digit_length: int = 2,
        num_teaching_shots: int = 5,
    ) -> None:
        """
        Initialize the bijection learning attack.

        The bijection settings are only stored here; a new ``BijectionConverter``
        is built from them on every attempt in ``_perform_async``.

        Args:
            objective_target: The target system to attack.
            attack_converter_config: Optional additional converter configuration.
                User-supplied request converters run *before* bijection encoding;
                user-supplied response converters run *after* bijection decoding.
            attack_scoring_config: Scoring configuration.
            prompt_normalizer: Optional normalizer override.
            max_attempts_on_failure: Additional attempts after the first
                failure (best-of-n sampling). Each attempt uses a fresh random
                bijection mapping.
            mapping_type: ``"letter"`` or ``"digit"`` — forwarded to
                ``BijectionConverter``.
            fixed_points: Letters that map to themselves (0–25). Lower values
                yield more complex encodings.
            digit_length: Numeric code length for ``mapping_type="digit"``.
            num_teaching_shots: Benign teaching pairs prepended to the query.
        """
        super().__init__(
            objective_target=objective_target,
            attack_converter_config=attack_converter_config,
            attack_scoring_config=attack_scoring_config,
            prompt_normalizer=prompt_normalizer,
            max_attempts_on_failure=max_attempts_on_failure,
            params_type=BijectionLearningParameters,
        )
        self._mapping_type = mapping_type
        self._fixed_points = fixed_points
        self._digit_length = digit_length
        self._num_teaching_shots = num_teaching_shots

    async def _prune_and_reset_async(self, *, context: SingleTurnAttackContext[Any]) -> None:
        """
        Mark the current conversation as pruned and re-run setup on the context.

        Shared by both retry paths (no response, and response scored as a
        failure) so the bookkeeping cannot drift between them.

        Args:
            context: The attack context whose conversation is being abandoned.
        """
        context.related_conversations.add(
            ConversationReference(
                conversation_id=context.conversation_id,
                conversation_type=ConversationType.PRUNED,
            )
        )
        # NOTE(review): _setup_async is inherited; presumably it gives the
        # context a fresh conversation for the next attempt — confirm in the base class.
        await self._setup_async(context=context)

    async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> AttackResult:
        """
        Run the bijection learning attack loop.

        Each iteration:
          1. Creates a fresh ``BijectionConverter(direction="encode")`` — new
             random mapping for this attempt.
          2. Builds a paired ``BijectionConverter(direction="decode")`` from the
             same mapping.
          3. Calls the normalizer with the objective as plain text, the encode
             converter appended to request converters, and the decode converter
             prepended to response converters. The normalizer handles all
             transformation; the scorer receives decoded plaintext.
          4. Scores and breaks on success; otherwise prunes the conversation
             and resets the context for the next attempt.

        The score is cleared at the start of every attempt, so the returned
        ``last_score`` always belongs to the same attempt as ``last_response``.

        Args:
            context: The single-turn attack context carrying the objective,
                conversation id, related conversations and memory labels.

        Returns:
            AttackResult: The outcome, last response, and score for the final attempt.
        """
        self._logger.info(f"Starting {self.__class__.__name__} with objective: {context.objective}")

        response: Optional[Message] = None
        score = None

        for attempt in range(self._max_attempts_on_failure + 1):
            self._logger.debug(f"Attempt {attempt + 1}/{self._max_attempts_on_failure + 1}")

            # A score from an earlier (now pruned) attempt must not leak into
            # the result if this attempt ends up without a response.
            score = None

            # Only retry (and therefore prune/reset) while attempts remain.
            can_retry = attempt < self._max_attempts_on_failure

            # Fresh random bijection for this attempt.
            encode_converter = BijectionConverter(
                direction="encode",
                mapping_type=self._mapping_type,
                fixed_points=self._fixed_points,
                digit_length=self._digit_length,
                num_teaching_shots=self._num_teaching_shots,
                append_description=True,
            )
            # Paired decoder built from THIS attempt's mapping.
            decode_converter = BijectionConverter(
                direction="decode",
                custom_mapping=encode_converter.mapping,
            )

            # Append the encode converter AFTER user-supplied request converters
            # so bijection encoding is the last transform before the target.
            request_configs = self._request_converters + PromptConverterConfiguration.from_converters(
                converters=[encode_converter]
            )
            # Prepend the decode converter BEFORE user-supplied response converters
            # so the scorer always receives decoded plaintext.
            response_configs = (
                PromptConverterConfiguration.from_converters(converters=[decode_converter]) + self._response_converters
            )

            # Send the plain objective; encoding is handled by the request converter.
            message = Message.from_prompt(prompt=context.objective, role="user")

            with execution_context(
                component_role=ComponentRole.OBJECTIVE_TARGET,
                attack_strategy_name=self.__class__.__name__,
                attack_identifier=self.get_identifier(),
                component_identifier=self._objective_target.get_identifier(),
                objective_target_conversation_id=context.conversation_id,
                objective=context.params.objective,
            ):
                response = await self._prompt_normalizer.send_prompt_async(
                    message=message,
                    target=self._objective_target,
                    conversation_id=context.conversation_id,
                    request_converter_configurations=request_configs,
                    response_converter_configurations=response_configs,
                    attack_identifier=self.get_identifier(),
                )

            if not response:
                self._logger.warning(f"No response on attempt {attempt + 1} (likely filtered)")
                if can_retry:
                    await self._prune_and_reset_async(context=context)
                # Nothing to score; on the last attempt this simply ends the loop.
                continue

            # The response's converted_value is already decoded by the response
            # converter; pass it directly to the scorer.
            score = await self._evaluate_response_async(response=response, objective=context.objective)

            # Without an objective scorer success cannot be judged, so one
            # response is all we can get out of retrying.
            if not self._objective_scorer:
                break

            if score and score.get_value():
                break

            if can_retry:
                await self._prune_and_reset_async(context=context)

        outcome, outcome_reason = self._determine_attack_outcome(response=response, score=score, context=context)

        return AttackResult(
            conversation_id=context.conversation_id,
            objective=context.objective,
            atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()),
            last_response=response.get_piece() if response else None,
            last_score=score,
            related_conversations=context.related_conversations,
            outcome=outcome,
            outcome_reason=outcome_reason,
            executed_turns=1,
            labels=context.memory_labels,
        )
0 commit comments