Skip to content

Commit 2c6fc47

Browse files
authored
[feat]: support offload for flux2 (#1034)
1 parent c7cf058 commit 2c6fc47

5 files changed

Lines changed: 75 additions & 8 deletions

File tree

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
+{
+    "model_cls": "flux2_dev",
+    "task": "t2i",
+    "infer_steps": 50,
+    "sample_guide_scale": 4.0,
+    "vae_scale_factor": 16,
+    "feature_caching": "None",
+    "enable_cfg": false,
+    "patch_size": 2,
+    "tokenizer_max_length": 512,
+    "rope_type": "flashinfer",
+    "text_encoder_out_layers": [10, 20, 30],
+    "cpu_offload": true,
+    "offload_granularity": "block"
+}

lightx2v/common/ops/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
 from pathlib import Path

 import torch
+from loguru import logger
 from safetensors import safe_open

 from lightx2v.utils.envs import *
@@ -81,10 +82,15 @@ def create_pin_tensor(tensor, transpose=False, dtype=None):
         dtype: Target data type of the pinned tensor (optional, defaults to source tensor's dtype)

     Returns:
-        Pinned memory tensor (on CPU) with optional transposition applied
+        Pinned memory tensor (on CPU) with optional transposition applied.
+        Falls back to regular CPU tensor if pinned memory allocation fails.
     """
     dtype = dtype or tensor.dtype
-    pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
+    try:
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
+    except Exception as e:
+        logger.warning(f"Failed to allocate pinned memory (shape={tensor.shape}, dtype={dtype}): {e}. Falling back to regular CPU memory.")
+        pin_tensor = torch.empty(tensor.shape, dtype=dtype)
     pin_tensor = pin_tensor.copy_(tensor)
     if transpose:
         pin_tensor = pin_tensor.t()

lightx2v/models/networks/flux2/weights/transformer_weights.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,18 @@ def to_cuda(self, non_blocking=True):
             block.to_cuda(non_blocking=non_blocking)
         for block in self.single_blocks:
             block.to_cuda(non_blocking=non_blocking)
+        self.double_stream_modulation_img_linear.to_cuda(non_blocking=non_blocking)
+        self.double_stream_modulation_txt_linear.to_cuda(non_blocking=non_blocking)
+        self.single_stream_modulation_linear.to_cuda(non_blocking=non_blocking)

     def to_cpu(self, non_blocking=True):
         for block in self.double_blocks:
             block.to_cpu(non_blocking=non_blocking)
         for block in self.single_blocks:
             block.to_cpu(non_blocking=non_blocking)
+        self.double_stream_modulation_img_linear.to_cpu(non_blocking=non_blocking)
+        self.double_stream_modulation_txt_linear.to_cpu(non_blocking=non_blocking)
+        self.single_stream_modulation_linear.to_cpu(non_blocking=non_blocking)


 # Backward-compatible aliases

lightx2v/models/runners/flux2/flux2_runner.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v_platform.base.global_var import AI_DEVICE

+torch_device_module = getattr(torch, AI_DEVICE)
+

 def calculate_dimensions(target_area, ratio):
     width = math.sqrt(target_area * ratio)
@@ -45,8 +47,11 @@ def load_vae(self):

     def init_modules(self):
         logger.info(f"Initializing {self.config['model_cls']} modules...")
-        self.load_model()
-        self.model.set_scheduler(self.scheduler)
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.load_model()
+            self.model.set_scheduler(self.scheduler)
+        elif self.config.get("lazy_load", False):
+            assert self.config.get("cpu_offload", False)

         task = self.config.get("task", "t2i")
         if task == "i2i":
@@ -59,8 +64,12 @@ def init_modules(self):
     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_t2i(self):
         prompt = self.input_info.prompt
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            self.text_encoders = self.load_text_encoder()
         text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
-        torch.cuda.empty_cache()
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            del self.text_encoders[0]
+        torch_device_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,
@@ -70,7 +79,11 @@ def _run_input_encoder_local_t2i(self):
     @ProfilingContext4DebugL2("Run Encoders I2I")
     def _run_input_encoder_local_i2i(self):
         prompt = self.input_info.prompt
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            self.text_encoders = self.load_text_encoder()
         text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            del self.text_encoders[0]

         image_path = self.input_info.image_path
         from PIL import Image
@@ -108,7 +121,7 @@ def _run_input_encoder_local_i2i(self):
             if index == 0:
                 self.input_info.target_shape = (image_height, image_width)

-        torch.cuda.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()

         return {
@@ -244,6 +257,9 @@ def set_img_shapes(self):

     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            self.vae = self.load_vae()
+
         B, _, C = latents.shape

         H = int((self.input_info.latent_image_ids[0, :, 1].max() + 1).item())
@@ -252,14 +268,20 @@ def run_vae_decoder(self, latents):
         latents = latents.view(B, H, W, C).permute(0, 3, 1, 2)

         bn_mean = self.vae.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
-        bn_std = torch.sqrt(self.vae.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.vae.config.batch_norm_eps)
+        bn_std = torch.sqrt(self.vae.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.vae.config.batch_norm_eps).to(latents.device, latents.dtype)
         latents = latents * bn_std + bn_mean

         latents = latents.reshape(B, C // 4, 2, 2, H, W)
         latents = latents.permute(0, 1, 4, 2, 5, 3)
         latents = latents.reshape(B, C // 4, H * 2, W * 2)

         images = self.vae.decode(latents, self.input_info)
+
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            del self.vae
+            torch_device_module.empty_cache()
+            gc.collect()
+
         return images

     @ProfilingContext4DebugL1("RUN pipeline")
@@ -279,7 +301,7 @@ def run_pipeline(self, input_info):
             image.save(input_info.save_result_path)
             logger.info(f"Image saved: {input_info.save_result_path}")

-        torch.cuda.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()

         if input_info.return_result_tensor:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
+#!/bin/bash
+lightx2v_path=
+model_path="/data/temp/FLUX.2-dev"
+export CUDA_VISIBLE_DEVICES=3
+
+source ${lightx2v_path}/scripts/base/base.sh
+
+# Create output directory
+mkdir -p ${lightx2v_path}/save_results
+
+python -m lightx2v.infer \
+    --model_cls flux2_dev \
+    --task t2i \
+    --model_path $model_path \
+    --config_json "${lightx2v_path}/configs/flux2/flux2_dev_offload.json" \
+    --prompt "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text 'BFL Diffusers' on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom." \
+    --save_result_path "${lightx2v_path}/save_results/flux2_dev_offload.png" \
+    --seed 42

0 commit comments

Comments
 (0)