Skip to content

Commit b2d79dc

Browse files
skunkworxdarkJPPhotoPfannkuchensacklstein
authored
feat:(model-manager) add sorting capabilities for models (#9024)
* feat(model-manager): add comprehensive sorting capabilities for models dded the ability to sort models in the Model Manager by various attributes including Name, Base, Type, Format, Size, Date Added, and Date Modified. Supports both ascending and descending order. - Backend: Added `order_by` and `direction` query parameters to the ``/api/v1/models`/` listing endpoint. Implemented case-insensitive sorting in the SQLite model records service. - Frontend: Introduced `<ModelSortControl />` UI, updated Redux slices to manage sort state, removed client-side entity adapter sorting to respect server-side ordering, and added i18n localization keys. - Tests: Added test coverage for SQL-based sorting on size and name. * feat(model-manager): add comprehensive sorting capabilities for models dded the ability to sort models in the Model Manager by various attributes including Name, Base, Type, Format, Size, Date Added, and Date Modified. Supports both ascending and descending order. - Backend: Added `order_by` and `direction` query parameters to the ``/api/v1/models`/` listing endpoint. Implemented case-insensitive sorting in the SQLite model records service. - Frontend: Introduced `<ModelSortControl />` UI, updated Redux slices to manage sort state, removed client-side entity adapter sorting to respect server-side ordering, and added i18n localization keys. - Tests: Added test coverage for SQL-based sorting on size and name. * ruff fix * typegen fix * typegen fix - this time without my custom nodes. * another typegen fix * refactor(ui): consolidate model filter and sort controls into a unified menu - Replaced separate `ModelSortControl` and `ModelTypeFilter` components with a single, unified "Filtering" dropdown menu. - Organised filtering options into categorised submenus in the following order: Direction, Sort By, and Model Type. - Enhanced submenu labels to display the currently active selection inline for quick reference. - Improved visual alignment within menus by using hidden checkmarks on unselected items, ensuring consistent indentation across all options. - Resolved styling and linting issues (unused variables, JSX bind warnings) within the new component. * Lint fix * Addresses PR feedback to use translation strings directly within `ORDER_BY_OPTIONS`. Previously, sort keys and their translated labels were maintained in separate constructs (`ORDER_BY_OPTIONS` array and `ORDER_BY_LABELS` map). This refactor converts `ORDER_BY_OPTIONS` into an array of objects containing both the `key` and its corresponding `i18nKey`, creating a single source of truth. This change: - Simplifies the `SortBySubMenu` component by removing the redundant `ORDER_BY_LABELS` lookup map. - Improves maintainability by ensuring developers only need to update one place when adding or modifying sort options. - Reduces the risk of mismatched keys and labels. --------- Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com> Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev> Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent d7d623e commit b2d79dc

14 files changed

Lines changed: 419 additions & 100 deletions

File tree

invokeai/app/api/routers/model_manager.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
from invokeai.app.services.model_records import (
2727
InvalidModelException,
2828
ModelRecordChanges,
29+
ModelRecordOrderBy,
2930
UnknownModelException,
3031
)
3132
from invokeai.app.services.orphaned_models import OrphanedModelInfo
33+
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
3234
from invokeai.app.util.suppress_output import SuppressOutput
3335
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
3436
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
@@ -159,6 +161,8 @@ async def list_model_records(
159161
model_format: Optional[ModelFormat] = Query(
160162
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
161163
),
164+
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Name, description="The field to order by"),
165+
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
162166
) -> ModelsList:
163167
"""Get a list of models."""
164168
record_store = ApiDependencies.invoker.services.model_manager.store
@@ -167,12 +171,23 @@ async def list_model_records(
167171
for base_model in base_models:
168172
found_models.extend(
169173
record_store.search_by_attr(
170-
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
174+
base_model=base_model,
175+
model_type=model_type,
176+
model_name=model_name,
177+
model_format=model_format,
178+
order_by=order_by,
179+
direction=direction,
171180
)
172181
)
173182
else:
174183
found_models.extend(
175-
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
184+
record_store.search_by_attr(
185+
model_type=model_type,
186+
model_name=model_name,
187+
model_format=model_format,
188+
order_by=order_by,
189+
direction=direction,
190+
)
176191
)
177192
for index, model in enumerate(found_models):
178193
found_models[index] = prepare_model_config_for_response(model, ApiDependencies)

invokeai/app/services/model_records/model_records_base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pydantic import BaseModel, Field
1212

1313
from invokeai.app.services.shared.pagination import PaginatedResults
14+
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
1415
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
1516
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
1617
from invokeai.backend.model_manager.configs.external_api import (
@@ -60,6 +61,10 @@ class ModelRecordOrderBy(str, Enum):
6061
Base = "base"
6162
Name = "name"
6263
Format = "format"
64+
Size = "size"
65+
DateAdded = "created_at"
66+
DateModified = "updated_at"
67+
Path = "path"
6368

6469

6570
class ModelSummary(BaseModel):
@@ -200,7 +205,11 @@ def get_model_by_hash(self, hash: str) -> AnyModelConfig:
200205

201206
@abstractmethod
202207
def list_models(
203-
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
208+
self,
209+
page: int = 0,
210+
per_page: int = 10,
211+
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
212+
direction: SQLiteDirection = SQLiteDirection.Ascending,
204213
) -> PaginatedResults[ModelSummary]:
205214
"""Return a paginated summary listing of each model in the database."""
206215
pass
@@ -237,6 +246,8 @@ def search_by_attr(
237246
base_model: Optional[BaseModelType] = None,
238247
model_type: Optional[ModelType] = None,
239248
model_format: Optional[ModelFormat] = None,
249+
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
250+
direction: SQLiteDirection = SQLiteDirection.Ascending,
240251
) -> List[AnyModelConfig]:
241252
"""
242253
Return models matching name, base and/or type.

invokeai/app/services/model_records/model_records_sql.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
UnknownModelException,
5858
)
5959
from invokeai.app.services.shared.pagination import PaginatedResults
60+
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
6061
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
6162
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
6263
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
@@ -257,6 +258,7 @@ def search_by_attr(
257258
model_type: Optional[ModelType] = None,
258259
model_format: Optional[ModelFormat] = None,
259260
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
261+
direction: SQLiteDirection = SQLiteDirection.Ascending,
260262
) -> List[AnyModelConfig]:
261263
"""
262264
Return models matching name, base and/or type.
@@ -266,18 +268,24 @@ def search_by_attr(
266268
:param model_type: Filter by type of model (optional)
267269
:param model_format: Filter by model format (e.g. "diffusers") (optional)
268270
:param order_by: Result order
271+
:param direction: Result direction
269272
270273
If none of the optional filters are passed, will return all
271274
models in the database.
272275
"""
273276
with self._db.transaction() as cursor:
274277
assert isinstance(order_by, ModelRecordOrderBy)
278+
order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
275279
ordering = {
276-
ModelRecordOrderBy.Default: "type, base, name, format",
280+
ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
277281
ModelRecordOrderBy.Type: "type",
278-
ModelRecordOrderBy.Base: "base",
279-
ModelRecordOrderBy.Name: "name",
282+
ModelRecordOrderBy.Base: "base COLLATE NOCASE",
283+
ModelRecordOrderBy.Name: "name COLLATE NOCASE",
280284
ModelRecordOrderBy.Format: "format",
285+
ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
286+
ModelRecordOrderBy.DateAdded: "created_at",
287+
ModelRecordOrderBy.DateModified: "updated_at",
288+
ModelRecordOrderBy.Path: "path",
281289
}
282290

283291
where_clause: list[str] = []
@@ -301,7 +309,7 @@ def search_by_attr(
301309
SELECT config
302310
FROM models
303311
{where}
304-
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
312+
ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason;
305313
""",
306314
tuple(bindings),
307315
)
@@ -357,17 +365,26 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
357365
return results
358366

359367
def list_models(
360-
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
368+
self,
369+
page: int = 0,
370+
per_page: int = 10,
371+
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
372+
direction: SQLiteDirection = SQLiteDirection.Ascending,
361373
) -> PaginatedResults[ModelSummary]:
362374
"""Return a paginated summary listing of each model in the database."""
363375
with self._db.transaction() as cursor:
364376
assert isinstance(order_by, ModelRecordOrderBy)
377+
order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
365378
ordering = {
366-
ModelRecordOrderBy.Default: "type, base, name, format",
379+
ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
367380
ModelRecordOrderBy.Type: "type",
368-
ModelRecordOrderBy.Base: "base",
369-
ModelRecordOrderBy.Name: "name",
381+
ModelRecordOrderBy.Base: "base COLLATE NOCASE",
382+
ModelRecordOrderBy.Name: "name COLLATE NOCASE",
370383
ModelRecordOrderBy.Format: "format",
384+
ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
385+
ModelRecordOrderBy.DateAdded: "created_at",
386+
ModelRecordOrderBy.DateModified: "updated_at",
387+
ModelRecordOrderBy.Path: "path",
371388
}
372389

373390
# Lock so that the database isn't updated while we're doing the two queries.
@@ -385,7 +402,7 @@ def list_models(
385402
f"""--sql
386403
SELECT config
387404
FROM models
388-
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
405+
ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason
389406
LIMIT ?
390407
OFFSET ?;
391408
""",

invokeai/frontend/web/public/locales/en.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,15 @@
12031203
"modelType": "Model Type",
12041204
"modelUpdated": "Model Updated",
12051205
"modelUpdateFailed": "Model Update Failed",
1206+
"sortByName": "Name",
1207+
"sortByBase": "Base",
1208+
"sortBySize": "Size",
1209+
"sortByDateAdded": "Date Added",
1210+
"sortByDateModified": "Date Modified",
1211+
"sortByPath": "Path",
1212+
"sortByType": "Type",
1213+
"sortByFormat": "Format",
1214+
"sortDefault": "Default",
12061215
"name": "Name",
12071216
"externalProvider": "External Provider",
12081217
"externalCapabilities": "External Capabilities",

invokeai/frontend/web/src/common/hooks/useSubMenu.tsx

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,18 @@ export const useSubMenu = (): UseSubMenuReturn => {
151151
};
152152
};
153153

154-
export const SubMenuButtonContent = ({ label }: { label: string }) => {
154+
export const SubMenuButtonContent = ({ label, value }: { label: string; value?: string }) => {
155155
return (
156156
<Flex w="full" h="full" flexDir="row" justifyContent="space-between" alignItems="center">
157157
<Text>{label}</Text>
158-
<Icon as={PiCaretRightBold} />
158+
<Flex alignItems="center" gap={2}>
159+
{value !== undefined && (
160+
<Text fontSize="sm" color="base.400">
161+
{value}
162+
</Text>
163+
)}
164+
<Icon as={PiCaretRightBold} />
165+
</Flex>
159166
</Flex>
160167
);
161168
};

invokeai/frontend/web/src/features/modelManagerV2/models.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ export type ModelCategoryData = {
3131
filter: (config: AnyModelConfig) => boolean;
3232
};
3333

34-
export const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
34+
const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
3535
unknown: {
3636
category: 'unknown',
3737
i18nKey: 'common.unknown',

invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ const zModelManagerState = z.object({
2525
scanPath: z.string().optional(),
2626
shouldInstallInPlace: z.boolean(),
2727
selectedModelKeys: z.array(z.string()),
28+
orderBy: z
29+
.enum(['default', 'name', 'type', 'base', 'size', 'created_at', 'updated_at', 'path', 'format'])
30+
.default('name'),
31+
sortDirection: z.enum(['asc', 'desc']).default('asc'),
2832
});
2933

3034
type ModelManagerState = z.infer<typeof zModelManagerState>;
@@ -38,6 +42,8 @@ const getInitialState = (): ModelManagerState => ({
3842
scanPath: undefined,
3943
shouldInstallInPlace: true,
4044
selectedModelKeys: [],
45+
orderBy: 'name',
46+
sortDirection: 'asc',
4147
});
4248

4349
const slice = createSlice({
@@ -77,6 +83,12 @@ const slice = createSlice({
7783
clearModelSelection: (state) => {
7884
state.selectedModelKeys = [];
7985
},
86+
setOrderBy: (state, action: PayloadAction<ModelManagerState['orderBy']>) => {
87+
state.orderBy = action.payload;
88+
},
89+
setSortDirection: (state, action: PayloadAction<ModelManagerState['sortDirection']>) => {
90+
state.sortDirection = action.payload;
91+
},
8092
},
8193
});
8294

@@ -90,6 +102,8 @@ export const {
90102
modelSelectionChanged,
91103
toggleModelSelection,
92104
clearModelSelection,
105+
setOrderBy,
106+
setSortDirection,
93107
} = slice.actions;
94108

95109
export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
@@ -119,3 +133,5 @@ export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm
119133
export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType);
120134
export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace);
121135
export const selectSelectedModelKeys = createModelManagerSelector((mm) => mm.selectedModelKeys);
136+
export const selectOrderBy = createModelManagerSelector((mm) => mm.orderBy);
137+
export const selectSortDirection = createModelManagerSelector((mm) => mm.sortDirection);

0 commit comments

Comments
 (0)