Skip to content

Commit 0f8dce0

Browse files
kappacommitYour NameJPPhotoPfannkuchensack
authored
fix anima model auto-selection (#9035)
Co-authored-by: Your Name <you@example.com> Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com> Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev>
1 parent 5436ced commit 0f8dce0

3 files changed

Lines changed: 366 additions & 5 deletions

File tree

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
import { zModelIdentifierField } from 'features/nodes/types/common';
2+
import { beforeEach, describe, expect, it, vi } from 'vitest';
3+
4+
// Mock model configs returned by selectors - these simulate what RTK Query provides
5+
const mockAnimaQwen3Encoder = {
6+
key: 'qwen3-06b-key',
7+
hash: 'qwen3-06b-hash',
8+
name: 'Qwen3 0.6B Encoder',
9+
base: 'any' as const,
10+
type: 'qwen3_encoder' as const,
11+
variant: 'qwen3_06b' as const,
12+
format: 'qwen3_encoder' as const,
13+
};
14+
15+
const mockAnimaVAE = {
16+
key: 'anima-vae-key',
17+
hash: 'anima-vae-hash',
18+
name: 'Anima VAE',
19+
base: 'anima' as const,
20+
type: 'vae' as const,
21+
format: 'diffusers' as const,
22+
};
23+
24+
const mockT5Encoder = {
25+
key: 't5-xxl-key',
26+
hash: 't5-xxl-hash',
27+
name: 'T5-XXL Encoder',
28+
base: 'any' as const,
29+
type: 't5_encoder' as const,
30+
format: 't5_encoder' as const,
31+
};
32+
33+
const mockAnimaMainModel = {
34+
key: 'anima-main-key',
35+
hash: 'anima-main-hash',
36+
name: 'Anima Generate',
37+
base: 'anima' as const,
38+
type: 'main' as const,
39+
};
40+
41+
const mockFluxMainModel = {
42+
key: 'flux-main-key',
43+
hash: 'flux-main-hash',
44+
name: 'FLUX.1 Dev',
45+
base: 'flux' as const,
46+
type: 'main' as const,
47+
};
48+
49+
// Track dispatched actions
50+
const dispatched: Array<{ type: string; payload: unknown }> = [];
51+
const mockDispatch = vi.fn((action: { type: string; payload: unknown }) => {
52+
dispatched.push(action);
53+
});
54+
55+
// Mock logger
56+
vi.mock('app/logging/logger', () => ({
57+
logger: () => ({
58+
debug: vi.fn(),
59+
error: vi.fn(),
60+
warn: vi.fn(),
61+
info: vi.fn(),
62+
}),
63+
}));
64+
65+
// Mock toast
66+
vi.mock('features/toast/toast', () => ({
67+
toast: vi.fn(),
68+
}));
69+
70+
// Mock i18next
71+
vi.mock('i18next', () => ({
72+
t: (key: string) => key,
73+
}));
74+
75+
// Mock model selectors from RTK Query hooks
76+
77+
const mockSelectAnimaQwen3EncoderModels = vi.fn((_state: unknown) => [mockAnimaQwen3Encoder]);
78+
79+
const mockSelectAnimaVAEModels = vi.fn((_state: unknown) => [mockAnimaVAE]);
80+
81+
const mockSelectT5EncoderModels = vi.fn((_state: unknown) => [mockT5Encoder]);
82+
83+
vi.mock('services/api/hooks/modelsByType', () => ({
84+
selectAnimaQwen3EncoderModels: (state: unknown) => mockSelectAnimaQwen3EncoderModels(state),
85+
selectAnimaVAEModels: (state: unknown) => mockSelectAnimaVAEModels(state),
86+
selectT5EncoderModels: (state: unknown) => mockSelectT5EncoderModels(state),
87+
selectQwen3EncoderModels: vi.fn(() => []),
88+
selectZImageDiffusersModels: vi.fn(() => []),
89+
selectFluxVAEModels: vi.fn(() => []),
90+
selectGlobalRefImageModels: vi.fn(() => []),
91+
selectRegionalRefImageModels: vi.fn(() => []),
92+
}));
93+
94+
// Mock model configs adapter
95+
vi.mock('services/api/endpoints/models', () => ({
96+
modelConfigsAdapterSelectors: { selectById: vi.fn() },
97+
selectModelConfigsQuery: vi.fn(() => ({ data: undefined })),
98+
}));
99+
100+
vi.mock('services/api/types', () => ({
101+
isFluxKontextModelConfig: vi.fn(() => false),
102+
isFluxReduxModelConfig: vi.fn(() => false),
103+
}));
104+
105+
// Mock canvas selectors
106+
vi.mock('features/controlLayers/store/canvasStagingAreaSlice', () => ({
107+
buildSelectIsStaging: vi.fn(() => vi.fn(() => false)),
108+
selectCanvasSessionId: vi.fn(() => null),
109+
}));
110+
111+
vi.mock('features/controlLayers/store/selectors', () => ({
112+
selectAllEntitiesOfType: vi.fn(() => []),
113+
selectBboxModelBase: vi.fn(() => 'anima'),
114+
selectCanvasSlice: vi.fn(() => ({})),
115+
}));
116+
117+
vi.mock('features/controlLayers/store/refImagesSlice', () => ({
118+
refImageConfigChanged: vi.fn(),
119+
refImageModelChanged: vi.fn(),
120+
selectReferenceImageEntities: vi.fn(() => []),
121+
}));
122+
123+
vi.mock('features/controlLayers/store/types', async () => {
124+
const actual = await vi.importActual('features/controlLayers/store/types');
125+
return {
126+
...(actual as Record<string, unknown>),
127+
getEntityIdentifier: vi.fn(),
128+
isFlux2ReferenceImageConfig: vi.fn(() => false),
129+
};
130+
});
131+
132+
vi.mock('features/controlLayers/store/util', () => ({
133+
initialFlux2ReferenceImage: {},
134+
initialFluxKontextReferenceImage: {},
135+
initialFLUXRedux: {},
136+
initialIPAdapter: {},
137+
}));
138+
139+
vi.mock('features/modelManagerV2/models', () => ({
140+
SUPPORTS_REF_IMAGES_BASE_MODELS: ['sd-1', 'sdxl', 'flux', 'flux2'],
141+
}));
142+
143+
vi.mock('features/controlLayers/store/canvasSlice', () => ({
144+
bboxSyncedToOptimalDimension: vi.fn(() => ({ type: 'bboxSyncedToOptimalDimension' })),
145+
rgRefImageModelChanged: vi.fn(),
146+
}));
147+
148+
vi.mock('features/controlLayers/store/lorasSlice', () => ({
149+
loraIsEnabledChanged: vi.fn((payload: unknown) => ({ type: 'loraIsEnabledChanged', payload })),
150+
}));
151+
152+
// Capture the listener effect so we can call it directly
153+
let capturedEffect: ((action: unknown, api: unknown) => void) | null = null;
154+
155+
// Import actual action creators for assertion matching
156+
const paramsSliceActual = (await vi.importActual('features/controlLayers/store/paramsSlice')) as {
157+
animaQwen3EncoderModelSelected: { type: string };
158+
animaT5EncoderModelSelected: { type: string };
159+
animaVaeModelSelected: { type: string };
160+
};
161+
const { animaQwen3EncoderModelSelected, animaT5EncoderModelSelected, animaVaeModelSelected } = paramsSliceActual;
162+
163+
// Import after mocks are set up
164+
const { addModelSelectedListener } = await import('./modelSelected');
165+
const { modelSelected } = await import('features/parameters/store/actions');
166+
const { zParameterModel } = await import('features/parameters/types/parameterSchemas');
167+
168+
// Capture the effect
169+
addModelSelectedListener(((config: { effect: typeof capturedEffect }) => {
170+
capturedEffect = config.effect;
171+
}) as never);
172+
173+
function buildMockState(overrides: Record<string, unknown> = {}) {
174+
return {
175+
params: {
176+
model: null,
177+
vae: null,
178+
zImageVaeModel: null,
179+
zImageQwen3EncoderModel: null,
180+
zImageQwen3SourceModel: null,
181+
animaVaeModel: null,
182+
animaQwen3EncoderModel: null,
183+
animaT5EncoderModel: null,
184+
animaScheduler: 'euler',
185+
kleinVaeModel: null,
186+
kleinQwen3EncoderModel: null,
187+
zImageScheduler: 'euler',
188+
...overrides,
189+
},
190+
loras: { loras: [] },
191+
canvas: {},
192+
};
193+
}
194+
195+
describe('modelSelected listener - Anima defaulting', () => {
196+
beforeEach(() => {
197+
dispatched.length = 0;
198+
mockDispatch.mockClear();
199+
mockSelectAnimaQwen3EncoderModels.mockReturnValue([mockAnimaQwen3Encoder]);
200+
mockSelectAnimaVAEModels.mockReturnValue([mockAnimaVAE]);
201+
mockSelectT5EncoderModels.mockReturnValue([mockT5Encoder]);
202+
});
203+
204+
it('should dispatch encoder models with full ModelIdentifierField payloads when switching to Anima', () => {
205+
const state = buildMockState({ model: mockFluxMainModel });
206+
const action = modelSelected(zParameterModel.parse(mockAnimaMainModel));
207+
208+
capturedEffect!(action, {
209+
getState: () => state,
210+
dispatch: mockDispatch,
211+
});
212+
213+
// Find the dispatched actions for Anima encoders
214+
const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type);
215+
const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type);
216+
const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type);
217+
218+
// All three should have been dispatched
219+
expect(qwen3Dispatch).toBeDefined();
220+
expect(t5Dispatch).toBeDefined();
221+
expect(vaeDispatch).toBeDefined();
222+
223+
// The payloads must pass zModelIdentifierField validation (the actual schema used by reducers)
224+
expect(zModelIdentifierField.safeParse(qwen3Dispatch!.payload).success).toBe(true);
225+
expect(zModelIdentifierField.safeParse(t5Dispatch!.payload).success).toBe(true);
226+
expect(zModelIdentifierField.safeParse(vaeDispatch!.payload).success).toBe(true);
227+
});
228+
229+
it('should include hash and type in Qwen3 encoder payload', () => {
230+
const state = buildMockState({ model: mockFluxMainModel });
231+
const action = modelSelected(zParameterModel.parse(mockAnimaMainModel));
232+
233+
capturedEffect!(action, {
234+
getState: () => state,
235+
dispatch: mockDispatch,
236+
});
237+
238+
const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type);
239+
expect(qwen3Dispatch!.payload).toMatchObject({
240+
key: mockAnimaQwen3Encoder.key,
241+
hash: mockAnimaQwen3Encoder.hash,
242+
name: mockAnimaQwen3Encoder.name,
243+
base: mockAnimaQwen3Encoder.base,
244+
type: mockAnimaQwen3Encoder.type,
245+
});
246+
});
247+
248+
it('should include hash and type in T5 encoder payload', () => {
249+
const state = buildMockState({ model: mockFluxMainModel });
250+
const action = modelSelected(zParameterModel.parse(mockAnimaMainModel));
251+
252+
capturedEffect!(action, {
253+
getState: () => state,
254+
dispatch: mockDispatch,
255+
});
256+
257+
const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type);
258+
expect(t5Dispatch!.payload).toMatchObject({
259+
key: mockT5Encoder.key,
260+
hash: mockT5Encoder.hash,
261+
name: mockT5Encoder.name,
262+
base: mockT5Encoder.base,
263+
type: mockT5Encoder.type,
264+
});
265+
});
266+
267+
it('should not dispatch encoder defaults when Anima models are already set', () => {
268+
const existingQwen3 = { key: 'existing', hash: 'h', name: 'Existing', base: 'any', type: 'qwen3_encoder' };
269+
const existingT5 = { key: 'existing-t5', hash: 'h', name: 'Existing T5', base: 'any', type: 't5_encoder' };
270+
const existingVae = { key: 'existing-vae', hash: 'h', name: 'Existing VAE', base: 'anima', type: 'vae' };
271+
272+
const state = buildMockState({
273+
model: mockFluxMainModel,
274+
animaQwen3EncoderModel: existingQwen3,
275+
animaT5EncoderModel: existingT5,
276+
animaVaeModel: existingVae,
277+
});
278+
279+
const action = modelSelected(zParameterModel.parse(mockAnimaMainModel));
280+
281+
capturedEffect!(action, {
282+
getState: () => state,
283+
dispatch: mockDispatch,
284+
});
285+
286+
// Should NOT dispatch any encoder model selections since they're already set
287+
const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type);
288+
const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type);
289+
const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type);
290+
291+
expect(qwen3Dispatch).toBeUndefined();
292+
expect(t5Dispatch).toBeUndefined();
293+
expect(vaeDispatch).toBeUndefined();
294+
});
295+
296+
it('should not dispatch encoder defaults when no encoder models are available', () => {
297+
mockSelectAnimaQwen3EncoderModels.mockReturnValue([]);
298+
mockSelectAnimaVAEModels.mockReturnValue([]);
299+
300+
const state = buildMockState({ model: mockFluxMainModel });
301+
const action = modelSelected(zParameterModel.parse(mockAnimaMainModel));
302+
303+
capturedEffect!(action, {
304+
getState: () => state,
305+
dispatch: mockDispatch,
306+
});
307+
308+
const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type);
309+
const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type);
310+
const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type);
311+
312+
expect(qwen3Dispatch).toBeUndefined();
313+
expect(t5Dispatch).toBeUndefined();
314+
expect(vaeDispatch).toBeUndefined();
315+
});
316+
317+
it('should clear Anima models when switching away from Anima', () => {
318+
const existingQwen3 = { key: 'existing', hash: 'h', name: 'Existing', base: 'any', type: 'qwen3_encoder' };
319+
const existingT5 = { key: 'existing-t5', hash: 'h', name: 'Existing T5', base: 'any', type: 't5_encoder' };
320+
const existingVae = { key: 'existing-vae', hash: 'h', name: 'Existing VAE', base: 'anima', type: 'vae' };
321+
322+
const state = buildMockState({
323+
model: mockAnimaMainModel,
324+
animaQwen3EncoderModel: existingQwen3,
325+
animaT5EncoderModel: existingT5,
326+
animaVaeModel: existingVae,
327+
});
328+
329+
const action = modelSelected(zParameterModel.parse(mockFluxMainModel));
330+
331+
capturedEffect!(action, {
332+
getState: () => state,
333+
dispatch: mockDispatch,
334+
});
335+
336+
// Should dispatch null for all three
337+
const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type);
338+
const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type);
339+
const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type);
340+
341+
expect(qwen3Dispatch).toBeDefined();
342+
expect(qwen3Dispatch!.payload).toBeNull();
343+
expect(t5Dispatch).toBeDefined();
344+
expect(t5Dispatch!.payload).toBeNull();
345+
expect(vaeDispatch).toBeDefined();
346+
expect(vaeDispatch!.payload).toBeNull();
347+
});
348+
});
349+
350+
describe('zModelIdentifierField schema validation', () => {
351+
it('should reject payloads missing hash and type', () => {
352+
const incomplete = { key: 'some-key', name: 'Some Model', base: 'any' };
353+
expect(zModelIdentifierField.safeParse(incomplete).success).toBe(false);
354+
});
355+
356+
it('should accept payloads with all required fields', () => {
357+
const complete = { key: 'some-key', hash: 'some-hash', name: 'Some Model', base: 'any', type: 'qwen3_encoder' };
358+
expect(zModelIdentifierField.safeParse(complete).success).toBe(true);
359+
});
360+
});

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,10 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
200200
dispatch(
201201
animaQwen3EncoderModelSelected({
202202
key: qwen3Encoder.key,
203+
hash: qwen3Encoder.hash,
203204
name: qwen3Encoder.name,
204205
base: qwen3Encoder.base,
206+
type: qwen3Encoder.type,
205207
})
206208
);
207209
}
@@ -221,8 +223,10 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
221223
dispatch(
222224
animaT5EncoderModelSelected({
223225
key: t5Encoder.key,
226+
hash: t5Encoder.hash,
224227
name: t5Encoder.name,
225228
base: t5Encoder.base,
229+
type: t5Encoder.type,
226230
})
227231
);
228232
}

0 commit comments

Comments
 (0)