Skip to content

Commit 89bf3ce

Browse files
authored
Merge pull request #841 from modelscope/qwen-image-lora-hotload
Support hot-loading LoRA weights for Qwen-Image
2 parents 3ebe118 + 6ab426e commit 89bf3ce

File tree

1 file changed

+57
-4
lines changed

1 file changed

+57
-4
lines changed

diffsynth/pipelines/qwen_image.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,63 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
7777
self.model_fn = model_fn_qwen_image
7878

7979

80-
def load_lora(self, module, path, alpha=1):
81-
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
82-
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
83-
loader.load(module, lora, alpha=alpha)
80+
def load_lora(
    self,
    module: torch.nn.Module,
    lora_config: Union[ModelConfig, str, None] = None,
    alpha=1,
    hotload=False,
    state_dict=None,
):
    """Load a LoRA into `module`, either by merging or by hot-loading.

    Args:
        module: Model (e.g. the DiT) to apply the LoRA to.
        lora_config: Path to a LoRA checkpoint, or a ModelConfig describing
            one. Ignored when `state_dict` is given.
        alpha: LoRA scale. For hot-loading it is folded into the A matrices;
            for merging it is forwarded to the loader.
        hotload: If True, append the LoRA A/B matrices to every
            AutoWrappedLinear submodule so they are applied at runtime
            without modifying the base weights (the layers must have been
            wrapped beforehand, e.g. via enable_lora_magic()). If False,
            merge the LoRA into the base weights.
        state_dict: Pre-loaded LoRA state dict; takes precedence over
            `lora_config`.

    Raises:
        ValueError: If neither `lora_config` nor `state_dict` is provided.
    """
    if state_dict is not None:
        lora = state_dict
    elif isinstance(lora_config, str):
        lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
    elif lora_config is not None:
        lora_config.download_if_necessary()
        lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
    else:
        # Previously this fell through to an opaque AttributeError on None.
        raise ValueError("Either `lora_config` or `state_dict` must be provided.")
    if hotload:
        # Iterate under distinct names so the `module` parameter is not
        # shadowed by the loop variable (the original rebound `module`).
        for child_name, child in module.named_modules():
            if isinstance(child, AutoWrappedLinear):
                lora_a_name = f'{child_name}.lora_A.default.weight'
                lora_b_name = f'{child_name}.lora_B.default.weight'
                if lora_a_name in lora and lora_b_name in lora:
                    # Fold the scale into A so B can be stored untouched.
                    child.lora_A_weights.append(lora[lora_a_name] * alpha)
                    child.lora_B_weights.append(lora[lora_b_name])
    else:
        # Merge path: permanently fold the LoRA into the module's weights.
        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
        loader.load(module, lora, alpha=alpha)
107+
108+
109+
def clear_lora(self):
    """Drop every hot-loaded LoRA from the pipeline's wrapped linear layers.

    Walks all submodules and empties the `lora_A_weights` /
    `lora_B_weights` lists on each AutoWrappedLinear that has them.
    """
    for _, submodule in self.named_modules():
        if not isinstance(submodule, AutoWrappedLinear):
            continue
        if hasattr(submodule, "lora_A_weights"):
            submodule.lora_A_weights.clear()
        if hasattr(submodule, "lora_B_weights"):
            submodule.lora_B_weights.clear()
116+
117+
118+
def enable_lora_magic(self):
    """Wrap the DiT's Linear layers with AutoWrappedLinear (via VRAM
    management) so LoRA weights can be hot-loaded at inference time.

    No-op when there is no DiT or VRAM management is already enabled.
    """
    if self.dit is None:
        return
    if hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled:
        return
    # Keep the model's native parameter dtype for storage; compute in the
    # pipeline dtype on the pipeline device.
    dtype = next(iter(self.dit.parameters())).dtype
    enable_vram_management(
        self.dit,
        module_map={torch.nn.Linear: AutoWrappedLinear},
        module_config=dict(
            offload_dtype=dtype,
            offload_device=self.device,
            onload_dtype=dtype,
            onload_device=self.device,
            computation_dtype=self.torch_dtype,
            computation_device=self.device,
        ),
        vram_limit=None,
    )
84137

85138

86139
def training_loss(self, **inputs):

0 commit comments

Comments
 (0)