1- """DyPE base configuration and utilities."""
1+ """DyPE base configuration and utilities for FLUX vision_yarn RoPE ."""
22
3- import math
43from dataclasses import dataclass
5- from typing import Literal
64
75import torch
86from torch import Tensor
@@ -14,72 +12,39 @@ class DyPEConfig:
1412
1513 enable_dype : bool = True
1614 base_resolution : int = 1024 # Native training resolution
17- method : Literal ["vision_yarn" , "yarn" , "ntk" , "base" ] = "vision_yarn"
1815 dype_scale : float = 2.0 # Magnitude λs (0.0-8.0)
1916 dype_exponent : float = 2.0 # Decay speed λt (0.0-1000.0)
2017 dype_start_sigma : float = 1.0 # When DyPE decay starts
2118
2219
def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
    """Compute the YaRN-style magnitude ("mscale") correction.

    Args:
        scale: Resolution scaling factor relative to training resolution.
        mscale_factor: Multiplier applied to the logarithmic term.

    Returns:
        1.0 when no upscaling is in effect, otherwise a logarithmically
        growing attention-magnitude correction.
    """
    # At or below native resolution there is nothing to correct.
    return 1.0 if scale <= 1.0 else mscale_factor * math.log(scale) + 1.0
36-
37-
def get_timestep_kappa(
    current_sigma: float,
    dype_scale: float,
    dype_exponent: float,
    dype_start_sigma: float,
) -> float:
    """Calculate the paper-style DyPE scheduler value κ(t).

    The key insight of DyPE: early steps focus on low frequencies (global
    structure), late steps on high frequencies (details). DyPE expresses
    this as a direct timestep scheduler over the positional extrapolation
    strength:

        κ(t) = λs * t^λt

    Args:
        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
        dype_scale: DyPE magnitude (λs)
        dype_exponent: DyPE decay speed (λt)
        dype_start_sigma: Sigma threshold to start decay

    Returns:
        Timestep scheduler value κ(t), in [0, dype_scale].
    """
    # Degenerate configurations disable the scheduler entirely; the
    # start-sigma check also guards the division below.
    if dype_scale <= 0.0 or dype_start_sigma <= 0.0:
        return 0.0

    # Normalize sigma to [0, 1] relative to the start threshold: sigmas at or
    # above dype_start_sigma clamp to full strength, negative sigmas to zero.
    t_normalized = max(0.0, min(current_sigma / dype_start_sigma, 1.0))

    # Power-law decay: strong extrapolation early (t→1), fading late (t→0).
    return dype_scale * (t_normalized ** dype_exponent)
8348
8449
8550def compute_vision_yarn_freqs (
@@ -117,35 +82,23 @@ def compute_vision_yarn_freqs(
11782 """
11883 assert dim % 2 == 0
11984
120- # Use the larger scale for NTK calculation
12185 scale = max (scale_h , scale_w )
12286
12387 device = pos .device
12488 dtype = torch .float64 if device .type != "mps" else torch .float32
12589
126- # NTK-aware theta scaling: extends position coverage for high-res
127- # Formula: theta_scaled = theta * scale^(dim/(dim-2))
128- # This increases the wavelength of position encodings proportionally
90+ # DyPE applies a direct timestep scheduler to the NTK extrapolation exponent.
91+ # Early steps keep strong extrapolation; late steps relax smoothly back
92+ # toward the training-time RoPE.
12993 if scale > 1.0 :
130- ntk_alpha = scale ** (dim / (dim - 2 ))
131-
132- # Apply timestep-dependent DyPE modulation
133- # mscale controls how strongly we apply the NTK extrapolation
134- # Early steps (high sigma): stronger extrapolation for global structure
135- # Late steps (low sigma): weaker extrapolation for fine details
136- mscale = get_timestep_mscale (
137- scale = scale ,
94+ ntk_exponent = dim / (dim - 2 )
95+ kappa = get_timestep_kappa (
13896 current_sigma = current_sigma ,
13997 dype_scale = dype_config .dype_scale ,
14098 dype_exponent = dype_config .dype_exponent ,
14199 dype_start_sigma = dype_config .dype_start_sigma ,
142100 )
143-
144- # Modulate NTK alpha by mscale
145- # When mscale > 1: interpolate towards stronger extrapolation
146- # When mscale = 1: use base NTK alpha
147- modulated_alpha = 1.0 + (ntk_alpha - 1.0 ) * mscale
148- scaled_theta = theta * modulated_alpha
101+ scaled_theta = theta * (scale ** (ntk_exponent * kappa ))
149102 else :
150103 scaled_theta = theta
151104
@@ -160,101 +113,3 @@ def compute_vision_yarn_freqs(
160113 sin = torch .sin (angles )
161114
162115 return cos .to (pos .dtype ), sin .to (pos .dtype )
163-
164-
def compute_yarn_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
    current_sigma: float,
    dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies via the YARN/NTK method.

    NTK-aware scaling of the RoPE base extends positional coverage at high
    resolutions; the strength of that extrapolation is modulated per
    timestep by DyPE.

    Args:
        pos: Position tensor
        dim: Embedding dimension (must be even)
        theta: RoPE base frequency
        scale: Uniform scaling factor
        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
        dype_config: DyPE configuration

    Returns:
        Tuple of (cos, sin) frequency tensors in ``pos.dtype``.
    """
    assert dim % 2 == 0

    # MPS has no float64 support, so fall back to float32 there.
    calc_dtype = torch.float32 if pos.device.type == "mps" else torch.float64

    if scale <= 1.0:
        # At or below native resolution: plain RoPE base.
        scaled_theta = theta
    else:
        # NTK-aware base extrapolation, softened by the timestep-dependent
        # mscale so late (low-sigma) steps relax toward the training base.
        ntk_alpha = scale ** (dim / (dim - 2))
        mscale = get_timestep_mscale(
            scale=scale,
            current_sigma=current_sigma,
            dype_scale=dype_config.dype_scale,
            dype_exponent=dype_config.dype_exponent,
            dype_start_sigma=dype_config.dype_start_sigma,
        )
        scaled_theta = theta * (1.0 + (ntk_alpha - 1.0) * mscale)

    exponents = torch.arange(0, dim, 2, dtype=calc_dtype, device=pos.device) / dim
    inv_freq = 1.0 / (scaled_theta ** exponents)
    angles = torch.einsum("...n,d->...nd", pos.to(calc_dtype), inv_freq)

    return torch.cos(angles).to(pos.dtype), torch.sin(angles).to(pos.dtype)
222-
223-
def compute_ntk_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies via the NTK method.

    Neural Tangent Kernel approach: continuous frequency scaling of the RoPE
    base, with no timestep dependency.

    Args:
        pos: Position tensor
        dim: Embedding dimension (must be even)
        theta: RoPE base frequency
        scale: Scaling factor

    Returns:
        Tuple of (cos, sin) frequency tensors in ``pos.dtype``.
    """
    assert dim % 2 == 0

    # MPS has no float64 support, so fall back to float32 there.
    calc_dtype = torch.float32 if pos.device.type == "mps" else torch.float64

    # NTK: stretch the base so per-dimension wavelengths grow with scale.
    adjusted_theta = theta * (scale ** (dim / (dim - 2)))

    exponents = torch.arange(0, dim, 2, dtype=calc_dtype, device=pos.device) / dim
    inv_freq = 1.0 / (adjusted_theta ** exponents)
    angles = torch.einsum("...n,d->...nd", pos.to(calc_dtype), inv_freq)

    return torch.cos(angles).to(pos.dtype), torch.sin(angles).to(pos.dtype)
0 commit comments