1313# limitations under the License.
1414
1515import math
16- from typing import Optional , List , TYPE_CHECKING , Dict , Union , Tuple
16+ from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Union
1717
1818import torch
1919
2020from .guider_utils import BaseGuidance , rescale_noise_cfg
2121
22+
2223if TYPE_CHECKING :
2324 from ..modular_pipelines .modular_pipeline import BlockState
2425
@@ -74,10 +75,10 @@ def __init__(
7475 self .momentum_buffer = None
7576
7677 def prepare_inputs (self , data : "BlockState" , input_fields : Optional [Dict [str , Union [str , Tuple [str , str ]]]] = None ) -> List ["BlockState" ]:
77-
78+
7879 if input_fields is None :
7980 input_fields = self ._input_fields
80-
81+
8182 if self ._step == 0 :
8283 if self .adaptive_projected_guidance_momentum is not None :
8384 self .momentum_buffer = MomentumBuffer (self .adaptive_projected_guidance_momentum )
@@ -123,19 +124,19 @@ def num_conditions(self) -> int:
123124 def _is_apg_enabled (self ) -> bool :
124125 if not self ._enabled :
125126 return False
126-
127+
127128 is_within_range = True
128129 if self ._num_inference_steps is not None :
129130 skip_start_step = int (self ._start * self ._num_inference_steps )
130131 skip_stop_step = int (self ._stop * self ._num_inference_steps )
131132 is_within_range = skip_start_step <= self ._step < skip_stop_step
132-
133+
133134 is_close = False
134135 if self .use_original_formulation :
135136 is_close = math .isclose (self .guidance_scale , 0.0 )
136137 else :
137138 is_close = math .isclose (self .guidance_scale , 1.0 )
138-
139+
139140 return is_within_range and not is_close
140141
141142
@@ -160,25 +161,25 @@ def normalized_guidance(
160161):
161162 diff = pred_cond - pred_uncond
162163 dim = [- i for i in range (1 , len (diff .shape ))]
163-
164+
164165 if momentum_buffer is not None :
165166 momentum_buffer .update (diff )
166167 diff = momentum_buffer .running_average
167-
168+
168169 if norm_threshold > 0 :
169170 ones = torch .ones_like (diff )
170171 diff_norm = diff .norm (p = 2 , dim = dim , keepdim = True )
171172 scale_factor = torch .minimum (ones , norm_threshold / diff_norm )
172173 diff = diff * scale_factor
173-
174+
174175 v0 , v1 = diff .double (), pred_cond .double ()
175176 v1 = torch .nn .functional .normalize (v1 , dim = dim )
176177 v0_parallel = (v0 * v1 ).sum (dim = dim , keepdim = True ) * v1
177178 v0_orthogonal = v0 - v0_parallel
178179 diff_parallel , diff_orthogonal = v0_parallel .type_as (diff ), v0_orthogonal .type_as (diff )
179180 normalized_update = diff_orthogonal + eta * diff_parallel
180-
181+
181182 pred = pred_cond if use_original_formulation else pred_uncond
182183 pred = pred + guidance_scale * normalized_update
183-
184+
184185 return pred
0 commit comments