Skip to content

Commit ef9780e

Browse files
committed
Align DyPE with paper
1 parent dbbf289 commit ef9780e

File tree

9 files changed

+478
-287
lines changed

9 files changed

+478
-287
lines changed

invokeai/app/invocations/flux_denoise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ def _run_diffusion(
477477
)
478478
context.logger.info(
479479
f"DyPE enabled: resolution={self.width}x{self.height}, preset={self.dype_preset}, "
480-
f"method={dype_config.method}, scale={dype_config.dype_scale:.2f}, "
480+
f"scale={dype_config.dype_scale:.2f}, "
481481
f"exponent={dype_config.dype_exponent:.2f}, start_sigma={dype_config.dype_start_sigma:.2f}, "
482482
f"base_resolution={dype_config.base_resolution}"
483483
)

invokeai/backend/flux/denoise.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,18 @@ def denoise(
9696
timestep = scheduler.timesteps[step_index]
9797
# Convert scheduler timestep (0-1000) to normalized (0-1) for the model
9898
t_curr = timestep.item() / scheduler.config.num_train_timesteps
99+
dype_sigma = DyPEExtension.resolve_step_sigma(
100+
fallback_sigma=t_curr,
101+
step_index=step_index,
102+
scheduler_sigmas=getattr(scheduler, "sigmas", None),
103+
)
99104
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
100105

101106
# DyPE: Update step state for timestep-dependent scaling
102107
if dype_extension is not None and dype_embedder is not None:
103108
dype_extension.update_step_state(
104109
embedder=dype_embedder,
105-
timestep=t_curr,
106-
timestep_index=user_step,
107-
total_steps=total_steps,
110+
sigma=dype_sigma,
108111
)
109112

110113
# For Heun scheduler, track if we're in first or second order step
@@ -264,9 +267,7 @@ def denoise(
264267
if dype_extension is not None and dype_embedder is not None:
265268
dype_extension.update_step_state(
266269
embedder=dype_embedder,
267-
timestep=t_curr,
268-
timestep_index=step_index,
269-
total_steps=total_steps,
270+
sigma=t_curr,
270271
)
271272

272273
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

invokeai/backend/flux/dype/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Dynamic Position Extrapolation (DyPE) for FLUX models.
22
3-
DyPE enables high-resolution image generation (4K+) with pretrained FLUX models
4-
by dynamically scaling RoPE position embeddings during the denoising process.
3+
DyPE enables high-resolution image generation with pretrained FLUX models by
4+
dynamically modulating RoPE extrapolation during denoising.
55
6-
Based on: https://github.com/wildminder/ComfyUI-DyPE
6+
Based on the official DyPE project: https://github.com/guyyariv/DyPE
77
"""
88

99
from invokeai.backend.flux.dype.base import DyPEConfig

invokeai/backend/flux/dype/base.py

Lines changed: 18 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
"""DyPE base configuration and utilities."""
1+
"""DyPE base configuration and utilities for FLUX vision_yarn RoPE."""
22

3-
import math
43
from dataclasses import dataclass
5-
from typing import Literal
64

75
import torch
86
from torch import Tensor
@@ -14,72 +12,39 @@ class DyPEConfig:
1412

1513
enable_dype: bool = True
1614
base_resolution: int = 1024 # Native training resolution
17-
method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn"
1815
dype_scale: float = 2.0 # Magnitude λs (0.0-8.0)
1916
dype_exponent: float = 2.0 # Decay speed λt (0.0-1000.0)
2017
dype_start_sigma: float = 1.0 # When DyPE decay starts
2118

2219

23-
def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
24-
"""Calculate magnitude scaling factor.
25-
26-
Args:
27-
scale: The resolution scaling factor
28-
mscale_factor: Adjustment factor for the scaling
29-
30-
Returns:
31-
The magnitude scaling factor
32-
"""
33-
if scale <= 1.0:
34-
return 1.0
35-
return mscale_factor * math.log(scale) + 1.0
36-
37-
38-
def get_timestep_mscale(
39-
scale: float,
20+
def get_timestep_kappa(
4021
current_sigma: float,
4122
dype_scale: float,
4223
dype_exponent: float,
4324
dype_start_sigma: float,
4425
) -> float:
45-
"""Calculate timestep-dependent magnitude scaling.
26+
"""Calculate the paper-style DyPE scheduler value κ(t).
4627
4728
The key insight of DyPE: early steps focus on low frequencies (global structure),
48-
late steps on high frequencies (details). This function modulates the scaling
49-
based on the current timestep/sigma.
29+
late steps on high frequencies (details). DyPE expresses this as a direct
30+
timestep scheduler over the positional extrapolation strength:
31+
32+
κ(t) = λs * t^λt
5033
5134
Args:
52-
scale: Resolution scaling factor
5335
current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
5436
dype_scale: DyPE magnitude (λs)
5537
dype_exponent: DyPE decay speed (λt)
5638
dype_start_sigma: Sigma threshold to start decay
5739
5840
Returns:
59-
Timestep-modulated scaling factor
41+
Timestep scheduler value κ(t)
6042
"""
61-
if scale <= 1.0:
62-
return 1.0
63-
64-
# Normalize sigma to [0, 1] range relative to start_sigma
65-
if current_sigma >= dype_start_sigma:
66-
t_normalized = 1.0
67-
else:
68-
t_normalized = current_sigma / dype_start_sigma
69-
70-
# Apply exponential decay: stronger extrapolation early, weaker late
71-
# decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late
72-
decay = math.exp(-dype_exponent * (1.0 - t_normalized))
73-
74-
# Base mscale from resolution
75-
base_mscale = get_mscale(scale)
43+
if dype_scale <= 0.0 or dype_start_sigma <= 0.0:
44+
return 0.0
7645

77-
# Interpolate between base_mscale and 1.0 based on decay and dype_scale
78-
# When decay=1 (early): use scaled value
79-
# When decay=0 (late): use base value
80-
scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay
81-
82-
return scaled_mscale
46+
t_normalized = max(0.0, min(current_sigma / dype_start_sigma, 1.0))
47+
return dype_scale * (t_normalized**dype_exponent)
8348

8449

8550
def compute_vision_yarn_freqs(
@@ -117,35 +82,23 @@ def compute_vision_yarn_freqs(
11782
"""
11883
assert dim % 2 == 0
11984

120-
# Use the larger scale for NTK calculation
12185
scale = max(scale_h, scale_w)
12286

12387
device = pos.device
12488
dtype = torch.float64 if device.type != "mps" else torch.float32
12589

126-
# NTK-aware theta scaling: extends position coverage for high-res
127-
# Formula: theta_scaled = theta * scale^(dim/(dim-2))
128-
# This increases the wavelength of position encodings proportionally
90+
# DyPE applies a direct timestep scheduler to the NTK extrapolation exponent.
91+
# Early steps keep strong extrapolation; late steps relax smoothly back
92+
# toward the training-time RoPE.
12993
if scale > 1.0:
130-
ntk_alpha = scale ** (dim / (dim - 2))
131-
132-
# Apply timestep-dependent DyPE modulation
133-
# mscale controls how strongly we apply the NTK extrapolation
134-
# Early steps (high sigma): stronger extrapolation for global structure
135-
# Late steps (low sigma): weaker extrapolation for fine details
136-
mscale = get_timestep_mscale(
137-
scale=scale,
94+
ntk_exponent = dim / (dim - 2)
95+
kappa = get_timestep_kappa(
13896
current_sigma=current_sigma,
13997
dype_scale=dype_config.dype_scale,
14098
dype_exponent=dype_config.dype_exponent,
14199
dype_start_sigma=dype_config.dype_start_sigma,
142100
)
143-
144-
# Modulate NTK alpha by mscale
145-
# When mscale > 1: interpolate towards stronger extrapolation
146-
# When mscale = 1: use base NTK alpha
147-
modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
148-
scaled_theta = theta * modulated_alpha
101+
scaled_theta = theta * (scale ** (ntk_exponent * kappa))
149102
else:
150103
scaled_theta = theta
151104

@@ -160,101 +113,3 @@ def compute_vision_yarn_freqs(
160113
sin = torch.sin(angles)
161114

162115
return cos.to(pos.dtype), sin.to(pos.dtype)
163-
164-
165-
def compute_yarn_freqs(
166-
pos: Tensor,
167-
dim: int,
168-
theta: int,
169-
scale: float,
170-
current_sigma: float,
171-
dype_config: DyPEConfig,
172-
) -> tuple[Tensor, Tensor]:
173-
"""Compute RoPE frequencies using YARN/NTK method.
174-
175-
Uses NTK-aware theta scaling for high-resolution support with
176-
timestep-dependent DyPE modulation.
177-
178-
Args:
179-
pos: Position tensor
180-
dim: Embedding dimension
181-
theta: RoPE base frequency
182-
scale: Uniform scaling factor
183-
current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
184-
dype_config: DyPE configuration
185-
186-
Returns:
187-
Tuple of (cos, sin) frequency tensors
188-
"""
189-
assert dim % 2 == 0
190-
191-
device = pos.device
192-
dtype = torch.float64 if device.type != "mps" else torch.float32
193-
194-
# NTK-aware theta scaling with DyPE modulation
195-
if scale > 1.0:
196-
ntk_alpha = scale ** (dim / (dim - 2))
197-
198-
# Apply timestep-dependent DyPE modulation
199-
mscale = get_timestep_mscale(
200-
scale=scale,
201-
current_sigma=current_sigma,
202-
dype_scale=dype_config.dype_scale,
203-
dype_exponent=dype_config.dype_exponent,
204-
dype_start_sigma=dype_config.dype_start_sigma,
205-
)
206-
207-
# Modulate NTK alpha by mscale
208-
modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
209-
scaled_theta = theta * modulated_alpha
210-
else:
211-
scaled_theta = theta
212-
213-
freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
214-
freqs = 1.0 / (scaled_theta**freq_seq)
215-
216-
angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
217-
218-
cos = torch.cos(angles)
219-
sin = torch.sin(angles)
220-
221-
return cos.to(pos.dtype), sin.to(pos.dtype)
222-
223-
224-
def compute_ntk_freqs(
225-
pos: Tensor,
226-
dim: int,
227-
theta: int,
228-
scale: float,
229-
) -> tuple[Tensor, Tensor]:
230-
"""Compute RoPE frequencies using NTK method.
231-
232-
Neural Tangent Kernel approach - continuous frequency scaling without
233-
timestep dependency.
234-
235-
Args:
236-
pos: Position tensor
237-
dim: Embedding dimension
238-
theta: RoPE base frequency
239-
scale: Scaling factor
240-
241-
Returns:
242-
Tuple of (cos, sin) frequency tensors
243-
"""
244-
assert dim % 2 == 0
245-
246-
device = pos.device
247-
dtype = torch.float64 if device.type != "mps" else torch.float32
248-
249-
# NTK scaling
250-
scaled_theta = theta * (scale ** (dim / (dim - 2)))
251-
252-
freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
253-
freqs = 1.0 / (scaled_theta**freq_seq)
254-
255-
angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
256-
257-
cos = torch.cos(angles)
258-
sin = torch.sin(angles)
259-
260-
return cos.to(pos.dtype), sin.to(pos.dtype)

invokeai/backend/flux/dype/presets.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ class DyPEPresetConfig:
3131
"""Preset configuration values."""
3232

3333
base_resolution: int
34-
method: str
3534
dype_scale: float
3635
dype_exponent: float
3736
dype_start_sigma: float
@@ -41,7 +40,6 @@ class DyPEPresetConfig:
4140
DYPE_PRESETS: dict[DyPEPreset, DyPEPresetConfig] = {
4241
DYPE_PRESET_4K: DyPEPresetConfig(
4342
base_resolution=1024,
44-
method="vision_yarn",
4543
dype_scale=2.0,
4644
dype_exponent=2.0,
4745
dype_start_sigma=1.0,
@@ -84,7 +82,6 @@ def get_dype_config_for_resolution(
8482
return DyPEConfig(
8583
enable_dype=True,
8684
base_resolution=base_resolution,
87-
method="vision_yarn",
8885
dype_scale=dynamic_dype_scale,
8986
dype_exponent=2.0,
9087
dype_start_sigma=1.0,
@@ -111,24 +108,24 @@ def get_dype_config_for_area(
111108
return None
112109

113110
area_ratio = area / base_area
114-
effective_side_ratio = math.sqrt(area_ratio) # 1.0 at base, 2.0 at 2K (if base is 1K)
115-
116-
# Strength: 0 at base area, 8 at sat_area, clamped thereafter.
117-
sat_area = 2027520 # Determined by experimentation where a vertical line appears
118-
sat_side_ratio = math.sqrt(sat_area / base_area)
119-
dynamic_dype_scale = 8.0 * (effective_side_ratio - 1.0) / (sat_side_ratio - 1.0)
111+
effective_side_ratio = math.sqrt(area_ratio)
112+
aspect_ratio = max(width, height) / min(width, height)
113+
aspect_attenuation = 1.0 if aspect_ratio <= 2.0 else 2.0 / aspect_ratio
114+
115+
# Retune area mode to be "auto, but area-aware" instead of dramatically
116+
# stronger than auto. This keeps it closer to the paper-style core DyPE.
117+
dynamic_dype_scale = 2.4 * effective_side_ratio
118+
dynamic_dype_scale *= aspect_attenuation
120119
dynamic_dype_scale = max(0.0, min(dynamic_dype_scale, 8.0))
121120

122-
# Continuous exponent schedule:
123-
# r=1 -> 0.5, r=2 -> 1.0, r=4 -> 2.0 (exact), smoothly varying in between.
124-
x = math.log2(effective_side_ratio)
125-
dype_exponent = 0.25 * (x**2) + 0.25 * x + 0.5
126-
dype_exponent = max(0.5, min(dype_exponent, 2.0))
121+
# Use a narrower, higher exponent range than the old area heuristic so the
122+
# paper-style scheduler decays more conservatively and artifacts are reduced.
123+
exponent_progress = max(0.0, min(effective_side_ratio - 1.0, 1.0))
124+
dype_exponent = 1.25 + 0.75 * exponent_progress
127125

128126
return DyPEConfig(
129127
enable_dype=True,
130128
base_resolution=base_resolution,
131-
method="vision_yarn",
132129
dype_scale=dynamic_dype_scale,
133130
dype_exponent=dype_exponent,
134131
dype_start_sigma=1.0,
@@ -165,7 +162,6 @@ def get_dype_config_from_preset(
165162
return DyPEConfig(
166163
enable_dype=True,
167164
base_resolution=1024,
168-
method="vision_yarn",
169165
dype_scale=custom_scale if custom_scale is not None else dynamic_dype_scale,
170166
dype_exponent=custom_exponent if custom_exponent is not None else 2.0,
171167
dype_start_sigma=1.0,
@@ -196,7 +192,6 @@ def get_dype_config_from_preset(
196192
return DyPEConfig(
197193
enable_dype=True,
198194
base_resolution=preset_config.base_resolution,
199-
method=preset_config.method,
200195
dype_scale=preset_config.dype_scale,
201196
dype_exponent=preset_config.dype_exponent,
202197
dype_start_sigma=preset_config.dype_start_sigma,

0 commit comments

Comments (0)