Commit 5588095

initial commit
0 parents  commit 5588095

12 files changed

Lines changed: 1739 additions & 0 deletions

File tree

.github/workflows/typecheck.yml

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
name: Type Check

on:
  push:
    branches: [ main, master ]
  pull_request:
    branches: [ main, master ]
  workflow_dispatch:

jobs:
  typecheck:
    runs-on: ubuntu-latest

    steps:
    - name: Checkout code
      uses: actions/checkout@v4

    - name: Install uv
      uses: astral-sh/setup-uv@v3
      with:
        enable-cache: true
        cache-dependency-glob: "uv.lock"

    - name: Set up Python
      run: uv python install 3.12

    - name: Install dependencies
      run: |
        uv sync --all-extras --dev

    - name: Run pyright
      run: |
        uv run pyright

    - name: Upload type checking results
      if: failure()
      uses: actions/upload-artifact@v4
      with:
        name: typecheck-results
        path: |
          **/pyrightconfig.json
          **/.pyright/

README.md

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
# Jailbreaking Frontier Models

This is the code accompanying the blog post ["Automatically Jailbreaking Frontier Language Models with Investigator Agents"](jailbreaking-frontier-models).

Note: We do not include the RL training code, as it is tightly coupled with our internal research tooling. However, we implement all datasets and reward functions, so that they can be used to train jailbreaking agents and reproduce our experiments.

## Example reproduction

### Install dependencies

```bash
uv sync
source .venv/bin/activate
```

### Run a test script demonstrating the reward function computation

```bash
# assumes that you have an OpenAI-compatible endpoint (e.g. vLLM or SGLang server) running at localhost:8080


```

## Dataset

The dataset is located at `data/jailbreak_dataset.jsonl`. It is a `jsonl` file with the following format:

```json
{
  "behavior_id": "<unique id for the behavior>",
  "behavior": "<user prompt for the behavior to elicit>",
  "category": "<one of drug/chemical/biological/radiological/nuclear/explosive>",
  "optimizer_target": "<start of an affirmative response, like \"Sure...\" -- this is not used by our investigator agent, but might be useful for other jailbreaking techniques>"
}
```
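
As a minimal sketch of reading this format, the snippet below parses JSONL records and checks that each one carries the four fields listed above. It assumes each record sits on a single line, as is conventional for JSONL. The helper name `load_behaviors` and the sample record are hypothetical and used only for illustration; they are not part of the repository.

```python
import io
import json
from typing import Iterable

# Field names taken from the README's format description.
REQUIRED_FIELDS = ("behavior_id", "behavior", "category", "optimizer_target")


def load_behaviors(lines: Iterable[str]) -> list[dict[str, str]]:
    """Parse JSONL lines into records, skipping blank lines (hypothetical helper)."""
    records = []
    for line_number, line in enumerate(lines, start=1):
        if not line.strip():
            continue
        record = json.loads(line)
        missing = [field for field in REQUIRED_FIELDS if field not in record]
        if missing:
            raise ValueError(f"line {line_number}: missing fields {missing}")
        records.append(record)
    return records


# Placeholder record for illustration only; not taken from the dataset.
sample = io.StringIO(
    json.dumps(
        {
            "behavior_id": "example-001",
            "behavior": "<user prompt>",
            "category": "chemical",
            "optimizer_target": "Sure...",
        }
    )
    + "\n\n"
)
behaviors = load_behaviors(sample)
print(len(behaviors), behaviors[0]["behavior_id"])
```

An open file handle can be passed in place of the `StringIO` object, since the helper only iterates over lines.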

datasets/transluce_cbrn.jsonl

Lines changed: 48 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
from __future__ import annotations

import typing
from typing import Awaitable, Callable, ParamSpec, TypeVar

import anyio
import anyio.abc

Args = typing.TypeVarTuple("Args")
T = typing.TypeVar("T")
U = typing.TypeVar("U")


class FutureFinishedWithExceptionError(Exception):
    """An error raised when a future finished with an exception."""


class FutureNotSetError(Exception):
    """An error raised when a future is not set."""


class SimpleFuture[T]:
    """A simple anyio-based future.

    Futures can hold a single result, which is provided by an async task. Following anyio (rather
    than asyncio) conventions, every task is associated with a task group. Futures allow us to
    more easily pass around "boxes" that will eventually contain results.

    Note that if you just want to wait for a set of tasks to complete, you can just use an anyio
    task group, and store the result in a mutable container like a list or dict; you don't need
    futures the same way you do in asyncio. This class is most useful when either:
    - you only have a single result of each type and don't want to deal with creating a list with
      one element, or want the more convenient typechecked interface of SimpleFuture,
    - you want to pass a box between two tasks, so that one task can set the result and another
      can wait for it to be set, without requiring a more heavyweight mechanism like a queue.

    A basic low-level pattern you can use to run a task and get its result later:

    ```python
    async with anyio.create_task_group() as tg:
        future = SimpleFuture()

        async def go(*args):
            future.set_result(await some_task_fn(*args))

        tg.start_soon(go, *args)

        # do something else asynchronously while the task runs

    # once you exit the task group, the future will have a result
    value = future.get()
    ```

    This pattern is useful enough that there is a wrapper so that you can instead do:

    ```python
    async with anyio.create_task_group() as tg:
        future = future_from_start_soon(tg, some_task_fn, *args)
        # do something else asynchronously while the task runs

    # once you exit the task group, the future will have a result
    value = future.get()
    ```

    You can also use `await future.wait_for_result()` to wait for the future to complete and get
    the result.
    """

    def __init__(self):
        self.event = anyio.Event()
        self.result = None
        self.exception = None

    def set_result(self, result: T):
        """Set the result of the future."""

        assert not self.event.is_set(), "Result already set"
        self.result = result
        self.event.set()

    def set_from_task(
        self,
        task_fn: Callable[[*Args], typing.Awaitable[T]],
        set_exception: bool = True,
        catch_exception: bool = False,
    ) -> Callable[[*Args], typing.Awaitable[None]]:
        """Wrap a callable to set the result of the future from its result.

        This function can be used to wrap a task that returns a value into a task that sets a
        future. It is most useful in combination with `TaskGroup.start_soon`, e.g.

        ```python
        future = SimpleFuture()
        async with anyio.create_task_group() as tg:
            # runs `await some_task_fn(*args)` and sets the future to the result
            tg.start_soon(future.set_from_task(some_task_fn), *args)

            # or, same pattern but with keyword arguments
            tg.start_soon(future.set_from_task(functools.partial(some_task_fn, **kwargs)))
        ```

        Args:
            task_fn: A callable that returns an awaitable.
            set_exception: Whether to set the future to the exception if the task raises an
                exception. If False, the exception will not be set.
            catch_exception: Whether to catch exceptions and set the future to the exception. If
                False (the default), exceptions will be propagated up to the task group. If True,
                exceptions will be caught and will not cancel the task group. `set_exception` must
                be True if `catch_exception` is True.

        Returns:
            A callable that can be used to start a task that sets the future to the result of
            `task_fn`.
        """
        if catch_exception and not set_exception:
            raise ValueError("set_exception must be True if catch_exception is True")

        async def wrapped_task_fn(*args: *Args):
            try:
                result = await task_fn(*args)
            except Exception as e:
                if set_exception:
                    self.set_exception(e)
                    if not catch_exception:
                        raise
                else:
                    raise
            else:
                self.set_result(result)

        return wrapped_task_fn

    def set_exception(self, exception: Exception):
        """Set the result of the future to an exception."""
        assert not self.event.is_set(), "Result already set"
        self.exception = exception
        self.event.set()

    def has_result(self) -> bool:
        """Check if the future has a result."""
        return self.event.is_set()

    async def wait_for_result(self) -> T:
        """Wait for the future to complete and return the result.

        If the future has an exception, a `FutureFinishedWithExceptionError` chained from it
        will be raised.
        """
        await self.event.wait()
        if self.exception is not None:
            raise FutureFinishedWithExceptionError(
                f"Future finished with an exception (of type {type(self.exception).__name__})!"
            ) from self.exception
        # The event is set and no exception was stored, so `result` holds the task's return
        # value. It may legitimately be None, so we do not assert on it.
        return typing.cast(T, self.result)

    def get(self) -> T:
        """Synchronously get the result of the future.

        If the future has an exception, it will be raised. Raises `FutureNotSetError` if no
        result has been set yet.
        """
        if not self.has_result():
            raise FutureNotSetError("Result not set")
        if self.exception is not None:
            raise self.exception
        # As above, a stored result of None is valid.
        return typing.cast(T, self.result)


def future_from_start_soon[
    T, *Args
](
    task_group: anyio.abc.TaskGroup,
    task_fn: Callable[[*Args], Awaitable[T]],
    *args: *Args,
    catch_exception: bool = False,
) -> SimpleFuture[T]:
    """Create a future from the result of starting a coroutine in a task group.

    This is a convenience function for creating a future and starting a task in a task group.
    It can be used like this:

    ```python
    async with anyio.create_task_group() as tg:
        future = future_from_start_soon(tg, some_task_fn, *args)
    ```

    Args:
        task_group: The task group to start the task in. This is the task group that owns the
            task and will be cancelled if the task raises an exception (unless
            `catch_exception` is True).
        task_fn: The coroutine to start.
        *args: The arguments to pass to the coroutine.
        catch_exception: Whether to catch exceptions and set the future to the exception. If
            False (the default), exceptions will be propagated up to the task group. If True,
            exceptions will be caught and stored in the future instead, and the task group
            will not be cancelled.

    Returns:
        A future that will eventually contain the result of the coroutine, or an exception if
        the coroutine raises an exception and `catch_exception` is True.
    """
    result = SimpleFuture[T]()
    task_group.start_soon(
        result.set_from_task(task_fn, set_exception=True, catch_exception=catch_exception),
        *args,
    )
    return result


P = ParamSpec("P")  # full parameter list of the wrapped function
T = TypeVar("T")  # its return type
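
The "box" idea in this module (one task sets a value, another waits for it, and a failure can be stored instead of propagated) can be sketched without any dependencies. The example below is a hypothetical stand-in, not the module's code: it uses the standard library's `asyncio` rather than anyio so that it runs as-is, and the names `MiniFuture`, `producer`, and `failing_producer` are invented for illustration.

```python
import asyncio


class MiniFuture:
    """Hypothetical asyncio-based stand-in for the idea above: an event plus a slot."""

    def __init__(self) -> None:
        self._event = asyncio.Event()
        self._result = None
        self._exception = None

    def set_result(self, result) -> None:
        assert not self._event.is_set(), "Result already set"
        self._result = result
        self._event.set()

    def set_exception(self, exception: Exception) -> None:
        assert not self._event.is_set(), "Result already set"
        self._exception = exception
        self._event.set()

    async def wait_for_result(self):
        await self._event.wait()
        if self._exception is not None:
            raise self._exception
        return self._result


async def producer(box: MiniFuture, value: int) -> None:
    await asyncio.sleep(0)  # yield once, standing in for real async work
    box.set_result(value * 2)


async def failing_producer(box: MiniFuture) -> None:
    try:
        raise RuntimeError("task failed")
    except Exception as e:
        # Mirrors the catch_exception=True behaviour: store the error, do not propagate it.
        box.set_exception(e)


async def main() -> tuple[int, str]:
    ok_box, err_box = MiniFuture(), MiniFuture()
    # The consumer waits on ok_box while both producers run concurrently.
    results = await asyncio.gather(
        producer(ok_box, 21),
        failing_producer(err_box),
        ok_box.wait_for_result(),
    )
    try:
        await err_box.wait_for_result()
        error_message = ""
    except RuntimeError as e:
        error_message = str(e)
    return results[2], error_message


value, error_message = asyncio.run(main())
print(value, error_message)
```

Because the failing producer stores its error in the box, the other tasks finish normally and the error only surfaces when the consumer asks for that box's result.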
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from jailbreaking_frontier_models.logprobs import LogProbs
from transformers import PreTrainedTokenizerBase
from tenacity import AsyncRetrying, before_sleep_log, stop_after_attempt, wait_random_exponential
import logging

logger = logging.getLogger(__name__)

async def get_token_logprobs(
    client: AsyncOpenAI,
    tokenizer: PreTrainedTokenizerBase,
    model: str,
    input_token_ids: list[int] | None = None,
    input_messages: list[ChatCompletionMessageParam] | None = None,
    output_token_ids: list[int] | None = None,
    output_text: str | None = None,
) -> LogProbs:
    """Get token-level log probabilities for a response.

    Exactly one of `input_token_ids` / `input_messages` must be provided. `output_text` is
    currently required: `output_token_ids` is accepted by the signature but not yet supported.

    Args:
        client: OpenAI client instance
        tokenizer: Tokenizer used to apply the chat template and decode the prompt tokens
        model: The model to use
        input_token_ids: Prompt token ids (alternative to `input_messages`)
        input_messages: The messages to use, rendered with the tokenizer's chat template
        output_token_ids: Response token ids (not yet supported; pass `output_text` instead)
        output_text: Response text whose tokens are scored

    Returns:
        LogProbs object containing prompt and response tokens with their logprobs

    WARNING: For vllm, logprobs are not affected by temperature (they assume temperature=1.0). Also, you can't use top_p with logprobs.
    """

    assert (input_token_ids is not None or input_messages is not None) and not (
        input_token_ids is not None and input_messages is not None
    ), "Must provide either input_token_ids or input_messages, but not both"
    assert (output_token_ids is not None or output_text is not None) and not (
        output_token_ids is not None and output_text is not None
    ), "Must provide either output_token_ids or output_text, but not both"

    if input_messages is not None:
        conversation_tokens = tokenizer.apply_chat_template(input_messages, add_generation_prompt=True)  # type: ignore
    elif input_token_ids is not None:
        conversation_tokens = input_token_ids
    else:
        raise ValueError("Must provide either input_messages or input_token_ids")

    assert output_text is not None, "output_text must be provided"

    full_text = tokenizer.decode(conversation_tokens) + output_text

    async for attempt in AsyncRetrying(
        wait=wait_random_exponential(multiplier=1, min=1, max=60),
        stop=stop_after_attempt(50),
        reraise=True,
        before_sleep=before_sleep_log(logger, logging.WARNING),
    ):
        with attempt:
            output = await client.completions.create(
                model=model,
                prompt=full_text,
                max_tokens=1,
                logprobs=1,
                echo=True,
            )

            # Check for missing logprobs before slicing, so a missing field raises the
            # intended ValueError rather than a TypeError.
            logprobs = output.choices[0].logprobs
            if logprobs is None or logprobs.tokens is None or logprobs.token_logprobs is None:
                raise ValueError("Failed to get logprobs from model")

            # Cut off the last token, since we sample max_tokens=1 (required for sglang)
            token_strs = logprobs.tokens[:-1]
            token_logprobs = logprobs.token_logprobs[:-1]

            suffix_idx = len(conversation_tokens)

            prompt_token_strs = token_strs[:suffix_idx]
            prompt_token_logprobs = list(
                token_logprobs[:suffix_idx]
            )
            response_token_strs = token_strs[suffix_idx:]
            response_token_logprobs = list(
                token_logprobs[suffix_idx:]
            )

            return LogProbs(
                prompt_token_strs=prompt_token_strs,
                prompt_token_logprobs=prompt_token_logprobs,  # type: ignore
                response_token_strs=response_token_strs,
                response_token_logprobs=response_token_logprobs,  # type: ignore
            )

    raise RuntimeError("Failed to get logprobs from model")
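
The prompt/response split in this function relies on the echoed token list lining up with the prompt tokenization: after the one sampled token is dropped, the first `len(conversation_tokens)` entries belong to the prompt and the rest to the response. The sketch below shows that slicing on hand-written values and sums the response log-probabilities. The helper `split_logprobs` and the sample numbers are hypothetical, and the first echoed logprob is shown as `None` on the assumption that the server has no context to score the first token.

```python
from __future__ import annotations


def split_logprobs(
    token_strs: list[str],
    token_logprobs: list[float | None],
    prompt_len: int,
) -> tuple[list[str], list[float | None], list[str], list[float | None]]:
    """Drop the one sampled token, then split at the prompt length (hypothetical helper)."""
    # The request samples max_tokens=1, so the final entry is a generated token, not echo.
    token_strs = token_strs[:-1]
    token_logprobs = token_logprobs[:-1]
    return (
        token_strs[:prompt_len],
        token_logprobs[:prompt_len],
        token_strs[prompt_len:],
        token_logprobs[prompt_len:],
    )


# Made-up values for illustration: two prompt tokens, two response tokens, one sampled token.
tokens = ["<s>", "Hello", " there", " friend", "!"]
logprobs = [None, -1.5, -0.25, -0.5, -2.0]

prompt_strs, prompt_lps, response_strs, response_lps = split_logprobs(
    tokens, logprobs, prompt_len=2
)
# Total log-probability of the response under the model, skipping any unscored entries.
response_total = sum(lp for lp in response_lps if lp is not None)
print(response_strs, response_total)
```

The split is only correct if the server tokenizes the decoded prompt plus response text into the same prompt tokens the client counted, so a length check against the returned token list may be worth adding in practice.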

jailbreaking_frontier_models/example_reward_fn_computation.py

Whitespace-only changes.
