Skip to content

Commit ee60097

Browse files
authored
Broaden text encoder partial-load recovery (invoke-ai#9034)
1 parent d4c0e63 commit ee60097

8 files changed

Lines changed: 438 additions & 8 deletions

File tree

invokeai/app/invocations/compel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from invokeai.app.invocations.primitives import ConditioningOutput
2020
from invokeai.app.services.shared.invocation_context import InvocationContext
2121
from invokeai.app.util.ti_utils import generate_ti_list
22+
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
2223
from invokeai.backend.model_patcher import ModelPatcher
2324
from invokeai.backend.patches.layer_patcher import LayerPatcher
2425
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -103,7 +104,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
103104
textual_inversion_manager=ti_manager,
104105
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
105106
truncate_long_prompts=False,
106-
device=text_encoder.device, # Use the device the model is actually on
107+
device=get_effective_device(text_encoder),
107108
split_long_text_mode=SplitLongTextMode.SENTENCES,
108109
)
109110

@@ -212,7 +213,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
212213
truncate_long_prompts=False, # TODO:
213214
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
214215
requires_pooled=get_pooled,
215-
device=text_encoder.device, # Use the device the model is actually on
216+
device=get_effective_device(text_encoder),
216217
split_long_text_mode=SplitLongTextMode.SENTENCES,
217218
)
218219

invokeai/app/invocations/sd3_text_encoder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from invokeai.app.invocations.model import CLIPField, T5EncoderField
1717
from invokeai.app.invocations.primitives import SD3ConditioningOutput
1818
from invokeai.app.services.shared.invocation_context import InvocationContext
19+
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
1920
from invokeai.backend.model_manager.taxonomy import ModelFormat
2021
from invokeai.backend.patches.layer_patcher import LayerPatcher
2122
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
@@ -103,6 +104,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tens
103104
context.util.signal_progress("Running T5 encoder")
104105
assert isinstance(t5_text_encoder, T5EncoderModel)
105106
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
107+
t5_device = get_effective_device(t5_text_encoder)
106108

107109
text_inputs = t5_tokenizer(
108110
prompt,
@@ -125,7 +127,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tens
125127
f" {max_seq_len} tokens: {removed_text}"
126128
)
127129

128-
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
130+
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_device))[0]
129131

130132
assert isinstance(prompt_embeds, torch.Tensor)
131133
return prompt_embeds
@@ -144,6 +146,7 @@ def _clip_encode(
144146
context.util.signal_progress("Running CLIP encoder")
145147
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
146148
assert isinstance(clip_tokenizer, CLIPTokenizer)
149+
clip_device = get_effective_device(clip_text_encoder)
147150

148151
clip_text_encoder_config = clip_text_encoder_info.config
149152
assert clip_text_encoder_config is not None
@@ -187,9 +190,7 @@ def _clip_encode(
187190
"The following part of your input was truncated because CLIP can only handle sequences up to"
188191
f" {tokenizer_max_length} tokens: {removed_text}"
189192
)
190-
prompt_embeds = clip_text_encoder(
191-
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
192-
)
193+
prompt_embeds = clip_text_encoder(input_ids=text_input_ids.to(clip_device), output_hidden_states=True)
193194
pooled_prompt_embeds = prompt_embeds[0]
194195
prompt_embeds = prompt_embeds.hidden_states[-2]
195196

invokeai/backend/flux/modules/conditioner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from torch import Tensor, nn
44
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
55

6+
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
7+
68

79
class HFEncoder(nn.Module):
810
def __init__(
@@ -32,7 +34,7 @@ def forward(self, text: list[str]) -> Tensor:
3234
)
3335

3436
# Move inputs to the same device as the model to support cpu_only models
35-
model_device = next(self.hf_module.parameters()).device
37+
model_device = get_effective_device(self.hf_module)
3638

3739
outputs = self.hf_module(
3840
input_ids=batch_encoding["input_ids"].to(model_device),

invokeai/backend/model_manager/load/load_base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ def __init__(self, cache_record: CacheRecord, cache: ModelCache):
5858

5959
def __enter__(self) -> AnyModel:
6060
self._cache.lock(self._cache_record, None)
61-
return self.model
61+
try:
62+
self.repair_required_tensors_on_device()
63+
return self.model
64+
except Exception:
65+
self._cache.unlock(self._cache_record)
66+
raise
6267

6368
def __exit__(self, *args: Any, **kwargs: Any) -> None:
6469
self._cache.unlock(self._cache_record)
@@ -74,6 +79,7 @@ def model_on_device(
7479
"""
7580
self._cache.lock(self._cache_record, working_mem_bytes)
7681
try:
82+
self.repair_required_tensors_on_device()
7783
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
7884
finally:
7985
self._cache.unlock(self._cache_record)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from contextlib import contextmanager, nullcontext
2+
from types import SimpleNamespace
3+
from unittest.mock import MagicMock
4+
5+
import torch
6+
7+
from invokeai.app.invocations.compel import SDXLPromptInvocationBase
8+
9+
10+
class FakeClipTextEncoder(torch.nn.Module):
    """Minimal CLIP-text-encoder stand-in for partial-load scenarios.

    Its ``device`` property always claims CPU, while one registered buffer may
    live on a different device — modeling a model whose ``.device`` attribute
    is stale/misleading after a partial load. ``get_effective_device`` should
    see through this.
    """

    def __init__(self, effective_device: torch.device):
        super().__init__()
        # Parameter pinned to CPU; nn.Module.__setattr__ registers it,
        # equivalent to register_parameter("cpu_param", ...).
        self.cpu_param = torch.nn.Parameter(torch.ones(1))
        # Buffer placed on the caller-chosen "effective" device.
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        self.dtype = torch.float32

    @property
    def device(self) -> torch.device:
        # Intentionally always CPU, regardless of where tensors really are.
        return torch.device("cpu")
20+
21+
22+
class FakeTokenizer:
    """Empty tokenizer placeholder; exists only to satisfy isinstance checks."""
24+
25+
26+
class FakeLoadedModel:
    """Imitates a loaded-model handle: context-manager access to the wrapped
    model plus an optional ``config`` record."""

    def __init__(self, model, config=None):
        self._model = model
        self.config = config

    @contextmanager
    def model_on_device(self):
        # The real API yields (cpu_state_dict, model); the state dict is
        # irrelevant for these tests, so None stands in for it.
        yield None, self._model

    def __enter__(self):
        return self._model

    def __exit__(self, *exc_info):
        # Never suppress exceptions.
        return False
40+
41+
42+
class FakeCompel:
43+
last_init_device: torch.device | None = None
44+
45+
def __init__(self, *args, device: torch.device, **kwargs):
46+
del args, kwargs
47+
FakeCompel.last_init_device = device
48+
self.conditioning_provider = SimpleNamespace(
49+
get_pooled_embeddings=lambda prompts: torch.ones((len(prompts), 4), dtype=torch.float32)
50+
)
51+
52+
@staticmethod
53+
def parse_prompt_string(prompt: str) -> str:
54+
return prompt
55+
56+
def build_conditioning_tensor_for_conjunction(self, conjunction: str):
57+
del conjunction
58+
return torch.ones((1, 4, 4), dtype=torch.float32), {}
59+
60+
61+
@contextmanager
def fake_apply_ti(tokenizer, text_encoder, ti_list):
    """Stub for ModelPatcher.apply_ti: yields the tokenizer unchanged plus a
    throwaway textual-inversion-manager placeholder."""
    # The stub ignores the encoder and TI list entirely.
    del text_encoder
    del ti_list
    ti_manager_stub = object()
    yield (tokenizer, ti_manager_stub)
65+
66+
67+
def _invoke_run_clip_compel(monkeypatch, effective_device: torch.device):
    """Shared fixture/patch setup for the run_clip_compel device tests.

    Builds a fake text encoder whose effective device is *effective_device*,
    patches the compel module's collaborators with the Fake* stand-ins above,
    runs SDXL CLIP encoding once, and returns its (conditioning, pooled)
    result. Extracted because both tests below previously duplicated this
    setup line-for-line.
    """
    module_path = "invokeai.app.invocations.compel"
    text_encoder = FakeClipTextEncoder(effective_device=effective_device)
    tokenizer = FakeTokenizer()
    text_encoder_info = FakeLoadedModel(text_encoder, config=SimpleNamespace(base="sdxl"))
    tokenizer_info = FakeLoadedModel(tokenizer)

    # run_clip_compel loads the text encoder first, then the tokenizer.
    mock_context = MagicMock()
    mock_context.models.load.side_effect = [text_encoder_info, tokenizer_info]
    mock_context.config.get.return_value.log_tokenization = False
    mock_context.util.signal_progress = MagicMock()

    # Swap every external collaborator for a test double; patching by dotted
    # path keeps the real compel module importable but inert.
    monkeypatch.setattr(f"{module_path}.CLIPTextModel", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTextModelWithProjection", FakeClipTextEncoder)
    monkeypatch.setattr(f"{module_path}.CLIPTokenizer", FakeTokenizer)
    monkeypatch.setattr(f"{module_path}.Compel", FakeCompel)
    monkeypatch.setattr(f"{module_path}.generate_ti_list", lambda prompt, base, context: [])
    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_clip_skip", lambda *args, **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_ti", fake_apply_ti)

    base = SDXLPromptInvocationBase()
    return base.run_clip_compel(
        context=mock_context,
        clip_field=SimpleNamespace(
            text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[], skipped_layers=0
        ),
        prompt="test prompt",
        get_pooled=False,
        lora_prefix="lora_te1_",
        zero_on_empty=False,
    )


def test_sdxl_run_clip_compel_uses_effective_device_for_partially_loaded_model(monkeypatch):
    """A partially loaded encoder reports CPU via `.device` while holding
    tensors elsewhere; Compel must be constructed with the effective device."""
    effective_device = torch.device("meta")
    cond, pooled = _invoke_run_clip_compel(monkeypatch, effective_device)

    assert FakeCompel.last_init_device == effective_device
    assert cond.shape == (1, 4, 4)
    assert pooled is None


def test_sdxl_run_clip_compel_uses_cpu_for_fully_cpu_model(monkeypatch):
    """A fully CPU-resident encoder should keep Compel on the CPU device."""
    _invoke_run_clip_compel(monkeypatch, torch.device("cpu"))

    assert FakeCompel.last_init_device == torch.device("cpu")

0 commit comments

Comments
 (0)