Skip to content

Commit f58a218

Browse files
romanlutzCopilot
andauthored
FEAT: Define GCG extension protocols (typing surface only) (#1861)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 13a1da6 commit f58a218

6 files changed

Lines changed: 516 additions & 17 deletions

File tree

pyrit/auxiliary_attacks/gcg/__init__.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
"""Public API for the Greedy Coordinate Gradient (GCG) auxiliary attack.
55
6-
The primary entry point is :class:`GCG` (alias for :class:`GCGGenerator`), a
7-
:class:`pyrit.executor.promptgen.core.PromptGeneratorStrategy` that produces
6+
The primary entry point is ``GCG`` (alias for ``GCGGenerator``), a
7+
``pyrit.executor.promptgen.core.PromptGeneratorStrategy`` that produces
88
adversarial suffixes via the GCG algorithm.
99
1010
Example:
@@ -41,16 +41,30 @@
4141
# only have the base `dev` extra (no torch). Touching any of these names from
4242
# the package root triggers the underlying module import on first access; if
4343
# torch is missing the user gets a clear ModuleNotFoundError pointing at torch.
44+
#
45+
# The extension Protocols live in ``extension_protocols`` (typing-only — that
46+
# module imports cleanly without torch) but are routed through the same lazy
47+
# mechanism so all GCG public symbols share one re-export pathway.
4448
_LAZY_IMPORTS = {
49+
"CandidateFilter": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "CandidateFilter"),
4550
"GCG": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"),
4651
"GCGContext": ("pyrit.auxiliary_attacks.gcg.generator", "GCGContext"),
4752
"GCGGenerator": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"),
4853
"GCGResult": ("pyrit.auxiliary_attacks.gcg.generator", "GCGResult"),
54+
"LossFunction": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "LossFunction"),
55+
"SamplingStrategy": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SamplingStrategy"),
56+
"SuffixInitializer": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SuffixInitializer"),
4957
"load_goals_and_targets": ("pyrit.auxiliary_attacks.gcg.data", "load_goals_and_targets"),
5058
}
5159

5260
if TYPE_CHECKING:
5361
from pyrit.auxiliary_attacks.gcg.data import load_goals_and_targets
62+
from pyrit.auxiliary_attacks.gcg.extension_protocols import (
63+
CandidateFilter,
64+
LossFunction,
65+
SamplingStrategy,
66+
SuffixInitializer,
67+
)
5468
from pyrit.auxiliary_attacks.gcg.generator import (
5569
GCGContext,
5670
GCGGenerator,
@@ -76,6 +90,7 @@ def __dir__() -> list[str]:
7690

7791

7892
__all__ = [
93+
"CandidateFilter",
7994
"GCG",
8095
"GCGAlgorithmConfig",
8196
"GCGConfig",
@@ -86,5 +101,8 @@ def __dir__() -> list[str]:
86101
"GCGOutputConfig",
87102
"GCGResult",
88103
"GCGStrategyConfig",
104+
"LossFunction",
105+
"SamplingStrategy",
106+
"SuffixInitializer",
89107
"load_goals_and_targets",
90108
]

pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def run(
654654
n_steps: int = 100,
655655
batch_size: int = 1024,
656656
topk: int = 256,
657-
temp: int = 1,
657+
temp: float = 1.0,
658658
allow_non_ascii: bool = True,
659659
target_weight: Optional[float] = None,
660660
control_weight: Optional[float] = None,

pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def sample_control(
9292
grad: torch.Tensor,
9393
batch_size: int,
9494
topk: int = 256,
95-
temp: int = 1,
95+
temp: float = 1.0,
9696
allow_non_ascii: bool = True,
9797
) -> torch.Tensor:
9898
"""
@@ -102,7 +102,7 @@ def sample_control(
102102
grad (torch.Tensor): Gradient tensor for control tokens.
103103
batch_size (int): Number of candidate controls to generate.
104104
topk (int): Number of top gradient positions to sample from. Defaults to 256.
105-
temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.
105+
temp (float): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.0.
106106
allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True.
107107
108108
Returns:
@@ -130,7 +130,7 @@ def step(
130130
*,
131131
batch_size: int = 1024,
132132
topk: int = 256,
133-
temp: int = 1,
133+
temp: float = 1.0,
134134
allow_non_ascii: bool = True,
135135
target_weight: float = 1,
136136
control_weight: float = 0.1,
@@ -146,7 +146,7 @@ def step(
146146
Args:
147147
batch_size (int): Number of candidate controls per batch. Defaults to 1024.
148148
topk (int): Number of top gradient positions to sample from. Defaults to 256.
149-
temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.
149+
temp (float): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.0.
150150
allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True.
151151
target_weight (float): Weight for target loss. Defaults to 1.
152152
control_weight (float): Weight for control loss. Defaults to 0.1.

pyrit/auxiliary_attacks/gcg/config.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class GCGDataConfig:
7171
Used as a typed bundle for AML transport (a job ships its data config as
7272
a separate JSON file alongside the strategy ``GCGConfig``). Library
7373
callers loading goals/targets from a CSV can construct one and pass it to
74-
:func:`pyrit.auxiliary_attacks.gcg.data.load_goals_and_targets`.
74+
``pyrit.auxiliary_attacks.gcg.data.load_goals_and_targets``.
7575
7676
Attributes:
7777
train_data (str): URL or filesystem path to the training-data CSV. Empty
@@ -100,7 +100,7 @@ def to_json(self) -> str:
100100

101101
@classmethod
102102
def from_json(cls, payload: str) -> GCGDataConfig:
103-
"""Deserialize a config previously produced by :meth:`to_json`."""
103+
"""Deserialize a config previously produced by ``to_json``."""
104104
try:
105105
data = json.loads(payload)
106106
except json.JSONDecodeError as e:
@@ -131,8 +131,8 @@ class GCGAlgorithmConfig:
131131
Defaults to 512.
132132
topk (int): Top-k gradient positions considered for substitution.
133133
Defaults to 256.
134-
temp (int): Sampling temperature placeholder; the current sampling
135-
implementation samples uniformly from the top-k. Defaults to 1.
134+
temp (float): Sampling temperature placeholder; the current sampling
135+
implementation samples uniformly from the top-k. Defaults to 1.0.
136136
target_weight (float): Weight on the target-string cross-entropy loss.
137137
Defaults to 1.0.
138138
control_weight (float): Weight on the control-string cross-entropy loss.
@@ -153,7 +153,7 @@ class GCGAlgorithmConfig:
153153
test_steps: int = 50
154154
batch_size: int = 512
155155
topk: int = 256
156-
temp: int = 1
156+
temp: float = 1.0
157157
target_weight: float = 1.0
158158
control_weight: float = 0.0
159159
learning_rate: float = 0.01
@@ -240,10 +240,10 @@ class GCGOutputConfig:
240240
class GCGConfig:
241241
"""Top-level strategy configuration for one GCG attack run.
242242
243-
Bundles everything :class:`pyrit.auxiliary_attacks.gcg.GCGGenerator`'s
243+
Bundles everything ``pyrit.auxiliary_attacks.gcg.GCGGenerator``'s
244244
constructor needs. Per-execution data (goals, targets) is **not** here —
245245
those flow through ``GCGGenerator.execute_async``, and for AML transport
246-
they ride alongside this object as a separate :class:`GCGDataConfig` JSON.
246+
they ride alongside this object as a separate ``GCGDataConfig`` JSON.
247247
248248
Attributes:
249249
models (list[GCGModelConfig]): Training models the attack optimizes
@@ -287,11 +287,11 @@ def to_json(self) -> str:
287287

288288
@classmethod
289289
def from_json(cls, payload: str) -> GCGConfig:
290-
"""Deserialize a config previously produced by :meth:`to_json`.
290+
"""Deserialize a config previously produced by ``to_json``.
291291
292292
Args:
293293
payload (str): JSON document matching the shape produced by
294-
:meth:`to_json`.
294+
``to_json``.
295295
296296
Returns:
297297
GCGConfig: A new ``GCGConfig`` reconstructed from ``payload``.
@@ -308,7 +308,7 @@ def from_json(cls, payload: str) -> GCGConfig:
308308

309309
@classmethod
310310
def from_json_file(cls, path: str | Path) -> GCGConfig:
311-
"""Load a config from a JSON file produced by :meth:`to_json_file`.
311+
"""Load a config from a JSON file produced by ``to_json_file``.
312312
313313
Args:
314314
path (str | Path): Filesystem path to a JSON config file.

0 commit comments

Comments
 (0)