diff --git a/pyrit/auxiliary_attacks/gcg/__init__.py b/pyrit/auxiliary_attacks/gcg/__init__.py index a10d862fe3..160b2f313a 100644 --- a/pyrit/auxiliary_attacks/gcg/__init__.py +++ b/pyrit/auxiliary_attacks/gcg/__init__.py @@ -47,18 +47,28 @@ # mechanism so all GCG public symbols share one re-export pathway. _LAZY_IMPORTS = { "CandidateFilter": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "CandidateFilter"), + "CrossEntropyLoss": ("pyrit.auxiliary_attacks.gcg.default_implementations", "CrossEntropyLoss"), "GCG": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGContext": ("pyrit.auxiliary_attacks.gcg.generator", "GCGContext"), "GCGGenerator": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGResult": ("pyrit.auxiliary_attacks.gcg.generator", "GCGResult"), + "LengthPreservingFilter": ("pyrit.auxiliary_attacks.gcg.default_implementations", "LengthPreservingFilter"), + "LiteralStringInit": ("pyrit.auxiliary_attacks.gcg.default_implementations", "LiteralStringInit"), "LossFunction": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "LossFunction"), "SamplingStrategy": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SamplingStrategy"), + "StandardGCGSampling": ("pyrit.auxiliary_attacks.gcg.default_implementations", "StandardGCGSampling"), "SuffixInitializer": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SuffixInitializer"), "load_goals_and_targets": ("pyrit.auxiliary_attacks.gcg.data", "load_goals_and_targets"), } if TYPE_CHECKING: from pyrit.auxiliary_attacks.gcg.data import load_goals_and_targets + from pyrit.auxiliary_attacks.gcg.default_implementations import ( + CrossEntropyLoss, + LengthPreservingFilter, + LiteralStringInit, + StandardGCGSampling, + ) from pyrit.auxiliary_attacks.gcg.extension_protocols import ( CandidateFilter, LossFunction, @@ -91,6 +101,7 @@ def __dir__() -> list[str]: __all__ = [ "CandidateFilter", + "CrossEntropyLoss", "GCG", "GCGAlgorithmConfig", "GCGConfig", @@ -101,8 +112,11 @@ def __dir__() 
-> list[str]: "GCGOutputConfig", "GCGResult", "GCGStrategyConfig", + "LengthPreservingFilter", + "LiteralStringInit", "LossFunction", "SamplingStrategy", + "StandardGCGSampling", "SuffixInitializer", "load_goals_and_targets", ] diff --git a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py index 4df1ae9205..ea0677527e 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py +++ b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py @@ -17,6 +17,12 @@ get_embedding_matrix, get_embeddings, ) +from pyrit.auxiliary_attacks.gcg.default_implementations import ( + CrossEntropyLoss, + LengthPreservingFilter, + StandardGCGSampling, +) +from pyrit.auxiliary_attacks.gcg.extension_protocols import CandidateFilter, LossFunction, SamplingStrategy logger = logging.getLogger(__name__) @@ -125,6 +131,99 @@ def sample_control( class GCGMultiPromptAttack(MultiPromptAttack): """GCG-specific multi-prompt attack that implements the GCG optimization step.""" + def __init__( + self, + goals: list[str], + targets: list[str], + workers: list[Any], + control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! 
!", + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[Any] | None = None, + *, + sampling: SamplingStrategy | None = None, + loss: LossFunction | None = None, + candidate_filter: CandidateFilter | None = None, + ) -> None: + super().__init__( + goals, + targets, + workers, + control_init, + test_prefixes, + logfile, + managers, + test_goals, + test_targets, + test_workers, + ) + self._sampling = sampling + self._loss = loss + self._candidate_filter = candidate_filter + + def _resolve_sampling(self) -> SamplingStrategy: + sampling = getattr(self, "_sampling", None) + if sampling is not None: + return sampling + return StandardGCGSampling() + + def _resolve_loss(self, *, target_weight: float, control_weight: float) -> LossFunction: + loss = getattr(self, "_loss", None) + if loss is not None: + return loss + return CrossEntropyLoss(target_weight=target_weight, control_weight=control_weight) + + def _resolve_candidate_filter(self, *, filter_cand: bool) -> CandidateFilter: + candidate_filter = getattr(self, "_candidate_filter", None) + if candidate_filter is not None: + return candidate_filter + return LengthPreservingFilter(filter=filter_cand) + + def _sample_control_candidates( + self, + *, + worker_index: int, + gradient: torch.Tensor, + batch_size: int, + topk: int, + temp: float, + allow_non_ascii: bool, + ) -> torch.Tensor: + sampler = self._resolve_sampling() + prompt_manager = self.prompts[worker_index] + return sampler.sample_candidates( + gradient=gradient, + control_tokens=prompt_manager.control_toks, + batch_size=batch_size, + top_k=topk, + temperature=temp, + allow_non_ascii=allow_non_ascii, + non_ascii_tokens=prompt_manager.disallowed_toks, + ) + + def _filter_control_candidates( + self, + *, + worker_index: int, + control_cand: torch.Tensor, + filter_cand: bool, + ) -> list[str]: + 
candidate_filter = self._resolve_candidate_filter(filter_cand=filter_cand) + return candidate_filter.filter_candidates( + candidate_tokens=control_cand, + tokenizer=self.workers[worker_index].tokenizer, + current_control=self.control_str, + ) + + def _get_control_length(self, *, control: str) -> int | None: + try: + return len(self.workers[0].tokenizer(control).input_ids[1:]) + except (AttributeError, TypeError, ValueError): + return None + def step( self, *, @@ -158,6 +257,7 @@ def step( """ main_device = self.models[0].device control_cands = [] + loss_function = self._resolve_loss(target_weight=target_weight, control_weight=control_weight) for j, worker in enumerate(self.workers): worker(self.prompts[j], "grad", worker.model) @@ -171,10 +271,19 @@ def step( grad = torch.zeros_like(new_grad) if grad.shape != new_grad.shape: with torch.no_grad(): - control_cand = self.prompts[j - 1].sample_control(grad, batch_size, topk, temp, allow_non_ascii) + control_cand = self._sample_control_candidates( + worker_index=j - 1, + gradient=grad, + batch_size=batch_size, + topk=topk, + temp=temp, + allow_non_ascii=allow_non_ascii, + ) control_cands.append( - self.get_filtered_cands( - j - 1, control_cand, filter_cand=filter_cand, curr_control=self.control_str + self._filter_control_candidates( + worker_index=j - 1, + control_cand=control_cand, + filter_cand=filter_cand, ) ) grad = new_grad @@ -182,9 +291,20 @@ def step( grad += new_grad with torch.no_grad(): - control_cand = self.prompts[j].sample_control(grad, batch_size, topk, temp, allow_non_ascii) + control_cand = self._sample_control_candidates( + worker_index=j, + gradient=grad, + batch_size=batch_size, + topk=topk, + temp=temp, + allow_non_ascii=allow_non_ascii, + ) control_cands.append( - self.get_filtered_cands(j, control_cand, filter_cand=filter_cand, curr_control=self.control_str) + self._filter_control_candidates( + worker_index=j, + control_cand=control_cand, + filter_cand=filter_cand, + ) ) del grad, control_cand 
gc.collect() @@ -205,14 +325,14 @@ def step( worker(self.prompts[k][i], "logits", worker.model, cand, return_ids=True) logits, ids = zip(*[worker.results.get() for worker in self.workers]) loss[j * batch_size : (j + 1) * batch_size] += sum( - target_weight * self.prompts[k][i].target_loss(logit, id).mean(dim=-1).to(main_device) + loss_function.compute_loss( + logits=logit, + token_ids=id, + target_slice=self.prompts[k][i]._target_slice, + control_slice=self.prompts[k][i]._control_slice, + ).to(main_device) for k, (logit, id) in enumerate(zip(logits, ids)) ) - if control_weight != 0: - loss[j * batch_size : (j + 1) * batch_size] += sum( - control_weight * self.prompts[k][i].control_loss(logit, id).mean(dim=-1).to(main_device) - for k, (logit, id) in enumerate(zip(logits, ids)) - ) del logits, ids gc.collect() @@ -229,7 +349,9 @@ def step( del control_cands, loss gc.collect() - logger.info(f"Current length: {len(self.workers[0].tokenizer(next_control).input_ids[1:])}") + current_length = self._get_control_length(control=next_control) + if current_length is not None: + logger.info(f"Current length: {current_length}") logger.info(next_control) return next_control, cand_loss.item() / len(self.prompts[0]) / len(self.workers) diff --git a/pyrit/auxiliary_attacks/gcg/config.py b/pyrit/auxiliary_attacks/gcg/config.py index 097a9087af..c2debada6e 100644 --- a/pyrit/auxiliary_attacks/gcg/config.py +++ b/pyrit/auxiliary_attacks/gcg/config.py @@ -25,6 +25,13 @@ if TYPE_CHECKING: from pathlib import Path + from pyrit.auxiliary_attacks.gcg.extension_protocols import ( + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, + ) + _DEFAULT_CONTROL_INIT: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !" @@ -147,6 +154,18 @@ class GCGAlgorithmConfig: random_seed (int): Seed for ``torch``/``numpy``/``random``. Defaults to 42. control_init (str): Initial suffix string the optimization starts from. Defaults to twenty space-separated ``!`` tokens. 
+ sampling (SamplingStrategy | None): Optional strategy object that + samples candidate suffix token sequences from the aggregated + gradient. ``None`` uses the built-in default implementation. + loss (LossFunction | None): Optional loss object used to score each + candidate suffix. ``None`` uses the built-in weighted + cross-entropy default that preserves legacy behavior. + candidate_filter (CandidateFilter | None): Optional candidate-filter + object that decodes/prunes sampled candidate token sequences. + ``None`` uses the built-in length-preserving filter. + suffix_init (SuffixInitializer | None): Optional initializer object + that produces the initial suffix string at attack construction + time. ``None`` uses ``control_init`` verbatim. """ n_steps: int = 500 @@ -161,6 +180,10 @@ class GCGAlgorithmConfig: filter_cand: bool = True random_seed: int = 42 control_init: str = _DEFAULT_CONTROL_INIT + sampling: SamplingStrategy | None = None + loss: LossFunction | None = None + candidate_filter: CandidateFilter | None = None + suffix_init: SuffixInitializer | None = None def __post_init__(self) -> None: if self.n_steps <= 0: @@ -183,6 +206,27 @@ def __post_init__(self) -> None: ) if not self.control_init: raise ValueError("GCGAlgorithmConfig.control_init must be a non-empty string.") + self._validate_extensions() + + def _validate_extensions(self) -> None: + from pyrit.auxiliary_attacks.gcg.extension_protocols import ( + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, + ) + + checks = ( + ("sampling", self.sampling, SamplingStrategy), + ("loss", self.loss, LossFunction), + ("candidate_filter", self.candidate_filter, CandidateFilter), + ("suffix_init", self.suffix_init, SuffixInitializer), + ) + for field_name, value, protocol in checks: + if value is not None and not isinstance(value, protocol): + raise ValueError( + f"GCGAlgorithmConfig.{field_name} must satisfy {protocol.__name__}, got {type(value)!r}." 
+ ) @dataclass diff --git a/pyrit/auxiliary_attacks/gcg/default_implementations.py b/pyrit/auxiliary_attacks/gcg/default_implementations.py new file mode 100644 index 0000000000..3967c128c7 --- /dev/null +++ b/pyrit/auxiliary_attacks/gcg/default_implementations.py @@ -0,0 +1,331 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Default concrete implementations of the four GCG extension protocols. + +Each class in this module reproduces the byte-identical behavior of the +legacy GCG attack code path it replaces: + +- ``StandardGCGSampling`` reproduces ``GCGPromptManager.sample_control``. +- ``CrossEntropyLoss`` reproduces ``AttackPrompt.target_loss`` and + ``AttackPrompt.control_loss`` combined via the weighted sum applied + inside ``GCGMultiPromptAttack.step``. +- ``LengthPreservingFilter`` reproduces ``MultiPromptAttack.get_filtered_cands``. +- ``LiteralStringInit`` reproduces the literal-string ``control_init`` + parameter threaded through the attack constructors. + +``GCGMultiPromptAttack`` falls back to the sampling, loss, and filter defaults +whenever no custom implementation is supplied. ``LiteralStringInit`` is not a +fallback: without ``suffix_init`` the literal ``control_init`` is used directly. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +class StandardGCGSampling: + """Top-k by ``-gradient``, uniform pick within top-k at one evenly spaced position per row. + + The standard GCG sampling rule: for each of ``batch_size`` candidate + rows, pick one of the ``control_length`` positions, then replace the + token at that position with a uniformly-sampled token id from the top-k + smallest-gradient (most-promising) candidates at that position. The + ``temperature`` argument is part of the protocol but is unused by this + sampler, which always samples uniformly within the top-k.
+ + Reproduces ``GCGPromptManager.sample_control`` from + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py`` byte-for-byte. + """ + + def sample_candidates( + self, + *, + gradient: torch.Tensor, + control_tokens: torch.Tensor, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: torch.Tensor, + ) -> torch.Tensor: + """Sample ``batch_size`` candidate suffix token sequences. + + Args: + gradient (torch.Tensor): Aggregated gradient over the control + tokens with shape ``(control_length, vocab_size)``. Mutated + in-place when ``allow_non_ascii`` is False (the disallowed + token positions are set to ``+inf``), matching legacy + behavior. + control_tokens (torch.Tensor): Current suffix token sequence + with shape ``(control_length,)``. + batch_size (int): Number of candidate suffix rows to return. + top_k (int): Number of top gradient positions per control slot + drawn from. + temperature (float): Sampling temperature. Unused by this + implementation; kept to match the protocol signature. + allow_non_ascii (bool): When False, mask the ``non_ascii_tokens`` + positions of ``gradient`` to ``+inf`` so they fall out of + the top-k. + non_ascii_tokens (torch.Tensor): Token ids to exclude when + ``allow_non_ascii`` is False. + + Returns: + torch.Tensor: Candidate suffix token sequences with shape + ``(batch_size, control_length)`` on the same device as + ``gradient``. 
+ """ + if not allow_non_ascii: + gradient[:, non_ascii_tokens.to(gradient.device)] = np.inf + top_indices = (-gradient).topk(top_k, dim=1).indices + control_tokens = control_tokens.to(gradient.device) + original_control_tokens = control_tokens.repeat(batch_size, 1) + new_token_pos = torch.arange( + 0, + len(control_tokens), + len(control_tokens) / batch_size, + device=gradient.device, + ).type(torch.int64) + new_token_val = torch.gather( + top_indices[new_token_pos], + 1, + torch.randint(0, top_k, (batch_size, 1), device=gradient.device), + ) + return original_control_tokens.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val) + + +class CrossEntropyLoss: + """Weighted token-level cross-entropy on the target and control slices. + + Per candidate: ``target_weight * CE(target_slice) + control_weight * + CE(control_slice)``, where each cross-entropy term is reduced over its + slice with ``.mean(dim=-1)`` to give one scalar per candidate. The + ``.mean(dim=-1)`` reduction matches where the legacy orchestrator + applies it: ``GCGMultiPromptAttack.step`` calls + ``target_loss(...).mean(dim=-1)`` outside the per-prompt loss method, + so the ``LossFunction`` protocol places the per-candidate scalar + reduction inside the implementation. + + When ``control_weight == 0`` the control term is skipped entirely, + matching the legacy ``if control_weight != 0:`` guard inside ``step``. + The same skip is applied when ``target_weight == 0`` for symmetry. + + Reproduces ``AttackPrompt.target_loss`` + ``AttackPrompt.control_loss`` + from ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``, + combined per ``GCGMultiPromptAttack.step`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. + """ + + def __init__(self, *, target_weight: float = 1.0, control_weight: float = 0.0) -> None: + """Initialize the cross-entropy loss with target / control weights. + + Args: + target_weight (float): Weight on the target-slice cross-entropy. + Defaults to 1.0. 
+ control_weight (float): Weight on the control-slice + cross-entropy. Defaults to 0.0 (target-only signal). + + Raises: + ValueError: If either weight is negative, or if both are zero. + """ + if target_weight < 0 or control_weight < 0: + raise ValueError( + "CrossEntropyLoss target_weight and control_weight must be >= 0, " + f"got target_weight={target_weight}, control_weight={control_weight}." + ) + if target_weight == 0 and control_weight == 0: + raise ValueError( + "CrossEntropyLoss requires at least one of target_weight or control_weight to be > 0; " + "with both at 0 the loss is identically zero and provides no signal." + ) + self._target_weight = target_weight + self._control_weight = control_weight + + def compute_loss( + self, + *, + logits: torch.Tensor, + token_ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + """Compute the per-candidate weighted cross-entropy loss. + + Args: + logits (torch.Tensor): Model logits for the candidate batch + with shape ``(batch_size, seq_len, vocab_size)``. + token_ids (torch.Tensor): Input token ids the model was run on + with shape ``(batch_size, seq_len)``. + target_slice (slice): Slice into the sequence dimension that + identifies the target tokens. + control_slice (slice): Slice into the sequence dimension that + identifies the control (suffix) tokens. + + Returns: + torch.Tensor: Per-candidate scalar loss with shape + ``(batch_size,)``. 
+ """ + criterion = nn.CrossEntropyLoss(reduction="none") + total: torch.Tensor | None = None + + if self._target_weight > 0: + target_loss_slice = slice(target_slice.start - 1, target_slice.stop - 1) + target_term = criterion( + logits[:, target_loss_slice, :].transpose(1, 2), + token_ids[:, target_slice], + ).mean(dim=-1) + total = self._target_weight * target_term + + if self._control_weight > 0: + control_loss_slice = slice(control_slice.start - 1, control_slice.stop - 1) + control_term = criterion( + logits[:, control_loss_slice, :].transpose(1, 2), + token_ids[:, control_slice], + ).mean(dim=-1) + weighted_control = self._control_weight * control_term + total = weighted_control if total is None else total + weighted_control + + # Constructor guarantees at least one weight is > 0, so ``total`` is + # always assigned. The check is kept for the type checker. + if total is None: + raise RuntimeError( + "CrossEntropyLoss.compute_loss produced no terms; " + "this indicates a corrupted instance with both weights at 0." + ) + return total + + +class LengthPreservingFilter: + """Decodes each candidate token row and drops any whose decoded string + either (a) equals ``current_control`` or (b) re-tokenizes to a different + token count, padding dropped rows by repeating the last accepted + candidate. + + The ``filter`` constructor parameter selects between filtering (legacy + ``filter_cand=True`` branch) and passthrough decode-only mode (legacy + ``filter_cand=False`` branch). + + Also performs the legacy out-of-vocab clamping: tokens above + ``tokenizer.vocab_size`` are replaced in-place by the id of ``"!"``, + matching the safety pass at the top of ``get_filtered_cands``. + + Reproduces ``MultiPromptAttack.get_filtered_cands`` from + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``. + """ + + def __init__(self, *, filter: bool = True) -> None: + """Initialize the filter. 
+ + Args: + filter (bool): When True, drop candidates that equal + ``current_control`` or re-tokenize to a different length, + padding the result with the last accepted candidate. When + False, decode every row and return them all unchanged. + Defaults to True. + """ + self._filter = filter + + def filter_candidates( + self, + *, + candidate_tokens: torch.Tensor, + tokenizer: Any, + current_control: str, + ) -> list[str]: + """Decode and filter a batch of candidate suffix token tensors. + + Args: + candidate_tokens (torch.Tensor): Sampled candidate suffixes + with shape ``(batch_size, control_length)``. Mutated + in-place by the out-of-vocab clamp, matching legacy + behavior. + tokenizer (Any): HuggingFace-style tokenizer. ``tokenizer.decode`` + renders each row to text; ``tokenizer(text, + add_special_tokens=False).input_ids`` is used to detect + re-tokenization drift; ``tokenizer("!").input_ids[0]`` + provides the replacement id for out-of-vocab clamping. + current_control (str): Current suffix string. When ``filter`` + is True, candidates that decode to this string are dropped. + + Returns: + list[str]: Decoded candidate suffix strings of length exactly + ``candidate_tokens.shape[0]``. 
+ """ + logger.info("Masking out of range token_id.") + vocab_size = tokenizer.vocab_size + candidate_tokens[candidate_tokens > vocab_size] = tokenizer("!").input_ids[0] + + candidates: list[str] = [] + for i in range(candidate_tokens.shape[0]): + decoded_str = tokenizer.decode( + candidate_tokens[i], skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + if self._filter: + if decoded_str != current_control and len( + tokenizer(decoded_str, add_special_tokens=False).input_ids + ) == len(candidate_tokens[i]): + candidates.append(decoded_str) + else: + candidates.append(decoded_str) + + if self._filter: + candidates = candidates + [candidates[-1]] * (len(candidate_tokens) - len(candidates)) + return candidates + + +class LiteralStringInit: + """Returns the configured literal suffix verbatim; ignores the tokenizer. + + Encapsulates the current ``control_init`` plumbing — a literal string + threaded through ``AttackPrompt.__init__``, ``PromptManager.__init__``, + ``MultiPromptAttack.__init__``, and the per-strategy ``*Attack`` + constructors — so that custom initializers that do need the tokenizer + (for example, a random vocabulary sampler) can be swapped in without + changing those constructor signatures. + + Reproduces the literal-string ``control_init`` parameter assignment + (``self.control = control_init``) inside ``AttackPrompt.__init__`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``. + """ + + def __init__(self, *, suffix: str) -> None: + """Initialize the literal-string suffix initializer. + + Args: + suffix (str): The literal suffix string to return on every + call to ``make_initial_suffix``. Must be non-empty. + + Raises: + ValueError: If ``suffix`` is the empty string. + """ + if not suffix: + raise ValueError("LiteralStringInit.suffix must be a non-empty string.") + self._suffix = suffix + + def make_initial_suffix(self, *, tokenizer: Any) -> str: + """Return the configured suffix string. 
+ + Args: + tokenizer (Any): Ignored. Present to match the protocol + signature so custom initializers that need vocabulary + access can be substituted without changing call sites. + + Returns: + str: The literal suffix string supplied at construction. + """ + return self._suffix + + +__all__ = [ + "CrossEntropyLoss", + "LengthPreservingFilter", + "LiteralStringInit", + "StandardGCGSampling", +] diff --git a/pyrit/auxiliary_attacks/gcg/extension_protocols.py b/pyrit/auxiliary_attacks/gcg/extension_protocols.py index f9f1a3013e..973fb22a2b 100644 --- a/pyrit/auxiliary_attacks/gcg/extension_protocols.py +++ b/pyrit/auxiliary_attacks/gcg/extension_protocols.py @@ -16,12 +16,11 @@ - ``SuffixInitializer`` — how the initial suffix string fed into the optimization loop is constructed. -The module is **typing surface only**. It ships no concrete implementations, -no defaults, and no wiring into ``GCGAlgorithmConfig`` or -``GCGMultiPromptAttack``. The default behaviors that match the current attack -code will land as concrete classes in a follow-up PR; the optional -``GCGAlgorithmConfig`` fields that select between defaults and custom -implementations will land in the PR after that. +The module is **typing surface only**. Concrete defaults live in +``default_implementations.py``, and orchestration wiring lives in +``GCGAlgorithmConfig`` + ``GCGMultiPromptAttack``. Keeping this module purely +protocol definitions preserves a stable extension API that can be imported +without pulling in heavy runtime dependencies. 
Tensor-typed signatures are kept lazy via ``from __future__ import annotations`` plus a ``TYPE_CHECKING`` import for ``torch`` so that diff --git a/pyrit/auxiliary_attacks/gcg/generator.py b/pyrit/auxiliary_attacks/gcg/generator.py index 4c812594e9..12ef46040c 100644 --- a/pyrit/auxiliary_attacks/gcg/generator.py +++ b/pyrit/auxiliary_attacks/gcg/generator.py @@ -38,6 +38,7 @@ import logging import time from dataclasses import dataclass, field +from functools import partial from typing import Any, overload import numpy as np @@ -212,6 +213,18 @@ def _build_identifier(self) -> ComponentIdentifier: "topk": self._algorithm.topk, "target_weight": self._algorithm.target_weight, "control_weight": self._algorithm.control_weight, + "sampling_impl": ( + type(self._algorithm.sampling).__name__ if self._algorithm.sampling is not None else "default" + ), + "loss_impl": type(self._algorithm.loss).__name__ if self._algorithm.loss is not None else "default", + "candidate_filter_impl": ( + type(self._algorithm.candidate_filter).__name__ + if self._algorithm.candidate_filter is not None + else "default" + ), + "suffix_init_impl": ( + type(self._algorithm.suffix_init).__name__ if self._algorithm.suffix_init is not None else "default" + ), "transfer": self._strategy.transfer, "progressive_goals": self._strategy.progressive_goals, "progressive_models": self._strategy.progressive_models, @@ -257,7 +270,12 @@ async def _perform_async(self, *, context: GCGContext) -> GCGResult: managers = { "AP": attack_lib.GCGAttackPrompt, "PM": attack_lib.GCGPromptManager, - "MPA": attack_lib.GCGMultiPromptAttack, + "MPA": partial( + attack_lib.GCGMultiPromptAttack, + sampling=self._algorithm.sampling, + loss=self._algorithm.loss, + candidate_filter=self._algorithm.candidate_filter, + ), } context.attack = self._create_attack( params=params, @@ -400,6 +418,7 @@ def _create_attack( logfile_path: str, ) -> Any: """Build the right attack object based on the strategy flags.""" + control_init = 
self._resolve_control_init(workers=workers) if self._strategy.transfer: return ProgressiveMultiPromptAttack( train_goals, @@ -407,7 +426,7 @@ def _create_attack( workers, progressive_models=self._strategy.progressive_models, progressive_goals=self._strategy.progressive_goals, - control_init=self._algorithm.control_init, + control_init=control_init, logfile=logfile_path, managers=managers, test_goals=test_goals, @@ -421,7 +440,7 @@ def _create_attack( train_goals, train_targets, workers, - control_init=self._algorithm.control_init, + control_init=control_init, logfile=logfile_path, managers=managers, test_goals=test_goals, @@ -432,6 +451,18 @@ def _create_attack( mpa_n_steps=self._algorithm.n_steps, ) + def _resolve_control_init(self, *, workers: list[Any]) -> str: + """Resolve the initial suffix string for a run. + + Uses the configured ``suffix_init`` extension when provided; otherwise + falls back to the legacy literal ``control_init`` value. + """ + if self._algorithm.suffix_init is None: + return self._algorithm.control_init + if not workers: + raise ValueError("Cannot resolve suffix_init without at least one worker tokenizer.") + return self._algorithm.suffix_init.make_initial_suffix(tokenizer=workers[0].tokenizer) + @staticmethod def _read_result(*, logfile_path: str, memory_labels: dict[str, str]) -> GCGResult: """Pull final-step values out of the JSON log written during the run.""" diff --git a/tests/unit/auxiliary_attacks/gcg/test_config.py b/tests/unit/auxiliary_attacks/gcg/test_config.py index da0a7f6a9a..922b0ffedd 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_config.py +++ b/tests/unit/auxiliary_attacks/gcg/test_config.py @@ -9,7 +9,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest @@ -28,6 +28,49 @@ _LLAMA_2 = "meta-llama/Llama-2-7b-chat-hf" +class _SamplingStub: + def sample_candidates( + self, + *, + gradient: Any, + control_tokens: Any, + batch_size: int, 
+ top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: Any, + ) -> Any: + return control_tokens + + +class _LossStub: + def compute_loss( + self, + *, + logits: Any, + token_ids: Any, + target_slice: slice, + control_slice: slice, + ) -> Any: + return logits + + +class _FilterStub: + def filter_candidates( + self, + *, + candidate_tokens: Any, + tokenizer: Any, + current_control: str, + ) -> list[str]: + return [current_control] + + +class _SuffixInitStub: + def make_initial_suffix(self, *, tokenizer: Any) -> str: + return "stub suffix" + + def _minimal_config() -> GCGConfig: return GCGConfig(models=[GCGModelConfig(name=_LLAMA_2)]) @@ -42,6 +85,10 @@ def test_minimal_config_constructs_with_defaults() -> None: assert config.test_models == [] assert config.algorithm.n_steps == 500 assert config.algorithm.batch_size == 512 + assert config.algorithm.sampling is None + assert config.algorithm.loss is None + assert config.algorithm.candidate_filter is None + assert config.algorithm.suffix_init is None assert config.strategy.transfer is False assert config.output.verbose is True assert config.hf_token is None @@ -100,6 +147,33 @@ def test_algorithm_empty_control_init_raises() -> None: GCGAlgorithmConfig(control_init="") +@pytest.mark.parametrize( + "field_name,value", + [ + ("sampling", object()), + ("loss", object()), + ("candidate_filter", object()), + ("suffix_init", object()), + ], +) +def test_algorithm_extension_type_validation(field_name: str, value: object) -> None: + with pytest.raises(ValueError, match=rf"GCGAlgorithmConfig\.{field_name} must satisfy"): + GCGAlgorithmConfig(**{field_name: value}) + + +def test_algorithm_accepts_protocol_implementations() -> None: + config = GCGAlgorithmConfig( + sampling=_SamplingStub(), + loss=_LossStub(), + candidate_filter=_FilterStub(), + suffix_init=_SuffixInitStub(), + ) + assert config.sampling is not None + assert config.loss is not None + assert config.candidate_filter is not None + assert 
config.suffix_init is not None + + @pytest.mark.parametrize("field_name", ["n_train_data", "n_test_data"]) def test_data_negative_count_raises(field_name: str) -> None: with pytest.raises(ValueError, match=f"GCGDataConfig.{field_name} must be >= 0"): diff --git a/tests/unit/auxiliary_attacks/gcg/test_default_implementations.py b/tests/unit/auxiliary_attacks/gcg/test_default_implementations.py new file mode 100644 index 0000000000..8b89745052 --- /dev/null +++ b/tests/unit/auxiliary_attacks/gcg/test_default_implementations.py @@ -0,0 +1,454 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for ``pyrit.auxiliary_attacks.gcg.default_implementations``. + +These tests verify byte-identical parity between the four default +implementations and the legacy GCG attack code paths they reproduce: + +- ``StandardGCGSampling`` vs ``GCGPromptManager.sample_control`` +- ``CrossEntropyLoss`` vs the weighted sum of ``AttackPrompt.target_loss`` + and ``AttackPrompt.control_loss`` applied inside + ``GCGMultiPromptAttack.step`` +- ``LengthPreservingFilter`` vs ``MultiPromptAttack.get_filtered_cands`` +- ``LiteralStringInit`` vs the literal-string ``control_init`` assignment + inside ``AttackPrompt.__init__`` + +Mocking patterns follow the conventions established in +``tests/unit/auxiliary_attacks/gcg/test_gcg_core.py`` (``object.__new__`` +to skip the real ``__init__``, ``MagicMock`` tokenizers). +""" + +from unittest.mock import MagicMock + +import pytest + +torch = pytest.importorskip("torch", reason="GCG default implementations require torch") + +attack_manager_mod = pytest.importorskip( + "pyrit.auxiliary_attacks.gcg.attack.base.attack_manager", + reason="GCG optional dependencies (torch, mlflow, etc.) 
not installed", +) +gcg_attack_mod = pytest.importorskip( + "pyrit.auxiliary_attacks.gcg.attack.gcg.gcg_attack", + reason="GCG optional dependencies not installed", +) + +import pyrit.auxiliary_attacks.gcg as gcg_pkg # noqa: E402 +from pyrit.auxiliary_attacks.gcg import ( # noqa: E402 + CrossEntropyLoss, + LengthPreservingFilter, + LiteralStringInit, + StandardGCGSampling, +) +from pyrit.auxiliary_attacks.gcg import default_implementations as defaults_module # noqa: E402 +from pyrit.auxiliary_attacks.gcg.config import GCGAlgorithmConfig # noqa: E402 + +AttackPrompt = attack_manager_mod.AttackPrompt +MultiPromptAttack = attack_manager_mod.MultiPromptAttack +GCGPromptManager = gcg_attack_mod.GCGPromptManager + + +DEFAULT_NAMES = ( + "CrossEntropyLoss", + "LengthPreservingFilter", + "LiteralStringInit", + "StandardGCGSampling", +) + + +class TestPackageReExports: + """Verify the four default classes are re-exported from the package root.""" + + @pytest.mark.parametrize("name", DEFAULT_NAMES) + def test_default_is_reexported_with_identity(self, name: str) -> None: + package_attr = getattr(gcg_pkg, name) + module_attr = getattr(defaults_module, name) + assert package_attr is module_attr, ( + f"{name} re-exported from pyrit.auxiliary_attacks.gcg must be the same " + f"object as pyrit.auxiliary_attacks.gcg.default_implementations.{name}" + ) + + @pytest.mark.parametrize("name", DEFAULT_NAMES) + def test_default_in_package_dunder_all(self, name: str) -> None: + assert name in gcg_pkg.__all__ + + +class TestStandardGCGSampling: + """Parity: ``StandardGCGSampling`` vs ``GCGPromptManager.sample_control``.""" + + def _make_legacy_prompt_manager( + self, + *, + control_tokens: torch.Tensor, + non_ascii_tokens: torch.Tensor, + ) -> GCGPromptManager: + # Mirrors the construction pattern used by TestSampleControl in + # test_gcg_core.py: skip __init__ and seed just the attributes that + # sample_control reads. 
+ prompt_manager = object.__new__(GCGPromptManager) + prompt_manager._nonascii_toks = non_ascii_tokens + prompt_manager._prompts = [MagicMock()] + prompt_manager._prompts[0].control_toks = control_tokens.clone() + return prompt_manager + + def test_sample_candidates_matches_legacy_with_ascii_only(self) -> None: + """Legacy reference: ``GCGPromptManager.sample_control(grad, batch_size, + topk=top_k, temp=1.0, allow_non_ascii=False)`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. + """ + n_control_tokens = 5 + vocab_size = 50 + batch_size = 4 + top_k = 8 + + torch.manual_seed(2026) + gradient_template = torch.randn(n_control_tokens, vocab_size) + control_tokens = torch.randint(0, vocab_size, (n_control_tokens,)) + non_ascii_tokens = torch.tensor([2, 7, 13]) + + # Legacy path + prompt_manager = self._make_legacy_prompt_manager( + control_tokens=control_tokens, non_ascii_tokens=non_ascii_tokens + ) + torch.manual_seed(12345) + legacy_out = prompt_manager.sample_control( + gradient_template.clone(), + batch_size, + topk=top_k, + temp=1.0, + allow_non_ascii=False, + ) + + # Default path + default = StandardGCGSampling() + torch.manual_seed(12345) + default_out = default.sample_candidates( + gradient=gradient_template.clone(), + control_tokens=control_tokens.clone(), + batch_size=batch_size, + top_k=top_k, + temperature=1.0, + allow_non_ascii=False, + non_ascii_tokens=non_ascii_tokens, + ) + + assert torch.equal(default_out, legacy_out) + + def test_sample_candidates_matches_legacy_with_non_ascii_allowed(self) -> None: + """Legacy reference: same as above but with ``allow_non_ascii=True`` + (the no-mask branch where the gradient is not mutated). 
+ """ + n_control_tokens = 6 + vocab_size = 40 + batch_size = 5 + top_k = 10 + + torch.manual_seed(2027) + gradient_template = torch.randn(n_control_tokens, vocab_size) + control_tokens = torch.randint(0, vocab_size, (n_control_tokens,)) + non_ascii_tokens = torch.tensor([1, 4]) + + prompt_manager = self._make_legacy_prompt_manager( + control_tokens=control_tokens, non_ascii_tokens=non_ascii_tokens + ) + torch.manual_seed(54321) + legacy_out = prompt_manager.sample_control( + gradient_template.clone(), + batch_size, + topk=top_k, + temp=1.0, + allow_non_ascii=True, + ) + + default = StandardGCGSampling() + torch.manual_seed(54321) + default_out = default.sample_candidates( + gradient=gradient_template.clone(), + control_tokens=control_tokens.clone(), + batch_size=batch_size, + top_k=top_k, + temperature=1.0, + allow_non_ascii=True, + non_ascii_tokens=non_ascii_tokens, + ) + + assert torch.equal(default_out, legacy_out) + + +class TestCrossEntropyLoss: + """Parity: ``CrossEntropyLoss`` vs ``AttackPrompt.target_loss`` + + ``AttackPrompt.control_loss``. + """ + + def _make_legacy_prompt( + self, + *, + target_slice: slice, + control_slice: slice, + ) -> AttackPrompt: + # Mirrors TestTargetAndControlLoss in test_gcg_core.py: skip + # __init__ and seed only the slice attributes that the loss methods + # consult. + prompt = object.__new__(AttackPrompt) + prompt._target_slice = target_slice + prompt._control_slice = control_slice + return prompt + + def test_compute_loss_matches_legacy_weighted_sum(self) -> None: + """Legacy reference: + ``target_weight * AttackPrompt.target_loss(logits, ids).mean(dim=-1)`` + ``+ control_weight * AttackPrompt.control_loss(logits, ids).mean(dim=-1)``, + per ``GCGMultiPromptAttack.step`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. 
+ """ + batch_size = 4 + seq_len = 10 + vocab_size = 30 + target_slice = slice(5, 8) + control_slice = slice(2, 5) + target_weight = 1.0 + control_weight = 0.1 + + torch.manual_seed(99) + logits = torch.randn(batch_size, seq_len, vocab_size) + token_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + + prompt = self._make_legacy_prompt(target_slice=target_slice, control_slice=control_slice) + legacy_target = prompt.target_loss(logits, token_ids).mean(dim=-1) + legacy_control = prompt.control_loss(logits, token_ids).mean(dim=-1) + legacy_total = target_weight * legacy_target + control_weight * legacy_control + + default = CrossEntropyLoss(target_weight=target_weight, control_weight=control_weight) + default_total = default.compute_loss( + logits=logits, + token_ids=token_ids, + target_slice=target_slice, + control_slice=control_slice, + ) + + assert torch.equal(default_total, legacy_total) + + def test_compute_loss_target_only_matches_legacy_target_loss(self) -> None: + """With ``control_weight=0`` the legacy ``step`` skips the control + term (``if control_weight != 0:`` guard at line 211). The default + must produce the same per-candidate value as + ``target_weight * target_loss(...).mean(dim=-1)`` alone. 
+ """ + target_slice = slice(4, 7) + control_slice = slice(1, 4) + + torch.manual_seed(7) + logits = torch.randn(3, 9, 25) + token_ids = torch.randint(0, 25, (3, 9)) + + prompt = self._make_legacy_prompt(target_slice=target_slice, control_slice=control_slice) + legacy_total = 1.0 * prompt.target_loss(logits, token_ids).mean(dim=-1) + + default = CrossEntropyLoss(target_weight=1.0, control_weight=0.0) + default_total = default.compute_loss( + logits=logits, + token_ids=token_ids, + target_slice=target_slice, + control_slice=control_slice, + ) + + assert torch.equal(default_total, legacy_total) + + def test_compute_loss_control_only_matches_legacy_control_loss(self) -> None: + """With ``target_weight=0`` the default must produce the same value + as ``control_weight * control_loss(...).mean(dim=-1)`` alone. + """ + target_slice = slice(4, 7) + control_slice = slice(1, 4) + + torch.manual_seed(13) + logits = torch.randn(3, 9, 25) + token_ids = torch.randint(0, 25, (3, 9)) + + prompt = self._make_legacy_prompt(target_slice=target_slice, control_slice=control_slice) + legacy_total = 0.5 * prompt.control_loss(logits, token_ids).mean(dim=-1) + + default = CrossEntropyLoss(target_weight=0.0, control_weight=0.5) + default_total = default.compute_loss( + logits=logits, + token_ids=token_ids, + target_slice=target_slice, + control_slice=control_slice, + ) + + assert torch.equal(default_total, legacy_total) + + def test_init_rejects_both_weights_zero(self) -> None: + with pytest.raises(ValueError, match="at least one"): + CrossEntropyLoss(target_weight=0.0, control_weight=0.0) + + def test_init_rejects_negative_target_weight(self) -> None: + with pytest.raises(ValueError, match=">= 0"): + CrossEntropyLoss(target_weight=-0.5, control_weight=1.0) + + def test_init_rejects_negative_control_weight(self) -> None: + with pytest.raises(ValueError, match=">= 0"): + CrossEntropyLoss(target_weight=1.0, control_weight=-0.5) + + def test_compute_loss_returns_batch_sized_tensor(self) -> 
None: + batch_size = 4 + logits = torch.randn(batch_size, 10, 20) + token_ids = torch.randint(0, 20, (batch_size, 10)) + + default = CrossEntropyLoss(target_weight=1.0, control_weight=0.1) + out = default.compute_loss( + logits=logits, + token_ids=token_ids, + target_slice=slice(5, 8), + control_slice=slice(2, 5), + ) + + assert out.shape == (batch_size,) + + +def _make_filter_tokenizer() -> MagicMock: + """Build a fresh, deterministic, stateless mock tokenizer for filter tests. + + Behavior: + - ``decode(tensor)`` -> ``"x" * int(tensor[0].item())`` — string length + is keyed off the first token id, so each row maps to a distinct + predictable string. + - ``tokenizer(text, ...).input_ids`` has length ``len(text)`` — so the + retokenized length check is fully predictable from the decoded + string. + - ``tokenizer("!").input_ids[0] == 0`` — provides the clamp + replacement id. + - ``vocab_size == 100``. + """ + tokenizer = MagicMock() + tokenizer.vocab_size = 100 + + def decode_fn(ids, **_kwargs): + return "x" * int(ids[0].item()) + + tokenizer.decode.side_effect = decode_fn + + def call_tokenizer(text, **_kwargs): + result = MagicMock() + if text == "!": + result.input_ids = [0] + else: + result.input_ids = list(range(len(text))) + return result + + tokenizer.side_effect = call_tokenizer + return tokenizer + + +class TestLengthPreservingFilter: + """Parity: ``LengthPreservingFilter`` vs + ``MultiPromptAttack.get_filtered_cands``. + """ + + def _make_legacy_attack(self, *, tokenizer: MagicMock) -> MultiPromptAttack: + # Mirrors TestGetFilteredCands in test_gcg_core.py: skip __init__ + # and only attach the workers list that get_filtered_cands reads. 
+ attack = object.__new__(MultiPromptAttack) + worker = MagicMock() + worker.tokenizer = tokenizer + attack.workers = [worker] + return attack + + def test_filter_candidates_matches_legacy_filtered(self) -> None: + """Legacy reference: + ``MultiPromptAttack.get_filtered_cands(0, control_cand, + filter_cand=True, curr_control=...)`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``. + + With the helper tokenizer: + - Row 0 ``[3, 0, 1]`` -> decode ``"xxx"`` (len 3); retok len 3 == + control_length 3 -> KEEP. + - Row 1 ``[5, 0, 0]`` -> decode ``"xxxxx"`` (len 5); retok len 5 + != 3 -> DROP. + - Row 2 ``[2, 0, 1]`` -> decode ``"xx"`` (len 2); retok len 2 != + 3 -> DROP. + Pad-with-last gives ``["xxx", "xxx", "xxx"]``. + """ + candidate_template = torch.tensor([[3, 0, 1], [5, 0, 0], [2, 0, 1]]) + + legacy_attack = self._make_legacy_attack(tokenizer=_make_filter_tokenizer()) + legacy_out = legacy_attack.get_filtered_cands( + 0, candidate_template.clone(), filter_cand=True, curr_control="never_matches" + ) + + default = LengthPreservingFilter(filter=True) + default_out = default.filter_candidates( + candidate_tokens=candidate_template.clone(), + tokenizer=_make_filter_tokenizer(), + current_control="never_matches", + ) + + assert default_out == legacy_out + assert legacy_out == ["xxx", "xxx", "xxx"] + + def test_filter_candidates_matches_legacy_unfiltered(self) -> None: + """Legacy reference: ``get_filtered_cands(0, control_cand, + filter_cand=False)``. Every row is decoded and returned unchanged. 
+ """ + candidate_template = torch.tensor([[3, 0, 1], [5, 0, 0], [2, 0, 1]]) + + legacy_attack = self._make_legacy_attack(tokenizer=_make_filter_tokenizer()) + legacy_out = legacy_attack.get_filtered_cands(0, candidate_template.clone(), filter_cand=False) + + default = LengthPreservingFilter(filter=False) + default_out = default.filter_candidates( + candidate_tokens=candidate_template.clone(), + tokenizer=_make_filter_tokenizer(), + current_control="ignored_when_filter_false", + ) + + assert default_out == legacy_out + assert legacy_out == ["xxx", "xxxxx", "xx"] + + def test_filter_candidates_clamps_out_of_vocab_tokens(self) -> None: + """Both code paths apply the legacy vocab-clamp in-place: tokens + above ``vocab_size`` are replaced by the id of ``"!"`` before any + decoding happens. + """ + candidate_template = torch.tensor([[150, 0, 1], [3, 0, 1]]) # 150 > vocab_size=100 + + legacy_input = candidate_template.clone() + legacy_attack = self._make_legacy_attack(tokenizer=_make_filter_tokenizer()) + legacy_attack.get_filtered_cands(0, legacy_input, filter_cand=False) + + default_input = candidate_template.clone() + default = LengthPreservingFilter(filter=False) + default.filter_candidates( + candidate_tokens=default_input, + tokenizer=_make_filter_tokenizer(), + current_control="", + ) + + assert torch.equal(default_input, legacy_input) + assert default_input[0, 0].item() == 0 + + +class TestLiteralStringInit: + """Parity: ``LiteralStringInit`` vs the literal-string ``control_init`` + assignment inside ``AttackPrompt.__init__`` (``self.control = + control_init``). + """ + + def test_make_initial_suffix_returns_default_control_init(self) -> None: + """Legacy reference: ``GCGAlgorithmConfig.control_init`` (default + ``_DEFAULT_CONTROL_INIT``) is assigned to ``self.control`` in + ``AttackPrompt.__init__``. 
+ """ + default_suffix = GCGAlgorithmConfig().control_init + initializer = LiteralStringInit(suffix=default_suffix) + assert initializer.make_initial_suffix(tokenizer=MagicMock()) == default_suffix + + def test_make_initial_suffix_ignores_tokenizer(self) -> None: + suffix = "custom suffix string" + initializer = LiteralStringInit(suffix=suffix) + assert initializer.make_initial_suffix(tokenizer=None) == suffix + + def test_init_rejects_empty_suffix(self) -> None: + with pytest.raises(ValueError, match="non-empty"): + LiteralStringInit(suffix="") diff --git a/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py b/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py index c3858bf357..a90563ed85 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py +++ b/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from unittest.mock import MagicMock +from typing import Any +from unittest.mock import MagicMock, patch import pytest @@ -25,6 +26,7 @@ "pyrit.auxiliary_attacks.gcg.attack.gcg.gcg_attack", reason="GCG optional dependencies not installed", ) +GCGMultiPromptAttack = gcg_attack_mod.GCGMultiPromptAttack GCGPromptManager = gcg_attack_mod.GCGPromptManager token_gradients = gcg_attack_mod.token_gradients @@ -501,3 +503,466 @@ def test_raises_when_tokenizer_has_no_chat_template(self) -> None: with patch.object(attack_manager_mod.AutoTokenizer, "from_pretrained", return_value=bare_tokenizer): with pytest.raises(ValueError, match="no chat_template configured"): get_workers(params) + + +class _Queue: + def __init__(self, items: list[Any]) -> None: + self._items = list(items) + + def get(self) -> Any: + return self._items.pop(0) + + +class _WorkerStub: + def __init__( + self, + *, + gradient: torch.Tensor, + logits: torch.Tensor, + token_ids: torch.Tensor, + tokenizer: MagicMock, + ) -> None: + self.model = MagicMock() + self.model.device = "cpu" + self.tokenizer = tokenizer + self.results 
= _Queue([gradient, (logits, token_ids)]) + self.calls: list[tuple] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> None: + self.calls.append((args, kwargs)) + + +class _PromptManagerStub: + def __init__( + self, + *, + prompt: AttackPrompt, + control_tokens: torch.Tensor, + disallowed_tokens: torch.Tensor, + control_str: str, + ) -> None: + self._prompts = [prompt] + self._control_tokens = control_tokens + self._disallowed_tokens = disallowed_tokens + self.control_str = control_str + + def __len__(self) -> int: + return len(self._prompts) + + def __getitem__(self, i: int) -> AttackPrompt: + return self._prompts[i] + + @property + def control_toks(self) -> torch.Tensor: + return self._control_tokens + + @property + def disallowed_toks(self) -> torch.Tensor: + return self._disallowed_tokens + + +class _SpySampling: + def __init__(self, *, sampled_tokens: torch.Tensor) -> None: + self.sampled_tokens = sampled_tokens + self.calls: list[dict] = [] + + def sample_candidates( + self, + *, + gradient: torch.Tensor, + control_tokens: torch.Tensor, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: torch.Tensor, + ) -> torch.Tensor: + self.calls.append( + { + "gradient": gradient.clone(), + "control_tokens": control_tokens.clone(), + "batch_size": batch_size, + "top_k": top_k, + "temperature": temperature, + "allow_non_ascii": allow_non_ascii, + "non_ascii_tokens": non_ascii_tokens.clone(), + } + ) + return self.sampled_tokens.clone() + + +class _SpyLoss: + def __init__(self, *, losses: torch.Tensor) -> None: + self.losses = losses + self.calls: list[dict] = [] + + def compute_loss( + self, + *, + logits: torch.Tensor, + token_ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + self.calls.append( + { + "logits": logits.clone(), + "token_ids": token_ids.clone(), + "target_slice": target_slice, + "control_slice": control_slice, + } + ) + return self.losses.to(logits.device) + + 
+class _SpyFilter: + def __init__(self, *, candidates: list[str]) -> None: + self.candidates = list(candidates) + self.calls: list[dict] = [] + + def filter_candidates( + self, + *, + candidate_tokens: torch.Tensor, + tokenizer: MagicMock, + current_control: str, + ) -> list[str]: + self.calls.append( + { + "candidate_tokens": candidate_tokens.clone(), + "tokenizer": tokenizer, + "current_control": current_control, + } + ) + return list(self.candidates) + + +class TestGCGMultiPromptAttackStepWiring: + @staticmethod + def _make_tokenizer() -> MagicMock: + tokenizer = MagicMock() + tokenizer.vocab_size = 100 + + def decode_fn(ids, **_kwargs): + values = ids.tolist() if hasattr(ids, "tolist") else list(ids) + return " ".join(str(int(v)) for v in values) + + def call_fn(text, **_kwargs): + output = MagicMock() + if text == "!": + output.input_ids = [0] + else: + output.input_ids = [int(piece) for piece in text.split()] if text else [] + return output + + tokenizer.decode.side_effect = decode_fn + tokenizer.side_effect = call_fn + return tokenizer + + @staticmethod + def _make_prompt(*, target_slice: slice, control_slice: slice) -> AttackPrompt: + prompt = object.__new__(AttackPrompt) + prompt._target_slice = target_slice + prompt._control_slice = control_slice + return prompt + + @staticmethod + def _make_attack( + *, + worker: _WorkerStub, + prompt_manager: _PromptManagerStub, + sampling: object | None = None, + loss: object | None = None, + candidate_filter: object | None = None, + ) -> GCGMultiPromptAttack: + attack = object.__new__(GCGMultiPromptAttack) + attack.workers = [worker] + attack.models = [worker.model] + attack.prompts = [prompt_manager] + attack._sampling = sampling + attack._loss = loss + attack._candidate_filter = candidate_filter + return attack + + def test_step_default_path_matches_legacy_behavior(self) -> None: + gradient = torch.tensor( + [ + [0.3, -0.4, 0.8, -0.2, 0.1, 0.5], + [-0.3, 0.2, -0.8, 0.4, 0.1, 0.7], + [0.2, 0.6, -0.1, -0.5, 0.4, 
-0.2], + ], + dtype=torch.float32, + ) + logits = torch.randn(1, 8, 10) + token_ids = torch.randint(0, 10, (1, 8)) + control_tokens = torch.tensor([1, 2, 3], dtype=torch.long) + disallowed_tokens = torch.tensor([], dtype=torch.long) + target_slice = slice(4, 6) + control_slice = slice(1, 4) + current_control = "99 99 99" + tokenizer = self._make_tokenizer() + + worker = _WorkerStub(gradient=gradient.clone(), logits=logits, token_ids=token_ids, tokenizer=tokenizer) + prompt = self._make_prompt(target_slice=target_slice, control_slice=control_slice) + prompt_manager = _PromptManagerStub( + prompt=prompt, + control_tokens=control_tokens, + disallowed_tokens=disallowed_tokens, + control_str=current_control, + ) + attack = self._make_attack(worker=worker, prompt_manager=prompt_manager) + + target_weight = 1.3 + control_weight = 0.2 + torch.manual_seed(2026) + actual_control, actual_loss = attack.step( + batch_size=1, + topk=3, + temp=1.0, + allow_non_ascii=True, + target_weight=target_weight, + control_weight=control_weight, + verbose=True, + filter_cand=True, + ) + + legacy_prompt_manager = object.__new__(GCGPromptManager) + legacy_prompt_for_sampling = MagicMock() + legacy_prompt_for_sampling.control_toks = control_tokens.clone() + legacy_prompt_manager._prompts = [legacy_prompt_for_sampling] + legacy_prompt_manager._nonascii_toks = disallowed_tokens + + legacy_attack = object.__new__(MultiPromptAttack) + legacy_worker = MagicMock() + legacy_worker.tokenizer = tokenizer + legacy_attack.workers = [legacy_worker] + + legacy_prompt_for_loss = self._make_prompt(target_slice=target_slice, control_slice=control_slice) + normalized_gradient = gradient / gradient.norm(dim=-1, keepdim=True) + torch.manual_seed(2026) + legacy_control_cand = legacy_prompt_manager.sample_control( + normalized_gradient.clone(), + 1, + topk=3, + temp=1.0, + allow_non_ascii=True, + ) + legacy_controls = legacy_attack.get_filtered_cands( + 0, + legacy_control_cand, + filter_cand=True, + 
curr_control=current_control, + ) + legacy_loss = target_weight * legacy_prompt_for_loss.target_loss(logits, token_ids).mean( + dim=-1 + ) + control_weight * legacy_prompt_for_loss.control_loss(logits, token_ids).mean(dim=-1) + + assert actual_control == legacy_controls[0] + assert actual_loss == pytest.approx(legacy_loss[0].item()) + + def test_step_uses_custom_protocol_implementations_when_supplied(self) -> None: + gradient = torch.randn(3, 6) + logits = torch.randn(2, 8, 10) + token_ids = torch.randint(0, 10, (2, 8)) + control_tokens = torch.tensor([1, 2, 3], dtype=torch.long) + disallowed_tokens = torch.tensor([5], dtype=torch.long) + tokenizer = self._make_tokenizer() + + worker = _WorkerStub(gradient=gradient.clone(), logits=logits, token_ids=token_ids, tokenizer=tokenizer) + prompt = self._make_prompt(target_slice=slice(4, 6), control_slice=slice(1, 4)) + prompt_manager = _PromptManagerStub( + prompt=prompt, + control_tokens=control_tokens, + disallowed_tokens=disallowed_tokens, + control_str="current control", + ) + + sampled_tokens = torch.tensor([[8, 8, 8], [9, 9, 9]], dtype=torch.long) + sampling = _SpySampling(sampled_tokens=sampled_tokens) + candidate_filter = _SpyFilter(candidates=["candidate-A", "candidate-B"]) + custom_losses = torch.tensor([3.0, 0.5], dtype=torch.float32) + loss = _SpyLoss(losses=custom_losses) + attack = self._make_attack( + worker=worker, + prompt_manager=prompt_manager, + sampling=sampling, + loss=loss, + candidate_filter=candidate_filter, + ) + + selected_control, normalized_loss = attack.step( + batch_size=2, + topk=4, + temp=0.8, + allow_non_ascii=False, + target_weight=0.0, + control_weight=1.0, + verbose=True, + filter_cand=True, + ) + + assert selected_control == "candidate-B" + assert normalized_loss == pytest.approx(0.5) + assert len(sampling.calls) == 1 + assert len(candidate_filter.calls) == 1 + assert len(loss.calls) == 1 + assert sampling.calls[0]["batch_size"] == 2 + assert sampling.calls[0]["top_k"] == 4 + assert 
sampling.calls[0]["allow_non_ascii"] is False + assert candidate_filter.calls[0]["current_control"] == "current control" + + def test_gcg_multi_prompt_attack_init_with_custom_protocols(self) -> None: + """Test GCGMultiPromptAttack.__init__ stores custom sampling/loss/filter.""" + sampling = _SpySampling(sampled_tokens=torch.tensor([[1, 2, 3]])) + loss = _SpyLoss(losses=torch.tensor([1.0])) + candidate_filter = _SpyFilter(candidates=["filtered"]) + workers = [MagicMock()] + + with patch.object(MultiPromptAttack, "__init__", return_value=None) as mock_base_init: + attack = GCGMultiPromptAttack( + goals=["goal"], + targets=["target"], + workers=workers, + control_init="seed control", + sampling=sampling, + loss=loss, + candidate_filter=candidate_filter, + ) + + assert mock_base_init.call_count == 1 + assert mock_base_init.call_args.args[:4] == (["goal"], ["target"], workers, "seed control") + + assert attack._sampling is sampling + assert attack._loss is loss + assert attack._candidate_filter is candidate_filter + + def test_step_aggregates_workers_when_grad_shapes_mismatch(self) -> None: + """Test step handles a worker gradient shape mismatch by sampling per group.""" + tokenizer = self._make_tokenizer() + prompt = self._make_prompt(target_slice=slice(0, 1), control_slice=slice(0, 1)) + prompt_manager1 = _PromptManagerStub( + prompt=prompt, + control_tokens=torch.tensor([1], dtype=torch.long), + disallowed_tokens=torch.tensor([], dtype=torch.long), + control_str="seed", + ) + prompt_manager2 = _PromptManagerStub( + prompt=prompt, + control_tokens=torch.tensor([1], dtype=torch.long), + disallowed_tokens=torch.tensor([], dtype=torch.long), + control_str="seed", + ) + + grad1 = torch.tensor([[0.1, 0.2, 0.3]], dtype=torch.float32) + grad2 = torch.tensor([[0.4, 0.5, 0.6, 0.7]], dtype=torch.float32) + logits = torch.randn(1, 8, 10) + token_ids = torch.randint(0, 10, (1, 8)) + worker1 = _WorkerStub(gradient=grad1, logits=logits, token_ids=token_ids, tokenizer=tokenizer) + 
worker2 = _WorkerStub(gradient=grad2, logits=logits, token_ids=token_ids, tokenizer=tokenizer) + worker1.results = _Queue([grad1, (logits, token_ids), (logits, token_ids)]) + worker2.results = _Queue([grad2, (logits, token_ids), (logits, token_ids)]) + + attack = object.__new__(GCGMultiPromptAttack) + attack.workers = [worker1, worker2] + attack.models = [worker1.model] + attack.prompts = [prompt_manager1, prompt_manager2] + attack.control_str = "seed" + + class _ConstantLoss: + @staticmethod + def compute_loss( + *, + logits: torch.Tensor, + token_ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + return torch.tensor([0.5], dtype=torch.float32) + + with ( + patch.object( + attack, + "_sample_control_candidates", + return_value=torch.tensor([[1, 2, 3]], dtype=torch.long), + ) as mock_sample, + patch.object(attack, "_filter_control_candidates", return_value=["candidate"]), + patch.object(attack, "_resolve_loss", return_value=_ConstantLoss()), + patch.object(attack, "_get_control_length", return_value=None), + ): + control, normalized_loss = attack.step( + batch_size=1, + topk=2, + temp=1.0, + allow_non_ascii=True, + target_weight=1.0, + control_weight=0.1, + verbose=True, + filter_cand=True, + ) + + assert control == "candidate" + assert normalized_loss == pytest.approx(0.5) + assert mock_sample.call_count == 2 + assert mock_sample.call_args_list[0].kwargs["worker_index"] == 0 + assert mock_sample.call_args_list[1].kwargs["worker_index"] == 1 + + def test_resolve_methods_return_defaults_when_none(self) -> None: + """Test _resolve_* methods return defaults when custom protocols are None.""" + worker = _WorkerStub( + gradient=torch.tensor([[0.1]]), + logits=torch.randn(1, 8, 10), + token_ids=torch.randint(0, 10, (1, 8)), + tokenizer=self._make_tokenizer(), + ) + prompt_manager = _PromptManagerStub( + prompt=self._make_prompt(target_slice=slice(0, 1), control_slice=slice(0, 1)), + control_tokens=torch.tensor([1]), + 
disallowed_tokens=torch.tensor([]), + control_str="test", + ) + + attack = self._make_attack(worker=worker, prompt_manager=prompt_manager) + + # Test _resolve_sampling returns default + sampler = attack._resolve_sampling() + assert sampler is not None + + # Test _resolve_loss returns default + loss_func = attack._resolve_loss(target_weight=1.0, control_weight=0.1) + assert loss_func is not None + + # Test _resolve_candidate_filter returns default + filter_func = attack._resolve_candidate_filter(filter_cand=True) + assert filter_func is not None + + def test_get_control_length_success(self) -> None: + """Test _get_control_length returns token count after dropping the first token.""" + tokenizer = self._make_tokenizer() + worker = _WorkerStub( + gradient=torch.tensor([[0.1]]), + logits=torch.randn(1, 8, 10), + token_ids=torch.randint(0, 10, (1, 8)), + tokenizer=tokenizer, + ) + attack = object.__new__(GCGMultiPromptAttack) + attack.workers = [worker] + + length = attack._get_control_length(control="1 2 3") + assert length == 2 + + def test_get_control_length_handles_error(self) -> None: + """Test _get_control_length returns None on tokenizer error.""" + tokenizer = MagicMock() + tokenizer.side_effect = ValueError("Tokenizer error") + + worker = _WorkerStub( + gradient=torch.tensor([[0.1]]), + logits=torch.randn(1, 8, 10), + token_ids=torch.randint(0, 10, (1, 8)), + tokenizer=tokenizer, + ) + attack = object.__new__(GCGMultiPromptAttack) + attack.workers = [worker] + + length = attack._get_control_length(control="test") + assert length is None diff --git a/tests/unit/auxiliary_attacks/gcg/test_generator.py b/tests/unit/auxiliary_attacks/gcg/test_generator.py index f410aa5079..956dcc4953 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_generator.py +++ b/tests/unit/auxiliary_attacks/gcg/test_generator.py @@ -6,8 +6,9 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING -from unittest.mock import MagicMock, patch +from functools import 
partial +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -238,6 +239,137 @@ def test_augmentation_modifies_at_least_some_targets(self) -> None: assert num_changed > 0 +class TestExtensionWiring: + def test_create_attack_uses_suffix_initializer_when_configured(self) -> None: + class _SuffixInitStub: + def __init__(self) -> None: + self.calls: list[object] = [] + + def make_initial_suffix(self, *, tokenizer: object) -> str: + self.calls.append(tokenizer) + return "initialized suffix" + + suffix_init = _SuffixInitStub() + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig(suffix_init=suffix_init), + ) + worker = MagicMock() + worker.tokenizer = MagicMock() + + with patch.object(generator_mod, "IndividualPromptAttack") as mock_individual: + gen._create_attack( + params=MagicMock(), + managers={"MPA": MagicMock()}, + train_goals=["g"], + train_targets=["t"], + test_goals=[], + test_targets=[], + workers=[worker], + test_workers=[], + logfile_path="out.json", + ) + + assert suffix_init.calls == [worker.tokenizer] + assert mock_individual.call_args.kwargs["control_init"] == "initialized suffix" + + def test_resolve_control_init_returns_default_when_suffix_init_not_configured(self) -> None: + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig(control_init="seed control"), + ) + + assert gen._resolve_control_init(workers=[]) == "seed control" + + def test_resolve_control_init_raises_when_suffix_init_requires_workers(self) -> None: + """Test _resolve_control_init raises ValueError when suffix_init configured but no workers.""" + + class _SuffixInitStub: + def make_initial_suffix(self, *, tokenizer: object) -> str: + return "initialized suffix" + + suffix_init = _SuffixInitStub() + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig(suffix_init=suffix_init), + ) + + with 
pytest.raises(ValueError, match="Cannot resolve suffix_init without at least one worker"): + gen._resolve_control_init(workers=[]) + + async def test_perform_async_binds_algorithm_extensions_into_mpa_factory(self, tmp_path: Path) -> None: + class _SamplingStub: + def sample_candidates( + self, + *, + gradient: Any, + control_tokens: Any, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: Any, + ) -> Any: + return control_tokens + + class _LossStub: + def compute_loss( + self, + *, + logits: Any, + token_ids: Any, + target_slice: slice, + control_slice: slice, + ) -> Any: + return logits + + class _FilterStub: + def filter_candidates( + self, + *, + candidate_tokens: Any, + tokenizer: Any, + current_control: str, + ) -> list[str]: + return [current_control] + + sampling = _SamplingStub() + loss = _LossStub() + candidate_filter = _FilterStub() + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig( + sampling=sampling, + loss=loss, + candidate_filter=candidate_filter, + ), + output=GCGOutputConfig(result_prefix=str(tmp_path / "gcg")), + ) + context = GCGContext( + goals=["g"], + targets=["t"], + workers=[MagicMock()], + test_workers=[], + ) + fake_attack = MagicMock() + + with ( + patch.object(gen, "_create_attack", return_value=fake_attack) as mock_create_attack, + patch.object(gen, "_build_logfile_path", return_value=str(tmp_path / "result.json")), + patch.object(gen, "_read_result", return_value=GCGResult(final_suffix="x")), + patch("pyrit.auxiliary_attacks.gcg.generator.asyncio.to_thread", new=AsyncMock(return_value=None)), + ): + await gen._perform_async(context=context) + + managers = mock_create_attack.call_args.kwargs["managers"] + mpa_factory = managers["MPA"] + assert isinstance(mpa_factory, partial) + assert mpa_factory.func is generator_mod.attack_lib.GCGMultiPromptAttack + assert mpa_factory.keywords["sampling"] is sampling + assert mpa_factory.keywords["loss"] is 
loss + assert mpa_factory.keywords["candidate_filter"] is candidate_filter + + class TestReadResult: def test_reads_final_suffix_and_loss(self, tmp_path: Path) -> None: log_path = tmp_path / "result.json"