diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index 65b059ecfce..632b5e26f0e 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -26,9 +26,11 @@
from invokeai.app.services.model_records import (
InvalidModelException,
ModelRecordChanges,
+ ModelRecordOrderBy,
UnknownModelException,
)
from invokeai.app.services.orphaned_models import OrphanedModelInfo
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.configs.main import (
@@ -130,6 +132,8 @@ async def list_model_records(
model_format: Optional[ModelFormat] = Query(
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
),
+ order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Name, description="The field to order by"),
+ direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
@@ -138,12 +142,23 @@ async def list_model_records(
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(
- base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
+ base_model=base_model,
+ model_type=model_type,
+ model_name=model_name,
+ model_format=model_format,
+ order_by=order_by,
+ direction=direction,
)
)
else:
found_models.extend(
- record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
+ record_store.search_by_attr(
+ model_type=model_type,
+ model_name=model_name,
+ model_format=model_format,
+ order_by=order_by,
+ direction=direction,
+ )
)
for model in found_models:
model = add_cover_image_to_model_config(model, ApiDependencies)
diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py
index dcdc0ce5956..c6a9463a998 100644
--- a/invokeai/app/services/model_records/model_records_base.py
+++ b/invokeai/app/services/model_records/model_records_base.py
@@ -11,6 +11,7 @@
from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
@@ -56,6 +57,10 @@ class ModelRecordOrderBy(str, Enum):
Base = "base"
Name = "name"
Format = "format"
+ Size = "size"
+ DateAdded = "created_at"
+ DateModified = "updated_at"
+ Path = "path"
class ModelSummary(BaseModel):
@@ -185,7 +190,11 @@ def get_model_by_hash(self, hash: str) -> AnyModelConfig:
@abstractmethod
def list_models(
- self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
+ self,
+ page: int = 0,
+ per_page: int = 10,
+ order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
+ direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
pass
@@ -222,6 +231,8 @@ def search_by_attr(
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
+ order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
+ direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py
index edcbba2acdc..f104c3855e7 100644
--- a/invokeai/app/services/model_records/model_records_sql.py
+++ b/invokeai/app/services/model_records/model_records_sql.py
@@ -57,6 +57,7 @@
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
@@ -257,6 +258,7 @@ def search_by_attr(
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
+ direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@@ -266,18 +268,24 @@ def search_by_attr(
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
:param order_by: Result order
+ :param direction: Result direction
If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
+ order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
ordering = {
- ModelRecordOrderBy.Default: "type, base, name, format",
+ ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
ModelRecordOrderBy.Type: "type",
- ModelRecordOrderBy.Base: "base",
- ModelRecordOrderBy.Name: "name",
+ ModelRecordOrderBy.Base: "base COLLATE NOCASE",
+ ModelRecordOrderBy.Name: "name COLLATE NOCASE",
ModelRecordOrderBy.Format: "format",
+ ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
+ ModelRecordOrderBy.DateAdded: "created_at",
+ ModelRecordOrderBy.DateModified: "updated_at",
+ ModelRecordOrderBy.Path: "path",
}
where_clause: list[str] = []
@@ -301,7 +309,7 @@ def search_by_attr(
SELECT config
FROM models
{where}
- ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
+ ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
@@ -357,17 +365,26 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
return results
def list_models(
- self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
+ self,
+ page: int = 0,
+ per_page: int = 10,
+ order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
+ direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
+ order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
ordering = {
- ModelRecordOrderBy.Default: "type, base, name, format",
+ ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
ModelRecordOrderBy.Type: "type",
- ModelRecordOrderBy.Base: "base",
- ModelRecordOrderBy.Name: "name",
+ ModelRecordOrderBy.Base: "base COLLATE NOCASE",
+ ModelRecordOrderBy.Name: "name COLLATE NOCASE",
ModelRecordOrderBy.Format: "format",
+ ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
+ ModelRecordOrderBy.DateAdded: "created_at",
+ ModelRecordOrderBy.DateModified: "updated_at",
+ ModelRecordOrderBy.Path: "path",
}
# Lock so that the database isn't updated while we're doing the two queries.
@@ -385,7 +402,7 @@ def list_models(
f"""--sql
SELECT config
FROM models
- ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
+ ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 9b2aaddad73..fcac3efa0a0 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -1165,6 +1165,15 @@
"modelType": "Model Type",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
+ "sortByName": "Name",
+ "sortByBase": "Base",
+ "sortBySize": "Size",
+ "sortByDateAdded": "Date Added",
+ "sortByDateModified": "Date Modified",
+ "sortByPath": "Path",
+ "sortByType": "Type",
+ "sortByFormat": "Format",
+ "sortDefault": "Default",
"name": "Name",
"modelPickerFallbackNoModelsInstalled": "No models installed.",
"modelPickerFallbackNoModelsInstalled2": "Visit the Model Manager to install models.",
diff --git a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx
index f8ea01909a7..4c1bc56e495 100644
--- a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx
+++ b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx
@@ -151,11 +151,18 @@ export const useSubMenu = (): UseSubMenuReturn => {
};
};
-export const SubMenuButtonContent = ({ label }: { label: string }) => {
+export const SubMenuButtonContent = ({ label, value }: { label: string; value?: string }) => {
return (
{label}
-
+
+ {value !== undefined && (
+
+ {value}
+
+ )}
+
+
);
};
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts
index 5cf18b337b0..e06ae38fcb0 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts
+++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts
@@ -30,7 +30,7 @@ export type ModelCategoryData = {
filter: (config: AnyModelConfig) => boolean;
};
-export const MODEL_CATEGORIES: Record = {
+const MODEL_CATEGORIES: Record = {
unknown: {
category: 'unknown',
i18nKey: 'common.unknown',
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts
index 092998d0c31..b6aea4138bb 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts
+++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts
@@ -22,6 +22,10 @@ const zModelManagerState = z.object({
scanPath: z.string().optional(),
shouldInstallInPlace: z.boolean(),
selectedModelKeys: z.array(z.string()),
+ orderBy: z
+ .enum(['default', 'name', 'type', 'base', 'size', 'created_at', 'updated_at', 'path', 'format'])
+ .default('name'),
+ sortDirection: z.enum(['asc', 'desc']).default('asc'),
});
type ModelManagerState = z.infer;
@@ -35,6 +39,8 @@ const getInitialState = (): ModelManagerState => ({
scanPath: undefined,
shouldInstallInPlace: true,
selectedModelKeys: [],
+ orderBy: 'name',
+ sortDirection: 'asc',
});
const slice = createSlice({
@@ -74,6 +80,12 @@ const slice = createSlice({
clearModelSelection: (state) => {
state.selectedModelKeys = [];
},
+ setOrderBy: (state, action: PayloadAction) => {
+ state.orderBy = action.payload;
+ },
+ setSortDirection: (state, action: PayloadAction) => {
+ state.sortDirection = action.payload;
+ },
},
});
@@ -87,6 +99,8 @@ export const {
modelSelectionChanged,
toggleModelSelection,
clearModelSelection,
+ setOrderBy,
+ setSortDirection,
} = slice.actions;
export const modelManagerSliceConfig: SliceConfig = {
@@ -116,3 +130,5 @@ export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm
export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType);
export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace);
export const selectSelectedModelKeys = createModelManagerSelector((mm) => mm.selectedModelKeys);
+export const selectOrderBy = createModelManagerSelector((mm) => mm.orderBy);
+export const selectSortDirection = createModelManagerSelector((mm) => mm.sortDirection);
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx
new file mode 100644
index 00000000000..370ffc05415
--- /dev/null
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx
@@ -0,0 +1,238 @@
+import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
+import type { ModelCategoryData } from 'features/modelManagerV2/models';
+import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
+import {
+ selectFilteredModelType,
+ selectOrderBy,
+ selectSortDirection,
+ setFilteredModelType,
+ setOrderBy,
+ setSortDirection,
+} from 'features/modelManagerV2/store/modelManagerV2Slice';
+import { memo, useCallback, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
+import {
+ PiCheckBold,
+ PiFunnelBold,
+ PiListBold,
+ PiSortAscendingBold,
+ PiSortDescendingBold,
+ PiWarningBold,
+} from 'react-icons/pi';
+
+type OrderBy = 'default' | 'name' | 'type' | 'base' | 'size' | 'created_at' | 'updated_at' | 'path' | 'format';
+
+const ORDER_BY_OPTIONS: OrderBy[] = [
+ 'default',
+ 'name',
+ 'base',
+ 'size',
+ 'created_at',
+ 'updated_at',
+ 'path',
+ 'type',
+ 'format',
+];
+
+const SortByMenuItem = memo(({ option, label }: { option: OrderBy; label: string }) => {
+ const dispatch = useAppDispatch();
+ const orderBy = useAppSelector(selectOrderBy);
+ const onClick = useCallback(() => {
+ dispatch(setOrderBy(option));
+ }, [dispatch, option]);
+
+ return (
+ : }
+ >
+ {label}
+
+ );
+});
+SortByMenuItem.displayName = 'SortByMenuItem';
+
+const SortBySubMenu = memo(() => {
+ const { t } = useTranslation();
+ const subMenu = useSubMenu();
+ const orderBy = useAppSelector(selectOrderBy);
+
+ const ORDER_BY_LABELS = useMemo(
+ () => ({
+ default: t('modelManager.sortDefault'),
+ name: t('modelManager.sortByName'),
+ base: t('modelManager.sortByBase'),
+ size: t('modelManager.sortBySize'),
+ created_at: t('modelManager.sortByDateAdded'),
+ updated_at: t('modelManager.sortByDateModified'),
+ path: t('modelManager.sortByPath'),
+ type: t('modelManager.sortByType'),
+ format: t('modelManager.sortByFormat'),
+ }),
+ [t]
+ );
+
+ return (
+ }>
+
+
+ );
+});
+SortBySubMenu.displayName = 'SortBySubMenu';
+
+const DirectionSubMenu = memo(() => {
+ const { t } = useTranslation();
+ const dispatch = useAppDispatch();
+ const direction = useAppSelector(selectSortDirection);
+ const subMenu = useSubMenu();
+
+ const setDirectionAsc = useCallback(() => {
+ dispatch(setSortDirection('asc'));
+ }, [dispatch]);
+
+ const setDirectionDesc = useCallback(() => {
+ dispatch(setSortDirection('desc'));
+ }, [dispatch]);
+
+ const currentValue = direction === 'asc' ? t('common.ascending', 'Ascending') : t('common.descending', 'Descending');
+
+ return (
+ : }
+ >
+
+
+ );
+});
+DirectionSubMenu.displayName = 'DirectionSubMenu';
+
+const ModelTypeSubMenu = memo(() => {
+ const { t } = useTranslation();
+ const dispatch = useAppDispatch();
+ const filteredModelType = useAppSelector(selectFilteredModelType);
+ const subMenu = useSubMenu();
+
+ const clearModelType = useCallback(() => {
+ dispatch(setFilteredModelType(null));
+ }, [dispatch]);
+
+ const setMissingFilter = useCallback(() => {
+ dispatch(setFilteredModelType('missing'));
+ }, [dispatch]);
+
+ const currentValue = useMemo(() => {
+ if (filteredModelType === null) {
+ return t('modelManager.allModels');
+ }
+ if (filteredModelType === 'missing') {
+ return t('modelManager.missingFiles');
+ }
+ const categoryData = MODEL_CATEGORIES_AS_LIST.find((data) => data.category === filteredModelType);
+ return categoryData ? t(categoryData.i18nKey) : '';
+ }, [filteredModelType, t]);
+
+ return (
+ }>
+
+
+ );
+});
+ModelTypeSubMenu.displayName = 'ModelTypeSubMenu';
+
+const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => {
+ const { t } = useTranslation();
+ const dispatch = useAppDispatch();
+ const filteredModelType = useAppSelector(selectFilteredModelType);
+ const onClick = useCallback(() => {
+ dispatch(setFilteredModelType(data.category));
+ }, [data.category, dispatch]);
+ return (
+ : }
+ >
+ {t(data.i18nKey)}
+
+ );
+});
+ModelMenuItem.displayName = 'ModelMenuItem';
+
+export const ModelFilterMenu = memo(() => {
+ const { t } = useTranslation();
+
+ return (
+
+ );
+});
+
+ModelFilterMenu.displayName = 'ModelFilterMenu';
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx
index ed49fa2870b..033a439bfcc 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx
@@ -8,8 +8,10 @@ import {
clearModelSelection,
type FilterableModelType,
selectFilteredModelType,
+ selectOrderBy,
selectSearchTerm,
selectSelectedModelKeys,
+ selectSortDirection,
setSelectedModelKey,
} from 'features/modelManagerV2/store/modelManagerV2Slice';
import { memo, useCallback, useMemo, useState } from 'react';
@@ -39,6 +41,8 @@ const ModelList = () => {
const dispatch = useAppDispatch();
const filteredModelType = useAppSelector(selectFilteredModelType);
const searchTerm = useAppSelector(selectSearchTerm);
+ const orderBy = useAppSelector(selectOrderBy);
+ const direction = useAppSelector(selectSortDirection);
const selectedModelKeys = useAppSelector(selectSelectedModelKeys);
const { t } = useTranslation();
const toast = useToast();
@@ -47,7 +51,8 @@ const ModelList = () => {
const [isDeleting, setIsDeleting] = useState(false);
const [isReidentifying, setIsReidentifying] = useState(false);
- const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery();
+ const queryArgs = useMemo(() => ({ order_by: orderBy, direction: direction.toUpperCase() }), [orderBy, direction]);
+ const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(queryArgs);
const { data: missingModelsData, isLoading: isLoadingMissing } = useGetMissingModelsQuery();
const [bulkDeleteModels] = useBulkDeleteModelsMutation();
const [bulkReidentifyModels] = useBulkReidentifyModelsMutation();
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx
index 78bed8ab830..bbfb88df5cf 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx
@@ -6,8 +6,8 @@ import type { ChangeEventHandler } from 'react';
import { memo, useCallback } from 'react';
import { PiXBold } from 'react-icons/pi';
+import { ModelFilterMenu } from './ModelFilterMenu';
import { ModelListBulkActions } from './ModelListBulkActions';
-import { ModelTypeFilter } from './ModelTypeFilter';
export const ModelListNavigation = memo(() => {
const dispatch = useAppDispatch();
@@ -50,7 +50,7 @@ export const ModelListNavigation = memo(() => {
-
+
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx
deleted file mode 100644
index 5aa8e628869..00000000000
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx
+++ /dev/null
@@ -1,78 +0,0 @@
-import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
-import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import type { ModelCategoryData } from 'features/modelManagerV2/models';
-import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
-import type { ModelCategoryType } from 'features/modelManagerV2/store/modelManagerV2Slice';
-import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
-import { memo, useCallback } from 'react';
-import { useTranslation } from 'react-i18next';
-import { PiFunnelBold, PiWarningBold } from 'react-icons/pi';
-
-const isModelCategoryType = (type: string): type is ModelCategoryType => {
- return type in MODEL_CATEGORIES;
-};
-
-export const ModelTypeFilter = memo(() => {
- const { t } = useTranslation();
- const dispatch = useAppDispatch();
- const filteredModelType = useAppSelector(selectFilteredModelType);
-
- const clearModelType = useCallback(() => {
- dispatch(setFilteredModelType(null));
- }, [dispatch]);
-
- const setMissingFilter = useCallback(() => {
- dispatch(setFilteredModelType('missing'));
- }, [dispatch]);
-
- const getButtonLabel = () => {
- if (filteredModelType === 'missing') {
- return t('modelManager.missingFiles');
- }
- if (filteredModelType && isModelCategoryType(filteredModelType)) {
- return t(MODEL_CATEGORIES[filteredModelType].i18nKey);
- }
- return t('modelManager.allModels');
- };
-
- return (
-
- );
-});
-
-ModelTypeFilter.displayName = 'ModelTypeFilter';
-
-const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => {
- const { t } = useTranslation();
- const dispatch = useAppDispatch();
- const filteredModelType = useAppSelector(selectFilteredModelType);
- const onClick = useCallback(() => {
- dispatch(setFilteredModelType(data.category));
- }, [data.category, dispatch]);
- return (
-
- );
-});
-ModelMenuItem.displayName = 'ModelMenuItem';
diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts
index c3d0decd53c..f279d46d823 100644
--- a/invokeai/frontend/web/src/services/api/endpoints/models.ts
+++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts
@@ -111,9 +111,13 @@ type DeleteOrphanedModelsResponse = {
errors: Record;
};
+type GetModelConfigsArg = {
+ order_by?: string;
+ direction?: string;
+} | void;
+
const modelConfigsAdapter = createEntityAdapter({
selectId: (entity) => entity.key,
- sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions);
@@ -338,8 +342,11 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['ModelInstalls'],
}),
- getModelConfigs: build.query, void>({
- query: () => ({ url: buildModelsUrl() }),
+ getModelConfigs: build.query, GetModelConfigsArg>({
+ query: (arg) => {
+ const queryStr = arg ? `?${queryString.stringify(arg)}` : '';
+ return { url: buildModelsUrl(queryStr) };
+ },
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }];
if (result) {
@@ -498,5 +505,5 @@ export const {
useDeleteOrphanedModelsMutation,
} = modelsApi;
-export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select();
+export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(undefined);
export const selectMissingModelsQuery = modelsApi.endpoints.getMissingModels.select();
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index f7ef229bf31..758428c6407 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -22429,6 +22429,12 @@ export type components = {
*/
config_path?: string | null;
};
+ /**
+ * ModelRecordOrderBy
+ * @description The order in which to return model summaries.
+ * @enum {string}
+ */
+ ModelRecordOrderBy: "default" | "type" | "base" | "name" | "format" | "size" | "created_at" | "updated_at" | "path";
/** ModelRelationshipBatchRequest */
ModelRelationshipBatchRequest: {
/**
@@ -30787,6 +30793,10 @@ export interface operations {
model_name?: string | null;
/** @description Exact match on the format of the model (e.g. 'diffusers') */
model_format?: components["schemas"]["ModelFormat"] | null;
+ /** @description The field to order by */
+ order_by?: components["schemas"]["ModelRecordOrderBy"];
+ /** @description The direction to order by */
+ direction?: components["schemas"]["SQLiteDirection"];
};
header?: never;
path?: never;
diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py
index 2b6c54d5b0f..19a1b74e73f 100644
--- a/tests/app/services/model_records/test_model_records_sql.py
+++ b/tests/app/services/model_records/test_model_records_sql.py
@@ -11,11 +11,13 @@
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
+ ModelRecordOrderBy,
ModelRecordServiceBase,
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config
from invokeai.backend.model_manager.configs.main import (
@@ -364,6 +366,73 @@ def test_filter_2(store: ModelRecordServiceBase):
assert len(matches) == 1
+def test_search_by_attr_sorting(store: ModelRecordServiceSQL):
+ config1 = Main_Diffusers_SD1_Config(
+ path="/tmp/config1",
+ name="alpha",
+ base=BaseModelType.StableDiffusion1,
+ type=ModelType.Main,
+ hash="CONFIG1HASH",
+ file_size=1000,
+ source="test/source/",
+ source_type=ModelSourceType.Path,
+ variant=ModelVariantType.Normal,
+ prediction_type=SchedulerPredictionType.Epsilon,
+ repo_variant=ModelRepoVariant.Default,
+ )
+ config2 = Main_Diffusers_SD2_Config(
+ path="/tmp/config2",
+ name="beta",
+ base=BaseModelType.StableDiffusion2,
+ type=ModelType.Main,
+ hash="CONFIG2HASH",
+ file_size=2000,
+ source="test/source/",
+ source_type=ModelSourceType.Path,
+ variant=ModelVariantType.Normal,
+ prediction_type=SchedulerPredictionType.Epsilon,
+ repo_variant=ModelRepoVariant.Default,
+ )
+ config3 = VAE_Diffusers_SD1_Config(
+ path="/tmp/config3",
+ name="gamma",
+ base=BaseModelType.StableDiffusion1,
+ type=ModelType.VAE,
+ hash="CONFIG3HASH",
+ file_size=500,
+ source="test/source/",
+ source_type=ModelSourceType.Path,
+ repo_variant=ModelRepoVariant.Default,
+ )
+ for c in config1, config2, config3:
+ store.add_model(c)
+
+ # Test sorting by Name Ascending
+ matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Ascending)
+ assert len(matches) == 3
+ assert matches[0].name == "alpha"
+ assert matches[1].name == "beta"
+ assert matches[2].name == "gamma"
+
+ # Test sorting by Name Descending
+ matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Descending)
+ assert matches[0].name == "gamma"
+ assert matches[1].name == "beta"
+ assert matches[2].name == "alpha"
+
+ # Test sorting by Size Ascending
+ matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Ascending)
+ assert matches[0].name == "gamma" # 500
+ assert matches[1].name == "alpha" # 1000
+ assert matches[2].name == "beta" # 2000
+
+ # Test sorting by Size Descending
+ matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Descending)
+ assert matches[0].name == "beta" # 2000
+ assert matches[1].name == "alpha" # 1000
+ assert matches[2].name == "gamma" # 500
+
+
def test_model_record_changes():
# This test guards against some unexpected behaviours from pydantic's union evaluation. See #6035
changes = ModelRecordChanges.model_validate({"default_settings": {"preprocessor": "value"}})