diff --git a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts index 2b7c0f7d177..de55d476ec5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts @@ -35,6 +35,14 @@ type PayloadActionWithId = T extends void } & T >; +/** Fingerprint used to match the same reference image entry after recall when ids are regenerated. */ +const getRefImageRecallMatchKey = (entity: RefImageState): string => { + const { config } = entity; + const imageName = config.image?.original.image.image_name ?? config.image?.crop?.image.image_name ?? ''; + const modelKey = 'model' in config && config.model ? config.model.key : ''; + return `${config.type}\0${modelKey}\0${imageName}`; +}; + const slice = createSlice({ name: 'refImages', initialState: getInitialRefImagesState(), @@ -54,13 +62,36 @@ const slice = createSlice({ }, refImagesRecalled: (state, action: PayloadAction<{ entities: RefImageState[]; replace: boolean }>) => { const { entities, replace } = action.payload; - if (replace) { - state.entities = entities; + if (!replace) { + state.entities.push(...entities); + return; + } + const wasPanelOpen = state.isPanelOpen; + const previousSelectedId = state.selectedEntityId; + let previousEntity: RefImageState | null = null; + if (previousSelectedId !== null) { + previousEntity = state.entities.find((e) => e.id === previousSelectedId) ?? null; + } + state.entities = entities; + if (entities.length === 0) { + state.selectedEntityId = null; state.isPanelOpen = false; + return; + } + if (!wasPanelOpen) { state.selectedEntityId = null; - } else { - state.entities.push(...entities); + return; + } + const firstEntity = entities[0]; + assert(firstEntity); + if (previousSelectedId !== null && entities.some((e) => e.id === previousSelectedId)) { + state.selectedEntityId = previousSelectedId; + return; } + const previousKey = previousEntity ? getRefImageRecallMatchKey(previousEntity) : null; + const matched = + previousKey !== null ? entities.find((e) => getRefImageRecallMatchKey(e) === previousKey) : undefined; + state.selectedEntityId = matched?.id ?? firstEntity.id; }, refImageImageChanged: (state, action: PayloadActionWithId<{ croppableImage: CroppableImageWithDims | null }>) => { const { id, croppableImage } = action.payload;