|
| 1 | +import { zModelIdentifierField } from 'features/nodes/types/common'; |
| 2 | +import { beforeEach, describe, expect, it, vi } from 'vitest'; |
| 3 | + |
| 4 | +// Mock model configs returned by selectors - these simulate what RTK Query provides |
| 5 | +const mockAnimaQwen3Encoder = { |
| 6 | + key: 'qwen3-06b-key', |
| 7 | + hash: 'qwen3-06b-hash', |
| 8 | + name: 'Qwen3 0.6B Encoder', |
| 9 | + base: 'any' as const, |
| 10 | + type: 'qwen3_encoder' as const, |
| 11 | + variant: 'qwen3_06b' as const, |
| 12 | + format: 'qwen3_encoder' as const, |
| 13 | +}; |
| 14 | + |
| 15 | +const mockAnimaVAE = { |
| 16 | + key: 'anima-vae-key', |
| 17 | + hash: 'anima-vae-hash', |
| 18 | + name: 'Anima VAE', |
| 19 | + base: 'anima' as const, |
| 20 | + type: 'vae' as const, |
| 21 | + format: 'diffusers' as const, |
| 22 | +}; |
| 23 | + |
| 24 | +const mockT5Encoder = { |
| 25 | + key: 't5-xxl-key', |
| 26 | + hash: 't5-xxl-hash', |
| 27 | + name: 'T5-XXL Encoder', |
| 28 | + base: 'any' as const, |
| 29 | + type: 't5_encoder' as const, |
| 30 | + format: 't5_encoder' as const, |
| 31 | +}; |
| 32 | + |
| 33 | +const mockAnimaMainModel = { |
| 34 | + key: 'anima-main-key', |
| 35 | + hash: 'anima-main-hash', |
| 36 | + name: 'Anima Generate', |
| 37 | + base: 'anima' as const, |
| 38 | + type: 'main' as const, |
| 39 | +}; |
| 40 | + |
| 41 | +const mockFluxMainModel = { |
| 42 | + key: 'flux-main-key', |
| 43 | + hash: 'flux-main-hash', |
| 44 | + name: 'FLUX.1 Dev', |
| 45 | + base: 'flux' as const, |
| 46 | + type: 'main' as const, |
| 47 | +}; |
| 48 | + |
| 49 | +// Track dispatched actions |
| 50 | +const dispatched: Array<{ type: string; payload: unknown }> = []; |
| 51 | +const mockDispatch = vi.fn((action: { type: string; payload: unknown }) => { |
| 52 | + dispatched.push(action); |
| 53 | +}); |
| 54 | + |
| 55 | +// Mock logger |
| 56 | +vi.mock('app/logging/logger', () => ({ |
| 57 | + logger: () => ({ |
| 58 | + debug: vi.fn(), |
| 59 | + error: vi.fn(), |
| 60 | + warn: vi.fn(), |
| 61 | + info: vi.fn(), |
| 62 | + }), |
| 63 | +})); |
| 64 | + |
| 65 | +// Mock toast |
| 66 | +vi.mock('features/toast/toast', () => ({ |
| 67 | + toast: vi.fn(), |
| 68 | +})); |
| 69 | + |
| 70 | +// Mock i18next |
| 71 | +vi.mock('i18next', () => ({ |
| 72 | + t: (key: string) => key, |
| 73 | +})); |
| 74 | + |
| 75 | +// Mock model selectors from RTK Query hooks |
| 76 | + |
| 77 | +const mockSelectAnimaQwen3EncoderModels = vi.fn((_state: unknown) => [mockAnimaQwen3Encoder]); |
| 78 | + |
| 79 | +const mockSelectAnimaVAEModels = vi.fn((_state: unknown) => [mockAnimaVAE]); |
| 80 | + |
| 81 | +const mockSelectT5EncoderModels = vi.fn((_state: unknown) => [mockT5Encoder]); |
| 82 | + |
| 83 | +vi.mock('services/api/hooks/modelsByType', () => ({ |
| 84 | + selectAnimaQwen3EncoderModels: (state: unknown) => mockSelectAnimaQwen3EncoderModels(state), |
| 85 | + selectAnimaVAEModels: (state: unknown) => mockSelectAnimaVAEModels(state), |
| 86 | + selectT5EncoderModels: (state: unknown) => mockSelectT5EncoderModels(state), |
| 87 | + selectQwen3EncoderModels: vi.fn(() => []), |
| 88 | + selectZImageDiffusersModels: vi.fn(() => []), |
| 89 | + selectFluxVAEModels: vi.fn(() => []), |
| 90 | + selectGlobalRefImageModels: vi.fn(() => []), |
| 91 | + selectRegionalRefImageModels: vi.fn(() => []), |
| 92 | +})); |
| 93 | + |
| 94 | +// Mock model configs adapter |
| 95 | +vi.mock('services/api/endpoints/models', () => ({ |
| 96 | + modelConfigsAdapterSelectors: { selectById: vi.fn() }, |
| 97 | + selectModelConfigsQuery: vi.fn(() => ({ data: undefined })), |
| 98 | +})); |
| 99 | + |
| 100 | +vi.mock('services/api/types', () => ({ |
| 101 | + isFluxKontextModelConfig: vi.fn(() => false), |
| 102 | + isFluxReduxModelConfig: vi.fn(() => false), |
| 103 | +})); |
| 104 | + |
| 105 | +// Mock canvas selectors |
| 106 | +vi.mock('features/controlLayers/store/canvasStagingAreaSlice', () => ({ |
| 107 | + buildSelectIsStaging: vi.fn(() => vi.fn(() => false)), |
| 108 | + selectCanvasSessionId: vi.fn(() => null), |
| 109 | +})); |
| 110 | + |
| 111 | +vi.mock('features/controlLayers/store/selectors', () => ({ |
| 112 | + selectAllEntitiesOfType: vi.fn(() => []), |
| 113 | + selectBboxModelBase: vi.fn(() => 'anima'), |
| 114 | + selectCanvasSlice: vi.fn(() => ({})), |
| 115 | +})); |
| 116 | + |
| 117 | +vi.mock('features/controlLayers/store/refImagesSlice', () => ({ |
| 118 | + refImageConfigChanged: vi.fn(), |
| 119 | + refImageModelChanged: vi.fn(), |
| 120 | + selectReferenceImageEntities: vi.fn(() => []), |
| 121 | +})); |
| 122 | + |
| 123 | +vi.mock('features/controlLayers/store/types', async () => { |
| 124 | + const actual = await vi.importActual('features/controlLayers/store/types'); |
| 125 | + return { |
| 126 | + ...(actual as Record<string, unknown>), |
| 127 | + getEntityIdentifier: vi.fn(), |
| 128 | + isFlux2ReferenceImageConfig: vi.fn(() => false), |
| 129 | + }; |
| 130 | +}); |
| 131 | + |
| 132 | +vi.mock('features/controlLayers/store/util', () => ({ |
| 133 | + initialFlux2ReferenceImage: {}, |
| 134 | + initialFluxKontextReferenceImage: {}, |
| 135 | + initialFLUXRedux: {}, |
| 136 | + initialIPAdapter: {}, |
| 137 | +})); |
| 138 | + |
| 139 | +vi.mock('features/modelManagerV2/models', () => ({ |
| 140 | + SUPPORTS_REF_IMAGES_BASE_MODELS: ['sd-1', 'sdxl', 'flux', 'flux2'], |
| 141 | +})); |
| 142 | + |
| 143 | +vi.mock('features/controlLayers/store/canvasSlice', () => ({ |
| 144 | + bboxSyncedToOptimalDimension: vi.fn(() => ({ type: 'bboxSyncedToOptimalDimension' })), |
| 145 | + rgRefImageModelChanged: vi.fn(), |
| 146 | +})); |
| 147 | + |
| 148 | +vi.mock('features/controlLayers/store/lorasSlice', () => ({ |
| 149 | + loraIsEnabledChanged: vi.fn((payload: unknown) => ({ type: 'loraIsEnabledChanged', payload })), |
| 150 | +})); |
| 151 | + |
| 152 | +// Capture the listener effect so we can call it directly |
| 153 | +let capturedEffect: ((action: unknown, api: unknown) => void) | null = null; |
| 154 | + |
| 155 | +// Import actual action creators for assertion matching |
| 156 | +const paramsSliceActual = (await vi.importActual('features/controlLayers/store/paramsSlice')) as { |
| 157 | + animaQwen3EncoderModelSelected: { type: string }; |
| 158 | + animaT5EncoderModelSelected: { type: string }; |
| 159 | + animaVaeModelSelected: { type: string }; |
| 160 | +}; |
| 161 | +const { animaQwen3EncoderModelSelected, animaT5EncoderModelSelected, animaVaeModelSelected } = paramsSliceActual; |
| 162 | + |
| 163 | +// Import after mocks are set up |
| 164 | +const { addModelSelectedListener } = await import('./modelSelected'); |
| 165 | +const { modelSelected } = await import('features/parameters/store/actions'); |
| 166 | +const { zParameterModel } = await import('features/parameters/types/parameterSchemas'); |
| 167 | + |
| 168 | +// Capture the effect |
| 169 | +addModelSelectedListener(((config: { effect: typeof capturedEffect }) => { |
| 170 | + capturedEffect = config.effect; |
| 171 | +}) as never); |
| 172 | + |
| 173 | +function buildMockState(overrides: Record<string, unknown> = {}) { |
| 174 | + return { |
| 175 | + params: { |
| 176 | + model: null, |
| 177 | + vae: null, |
| 178 | + zImageVaeModel: null, |
| 179 | + zImageQwen3EncoderModel: null, |
| 180 | + zImageQwen3SourceModel: null, |
| 181 | + animaVaeModel: null, |
| 182 | + animaQwen3EncoderModel: null, |
| 183 | + animaT5EncoderModel: null, |
| 184 | + animaScheduler: 'euler', |
| 185 | + kleinVaeModel: null, |
| 186 | + kleinQwen3EncoderModel: null, |
| 187 | + zImageScheduler: 'euler', |
| 188 | + ...overrides, |
| 189 | + }, |
| 190 | + loras: { loras: [] }, |
| 191 | + canvas: {}, |
| 192 | + }; |
| 193 | +} |
| 194 | + |
| 195 | +describe('modelSelected listener - Anima defaulting', () => { |
| 196 | + beforeEach(() => { |
| 197 | + dispatched.length = 0; |
| 198 | + mockDispatch.mockClear(); |
| 199 | + mockSelectAnimaQwen3EncoderModels.mockReturnValue([mockAnimaQwen3Encoder]); |
| 200 | + mockSelectAnimaVAEModels.mockReturnValue([mockAnimaVAE]); |
| 201 | + mockSelectT5EncoderModels.mockReturnValue([mockT5Encoder]); |
| 202 | + }); |
| 203 | + |
| 204 | + it('should dispatch encoder models with full ModelIdentifierField payloads when switching to Anima', () => { |
| 205 | + const state = buildMockState({ model: mockFluxMainModel }); |
| 206 | + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); |
| 207 | + |
| 208 | + capturedEffect!(action, { |
| 209 | + getState: () => state, |
| 210 | + dispatch: mockDispatch, |
| 211 | + }); |
| 212 | + |
| 213 | + // Find the dispatched actions for Anima encoders |
| 214 | + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); |
| 215 | + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); |
| 216 | + const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type); |
| 217 | + |
| 218 | + // All three should have been dispatched |
| 219 | + expect(qwen3Dispatch).toBeDefined(); |
| 220 | + expect(t5Dispatch).toBeDefined(); |
| 221 | + expect(vaeDispatch).toBeDefined(); |
| 222 | + |
| 223 | + // The payloads must pass zModelIdentifierField validation (the actual schema used by reducers) |
| 224 | + expect(zModelIdentifierField.safeParse(qwen3Dispatch!.payload).success).toBe(true); |
| 225 | + expect(zModelIdentifierField.safeParse(t5Dispatch!.payload).success).toBe(true); |
| 226 | + expect(zModelIdentifierField.safeParse(vaeDispatch!.payload).success).toBe(true); |
| 227 | + }); |
| 228 | + |
| 229 | + it('should include hash and type in Qwen3 encoder payload', () => { |
| 230 | + const state = buildMockState({ model: mockFluxMainModel }); |
| 231 | + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); |
| 232 | + |
| 233 | + capturedEffect!(action, { |
| 234 | + getState: () => state, |
| 235 | + dispatch: mockDispatch, |
| 236 | + }); |
| 237 | + |
| 238 | + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); |
| 239 | + expect(qwen3Dispatch!.payload).toMatchObject({ |
| 240 | + key: mockAnimaQwen3Encoder.key, |
| 241 | + hash: mockAnimaQwen3Encoder.hash, |
| 242 | + name: mockAnimaQwen3Encoder.name, |
| 243 | + base: mockAnimaQwen3Encoder.base, |
| 244 | + type: mockAnimaQwen3Encoder.type, |
| 245 | + }); |
| 246 | + }); |
| 247 | + |
| 248 | + it('should include hash and type in T5 encoder payload', () => { |
| 249 | + const state = buildMockState({ model: mockFluxMainModel }); |
| 250 | + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); |
| 251 | + |
| 252 | + capturedEffect!(action, { |
| 253 | + getState: () => state, |
| 254 | + dispatch: mockDispatch, |
| 255 | + }); |
| 256 | + |
| 257 | + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); |
| 258 | + expect(t5Dispatch!.payload).toMatchObject({ |
| 259 | + key: mockT5Encoder.key, |
| 260 | + hash: mockT5Encoder.hash, |
| 261 | + name: mockT5Encoder.name, |
| 262 | + base: mockT5Encoder.base, |
| 263 | + type: mockT5Encoder.type, |
| 264 | + }); |
| 265 | + }); |
| 266 | + |
| 267 | + it('should not dispatch encoder defaults when Anima models are already set', () => { |
| 268 | + const existingQwen3 = { key: 'existing', hash: 'h', name: 'Existing', base: 'any', type: 'qwen3_encoder' }; |
| 269 | + const existingT5 = { key: 'existing-t5', hash: 'h', name: 'Existing T5', base: 'any', type: 't5_encoder' }; |
| 270 | + const existingVae = { key: 'existing-vae', hash: 'h', name: 'Existing VAE', base: 'anima', type: 'vae' }; |
| 271 | + |
| 272 | + const state = buildMockState({ |
| 273 | + model: mockFluxMainModel, |
| 274 | + animaQwen3EncoderModel: existingQwen3, |
| 275 | + animaT5EncoderModel: existingT5, |
| 276 | + animaVaeModel: existingVae, |
| 277 | + }); |
| 278 | + |
| 279 | + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); |
| 280 | + |
| 281 | + capturedEffect!(action, { |
| 282 | + getState: () => state, |
| 283 | + dispatch: mockDispatch, |
| 284 | + }); |
| 285 | + |
| 286 | + // Should NOT dispatch any encoder model selections since they're already set |
| 287 | + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); |
| 288 | + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); |
| 289 | + const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type); |
| 290 | + |
| 291 | + expect(qwen3Dispatch).toBeUndefined(); |
| 292 | + expect(t5Dispatch).toBeUndefined(); |
| 293 | + expect(vaeDispatch).toBeUndefined(); |
| 294 | + }); |
| 295 | + |
| 296 | + it('should not dispatch encoder defaults when no encoder models are available', () => { |
| 297 | + mockSelectAnimaQwen3EncoderModels.mockReturnValue([]); |
| 298 | + mockSelectAnimaVAEModels.mockReturnValue([]); |
| 299 | + |
| 300 | + const state = buildMockState({ model: mockFluxMainModel }); |
| 301 | + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); |
| 302 | + |
| 303 | + capturedEffect!(action, { |
| 304 | + getState: () => state, |
| 305 | + dispatch: mockDispatch, |
| 306 | + }); |
| 307 | + |
| 308 | + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); |
| 309 | + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); |
| 310 | + const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type); |
| 311 | + |
| 312 | + expect(qwen3Dispatch).toBeUndefined(); |
| 313 | + expect(t5Dispatch).toBeUndefined(); |
| 314 | + expect(vaeDispatch).toBeUndefined(); |
| 315 | + }); |
| 316 | + |
| 317 | + it('should clear Anima models when switching away from Anima', () => { |
| 318 | + const existingQwen3 = { key: 'existing', hash: 'h', name: 'Existing', base: 'any', type: 'qwen3_encoder' }; |
| 319 | + const existingT5 = { key: 'existing-t5', hash: 'h', name: 'Existing T5', base: 'any', type: 't5_encoder' }; |
| 320 | + const existingVae = { key: 'existing-vae', hash: 'h', name: 'Existing VAE', base: 'anima', type: 'vae' }; |
| 321 | + |
| 322 | + const state = buildMockState({ |
| 323 | + model: mockAnimaMainModel, |
| 324 | + animaQwen3EncoderModel: existingQwen3, |
| 325 | + animaT5EncoderModel: existingT5, |
| 326 | + animaVaeModel: existingVae, |
| 327 | + }); |
| 328 | + |
| 329 | + const action = modelSelected(zParameterModel.parse(mockFluxMainModel)); |
| 330 | + |
| 331 | + capturedEffect!(action, { |
| 332 | + getState: () => state, |
| 333 | + dispatch: mockDispatch, |
| 334 | + }); |
| 335 | + |
| 336 | + // Should dispatch null for all three |
| 337 | + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); |
| 338 | + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); |
| 339 | + const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type); |
| 340 | + |
| 341 | + expect(qwen3Dispatch).toBeDefined(); |
| 342 | + expect(qwen3Dispatch!.payload).toBeNull(); |
| 343 | + expect(t5Dispatch).toBeDefined(); |
| 344 | + expect(t5Dispatch!.payload).toBeNull(); |
| 345 | + expect(vaeDispatch).toBeDefined(); |
| 346 | + expect(vaeDispatch!.payload).toBeNull(); |
| 347 | + }); |
| 348 | +}); |
| 349 | + |
| 350 | +describe('zModelIdentifierField schema validation', () => { |
| 351 | + it('should reject payloads missing hash and type', () => { |
| 352 | + const incomplete = { key: 'some-key', name: 'Some Model', base: 'any' }; |
| 353 | + expect(zModelIdentifierField.safeParse(incomplete).success).toBe(false); |
| 354 | + }); |
| 355 | + |
| 356 | + it('should accept payloads with all required fields', () => { |
| 357 | + const complete = { key: 'some-key', hash: 'some-hash', name: 'Some Model', base: 'any', type: 'qwen3_encoder' }; |
| 358 | + expect(zModelIdentifierField.safeParse(complete).success).toBe(true); |
| 359 | + }); |
| 360 | +}); |
0 commit comments