diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 65b059ecfce..632b5e26f0e 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -26,9 +26,11 @@ from invokeai.app.services.model_records import ( InvalidModelException, ModelRecordChanges, + ModelRecordOrderBy, UnknownModelException, ) from invokeai.app.services.orphaned_models import OrphanedModelInfo +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.suppress_output import SuppressOutput from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory from invokeai.backend.model_manager.configs.main import ( @@ -130,6 +132,8 @@ async def list_model_records( model_format: Optional[ModelFormat] = Query( default=None, description="Exact match on the format of the model (e.g. 'diffusers')" ), + order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Name, description="The field to order by"), + direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"), ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_manager.store @@ -138,12 +142,23 @@ async def list_model_records( for base_model in base_models: found_models.extend( record_store.search_by_attr( - base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format + base_model=base_model, + model_type=model_type, + model_name=model_name, + model_format=model_format, + order_by=order_by, + direction=direction, ) ) else: found_models.extend( - record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) + record_store.search_by_attr( + model_type=model_type, + model_name=model_name, + model_format=model_format, + order_by=order_by, + direction=direction, + ) ) for model in found_models: model = add_cover_image_to_model_config(model, ApiDependencies) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index dcdc0ce5956..c6a9463a998 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.factory import AnyModelConfig @@ -56,6 +57,10 @@ class ModelRecordOrderBy(str, Enum): Base = "base" Name = "name" Format = "format" + Size = "size" + DateAdded = "created_at" + DateModified = "updated_at" + Path = "path" class ModelSummary(BaseModel): @@ -185,7 +190,11 @@ def get_model_by_hash(self, hash: str) -> AnyModelConfig: @abstractmethod def list_models( - self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default + self, + page: int = 0, + per_page: int = 10, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" pass @@ -222,6 +231,8 @@ def search_by_attr( base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, model_format: Optional[ModelFormat] = None, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index edcbba2acdc..f104c3855e7 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -57,6 +57,7 @@ UnknownModelException, ) from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType @@ -257,6 +258,7 @@ def search_by_attr( model_type: Optional[ModelType] = None, model_format: Optional[ModelFormat] = None, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -266,18 +268,24 @@ def search_by_attr( :param model_type: Filter by type of model (optional) :param model_format: Filter by model format (e.g. "diffusers") (optional) :param order_by: Result order + :param direction: Result direction If none of the optional filters are passed, will return all models in the database. """ with self._db.transaction() as cursor: assert isinstance(order_by, ModelRecordOrderBy) + order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC" ordering = { - ModelRecordOrderBy.Default: "type, base, name, format", + ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format", ModelRecordOrderBy.Type: "type", - ModelRecordOrderBy.Base: "base", - ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Base: "base COLLATE NOCASE", + ModelRecordOrderBy.Name: "name COLLATE NOCASE", ModelRecordOrderBy.Format: "format", + ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)", + ModelRecordOrderBy.DateAdded: "created_at", + ModelRecordOrderBy.DateModified: "updated_at", + ModelRecordOrderBy.Path: "path", } where_clause: list[str] = [] @@ -301,7 +309,7 @@ def search_by_attr( SELECT config FROM models {where} - ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason; + ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason; """, tuple(bindings), ) @@ -357,17 +365,26 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: return results def list_models( - self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default + self, + page: int = 0, + per_page: int = 10, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" with self._db.transaction() as cursor: assert isinstance(order_by, ModelRecordOrderBy) + order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC" ordering = { - ModelRecordOrderBy.Default: "type, base, name, format", + ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format", ModelRecordOrderBy.Type: "type", - ModelRecordOrderBy.Base: "base", - ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Base: "base COLLATE NOCASE", + ModelRecordOrderBy.Name: "name COLLATE NOCASE", ModelRecordOrderBy.Format: "format", + ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)", + ModelRecordOrderBy.DateAdded: "created_at", + ModelRecordOrderBy.DateModified: "updated_at", + ModelRecordOrderBy.Path: "path", } # Lock so that the database isn't updated while we're doing the two queries. @@ -385,7 +402,7 @@ def list_models( f"""--sql SELECT config FROM models - ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason + ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason LIMIT ? OFFSET ?; """, diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 9b2aaddad73..fcac3efa0a0 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1165,6 +1165,15 @@ "modelType": "Model Type", "modelUpdated": "Model Updated", "modelUpdateFailed": "Model Update Failed", + "sortByName": "Name", + "sortByBase": "Base", + "sortBySize": "Size", + "sortByDateAdded": "Date Added", + "sortByDateModified": "Date Modified", + "sortByPath": "Path", + "sortByType": "Type", + "sortByFormat": "Format", + "sortDefault": "Default", "name": "Name", "modelPickerFallbackNoModelsInstalled": "No models installed.", "modelPickerFallbackNoModelsInstalled2": "Visit the Model Manager to install models.", diff --git a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx index f8ea01909a7..4c1bc56e495 100644 --- a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx +++ b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx @@ -151,11 +151,18 @@ export const useSubMenu = (): UseSubMenuReturn => { }; }; -export const SubMenuButtonContent = ({ label }: { label: string }) => { +export const SubMenuButtonContent = ({ label, value }: { label: string; value?: string }) => { return ( {label} - + + {value !== undefined && ( + + {value} + + )} + + ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 5cf18b337b0..e06ae38fcb0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -30,7 +30,7 @@ export type ModelCategoryData = { filter: (config: AnyModelConfig) => boolean; }; -export const MODEL_CATEGORIES: Record = { +const MODEL_CATEGORIES: Record = { unknown: { category: 'unknown', i18nKey: 'common.unknown', diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts index 092998d0c31..b6aea4138bb 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -22,6 +22,10 @@ const zModelManagerState = z.object({ scanPath: z.string().optional(), shouldInstallInPlace: z.boolean(), selectedModelKeys: z.array(z.string()), + orderBy: z + .enum(['default', 'name', 'type', 'base', 'size', 'created_at', 'updated_at', 'path', 'format']) + .default('name'), + sortDirection: z.enum(['asc', 'desc']).default('asc'), }); type ModelManagerState = z.infer; @@ -35,6 +39,8 @@ const getInitialState = (): ModelManagerState => ({ scanPath: undefined, shouldInstallInPlace: true, selectedModelKeys: [], + orderBy: 'name', + sortDirection: 'asc', }); const slice = createSlice({ @@ -74,6 +80,12 @@ const slice = createSlice({ clearModelSelection: (state) => { state.selectedModelKeys = []; }, + setOrderBy: (state, action: PayloadAction) => { + state.orderBy = action.payload; + }, + setSortDirection: (state, action: PayloadAction) => { + state.sortDirection = action.payload; + }, }, }); @@ -87,6 +99,8 @@ export const { modelSelectionChanged, toggleModelSelection, clearModelSelection, + setOrderBy, + setSortDirection, } = slice.actions; export const modelManagerSliceConfig: SliceConfig = { @@ -116,3 +130,5 @@ export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType); export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace); export const selectSelectedModelKeys = createModelManagerSelector((mm) => mm.selectedModelKeys); +export const selectOrderBy = createModelManagerSelector((mm) => mm.orderBy); +export const selectSortDirection = createModelManagerSelector((mm) => mm.sortDirection); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx new file mode 100644 index 00000000000..370ffc05415 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx @@ -0,0 +1,238 @@ +import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu'; +import type { ModelCategoryData } from 'features/modelManagerV2/models'; +import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; +import { + selectFilteredModelType, + selectOrderBy, + selectSortDirection, + setFilteredModelType, + setOrderBy, + setSortDirection, +} from 'features/modelManagerV2/store/modelManagerV2Slice'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + PiCheckBold, + PiFunnelBold, + PiListBold, + PiSortAscendingBold, + PiSortDescendingBold, + PiWarningBold, +} from 'react-icons/pi'; + +type OrderBy = 'default' | 'name' | 'type' | 'base' | 'size' | 'created_at' | 'updated_at' | 'path' | 'format'; + +const ORDER_BY_OPTIONS: OrderBy[] = [ + 'default', + 'name', + 'base', + 'size', + 'created_at', + 'updated_at', + 'path', + 'type', + 'format', +]; + +const SortByMenuItem = memo(({ option, label }: { option: OrderBy; label: string }) => { + const dispatch = useAppDispatch(); + const orderBy = useAppSelector(selectOrderBy); + const onClick = useCallback(() => { + dispatch(setOrderBy(option)); + }, [dispatch, option]); + + return ( + : } + > + {label} + + ); +}); +SortByMenuItem.displayName = 'SortByMenuItem'; + +const SortBySubMenu = memo(() => { + const { t } = useTranslation(); + const subMenu = useSubMenu(); + const orderBy = useAppSelector(selectOrderBy); + + const ORDER_BY_LABELS = useMemo( + () => ({ + default: t('modelManager.sortDefault'), + name: t('modelManager.sortByName'), + base: t('modelManager.sortByBase'), + size: t('modelManager.sortBySize'), + created_at: t('modelManager.sortByDateAdded'), + updated_at: t('modelManager.sortByDateModified'), + path: t('modelManager.sortByPath'), + type: t('modelManager.sortByType'), + format: t('modelManager.sortByFormat'), + }), + [t] + ); + + return ( + }> + + + + + + {ORDER_BY_OPTIONS.map((option) => ( + + ))} + + + + ); +}); +SortBySubMenu.displayName = 'SortBySubMenu'; + +const DirectionSubMenu = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const direction = useAppSelector(selectSortDirection); + const subMenu = useSubMenu(); + + const setDirectionAsc = useCallback(() => { + dispatch(setSortDirection('asc')); + }, [dispatch]); + + const setDirectionDesc = useCallback(() => { + dispatch(setSortDirection('desc')); + }, [dispatch]); + + const currentValue = direction === 'asc' ? t('common.ascending', 'Ascending') : t('common.descending', 'Descending'); + + return ( + : } + > + + + + + + : } + > + {t('common.ascending', 'Ascending')} + + : } + > + {t('common.descending', 'Descending')} + + + + + ); +}); +DirectionSubMenu.displayName = 'DirectionSubMenu'; + +const ModelTypeSubMenu = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const subMenu = useSubMenu(); + + const clearModelType = useCallback(() => { + dispatch(setFilteredModelType(null)); + }, [dispatch]); + + const setMissingFilter = useCallback(() => { + dispatch(setFilteredModelType('missing')); + }, [dispatch]); + + const currentValue = useMemo(() => { + if (filteredModelType === null) { + return t('modelManager.allModels'); + } + if (filteredModelType === 'missing') { + return t('modelManager.missingFiles'); + } + const categoryData = MODEL_CATEGORIES_AS_LIST.find((data) => data.category === filteredModelType); + return categoryData ? t(categoryData.i18nKey) : ''; + }, [filteredModelType, t]); + + return ( + }> + + + + + + : } + > + {t('modelManager.allModels')} + + : } + > + + {filteredModelType !== 'missing' && } + {t('modelManager.missingFiles')} + + + {MODEL_CATEGORIES_AS_LIST.map((data) => ( + + ))} + + + + ); +}); +ModelTypeSubMenu.displayName = 'ModelTypeSubMenu'; + +const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const onClick = useCallback(() => { + dispatch(setFilteredModelType(data.category)); + }, [data.category, dispatch]); + return ( + : } + > + {t(data.i18nKey)} + + ); +}); +ModelMenuItem.displayName = 'ModelMenuItem'; + +export const ModelFilterMenu = memo(() => { + const { t } = useTranslation(); + + return ( + + }> + {t('common.filtering', 'Filtering')} + + + + + + + + ); +}); + +ModelFilterMenu.displayName = 'ModelFilterMenu'; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index ed49fa2870b..033a439bfcc 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -8,8 +8,10 @@ import { clearModelSelection, type FilterableModelType, selectFilteredModelType, + selectOrderBy, selectSearchTerm, selectSelectedModelKeys, + selectSortDirection, setSelectedModelKey, } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { memo, useCallback, useMemo, useState } from 'react'; @@ -39,6 +41,8 @@ const ModelList = () => { const dispatch = useAppDispatch(); const filteredModelType = useAppSelector(selectFilteredModelType); const searchTerm = useAppSelector(selectSearchTerm); + const orderBy = useAppSelector(selectOrderBy); + const direction = useAppSelector(selectSortDirection); const selectedModelKeys = useAppSelector(selectSelectedModelKeys); const { t } = useTranslation(); const toast = useToast(); @@ -47,7 +51,8 @@ const ModelList = () => { const [isDeleting, setIsDeleting] = useState(false); const [isReidentifying, setIsReidentifying] = useState(false); - const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(); + const queryArgs = useMemo(() => ({ order_by: orderBy, direction: direction.toUpperCase() }), [orderBy, direction]); + const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(queryArgs); const { data: missingModelsData, isLoading: isLoadingMissing } = useGetMissingModelsQuery(); const [bulkDeleteModels] = useBulkDeleteModelsMutation(); const [bulkReidentifyModels] = useBulkReidentifyModelsMutation(); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx index 78bed8ab830..bbfb88df5cf 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx @@ -6,8 +6,8 @@ import type { ChangeEventHandler } from 'react'; import { memo, useCallback } from 'react'; import { PiXBold } from 'react-icons/pi'; +import { ModelFilterMenu } from './ModelFilterMenu'; import { ModelListBulkActions } from './ModelListBulkActions'; -import { ModelTypeFilter } from './ModelTypeFilter'; export const ModelListNavigation = memo(() => { const dispatch = useAppDispatch(); @@ -50,7 +50,7 @@ export const ModelListNavigation = memo(() => { - + diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx deleted file mode 100644 index 5aa8e628869..00000000000 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ /dev/null @@ -1,78 +0,0 @@ -import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import type { ModelCategoryData } from 'features/modelManagerV2/models'; -import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; -import type { ModelCategoryType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { memo, useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; -import { PiFunnelBold, PiWarningBold } from 'react-icons/pi'; - -const isModelCategoryType = (type: string): type is ModelCategoryType => { - return type in MODEL_CATEGORIES; -}; - -export const ModelTypeFilter = memo(() => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const filteredModelType = useAppSelector(selectFilteredModelType); - - const clearModelType = useCallback(() => { - dispatch(setFilteredModelType(null)); - }, [dispatch]); - - const setMissingFilter = useCallback(() => { - dispatch(setFilteredModelType('missing')); - }, [dispatch]); - - const getButtonLabel = () => { - if (filteredModelType === 'missing') { - return t('modelManager.missingFiles'); - } - if (filteredModelType && isModelCategoryType(filteredModelType)) { - return t(MODEL_CATEGORIES[filteredModelType].i18nKey); - } - return t('modelManager.allModels'); - }; - - return ( - - }> - {getButtonLabel()} - - - {t('modelManager.allModels')} - - - - {t('modelManager.missingFiles')} - - - {MODEL_CATEGORIES_AS_LIST.map((data) => ( - - ))} - - - ); -}); - -ModelTypeFilter.displayName = 'ModelTypeFilter'; - -const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const filteredModelType = useAppSelector(selectFilteredModelType); - const onClick = useCallback(() => { - dispatch(setFilteredModelType(data.category)); - }, [data.category, dispatch]); - return ( - - {t(data.i18nKey)} - - ); -}); -ModelMenuItem.displayName = 'ModelMenuItem'; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index c3d0decd53c..f279d46d823 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -111,9 +111,13 @@ type DeleteOrphanedModelsResponse = { errors: Record; }; +type GetModelConfigsArg = { + order_by?: string; + direction?: string; +} | void; + const modelConfigsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions); @@ -338,8 +342,11 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['ModelInstalls'], }), - getModelConfigs: build.query, void>({ - query: () => ({ url: buildModelsUrl() }), + getModelConfigs: build.query, GetModelConfigsArg>({ + query: (arg) => { + const queryStr = arg ? `?${queryString.stringify(arg)}` : ''; + return { url: buildModelsUrl(queryStr) }; + }, providesTags: (result) => { const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }]; if (result) { @@ -498,5 +505,5 @@ export const { useDeleteOrphanedModelsMutation, } = modelsApi; -export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(); +export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(undefined); export const selectMissingModelsQuery = modelsApi.endpoints.getMissingModels.select(); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index f7ef229bf31..758428c6407 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -22429,6 +22429,12 @@ export type components = { */ config_path?: string | null; }; + /** + * ModelRecordOrderBy + * @description The order in which to return model summaries. + * @enum {string} + */ + ModelRecordOrderBy: "default" | "type" | "base" | "name" | "format" | "size" | "created_at" | "updated_at" | "path"; /** ModelRelationshipBatchRequest */ ModelRelationshipBatchRequest: { /** @@ -30787,6 +30793,10 @@ export interface operations { model_name?: string | null; /** @description Exact match on the format of the model (e.g. 'diffusers') */ model_format?: components["schemas"]["ModelFormat"] | null; + /** @description The field to order by */ + order_by?: components["schemas"]["ModelRecordOrderBy"]; + /** @description The direction to order by */ + direction?: components["schemas"]["SQLiteDirection"]; }; header?: never; path?: never; diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 2b6c54d5b0f..19a1b74e73f 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -11,11 +11,13 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.model_records import ( DuplicateModelException, + ModelRecordOrderBy, ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException, ) from invokeai.app.services.model_records.model_records_base import ModelRecordChanges +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config from invokeai.backend.model_manager.configs.main import ( @@ -364,6 +366,73 @@ def test_filter_2(store: ModelRecordServiceBase): assert len(matches) == 1 +def test_search_by_attr_sorting(store: ModelRecordServiceSQL): + config1 = Main_Diffusers_SD1_Config( + path="/tmp/config1", + name="alpha", + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, + hash="CONFIG1HASH", + file_size=1000, + source="test/source/", + source_type=ModelSourceType.Path, + variant=ModelVariantType.Normal, + prediction_type=SchedulerPredictionType.Epsilon, + repo_variant=ModelRepoVariant.Default, + ) + config2 = Main_Diffusers_SD2_Config( + path="/tmp/config2", + name="beta", + base=BaseModelType.StableDiffusion2, + type=ModelType.Main, + hash="CONFIG2HASH", + file_size=2000, + source="test/source/", + source_type=ModelSourceType.Path, + variant=ModelVariantType.Normal, + prediction_type=SchedulerPredictionType.Epsilon, + repo_variant=ModelRepoVariant.Default, + ) + config3 = VAE_Diffusers_SD1_Config( + path="/tmp/config3", + name="gamma", + base=BaseModelType.StableDiffusion1, + type=ModelType.VAE, + hash="CONFIG3HASH", + file_size=500, + source="test/source/", + source_type=ModelSourceType.Path, + repo_variant=ModelRepoVariant.Default, + ) + for c in config1, config2, config3: + store.add_model(c) + + # Test sorting by Name Ascending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Ascending) + assert len(matches) == 3 + assert matches[0].name == "alpha" + assert matches[1].name == "beta" + assert matches[2].name == "gamma" + + # Test sorting by Name Descending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Descending) + assert matches[0].name == "gamma" + assert matches[1].name == "beta" + assert matches[2].name == "alpha" + + # Test sorting by Size Ascending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Ascending) + assert matches[0].name == "gamma" # 500 + assert matches[1].name == "alpha" # 1000 + assert matches[2].name == "beta" # 2000 + + # Test sorting by Size Descending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Descending) + assert matches[0].name == "beta" # 2000 + assert matches[1].name == "alpha" # 1000 + assert matches[2].name == "gamma" # 500 + + def test_model_record_changes(): # This test guards against some unexpected behaviours from pydantic's union evaluation. See #6035 changes = ModelRecordChanges.model_validate({"default_settings": {"preprocessor": "value"}})