Skip to content

Commit cbc03f4

Browse files
authored
Merge branch 'main' into copilot/enhancement-allow-shared-boards
2 parents edd1258 + ee60097 commit cbc03f4

11 files changed

Lines changed: 482 additions & 11 deletions

File tree

invokeai/app/invocations/compel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from invokeai.app.invocations.primitives import ConditioningOutput
2020
from invokeai.app.services.shared.invocation_context import InvocationContext
2121
from invokeai.app.util.ti_utils import generate_ti_list
22+
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
2223
from invokeai.backend.model_patcher import ModelPatcher
2324
from invokeai.backend.patches.layer_patcher import LayerPatcher
2425
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -103,7 +104,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
103104
textual_inversion_manager=ti_manager,
104105
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
105106
truncate_long_prompts=False,
106-
device=text_encoder.device, # Use the device the model is actually on
107+
device=get_effective_device(text_encoder),
107108
split_long_text_mode=SplitLongTextMode.SENTENCES,
108109
)
109110

@@ -212,7 +213,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
212213
truncate_long_prompts=False, # TODO:
213214
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
214215
requires_pooled=get_pooled,
215-
device=text_encoder.device, # Use the device the model is actually on
216+
device=get_effective_device(text_encoder),
216217
split_long_text_mode=SplitLongTextMode.SENTENCES,
217218
)
218219

invokeai/app/invocations/sd3_text_encoder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from invokeai.app.invocations.model import CLIPField, T5EncoderField
1717
from invokeai.app.invocations.primitives import SD3ConditioningOutput
1818
from invokeai.app.services.shared.invocation_context import InvocationContext
19+
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
1920
from invokeai.backend.model_manager.taxonomy import ModelFormat
2021
from invokeai.backend.patches.layer_patcher import LayerPatcher
2122
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
@@ -103,6 +104,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tens
103104
context.util.signal_progress("Running T5 encoder")
104105
assert isinstance(t5_text_encoder, T5EncoderModel)
105106
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
107+
t5_device = get_effective_device(t5_text_encoder)
106108

107109
text_inputs = t5_tokenizer(
108110
prompt,
@@ -125,7 +127,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tens
125127
f" {max_seq_len} tokens: {removed_text}"
126128
)
127129

128-
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
130+
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_device))[0]
129131

130132
assert isinstance(prompt_embeds, torch.Tensor)
131133
return prompt_embeds
@@ -144,6 +146,7 @@ def _clip_encode(
144146
context.util.signal_progress("Running CLIP encoder")
145147
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
146148
assert isinstance(clip_tokenizer, CLIPTokenizer)
149+
clip_device = get_effective_device(clip_text_encoder)
147150

148151
clip_text_encoder_config = clip_text_encoder_info.config
149152
assert clip_text_encoder_config is not None
@@ -187,9 +190,7 @@ def _clip_encode(
187190
"The following part of your input was truncated because CLIP can only handle sequences up to"
188191
f" {tokenizer_max_length} tokens: {removed_text}"
189192
)
190-
prompt_embeds = clip_text_encoder(
191-
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
192-
)
193+
prompt_embeds = clip_text_encoder(input_ids=text_input_ids.to(clip_device), output_hidden_states=True)
193194
pooled_prompt_embeds = prompt_embeds[0]
194195
prompt_embeds = prompt_embeds.hidden_states[-2]
195196

invokeai/backend/flux/modules/conditioner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from torch import Tensor, nn
44
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
55

6+
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
7+
68

79
class HFEncoder(nn.Module):
810
def __init__(
@@ -32,7 +34,7 @@ def forward(self, text: list[str]) -> Tensor:
3234
)
3335

3436
# Move inputs to the same device as the model to support cpu_only models
35-
model_device = next(self.hf_module.parameters()).device
37+
model_device = get_effective_device(self.hf_module)
3638

3739
outputs = self.hf_module(
3840
input_ids=batch_encoding["input_ids"].to(model_device),

invokeai/backend/model_manager/load/load_base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ def __init__(self, cache_record: CacheRecord, cache: ModelCache):
5858

5959
def __enter__(self) -> AnyModel:
6060
self._cache.lock(self._cache_record, None)
61-
return self.model
61+
try:
62+
self.repair_required_tensors_on_device()
63+
return self.model
64+
except Exception:
65+
self._cache.unlock(self._cache_record)
66+
raise
6267

6368
def __exit__(self, *args: Any, **kwargs: Any) -> None:
6469
self._cache.unlock(self._cache_record)
@@ -74,6 +79,7 @@ def model_on_device(
7479
"""
7580
self._cache.lock(self._cache_record, working_mem_bytes)
7681
try:
82+
self.repair_required_tensors_on_device()
7783
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
7884
finally:
7985
self._cache.unlock(self._cache_record)

invokeai/frontend/web/public/locales/it.json

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,16 @@
844844
"settingsImportedPartial": "Impostazioni del modello parzialmente importate. Le impostazioni incompatibili sono state ignorate: {{fields}}",
845845
"settingsImportFailed": "Impossibile importare le impostazioni del modello",
846846
"settingsImportIncompatible": "Il file delle impostazioni non contiene impostazioni compatibili per questo tipo di modello",
847-
"settingsImportInvalidFile": "File di impostazioni non valido"
847+
"settingsImportInvalidFile": "File di impostazioni non valido",
848+
"reidentifyModels": "Re-identificare i modelli",
849+
"reidentifyModelsConfirm": "Sei sicuro di voler re-identificare {{count}} modello(i)? Questa operazione eseguirà una nuova scansione dei relativi file dei pesi per determinarne il formato e le impostazioni corrette.",
850+
"reidentifyWarning": "Questa operazione ripristinerà tutte le impostazioni personalizzate che potresti aver applicato a questi modelli.",
851+
"modelsReidentified": "{{count}} modello(i) re-identificato(i) con successo",
852+
"modelsReidentifyFailed": "Impossibile re-identificare i modelli",
853+
"someModelsFailedToReidentify": "Non è stato possibile re-identificare {{count}} modello(i)",
854+
"modelsReidentifiedPartial": "Completato parzialmente",
855+
"someModelsReidentified": "{{succeeded}} re-identificato(i), {{failed}} fallito(i)",
856+
"modelsReidentifyError": "Errore nella re-identificazione dei modelli"
848857
},
849858
"parameters": {
850859
"images": "Immagini",
@@ -2585,7 +2594,9 @@
25852594
"saveToGallery": "Salva nella Galleria",
25862595
"previous": "Precedente",
25872596
"showResultsOn": "Visualizzare i risultati",
2588-
"showResultsOff": "Nascondere i risultati"
2597+
"showResultsOff": "Nascondere i risultati",
2598+
"hideThumbnails": "Nascondi le miniature",
2599+
"showThumbnails": "Mostra miniature"
25892600
},
25902601
"HUD": {
25912602
"bbox": "Riquadro di delimitazione",
@@ -2826,6 +2837,27 @@
28262837
"alignLeft": "Allinea a sinistra",
28272838
"alignCenter": "Allinea al centro",
28282839
"alignRight": "Allinea a destra"
2840+
},
2841+
"workflowIntegration": {
2842+
"title": "Eseguire il flusso di lavoro sulla Tela",
2843+
"description": "Seleziona un flusso di lavoro con un nodo Output su tela e un parametro immagine da eseguire sul livello corrente della tela. Puoi regolare i parametri prima dell'esecuzione. Il risultato verrà aggiunto nuovamente alla tela.",
2844+
"execute": "Eseguire il flusso di lavoro",
2845+
"executing": "Esecuzione in corso...",
2846+
"runWorkflow": "Avvia il flusso di lavoro",
2847+
"filteringWorkflows": "Filtraggio dei flussi di lavoro...",
2848+
"loadingWorkflows": "Caricamento dei flussi di lavoro...",
2849+
"noWorkflowsFound": "Nessun flusso di lavoro trovato.",
2850+
"noWorkflowsWithImageField": "Nessun flusso di lavoro compatibile trovato. Un flusso di lavoro richiede un Generatore Modello con un campo di input immagine e un nodo Output su tela.",
2851+
"selectWorkflow": "Seleziona il flusso di lavoro",
2852+
"selectPlaceholder": "Scegli un flusso di lavoro...",
2853+
"unnamedWorkflow": "Flusso di lavoro senza nome",
2854+
"loadingParameters": "Caricamento dei parametri del flusso di lavoro in corso...",
2855+
"noFormBuilderError": "Questo flusso di lavoro non dispone di un generatore di moduli e non può essere utilizzato. Selezionare un flusso di lavoro diverso.",
2856+
"imageFieldSelected": "Questo campo riceverà l'immagine della tela",
2857+
"imageFieldNotSelected": "Fai clic su questo campo per usarlo per l'immagine sulla tela",
2858+
"executionStarted": "L'esecuzione del flusso di lavoro è stata avviata",
2859+
"executionStartedDescription": "Il risultato apparirà nell'area di lavoro una volta completata l'operazione.",
2860+
"executionFailed": "Impossibile eseguire il flusso di lavoro"
28292861
}
28302862
},
28312863
"ui": {
@@ -3072,7 +3104,8 @@
30723104
"rememberMe": "Ricordami per 7 giorni",
30733105
"signIn": "Accedi",
30743106
"signingIn": "Accesso in corso...",
3075-
"loginFailed": "Accesso non riuscito. Controlla le tue credenziali."
3107+
"loginFailed": "Accesso non riuscito. Controlla le tue credenziali.",
3108+
"sessionExpired": "Le tue credenziali sono scadute. Effettua nuovamente l'accesso per continuare."
30763109
},
30773110
"setup": {
30783111
"title": "Benvenuti a InvokeAI",

invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import '@fontsource-variable/inter';
22
import 'overlayscrollbars/overlayscrollbars.css';
33
import '@xyflow/react/dist/base.css';
44
import 'common/components/OverlayScrollbars/overlayscrollbars.css';
5+
import 'app/components/touchDevice.css';
56

67
import { ChakraProvider, DarkMode, extendTheme, theme as baseTheme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
78
import { useStore } from '@nanostores/react';
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
/* Hide tooltips on touch devices where hover gets "stuck" */
2+
@media (hover: none) {
3+
[role='tooltip'] {
4+
visibility: hidden !important;
5+
opacity: 0 !important;
6+
}
7+
}
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from contextlib import contextmanager, nullcontext
2+
from types import SimpleNamespace
3+
from unittest.mock import MagicMock
4+
5+
import torch
6+
7+
from invokeai.app.invocations.compel import SDXLPromptInvocationBase
8+
9+
10+
class FakeClipTextEncoder(torch.nn.Module):
    """CLIP-encoder stand-in whose `.device` property misreports the real device.

    The lone parameter lives on CPU and `.device` always claims CPU, while the
    registered buffer sits on `effective_device` — imitating a partially
    loaded model where trusting `.device` naively would pick the wrong device.
    """

    def __init__(self, effective_device: torch.device):
        super().__init__()
        cpu_weight = torch.nn.Parameter(torch.ones(1))
        self.register_parameter("cpu_param", cpu_weight)
        self.register_buffer("active_buffer", torch.ones(1, device=effective_device))
        # Real HF text encoders expose a dtype attribute; mirror that here.
        self.dtype = torch.float32

    @property
    def device(self) -> torch.device:
        # Deliberately reports CPU regardless of where the buffer lives.
        return torch.device("cpu")
20+
21+
22+
class FakeTokenizer:
    """Placeholder tokenizer; no behavior is needed for these tests."""
24+
25+
26+
class FakeLoadedModel:
    """Imitates the loaded-model handle returned by `context.models.load`.

    Supports both access styles the invocation code uses: the plain
    context-manager protocol (`with info as model:`) and the
    `model_on_device()` context manager.
    """

    def __init__(self, model, config=None):
        # Held privately; only handed out via the context-manager entry points.
        self._model = model
        self.config = config

    @contextmanager
    def model_on_device(self):
        # The real API yields (cpu_state_dict, model); the state dict is unused here.
        yield None, self._model

    def __enter__(self):
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        # Never suppress exceptions.
        return False
40+
41+
42+
class FakeCompel:
43+
last_init_device: torch.device | None = None
44+
45+
def __init__(self, *args, device: torch.device, **kwargs):
46+
del args, kwargs
47+
FakeCompel.last_init_device = device
48+
self.conditioning_provider = SimpleNamespace(
49+
get_pooled_embeddings=lambda prompts: torch.ones((len(prompts), 4), dtype=torch.float32)
50+
)
51+
52+
@staticmethod
53+
def parse_prompt_string(prompt: str) -> str:
54+
return prompt
55+
56+
def build_conditioning_tensor_for_conjunction(self, conjunction: str):
57+
del conjunction
58+
return torch.ones((1, 4, 4), dtype=torch.float32), {}
59+
60+
61+
@contextmanager
def fake_apply_ti(tokenizer, text_encoder, ti_list):
    """No-op replacement for `ModelPatcher.apply_ti`.

    Yields the tokenizer unchanged together with a dummy TI-manager object.
    """
    del text_encoder, ti_list
    dummy_ti_manager = object()
    yield tokenizer, dummy_ti_manager
65+
66+
67+
def test_sdxl_run_clip_compel_uses_effective_device_for_partially_loaded_model(monkeypatch):
    """Compel must be given the device of the materialized weights, not `.device`.

    The fake encoder's `.device` reports CPU while its buffer lives on `meta`;
    `run_clip_compel` should pass the effective (meta) device to Compel.
    """
    module_path = "invokeai.app.invocations.compel"
    target_device = torch.device("meta")
    encoder = FakeClipTextEncoder(effective_device=target_device)
    encoder_info = FakeLoadedModel(encoder, config=SimpleNamespace(base="sdxl"))
    tokenizer_info = FakeLoadedModel(FakeTokenizer())

    context = MagicMock()
    context.models.load.side_effect = [encoder_info, tokenizer_info]
    context.config.get.return_value.log_tokenization = False
    context.util.signal_progress = MagicMock()

    # Swap the real model/compel machinery for the local fakes.
    for attr, replacement in {
        "CLIPTextModel": FakeClipTextEncoder,
        "CLIPTextModelWithProjection": FakeClipTextEncoder,
        "CLIPTokenizer": FakeTokenizer,
        "Compel": FakeCompel,
    }.items():
        monkeypatch.setattr(f"{module_path}.{attr}", replacement)
    monkeypatch.setattr(f"{module_path}.generate_ti_list", lambda prompt, base, context: [])
    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_clip_skip", lambda *args, **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_ti", fake_apply_ti)

    invocation = SDXLPromptInvocationBase()
    cond, pooled = invocation.run_clip_compel(
        context=context,
        clip_field=SimpleNamespace(
            text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[], skipped_layers=0
        ),
        prompt="test prompt",
        get_pooled=False,
        lora_prefix="lora_te1_",
        zero_on_empty=False,
    )

    assert FakeCompel.last_init_device == target_device
    assert cond.shape == (1, 4, 4)
    assert pooled is None
104+
105+
106+
def test_sdxl_run_clip_compel_uses_cpu_for_fully_cpu_model(monkeypatch):
    """When every weight genuinely lives on CPU, Compel should get the CPU device."""
    module_path = "invokeai.app.invocations.compel"
    encoder = FakeClipTextEncoder(effective_device=torch.device("cpu"))
    encoder_info = FakeLoadedModel(encoder, config=SimpleNamespace(base="sdxl"))
    tokenizer_info = FakeLoadedModel(FakeTokenizer())

    context = MagicMock()
    context.models.load.side_effect = [encoder_info, tokenizer_info]
    context.config.get.return_value.log_tokenization = False
    context.util.signal_progress = MagicMock()

    # Swap the real model/compel machinery for the local fakes.
    for attr, replacement in {
        "CLIPTextModel": FakeClipTextEncoder,
        "CLIPTextModelWithProjection": FakeClipTextEncoder,
        "CLIPTokenizer": FakeTokenizer,
        "Compel": FakeCompel,
    }.items():
        monkeypatch.setattr(f"{module_path}.{attr}", replacement)
    monkeypatch.setattr(f"{module_path}.generate_ti_list", lambda prompt, base, context: [])
    monkeypatch.setattr(f"{module_path}.LayerPatcher.apply_smart_model_patches", lambda **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_clip_skip", lambda *args, **kwargs: nullcontext())
    monkeypatch.setattr(f"{module_path}.ModelPatcher.apply_ti", fake_apply_ti)

    invocation = SDXLPromptInvocationBase()
    invocation.run_clip_compel(
        context=context,
        clip_field=SimpleNamespace(
            text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace(), loras=[], skipped_layers=0
        ),
        prompt="test prompt",
        get_pooled=False,
        lora_prefix="lora_te1_",
        zero_on_empty=False,
    )

    assert FakeCompel.last_init_device == torch.device("cpu")

0 commit comments

Comments
 (0)