Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
from invokeai.app.services.model_records import (
InvalidModelException,
ModelRecordChanges,
ModelRecordOrderBy,
UnknownModelException,
)
from invokeai.app.services.orphaned_models import OrphanedModelInfo
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.configs.main import (
Expand Down Expand Up @@ -130,6 +132,8 @@ async def list_model_records(
model_format: Optional[ModelFormat] = Query(
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Name, description="The field to order by"),
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
Expand All @@ -138,12 +142,23 @@ async def list_model_records(
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
base_model=base_model,
model_type=model_type,
model_name=model_name,
model_format=model_format,
order_by=order_by,
direction=direction,
)
)
else:
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
record_store.search_by_attr(
model_type=model_type,
model_name=model_name,
model_format=model_format,
order_by=order_by,
direction=direction,
)
)
for model in found_models:
model = add_cover_image_to_model_config(model, ApiDependencies)
Expand Down
13 changes: 12 additions & 1 deletion invokeai/app/services/model_records/model_records_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pydantic import BaseModel, Field

from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
Expand Down Expand Up @@ -56,6 +57,10 @@ class ModelRecordOrderBy(str, Enum):
Base = "base"
Name = "name"
Format = "format"
Size = "size"
DateAdded = "created_at"
DateModified = "updated_at"
Path = "path"


class ModelSummary(BaseModel):
Expand Down Expand Up @@ -185,7 +190,11 @@ def get_model_by_hash(self, hash: str) -> AnyModelConfig:

@abstractmethod
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
self,
page: int = 0,
per_page: int = 10,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
pass
Expand Down Expand Up @@ -222,6 +231,8 @@ def search_by_attr(
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
Expand Down
35 changes: 26 additions & 9 deletions invokeai/app/services/model_records/model_records_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
Expand Down Expand Up @@ -257,6 +258,7 @@ def search_by_attr(
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
Expand All @@ -266,18 +268,24 @@ def search_by_attr(
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
:param order_by: Result order
:param direction: Result direction

If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Base: "base COLLATE NOCASE",
ModelRecordOrderBy.Name: "name COLLATE NOCASE",
ModelRecordOrderBy.Format: "format",
ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
ModelRecordOrderBy.DateAdded: "created_at",
ModelRecordOrderBy.DateModified: "updated_at",
ModelRecordOrderBy.Path: "path",
}

where_clause: list[str] = []
Expand All @@ -301,7 +309,7 @@ def search_by_attr(
SELECT config
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
Expand Down Expand Up @@ -357,17 +365,26 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
return results

def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
self,
page: int = 0,
per_page: int = 10,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Base: "base COLLATE NOCASE",
ModelRecordOrderBy.Name: "name COLLATE NOCASE",
ModelRecordOrderBy.Format: "format",
ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
ModelRecordOrderBy.DateAdded: "created_at",
ModelRecordOrderBy.DateModified: "updated_at",
ModelRecordOrderBy.Path: "path",
}

# Lock so that the database isn't updated while we're doing the two queries.
Expand All @@ -385,7 +402,7 @@ def list_models(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
Expand Down
9 changes: 9 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,15 @@
"modelType": "Model Type",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"sortByName": "Name",
"sortByBase": "Base",
"sortBySize": "Size",
"sortByDateAdded": "Date Added",
"sortByDateModified": "Date Modified",
"sortByPath": "Path",
"sortByType": "Type",
"sortByFormat": "Format",
"sortDefault": "Default",
"name": "Name",
"modelPickerFallbackNoModelsInstalled": "No models installed.",
"modelPickerFallbackNoModelsInstalled2": "Visit the <LinkComponent>Model Manager</LinkComponent> to install models.",
Expand Down
11 changes: 9 additions & 2 deletions invokeai/frontend/web/src/common/hooks/useSubMenu.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,18 @@ export const useSubMenu = (): UseSubMenuReturn => {
};
};

export const SubMenuButtonContent = ({ label }: { label: string }) => {
export const SubMenuButtonContent = ({ label, value }: { label: string; value?: string }) => {
return (
<Flex w="full" h="full" flexDir="row" justifyContent="space-between" alignItems="center">
<Text>{label}</Text>
<Icon as={PiCaretRightBold} />
<Flex alignItems="center" gap={2}>
{value !== undefined && (
<Text fontSize="sm" color="base.400">
{value}
</Text>
)}
<Icon as={PiCaretRightBold} />
</Flex>
</Flex>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export type ModelCategoryData = {
filter: (config: AnyModelConfig) => boolean;
};

export const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
unknown: {
category: 'unknown',
i18nKey: 'common.unknown',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ const zModelManagerState = z.object({
scanPath: z.string().optional(),
shouldInstallInPlace: z.boolean(),
selectedModelKeys: z.array(z.string()),
orderBy: z
.enum(['default', 'name', 'type', 'base', 'size', 'created_at', 'updated_at', 'path', 'format'])
.default('name'),
sortDirection: z.enum(['asc', 'desc']).default('asc'),
});

type ModelManagerState = z.infer<typeof zModelManagerState>;
Expand All @@ -35,6 +39,8 @@ const getInitialState = (): ModelManagerState => ({
scanPath: undefined,
shouldInstallInPlace: true,
selectedModelKeys: [],
orderBy: 'name',
sortDirection: 'asc',
});

const slice = createSlice({
Expand Down Expand Up @@ -74,6 +80,12 @@ const slice = createSlice({
clearModelSelection: (state) => {
state.selectedModelKeys = [];
},
setOrderBy: (state, action: PayloadAction<ModelManagerState['orderBy']>) => {
state.orderBy = action.payload;
},
setSortDirection: (state, action: PayloadAction<ModelManagerState['sortDirection']>) => {
state.sortDirection = action.payload;
},
},
});

Expand All @@ -87,6 +99,8 @@ export const {
modelSelectionChanged,
toggleModelSelection,
clearModelSelection,
setOrderBy,
setSortDirection,
} = slice.actions;

export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
Expand Down Expand Up @@ -116,3 +130,5 @@ export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm
export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType);
export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace);
export const selectSelectedModelKeys = createModelManagerSelector((mm) => mm.selectedModelKeys);
export const selectOrderBy = createModelManagerSelector((mm) => mm.orderBy);
export const selectSortDirection = createModelManagerSelector((mm) => mm.sortDirection);
Loading
Loading