Skip to content

Commit ac00831

Browse files
committed
Extend paper experiments and submission artifacts
1 parent 0f12508 commit ac00831

709 files changed

Lines changed: 218029 additions & 1186 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

app/bootstrap/experiment_models.py

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
_CLIP_CACHE: dict[tuple[str, str], tuple[object, object]] = {}
1010
_DINO_CACHE: dict[tuple[str, str], tuple[object, object]] = {}
11+
_SIGLIP_CACHE: dict[tuple[str, str], tuple[object, object]] = {}
12+
_BLIP_CAPTION_CACHE: dict[tuple[str, str], tuple[object, object]] = {}
13+
_LPIPS_CACHE: dict[tuple[str, str], object] = {}
1114

1215

1316
def huggingface_cache_dir() -> Path:
@@ -16,6 +19,13 @@ def huggingface_cache_dir() -> Path:
1619
return path
1720

1821

22+
def _local_cache_error(model_family: str, model_id: str) -> RuntimeError:
    """Build the uniform "preload first" error raised for a missing cached model."""
    message = (
        f"{model_family} model '{model_id}' is not available in the local cache. "
        "Run scripts/preload_experiment_models.py first."
    )
    return RuntimeError(message)
27+
28+
1929
def get_clip_components(model_id: str, device: str, *, local_only: bool = True):
2030
key = (model_id, device)
2131
cached = _CLIP_CACHE.get(key)
@@ -38,10 +48,7 @@ def get_clip_components(model_id: str, device: str, *, local_only: bool = True):
3848
)
3949
except (OSError, LocalEntryNotFoundError) as exc:
4050
if local_only:
41-
raise RuntimeError(
42-
f"CLIP model '{model_id}' is not available in the local cache. "
43-
"Run scripts/preload_experiment_models.py first."
44-
) from exc
51+
raise _local_cache_error("CLIP", model_id) from exc
4552
raise
4653
model.eval()
4754
_CLIP_CACHE[key] = (model, processor)
@@ -70,11 +77,82 @@ def get_dino_components(model_id: str, device: str, *, local_only: bool = True):
7077
).to(device)
7178
except (OSError, LocalEntryNotFoundError) as exc:
7279
if local_only:
73-
raise RuntimeError(
74-
f"DINO model '{model_id}' is not available in the local cache. "
75-
"Run scripts/preload_experiment_models.py first."
76-
) from exc
80+
raise _local_cache_error("DINO", model_id) from exc
7781
raise
7882
model.eval()
7983
_DINO_CACHE[key] = (processor, model)
8084
return processor, model
85+
86+
87+
def get_siglip_components(model_id: str, device: str, *, local_only: bool = True):
    """Return a cached (image processor, model) pair for a SigLIP checkpoint.

    Loads from the local Hugging Face cache, moves the model to ``device``,
    puts it in eval mode, and memoizes the pair per (model_id, device). When
    ``local_only`` is true and the checkpoint is not cached, raises the
    uniform "preload first" RuntimeError.
    """
    cache_key = (model_id, device)
    if cache_key in _SIGLIP_CACHE:
        return _SIGLIP_CACHE[cache_key]

    from transformers import SiglipImageProcessor, SiglipModel

    hub_cache = str(huggingface_cache_dir())
    try:
        processor = SiglipImageProcessor.from_pretrained(
            model_id,
            cache_dir=hub_cache,
            local_files_only=local_only,
            use_fast=False,
        )
        model = SiglipModel.from_pretrained(
            model_id,
            cache_dir=hub_cache,
            local_files_only=local_only,
        ).to(device)
    except (OSError, LocalEntryNotFoundError) as exc:
        if not local_only:
            raise
        raise _local_cache_error("SigLIP", model_id) from exc
    model.eval()
    _SIGLIP_CACHE[cache_key] = (processor, model)
    return processor, model
115+
116+
117+
def get_blip_caption_components(model_id: str, device: str, *, local_only: bool = True):
    """Return a cached (processor, model) pair for a BLIP captioning checkpoint.

    Loads from the local Hugging Face cache, moves the model to ``device``,
    puts it in eval mode, and memoizes the pair per (model_id, device). When
    ``local_only`` is true and the checkpoint is not cached, raises the
    uniform "preload first" RuntimeError.
    """
    cache_key = (model_id, device)
    if cache_key in _BLIP_CAPTION_CACHE:
        return _BLIP_CAPTION_CACHE[cache_key]

    from transformers import BlipForConditionalGeneration, BlipProcessor

    hub_cache = str(huggingface_cache_dir())
    try:
        processor = BlipProcessor.from_pretrained(
            model_id,
            cache_dir=hub_cache,
            local_files_only=local_only,
            use_fast=False,
        )
        model = BlipForConditionalGeneration.from_pretrained(
            model_id,
            cache_dir=hub_cache,
            local_files_only=local_only,
        ).to(device)
    except (OSError, LocalEntryNotFoundError) as exc:
        if not local_only:
            raise
        raise _local_cache_error("BLIP caption", model_id) from exc
    model.eval()
    _BLIP_CAPTION_CACHE[cache_key] = (processor, model)
    return processor, model
145+
146+
147+
def get_lpips_metric(net_name: str, device: str):
    """Return a cached LPIPS distance module for the given backbone on ``device``.

    The module is constructed once per (net_name, device), moved to the
    device, switched to eval mode, and memoized in the module-level cache.
    """
    cache_key = (net_name, device)
    if cache_key not in _LPIPS_CACHE:
        import lpips

        metric = lpips.LPIPS(net=net_name).to(device)
        metric.eval()
        _LPIPS_CACHE[cache_key] = metric
    return _LPIPS_CACHE[cache_key]

app/core/config_yaml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
# Edit any of these values before creating a new session.
1515
# This YAML is reloaded fresh for each setup page visit or reset action.
1616
#
17-
# sampler: random_local | exploit_orthogonal | uncertainty_guided | axis_sweep | incumbent_mix | diversity_shell | line_search | plateau_escape | annealed_shell | spherical_cover | two_scale_cover | quality_diversity_mix
18-
# updater: winner_average | winner_copy | linear_preference | score_weighted_preference | contrastive_preference | softmax_preference | borda_preference | bradley_terry_preference | challenger_mixture_preference | plackett_luce_preference
17+
# sampler: random_local | exploit_orthogonal | uncertainty_guided | axis_sweep | incumbent_mix | diversity_shell | line_search | plateau_escape | annealed_shell | spherical_cover | two_scale_cover | quality_diversity_mix | restart_bridge_mix
18+
# updater: winner_average | winner_copy | linear_preference | score_weighted_preference | contrastive_preference | softmax_preference | borda_preference | bradley_terry_preference | challenger_mixture_preference | plackett_luce_preference | advantage_softmax_preference
1919
# feedback_mode: scalar_rating | pairwise | top_k | winner_only | approve_reject
2020
# seed_policy: fixed-per-round | fixed-per-candidate | fixed-per-candidate-role
21-
# steering_mode: currently low_dimensional
21+
# steering_mode: low_dimensional | content_masked | token_factorized | token_vector_field
2222
# steering_dimension: low-dimensional steering vector size, for example 3 or 5
2323
# candidate_count: visible candidates per round
2424
# image_size: WIDTHxHEIGHT, for example 512x512

app/core/schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class SamplerType(str, Enum):
7272
spherical_cover = "spherical_cover"
7373
two_scale_cover = "two_scale_cover"
7474
quality_diversity_mix = "quality_diversity_mix"
75+
restart_bridge_mix = "restart_bridge_mix"
7576

7677

7778
class UpdaterType(str, Enum):
@@ -85,10 +86,14 @@ class UpdaterType(str, Enum):
8586
bradley_terry_preference = "bradley_terry_preference"
8687
challenger_mixture_preference = "challenger_mixture_preference"
8788
plackett_luce_preference = "plackett_luce_preference"
89+
advantage_softmax_preference = "advantage_softmax_preference"
8890

8991

9092
class SteeringMode(str, Enum):
9193
low_dimensional = "low_dimensional"
94+
content_masked = "content_masked"
95+
token_factorized = "token_factorized"
96+
token_vector_field = "token_vector_field"
9297

9398

9499
class StrategyConfig(BaseModel):

app/engine/generation.py

Lines changed: 135 additions & 15 deletions
Original file line numberDiff line numberDiff line change
def resolve_steering_mode(session: Session) -> SteeringMode:
    """Resolve and validate the session steering mode used at generation time."""

    mode = session.config.steering_mode
    # Explicit allow-list: a new enum member must be wired into generation
    # before it is accepted here.
    supported = (
        SteeringMode.low_dimensional,
        SteeringMode.content_masked,
        SteeringMode.token_factorized,
        SteeringMode.token_vector_field,
    )
    if mode not in supported:
        raise ValueError(f"Unsupported steering mode: {mode}")
    return mode
4045

@@ -259,19 +264,127 @@ def _resolve_model_source(self, session: Session) -> str:
259264
"Run scripts/setup_huggingface.py first or enable STABLE_STEERING_ALLOW_REMOTE_MODEL_DOWNLOAD=true."
260265
)
261266

262-
def _steering_offset(self, prompt_embeds, z, anchor_strength: float):
267+
def _hidden_basis(self, hidden: int, index_id: int, *, device, dtype):
    """Build a deterministic hidden-space basis vector for one steering axis."""

    torch = self._torch
    positions = torch.linspace(0.0, 1.0, hidden, device=device, dtype=dtype)
    # Axis index sets the wave frequency so distinct axes get distinct bases.
    frequency = index_id + 1
    waveform = torch.sin(positions * frequency * torch.pi)
    waveform = waveform + torch.cos(positions * frequency * 0.5 * torch.pi)
    # Unit-normalize so anchor_strength alone controls the magnitude.
    return waveform / torch.norm(waveform)
274+
275+
def _token_hidden_basis(self, seq_len: int, hidden: int, index_id: int, *, device, dtype):
    """Build a deterministic per-token hidden-vector field for one steering axis."""

    torch = self._torch
    rows = torch.linspace(0.0, 1.0, seq_len, device=device, dtype=dtype).view(seq_len, 1)
    cols = torch.linspace(0.0, 1.0, hidden, device=device, dtype=dtype).view(1, hidden)
    freq = float(index_id + 1)
    # Superpose three axis-dependent wave terms so each axis yields a
    # visually distinct (seq_len, hidden) interference pattern.
    wave_a = torch.sin((rows + 0.17 * freq) * (cols + 0.11) * torch.pi * (1.0 + freq))
    wave_b = 0.7 * torch.cos((rows * (0.45 + 0.08 * freq) - cols * (0.63 + 0.04 * freq)) * torch.pi)
    wave_c = 0.35 * torch.sin((rows * cols + 0.13 * freq) * 2.0 * torch.pi)
    field = wave_a + wave_b + wave_c
    # Clamp the norm away from zero before normalizing to avoid divide-by-zero.
    floor = torch.tensor(1e-6, device=device, dtype=dtype)
    return field / torch.clamp(torch.norm(field), min=floor)
288+
289+
def _token_inputs(self, pipe, prompt: str, *, seq_len: int, device, dtype):
    """Tokenize the prompt so token-aware steering modes can shape per-token offsets."""

    tokenizer = getattr(pipe, "tokenizer", None)
    if tokenizer is None:
        # Pipelines without a tokenizer cannot support token-aware steering.
        return None

    # Pad/truncate to the embedding sequence length so masks line up with
    # prompt_embeds positions.
    encoded = tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        max_length=seq_len,
        return_tensors="pt",
    )
    return {
        "input_ids": encoded.input_ids.to(device=device),
        # Mask is cast to the embedding dtype for direct multiplication.
        "attention_mask": encoded.attention_mask.to(device=device, dtype=dtype),
    }
306+
307+
def _content_mask(self, token_inputs, *, tokenizer, dtype):
    """Build a mask that suppresses padding and special tokens for token-aware steering."""

    base_mask = token_inputs["attention_mask"].to(dtype=dtype)
    ids = token_inputs["input_ids"]
    masked = base_mask.clone()

    # Collect whichever special-token ids the tokenizer defines, then zero
    # the corresponding positions out of the mask.
    special_ids = []
    if tokenizer is not None:
        for attr_name in ("bos_token_id", "eos_token_id", "pad_token_id"):
            special = getattr(tokenizer, attr_name, None)
            if special is not None:
                special_ids.append(special)
    for special in special_ids:
        masked = masked * (ids != special).to(dtype=dtype)

    # If masking removed every token, fall back to the raw attention mask.
    return base_mask if float(masked.sum()) <= 0.0 else masked
323+
324+
def _steering_offset(self, prompt_embeds, z, anchor_strength: float, *, steering_mode: SteeringMode, token_inputs=None, tokenizer=None):
    """Project the low-dimensional steering vector into embedding space.

    Builds an additive offset shaped like ``prompt_embeds`` from the steering
    coordinates ``z``. Each mode distributes the signal differently across the
    token and hidden dimensions:

    - ``low_dimensional``: one fixed hidden-space basis per axis, applied
      uniformly to every token position.
    - ``content_masked``: the same hidden bases, scaled per token by a smooth
      profile that is zeroed on padding/special tokens.
    - ``token_factorized``: a rank-1 (token basis x hidden basis) pattern per
      axis, mean-centered over content tokens.
    - ``token_vector_field``: a dense per-token, per-hidden wave field per
      axis, masked and centered over content tokens.

    Args:
        prompt_embeds: encoder output; indexed as (batch, seq, hidden) via
            ``shape[1]`` / ``shape[-1]`` below — TODO confirm against caller.
        z: per-axis steering coordinates (iterable of numbers).
        anchor_strength: global scale applied to every axis contribution.
        steering_mode: which projection strategy to use.
        token_inputs: tokenizer output dict; required by all token-aware
            modes (everything except ``low_dimensional``).
        tokenizer: used only to identify special-token ids for masking.

    Returns:
        A tensor added to ``prompt_embeds`` by the caller.

    Raises:
        ValueError: a token-aware mode was requested without ``token_inputs``.
    """

    torch = self._torch
    seq_len = prompt_embeds.shape[1]
    hidden = prompt_embeds.shape[-1]
    device = prompt_embeds.device
    dtype = prompt_embeds.dtype
    offset = torch.zeros_like(prompt_embeds)

    if steering_mode == SteeringMode.low_dimensional:
        # Token-agnostic: each axis adds a single hidden-space wave everywhere.
        for i, value in enumerate(z):
            basis = self._hidden_basis(hidden, i, device=device, dtype=dtype)
            offset = offset + (float(value) * float(anchor_strength)) * basis.view(1, 1, hidden)
        return offset

    if token_inputs is None:
        raise ValueError(f"Token-aware steering mode {steering_mode.value} requires token inputs.")

    content_mask = self._content_mask(token_inputs, tokenizer=tokenizer, dtype=dtype)
    token_positions = torch.linspace(0.0, 1.0, seq_len, device=device, dtype=dtype)

    if steering_mode == SteeringMode.content_masked:
        # Smooth arch over the sequence (0.35..1.0 via sin), zeroed outside content.
        token_profile = 0.35 + 0.65 * torch.sin(token_positions * torch.pi)
        token_profile = token_profile.view(1, seq_len, 1) * content_mask.view(1, seq_len, 1)
        # Rescale so the profile's total mass equals the active-token count,
        # keeping effective strength comparable to low_dimensional.
        active_tokens = torch.clamp(content_mask.sum(), min=1.0)
        normalizer = torch.clamp(token_profile.sum(dim=1, keepdim=True), min=1.0)
        token_profile = token_profile * (active_tokens / normalizer)
        for i, value in enumerate(z):
            basis = self._hidden_basis(hidden, i, device=device, dtype=dtype)
            offset = offset + (float(value) * float(anchor_strength)) * token_profile * basis.view(1, 1, hidden)
        return offset

    if steering_mode == SteeringMode.token_factorized:
        mask = content_mask.view(seq_len)
        for i, value in enumerate(z):
            hidden_basis = self._hidden_basis(hidden, i, device=device, dtype=dtype)
            # Axis-specific sinusoid over token positions, restricted to content.
            token_basis = (
                torch.sin(token_positions * (i + 1) * torch.pi)
                + 0.5 * torch.cos(token_positions * (i + 1) * 2.0 * torch.pi)
            ) * mask
            if float(token_basis.abs().sum()) > 0.0:
                # Subtract the content-token mean so the pattern does not
                # uniformly shift the embedding.
                token_basis = token_basis - ((token_basis * mask).sum() / torch.clamp(mask.sum(), min=1.0)) * mask
            token_norm = torch.norm(token_basis)
            if float(token_norm) > 0.0:
                token_basis = token_basis / token_norm
            # 0.8 damps the rank-1 pattern relative to the uniform modes.
            offset = offset + (float(value) * float(anchor_strength) * 0.8) * token_basis.view(1, seq_len, 1) * hidden_basis.view(1, 1, hidden)
        return offset

    if steering_mode == SteeringMode.token_vector_field:
        mask = content_mask.view(seq_len, 1)
        active_tokens = torch.clamp(mask.sum(), min=1.0)
        for i, value in enumerate(z):
            token_hidden_basis = self._token_hidden_basis(seq_len, hidden, i, device=device, dtype=dtype) * mask
            if float(token_hidden_basis.abs().sum()) > 0.0:
                # Center each hidden channel over content tokens (sum over the
                # token axis divided by the active count).
                token_hidden_basis = token_hidden_basis - token_hidden_basis.sum(dim=0, keepdim=True) / active_tokens
            # Re-mask after centering, then normalize the whole field; the
            # clamped norm avoids division by zero for an all-zero field.
            token_hidden_basis = token_hidden_basis * mask
            token_hidden_basis = token_hidden_basis / torch.clamp(
                torch.norm(token_hidden_basis),
                min=torch.tensor(1e-6, device=device, dtype=dtype),
            )
            # 0.7 damps the dense field relative to the uniform modes.
            offset = offset + (float(value) * float(anchor_strength) * 0.7) * token_hidden_basis.unsqueeze(0)
        return offset

    # Unreachable when resolve_steering_mode() has validated the mode;
    # returns the zero offset as a conservative fallback.
    return offset
276389

277390
def _encode_steered_embeddings(self, session: Session, candidate: Candidate):
@@ -286,14 +399,21 @@ def _encode_steered_embeddings(self, session: Session, candidate: Candidate):
286399
do_classifier_free_guidance=True,
287400
negative_prompt=session.negative_prompt or "",
288401
)
289-
if steering_mode == SteeringMode.low_dimensional:
290-
steered_prompt_embeds = prompt_embeds + self._steering_offset(
291-
prompt_embeds,
292-
candidate.z,
293-
session.config.anchor_strength,
294-
)
295-
else:
296-
raise ValueError(f"Unsupported steering mode: {steering_mode}")
402+
token_inputs = self._token_inputs(
403+
pipe,
404+
session.prompt,
405+
seq_len=prompt_embeds.shape[1],
406+
device=prompt_embeds.device,
407+
dtype=prompt_embeds.dtype,
408+
)
409+
steered_prompt_embeds = prompt_embeds + self._steering_offset(
410+
prompt_embeds,
411+
candidate.z,
412+
session.config.anchor_strength,
413+
steering_mode=steering_mode,
414+
token_inputs=token_inputs,
415+
tokenizer=getattr(pipe, "tokenizer", None),
416+
)
297417
return steered_prompt_embeds, negative_prompt_embeds
298418

299419
def render_candidate(self, session: Session, candidate: Candidate) -> Candidate:

app/engine/orchestrator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@
3535
from app.samplers.plateau_escape import PlateauEscapeSampler
3636
from app.samplers.quality_diversity_mix import QualityDiversityMixSampler
3737
from app.samplers.random_local import RandomLocalSampler
38+
from app.samplers.restart_bridge_mix import RestartBridgeMixSampler
3839
from app.samplers.spherical_cover import SphericalCoverSampler
3940
from app.samplers.two_scale_cover import TwoScaleCoverSampler
4041
from app.samplers.uncertainty import UncertaintyGuidedSampler
42+
from app.updaters.advantage_softmax_pref import AdvantageSoftmaxPreferenceUpdater
4143
from app.storage.repository import JsonRepository
4244
from app.updaters.contrastive_pref import ContrastivePreferenceUpdater
4345
from app.updaters.borda_pref import BordaPreferenceUpdater
@@ -77,6 +79,7 @@ def __init__(
7779
"spherical_cover": SphericalCoverSampler(),
7880
"two_scale_cover": TwoScaleCoverSampler(),
7981
"quality_diversity_mix": QualityDiversityMixSampler(),
82+
"restart_bridge_mix": RestartBridgeMixSampler(),
8083
}
8184
self.updaters = {
8285
"winner_copy": WinnerCopyUpdater(),
@@ -89,6 +92,7 @@ def __init__(
8992
"bradley_terry_preference": BradleyTerryPreferenceUpdater(),
9093
"challenger_mixture_preference": ChallengerMixturePreferenceUpdater(),
9194
"plackett_luce_preference": PlackettLucePreferenceUpdater(),
95+
"advantage_softmax_preference": AdvantageSoftmaxPreferenceUpdater(),
9296
}
9397

9498
@staticmethod

0 commit comments

Comments
 (0)