def resolve_steering_mode(session: Session) -> SteeringMode:
    """Return the session's configured steering mode after validating it.

    Raises:
        ValueError: if the configured mode is not one of the supported
            steering modes.
    """
    supported = (
        SteeringMode.low_dimensional,
        SteeringMode.content_masked,
        SteeringMode.token_factorized,
        SteeringMode.token_vector_field,
    )
    mode = session.config.steering_mode
    if mode not in supported:
        raise ValueError(f"Unsupported steering mode: {mode}")
    return mode
4045
@@ -259,19 +264,127 @@ def _resolve_model_source(self, session: Session) -> str:
259264 "Run scripts/setup_huggingface.py first or enable STABLE_STEERING_ALLOW_REMOTE_MODEL_DOWNLOAD=true."
260265 )
261266
262- def _steering_offset (self , prompt_embeds , z , anchor_strength : float ):
267+ def _hidden_basis (self , hidden : int , index_id : int , * , device , dtype ):
268+ """Build a deterministic hidden-space basis vector for one steering axis."""
269+
270+ torch = self ._torch
271+ index = torch .linspace (0.0 , 1.0 , hidden , device = device , dtype = dtype )
272+ basis = torch .sin (index * (index_id + 1 ) * torch .pi ) + torch .cos (index * (index_id + 1 ) * 0.5 * torch .pi )
273+ return basis / torch .norm (basis )
274+
275+ def _token_hidden_basis (self , seq_len : int , hidden : int , index_id : int , * , device , dtype ):
276+ """Build a deterministic per-token hidden-vector field for one steering axis."""
277+
278+ torch = self ._torch
279+ token_index = torch .linspace (0.0 , 1.0 , seq_len , device = device , dtype = dtype ).view (seq_len , 1 )
280+ hidden_index = torch .linspace (0.0 , 1.0 , hidden , device = device , dtype = dtype ).view (1 , hidden )
281+ frequency = float (index_id + 1 )
282+ basis = (
283+ torch .sin ((token_index + 0.17 * frequency ) * (hidden_index + 0.11 ) * torch .pi * (1.0 + frequency ))
284+ + 0.7 * torch .cos ((token_index * (0.45 + 0.08 * frequency ) - hidden_index * (0.63 + 0.04 * frequency )) * torch .pi )
285+ + 0.35 * torch .sin ((token_index * hidden_index + 0.13 * frequency ) * 2.0 * torch .pi )
286+ )
287+ return basis / torch .clamp (torch .norm (basis ), min = torch .tensor (1e-6 , device = device , dtype = dtype ))
288+
289+ def _token_inputs (self , pipe , prompt : str , * , seq_len : int , device , dtype ):
290+ """Tokenize the prompt so token-aware steering modes can shape per-token offsets."""
291+
292+ tokenizer = getattr (pipe , "tokenizer" , None )
293+ if tokenizer is None :
294+ return None
295+
296+ tokenized = tokenizer (
297+ prompt ,
298+ padding = "max_length" ,
299+ truncation = True ,
300+ max_length = seq_len ,
301+ return_tensors = "pt" ,
302+ )
303+ input_ids = tokenized .input_ids .to (device = device )
304+ attention_mask = tokenized .attention_mask .to (device = device , dtype = dtype )
305+ return {"input_ids" : input_ids , "attention_mask" : attention_mask }
306+
307+ def _content_mask (self , token_inputs , * , tokenizer , dtype ):
308+ """Build a mask that suppresses padding and special tokens for token-aware steering."""
309+
310+ attention_mask = token_inputs ["attention_mask" ].to (dtype = dtype )
311+ input_ids = token_inputs ["input_ids" ]
312+ content_mask = attention_mask .clone ()
313+
314+ if tokenizer is not None :
315+ for attr in ("bos_token_id" , "eos_token_id" , "pad_token_id" ):
316+ token_id = getattr (tokenizer , attr , None )
317+ if token_id is not None :
318+ content_mask = content_mask * (input_ids != token_id ).to (dtype = dtype )
319+
320+ if float (content_mask .sum ()) <= 0.0 :
321+ return attention_mask
322+ return content_mask
323+
def _steering_offset(self, prompt_embeds, z, anchor_strength: float, *, steering_mode: SteeringMode, token_inputs=None, tokenizer=None):
    """Project the steering coefficients ``z`` into prompt-embedding space.

    Dispatches on ``steering_mode``:
      * ``low_dimensional``    — one hidden-space basis per axis, applied
        uniformly to every token position.
      * ``content_masked``     — the same hidden bases, weighted by a smooth
        per-token profile that is zeroed on padding/special tokens.
      * ``token_factorized``   — rank-1 (token basis x hidden basis) offsets,
        demeaned and unit-normalized over the content tokens.
      * ``token_vector_field`` — a full (seq, hidden) deterministic field per
        axis, masked, demeaned, and normalized.

    Raises:
        ValueError: for token-aware modes when ``token_inputs`` is missing.

    Any unrecognized mode falls through to a zero offset.
    """
    torch = self._torch
    seq_len = prompt_embeds.shape[1]
    hidden = prompt_embeds.shape[-1]
    device = prompt_embeds.device
    dtype = prompt_embeds.dtype
    strength = float(anchor_strength)
    offset = torch.zeros_like(prompt_embeds)

    if steering_mode == SteeringMode.low_dimensional:
        # Token-independent: broadcast each hidden basis across all positions.
        for axis, coeff in enumerate(z):
            basis = self._hidden_basis(hidden, axis, device=device, dtype=dtype)
            offset = offset + (float(coeff) * strength) * basis.view(1, 1, hidden)
        return offset

    if token_inputs is None:
        raise ValueError(f"Token-aware steering mode {steering_mode.value} requires token inputs.")

    content_mask = self._content_mask(token_inputs, tokenizer=tokenizer, dtype=dtype)
    token_positions = torch.linspace(0.0, 1.0, seq_len, device=device, dtype=dtype)

    if steering_mode == SteeringMode.content_masked:
        # Arch-shaped positional profile, zeroed outside content, rescaled so
        # its total weight equals the number of active tokens.
        profile = 0.35 + 0.65 * torch.sin(token_positions * torch.pi)
        profile = profile.view(1, seq_len, 1) * content_mask.view(1, seq_len, 1)
        active = torch.clamp(content_mask.sum(), min=1.0)
        denom = torch.clamp(profile.sum(dim=1, keepdim=True), min=1.0)
        profile = profile * (active / denom)
        for axis, coeff in enumerate(z):
            basis = self._hidden_basis(hidden, axis, device=device, dtype=dtype)
            offset = offset + (float(coeff) * strength) * profile * basis.view(1, 1, hidden)
        return offset

    if steering_mode == SteeringMode.token_factorized:
        flat_mask = content_mask.view(seq_len)
        for axis, coeff in enumerate(z):
            hidden_basis = self._hidden_basis(hidden, axis, device=device, dtype=dtype)
            token_basis = (
                torch.sin(token_positions * (axis + 1) * torch.pi)
                + 0.5 * torch.cos(token_positions * (axis + 1) * 2.0 * torch.pi)
            ) * flat_mask
            if float(token_basis.abs().sum()) > 0.0:
                # Remove the mean over content tokens, then unit-normalize.
                mean = (token_basis * flat_mask).sum() / torch.clamp(flat_mask.sum(), min=1.0)
                token_basis = token_basis - mean * flat_mask
                token_norm = torch.norm(token_basis)
                if float(token_norm) > 0.0:
                    token_basis = token_basis / token_norm
            offset = offset + (float(coeff) * strength * 0.8) * token_basis.view(1, seq_len, 1) * hidden_basis.view(1, 1, hidden)
        return offset

    if steering_mode == SteeringMode.token_vector_field:
        col_mask = content_mask.view(seq_len, 1)
        active = torch.clamp(col_mask.sum(), min=1.0)
        floor = torch.tensor(1e-6, device=device, dtype=dtype)
        for axis, coeff in enumerate(z):
            field = self._token_hidden_basis(seq_len, hidden, axis, device=device, dtype=dtype) * col_mask
            if float(field.abs().sum()) > 0.0:
                # Demean over positions, re-mask, and normalize with a floored
                # norm so an all-zero field never divides by zero.
                field = field - field.sum(dim=0, keepdim=True) / active
                field = field * col_mask
                field = field / torch.clamp(torch.norm(field), min=floor)
            offset = offset + (float(coeff) * strength * 0.7) * field.unsqueeze(0)
        return offset

    return offset
276389
277390 def _encode_steered_embeddings (self , session : Session , candidate : Candidate ):
@@ -286,14 +399,21 @@ def _encode_steered_embeddings(self, session: Session, candidate: Candidate):
286399 do_classifier_free_guidance = True ,
287400 negative_prompt = session .negative_prompt or "" ,
288401 )
289- if steering_mode == SteeringMode .low_dimensional :
290- steered_prompt_embeds = prompt_embeds + self ._steering_offset (
291- prompt_embeds ,
292- candidate .z ,
293- session .config .anchor_strength ,
294- )
295- else :
296- raise ValueError (f"Unsupported steering mode: { steering_mode } " )
402+ token_inputs = self ._token_inputs (
403+ pipe ,
404+ session .prompt ,
405+ seq_len = prompt_embeds .shape [1 ],
406+ device = prompt_embeds .device ,
407+ dtype = prompt_embeds .dtype ,
408+ )
409+ steered_prompt_embeds = prompt_embeds + self ._steering_offset (
410+ prompt_embeds ,
411+ candidate .z ,
412+ session .config .anchor_strength ,
413+ steering_mode = steering_mode ,
414+ token_inputs = token_inputs ,
415+ tokenizer = getattr (pipe , "tokenizer" , None ),
416+ )
297417 return steered_prompt_embeds , negative_prompt_embeds
298418
299419 def render_candidate (self , session : Session , candidate : Candidate ) -> Candidate :
0 commit comments