diff --git a/core/application/startup.go b/core/application/startup.go index 728c3c97221e..bb7ec82750c8 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -17,6 +17,7 @@ import ( "github.com/mudler/LocalAI/core/services/jobs" "github.com/mudler/LocalAI/core/services/nodes" "github.com/mudler/LocalAI/core/services/storage" + "github.com/mudler/LocalAI/pkg/vram" coreStartup "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/internal" @@ -231,6 +232,10 @@ func New(opts ...config.AppOption) (*Application, error) { xlog.Error("error registering external backends", "error", err) } + // Wire gallery generation counter into VRAM caches so they invalidate + // when gallery data refreshes instead of using a fixed TTL. + vram.SetGalleryGenerationFunc(gallery.GalleryGeneration) + if options.ConfigFile != "" { if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil { xlog.Error("error loading config file", "error", err) diff --git a/core/config/backend_capabilities.go b/core/config/backend_capabilities.go new file mode 100644 index 000000000000..02af395545be --- /dev/null +++ b/core/config/backend_capabilities.go @@ -0,0 +1,458 @@ +package config + +import ( + "slices" + "strings" +) + +// Usecase name constants — the canonical string values used in gallery entries, +// model configs (known_usecases), and UsecaseInfoMap keys. +const ( + UsecaseChat = "chat" + UsecaseCompletion = "completion" + UsecaseEdit = "edit" + UsecaseVision = "vision" + UsecaseEmbeddings = "embeddings" + UsecaseTokenize = "tokenize" + UsecaseImage = "image" + UsecaseVideo = "video" + UsecaseTranscript = "transcript" + UsecaseTTS = "tts" + UsecaseSoundGeneration = "sound_generation" + UsecaseRerank = "rerank" + UsecaseDetection = "detection" + UsecaseVAD = "vad" +) + +// GRPCMethod identifies a Backend service RPC from backend.proto. 
+type GRPCMethod string + +const ( + MethodPredict GRPCMethod = "Predict" + MethodPredictStream GRPCMethod = "PredictStream" + MethodEmbedding GRPCMethod = "Embedding" + MethodGenerateImage GRPCMethod = "GenerateImage" + MethodGenerateVideo GRPCMethod = "GenerateVideo" + MethodAudioTranscription GRPCMethod = "AudioTranscription" + MethodTTS GRPCMethod = "TTS" + MethodTTSStream GRPCMethod = "TTSStream" + MethodSoundGeneration GRPCMethod = "SoundGeneration" + MethodTokenizeString GRPCMethod = "TokenizeString" + MethodDetect GRPCMethod = "Detect" + MethodRerank GRPCMethod = "Rerank" + MethodVAD GRPCMethod = "VAD" +) + +// UsecaseInfo describes a single known_usecase value and how it maps +// to the gRPC backend API. +type UsecaseInfo struct { + // Flag is the ModelConfigUsecase bitmask value. + Flag ModelConfigUsecase + // GRPCMethod is the primary Backend service RPC this usecase maps to. + GRPCMethod GRPCMethod + // IsModifier is true when this usecase doesn't map to its own gRPC RPC + // but modifies how another RPC behaves (e.g., vision uses Predict with images). + IsModifier bool + // DependsOn names the usecase(s) this modifier requires (e.g., "chat"). + DependsOn string + // Description is a human/LLM-readable explanation of what this usecase means. + Description string +} + +// UsecaseInfoMap maps each known_usecase string to its gRPC and semantic info. 
+var UsecaseInfoMap = map[string]UsecaseInfo{ + UsecaseChat: { + Flag: FLAG_CHAT, + GRPCMethod: MethodPredict, + Description: "Conversational/instruction-following via the Predict RPC with chat templates.", + }, + UsecaseCompletion: { + Flag: FLAG_COMPLETION, + GRPCMethod: MethodPredict, + Description: "Text completion via the Predict RPC with a completion template.", + }, + UsecaseEdit: { + Flag: FLAG_EDIT, + GRPCMethod: MethodPredict, + Description: "Text editing via the Predict RPC with an edit template.", + }, + UsecaseVision: { + Flag: FLAG_VISION, + GRPCMethod: MethodPredict, + IsModifier: true, + DependsOn: UsecaseChat, + Description: "The model accepts images alongside text in the Predict RPC. For llama-cpp this requires an mmproj file.", + }, + UsecaseEmbeddings: { + Flag: FLAG_EMBEDDINGS, + GRPCMethod: MethodEmbedding, + Description: "Vector embedding generation via the Embedding RPC.", + }, + UsecaseTokenize: { + Flag: FLAG_TOKENIZE, + GRPCMethod: MethodTokenizeString, + Description: "Tokenization via the TokenizeString RPC without running inference.", + }, + UsecaseImage: { + Flag: FLAG_IMAGE, + GRPCMethod: MethodGenerateImage, + Description: "Image generation via the GenerateImage RPC (Stable Diffusion, Flux, etc.).", + }, + UsecaseVideo: { + Flag: FLAG_VIDEO, + GRPCMethod: MethodGenerateVideo, + Description: "Video generation via the GenerateVideo RPC.", + }, + UsecaseTranscript: { + Flag: FLAG_TRANSCRIPT, + GRPCMethod: MethodAudioTranscription, + Description: "Speech-to-text via the AudioTranscription RPC.", + }, + UsecaseTTS: { + Flag: FLAG_TTS, + GRPCMethod: MethodTTS, + Description: "Text-to-speech via the TTS RPC.", + }, + UsecaseSoundGeneration: { + Flag: FLAG_SOUND_GENERATION, + GRPCMethod: MethodSoundGeneration, + Description: "Music/sound generation via the SoundGeneration RPC (not speech).", + }, + UsecaseRerank: { + Flag: FLAG_RERANK, + GRPCMethod: MethodRerank, + Description: "Document reranking via the Rerank RPC.", + }, + 
UsecaseDetection: { + Flag: FLAG_DETECTION, + GRPCMethod: MethodDetect, + Description: "Object detection via the Detect RPC with bounding boxes.", + }, + UsecaseVAD: { + Flag: FLAG_VAD, + GRPCMethod: MethodVAD, + Description: "Voice activity detection via the VAD RPC.", + }, +} + +// BackendCapability describes which gRPC methods and usecases a backend supports. +// Derived from reviewing actual implementations in backend/go/ and backend/python/. +type BackendCapability struct { + // GRPCMethods lists the Backend service RPCs this backend implements. + GRPCMethods []GRPCMethod + // PossibleUsecases lists all usecase strings this backend can support. + PossibleUsecases []string + // DefaultUsecases lists the conservative safe defaults. + DefaultUsecases []string + // AcceptsImages indicates multimodal image input in Predict. + AcceptsImages bool + // AcceptsVideos indicates multimodal video input in Predict. + AcceptsVideos bool + // AcceptsAudios indicates multimodal audio input in Predict. + AcceptsAudios bool + // Description is a human-readable summary of the backend. + Description string +} + +// BackendCapabilities maps each backend name (as used in model configs and gallery +// entries) to its verified capabilities. This is the single source of truth for +// what each backend supports. +// +// Backend names use hyphens (e.g., "llama-cpp") matching the gallery convention. +// Use NormalizeBackendName() for names with dots (e.g., "llama.cpp"). 
+var BackendCapabilities = map[string]BackendCapability{ + // --- LLM / text generation backends --- + "llama-cpp": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding, MethodTokenizeString}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEdit, UsecaseEmbeddings, UsecaseTokenize, UsecaseVision}, + DefaultUsecases: []string{UsecaseChat}, + AcceptsImages: true, // requires mmproj + Description: "llama.cpp GGUF models — LLM inference with optional vision via mmproj", + }, + "vllm": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseVision}, + DefaultUsecases: []string{UsecaseChat}, + AcceptsImages: true, + AcceptsVideos: true, + Description: "vLLM engine — high-throughput LLM serving with optional multimodal", + }, + "vllm-omni": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodGenerateImage, MethodGenerateVideo, MethodTTS}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseImage, UsecaseVideo, UsecaseTTS, UsecaseVision}, + DefaultUsecases: []string{UsecaseChat}, + AcceptsImages: true, + AcceptsVideos: true, + AcceptsAudios: true, + Description: "vLLM omni-modal — supports text, image, video generation and TTS", + }, + "transformers": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding, MethodTTS, MethodSoundGeneration}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseTTS, UsecaseSoundGeneration}, + DefaultUsecases: []string{UsecaseChat}, + Description: "HuggingFace transformers — general-purpose Python inference", + }, + "mlx": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings}, + DefaultUsecases: []string{UsecaseChat}, + Description: "Apple MLX framework — optimized for Apple 
Silicon", + }, + "mlx-distributed": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings}, + DefaultUsecases: []string{UsecaseChat}, + Description: "MLX distributed inference across multiple Apple Silicon devices", + }, + "mlx-vlm": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseVision}, + DefaultUsecases: []string{UsecaseChat, UsecaseVision}, + AcceptsImages: true, + AcceptsAudios: true, + Description: "MLX vision-language models with multimodal input", + }, + "mlx-audio": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodTTS}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseTTS}, + DefaultUsecases: []string{UsecaseChat}, + Description: "MLX audio models — text generation and TTS", + }, + + // --- Image/video generation backends --- + "diffusers": { + GRPCMethods: []GRPCMethod{MethodGenerateImage, MethodGenerateVideo}, + PossibleUsecases: []string{UsecaseImage, UsecaseVideo}, + DefaultUsecases: []string{UsecaseImage}, + Description: "HuggingFace diffusers — Stable Diffusion, Flux, video generation", + }, + "stablediffusion": { + GRPCMethods: []GRPCMethod{MethodGenerateImage}, + PossibleUsecases: []string{UsecaseImage}, + DefaultUsecases: []string{UsecaseImage}, + Description: "Stable Diffusion native backend", + }, + "stablediffusion-ggml": { + GRPCMethods: []GRPCMethod{MethodGenerateImage}, + PossibleUsecases: []string{UsecaseImage}, + DefaultUsecases: []string{UsecaseImage}, + Description: "Stable Diffusion via GGML quantized models", + }, + + // --- Speech-to-text backends --- + "whisper": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodVAD}, + PossibleUsecases: []string{UsecaseTranscript, UsecaseVAD}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "OpenAI Whisper — speech 
recognition and voice activity detection", + }, + "faster-whisper": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "CTranslate2-accelerated Whisper for faster transcription", + }, + "whisperx": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "WhisperX — Whisper with word-level timestamps and speaker diarization", + }, + "moonshine": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "Moonshine speech recognition", + }, + "nemo": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "NVIDIA NeMo speech recognition", + }, + "qwen-asr": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "Qwen automatic speech recognition", + }, + "voxtral": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "Voxtral speech recognition", + }, + "vibevoice": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodTTS}, + PossibleUsecases: []string{UsecaseTranscript, UsecaseTTS}, + DefaultUsecases: []string{UsecaseTranscript, UsecaseTTS}, + Description: "VibeVoice — bidirectional speech (transcription and synthesis)", + }, + + // --- TTS backends --- + "piper": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Piper — fast neural TTS optimized for Raspberry Pi", + }, + "kokoro": { + 
GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Kokoro TTS", + }, + "coqui": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Coqui TTS — multi-speaker neural synthesis", + }, + "kitten-tts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Kitten TTS", + }, + "outetts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "OuteTTS", + }, + "pocket-tts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Pocket TTS — lightweight text-to-speech", + }, + "qwen-tts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Qwen TTS", + }, + "faster-qwen3-tts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Faster Qwen3 TTS — accelerated Qwen TTS", + }, + "fish-speech": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Fish Speech TTS", + }, + "neutts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "NeuTTS — neural text-to-speech", + }, + "chatterbox": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Chatterbox TTS", + }, + "voxcpm": { + GRPCMethods: []GRPCMethod{MethodTTS, MethodTTSStream}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + 
Description: "VoxCPM TTS with streaming support", + }, + + // --- Sound generation backends --- + "ace-step": { + GRPCMethods: []GRPCMethod{MethodTTS, MethodSoundGeneration}, + PossibleUsecases: []string{UsecaseTTS, UsecaseSoundGeneration}, + DefaultUsecases: []string{UsecaseSoundGeneration}, + Description: "ACE-Step — music and sound generation", + }, + "acestep-cpp": { + GRPCMethods: []GRPCMethod{MethodSoundGeneration}, + PossibleUsecases: []string{UsecaseSoundGeneration}, + DefaultUsecases: []string{UsecaseSoundGeneration}, + Description: "ACE-Step C++ — native sound generation", + }, + "transformers-musicgen": { + GRPCMethods: []GRPCMethod{MethodTTS, MethodSoundGeneration}, + PossibleUsecases: []string{UsecaseTTS, UsecaseSoundGeneration}, + DefaultUsecases: []string{UsecaseSoundGeneration}, + Description: "Meta MusicGen via transformers — music generation from text", + }, + + // --- Utility backends --- + "rerankers": { + GRPCMethods: []GRPCMethod{MethodRerank}, + PossibleUsecases: []string{UsecaseRerank}, + DefaultUsecases: []string{UsecaseRerank}, + Description: "Cross-encoder reranking models", + }, + "rfdetr": { + GRPCMethods: []GRPCMethod{MethodDetect}, + PossibleUsecases: []string{UsecaseDetection}, + DefaultUsecases: []string{UsecaseDetection}, + Description: "RF-DETR object detection", + }, + "silero-vad": { + GRPCMethods: []GRPCMethod{MethodVAD}, + PossibleUsecases: []string{UsecaseVAD}, + DefaultUsecases: []string{UsecaseVAD}, + Description: "Silero VAD — voice activity detection", + }, +} + +// NormalizeBackendName converts backend names to the canonical hyphenated form +// used in gallery entries (e.g., "llama.cpp" → "llama-cpp"). +func NormalizeBackendName(backend string) string { + return strings.ReplaceAll(backend, ".", "-") +} + +// GetBackendCapability returns the capability info for a backend, or nil if unknown. +// Handles backend name normalization. 
+func GetBackendCapability(backend string) *BackendCapability {
+	// Look up via the normalized (hyphenated) name; return a pointer to a
+	// copy so callers cannot mutate the canonical table entry.
+	if bc, ok := BackendCapabilities[NormalizeBackendName(backend)]; ok {
+		return &bc
+	}
+	return nil
+}
+
+// PossibleUsecasesForBackend returns all usecases a backend can support.
+// Returns nil if the backend is unknown.
+func PossibleUsecasesForBackend(backend string) []string {
+	if bc := GetBackendCapability(backend); bc != nil {
+		return bc.PossibleUsecases
+	}
+	return nil
+}
+
+// DefaultUsecasesForBackendCap returns the conservative default usecases
+// for a backend. Returns nil if the backend is unknown.
+func DefaultUsecasesForBackendCap(backend string) []string {
+	if bc := GetBackendCapability(backend); bc != nil {
+		return bc.DefaultUsecases
+	}
+	return nil
+}
+
+// IsValidUsecaseForBackend checks whether a usecase is in a backend's possible set.
+// Returns true for unknown backends (permissive fallback).
+func IsValidUsecaseForBackend(backend, usecase string) bool {
+	bc := GetBackendCapability(backend)
+	if bc == nil {
+		return true // unknown backend — don't restrict
+	}
+	return slices.Contains(bc.PossibleUsecases, usecase)
+}
+
+// AllBackendNames returns a sorted list of all known backend names.
+func AllBackendNames() []string { + names := make([]string, 0, len(BackendCapabilities)) + for name := range BackendCapabilities { + names = append(names, name) + } + slices.Sort(names) + return names +} diff --git a/core/config/backend_capabilities_test.go b/core/config/backend_capabilities_test.go new file mode 100644 index 000000000000..d3ca74a18241 --- /dev/null +++ b/core/config/backend_capabilities_test.go @@ -0,0 +1,116 @@ +package config + +import ( + "slices" + "strings" + "testing" +) + +func TestBackendCapabilities_AllHaveUsecases(t *testing.T) { + for name, cap := range BackendCapabilities { + if len(cap.PossibleUsecases) == 0 { + t.Errorf("backend %q has no possible usecases", name) + } + if len(cap.DefaultUsecases) == 0 { + t.Errorf("backend %q has no default usecases", name) + } + if len(cap.GRPCMethods) == 0 { + t.Errorf("backend %q has no gRPC methods", name) + } + } +} + +func TestBackendCapabilities_DefaultsSubsetOfPossible(t *testing.T) { + for name, cap := range BackendCapabilities { + for _, d := range cap.DefaultUsecases { + if !slices.Contains(cap.PossibleUsecases, d) { + t.Errorf("backend %q: default %q not in possible %v", name, d, cap.PossibleUsecases) + } + } + } +} + +func TestBackendCapabilities_UsecasesMatchFlags(t *testing.T) { + allFlags := GetAllModelConfigUsecases() + for name, cap := range BackendCapabilities { + for _, u := range cap.PossibleUsecases { + info, ok := UsecaseInfoMap[u] + if !ok { + t.Errorf("backend %q: usecase %q not in UsecaseInfoMap", name, u) + continue + } + flagName := "FLAG_" + strings.ToUpper(u) + if _, ok := allFlags[flagName]; !ok { + // Try without transform — some names differ + found := false + for _, flag := range allFlags { + if flag == info.Flag { + found = true + break + } + } + if !found { + t.Errorf("backend %q: usecase %q flag %d not in GetAllModelConfigUsecases", name, u, info.Flag) + } + } + } + } +} + +func TestUsecaseInfoMap_AllHaveFlags(t *testing.T) { + for name, info := range 
UsecaseInfoMap { + if info.Flag == FLAG_ANY { + t.Errorf("usecase %q has FLAG_ANY (zero) — should have a real flag", name) + } + if info.GRPCMethod == "" { + t.Errorf("usecase %q has no gRPC method", name) + } + } +} + +func TestGetBackendCapability(t *testing.T) { + cap := GetBackendCapability("llama-cpp") + if cap == nil { + t.Fatal("llama-cpp should be known") + } + if !slices.Contains(cap.PossibleUsecases, "chat") { + t.Error("llama-cpp should support chat") + } +} + +func TestGetBackendCapability_Normalize(t *testing.T) { + cap := GetBackendCapability("llama.cpp") + if cap == nil { + t.Fatal("llama.cpp should normalize to llama-cpp") + } +} + +func TestGetBackendCapability_Unknown(t *testing.T) { + cap := GetBackendCapability("nonexistent") + if cap != nil { + t.Error("unknown backend should return nil") + } +} + +func TestIsValidUsecaseForBackend(t *testing.T) { + if !IsValidUsecaseForBackend("piper", "tts") { + t.Error("piper should support tts") + } + if IsValidUsecaseForBackend("piper", "chat") { + t.Error("piper should not support chat") + } + // Unknown backend is permissive + if !IsValidUsecaseForBackend("unknown", "anything") { + t.Error("unknown backend should allow any usecase") + } +} + +func TestAllBackendNames(t *testing.T) { + names := AllBackendNames() + if len(names) < 30 { + t.Errorf("expected 30+ backends, got %d", len(names)) + } + if !slices.IsSorted(names) { + t.Error("should be sorted") + } +} diff --git a/core/config/model_config.go b/core/config/model_config.go index a4815c766755..ecdaf723a533 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -565,11 +565,39 @@ const ( FLAG_VAD ModelConfigUsecase = 0b010000000000 FLAG_VIDEO ModelConfigUsecase = 0b100000000000 FLAG_DETECTION ModelConfigUsecase = 0b1000000000000 + FLAG_VISION ModelConfigUsecase = 0b10000000000000 // Common Subsets FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT ) +// ModalityGroups defines groups of usecases that belong to 
the same modality. +// Flags within the same group are NOT orthogonal (e.g., chat and completion are +// both text/language). A model is multimodal when its usecases span 2+ groups. +var ModalityGroups = []ModelConfigUsecase{ + FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT, // text/language + FLAG_VISION | FLAG_DETECTION, // visual understanding + FLAG_TRANSCRIPT, // speech input + FLAG_TTS | FLAG_SOUND_GENERATION, // audio output + FLAG_IMAGE | FLAG_VIDEO, // visual generation +} + +// IsMultimodal returns true if the given usecases span two or more orthogonal +// modality groups. For example chat+vision is multimodal, but chat+completion +// is not (both belong to the text/language group). +func IsMultimodal(usecases ModelConfigUsecase) bool { + groupCount := 0 + for _, group := range ModalityGroups { + if usecases&group != 0 { + groupCount++ + if groupCount >= 2 { + return true + } + } + } + return false +} + func GetAllModelConfigUsecases() map[string]ModelConfigUsecase { return map[string]ModelConfigUsecase{ // Note: FLAG_ANY is intentionally excluded from this map @@ -588,6 +616,7 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase { "FLAG_LLM": FLAG_LLM, "FLAG_VIDEO": FLAG_VIDEO, "FLAG_DETECTION": FLAG_DETECTION, + "FLAG_VISION": FLAG_VISION, } } diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index 0b0791afe75f..b7667b234bc7 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -7,6 +7,8 @@ import ( "path/filepath" "slices" "strings" + "sync" + "sync/atomic" "time" "github.com/lithammer/fuzzysearch/fuzzy" @@ -92,6 +94,34 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] { return filteredModels } +// FilterGalleryModelsByUsecase returns models whose known_usecases include all +// the bits set in usecase. For example, passing FLAG_CHAT matches any model +// with the chat usecase; passing FLAG_CHAT|FLAG_VISION matches only models +// that have both. 
+func FilterGalleryModelsByUsecase(models GalleryElements[*GalleryModel], usecase config.ModelConfigUsecase) GalleryElements[*GalleryModel] { + var filtered GalleryElements[*GalleryModel] + for _, m := range models { + u := m.GetKnownUsecases() + if u != nil && (*u&usecase) == usecase { + filtered = append(filtered, m) + } + } + return filtered +} + +// FilterGalleryModelsByMultimodal returns models whose known_usecases span two +// or more orthogonal modality groups (e.g. chat+vision, tts+transcript). +func FilterGalleryModelsByMultimodal(models GalleryElements[*GalleryModel]) GalleryElements[*GalleryModel] { + var filtered GalleryElements[*GalleryModel] + for _, m := range models { + u := m.GetKnownUsecases() + if u != nil && config.IsMultimodal(*u) { + filtered = append(filtered, m) + } + } + return filtered +} + func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] { var filtered GalleryElements[T] for _, m := range gm { @@ -267,6 +297,77 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst return models, nil } +var ( + availableModelsMu sync.RWMutex + availableModelsCache GalleryElements[*GalleryModel] + refreshing atomic.Bool + galleryGeneration atomic.Uint64 +) + +// GalleryGeneration returns a counter that increments each time the gallery +// model list is refreshed from upstream. VRAM estimation caches use this to +// invalidate entries when the gallery data changes. +func GalleryGeneration() uint64 { return galleryGeneration.Load() } + +// AvailableGalleryModelsCached returns gallery models from an in-memory cache. +// Local-only fields (installed status) are refreshed on every call. A background +// goroutine is triggered to re-fetch the full model list (including network +// calls) so subsequent requests pick up changes without blocking the caller. +// The first call with an empty cache blocks until the initial load completes. 
+func AvailableGalleryModelsCached(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryModel], error) {
+	// Fast path: take a snapshot of the cached slice under the read lock.
+	availableModelsMu.RLock()
+	cached := availableModelsCache
+	availableModelsMu.RUnlock()
+
+	if cached != nil {
+		// Refresh installed status under write lock to avoid races with
+		// concurrent readers and the background refresh goroutine.
+		// NOTE(review): this performs one os.Stat per model while holding the
+		// global write lock — filesystem I/O under a process-wide lock may
+		// become a contention point with large galleries; verify acceptable.
+		availableModelsMu.Lock()
+		for _, m := range cached {
+			// A model is "installed" when its <name>.yaml exists in ModelsPath.
+			_, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", m.GetName())))
+			m.SetInstalled(err == nil)
+		}
+		availableModelsMu.Unlock()
+		// Trigger a background refresh if one is not already running.
+		triggerGalleryRefresh(galleries, systemState)
+		// NOTE(review): the returned slice shares *GalleryModel pointers with
+		// the cache; a later caller's SetInstalled loop can mutate elements
+		// while this caller still reads them outside the lock — confirm
+		// callers tolerate that, or consider returning copies.
+		return cached, nil
+	}
+
+	// No cache yet — must do a blocking load.
+	// NOTE(review): concurrent first calls each perform a full blocking load
+	// (no singleflight); last writer wins below. Presumably acceptable at
+	// startup — confirm.
+	models, err := AvailableGalleryModels(galleries, systemState)
+	if err != nil {
+		return nil, err
+	}
+
+	// Publish the freshly loaded list and bump the generation counter so
+	// VRAM caches keyed on GalleryGeneration() invalidate.
+	availableModelsMu.Lock()
+	availableModelsCache = models
+	galleryGeneration.Add(1)
+	availableModelsMu.Unlock()
+
+	return models, nil
+}
+
+// triggerGalleryRefresh starts a background goroutine that refreshes the
+// gallery model cache. Only one refresh runs at a time; concurrent calls
+// are no-ops.
+func triggerGalleryRefresh(galleries []config.Gallery, systemState *system.SystemState) { + if !refreshing.CompareAndSwap(false, true) { + return + } + go func() { + defer refreshing.Store(false) + models, err := AvailableGalleryModels(galleries, systemState) + if err != nil { + xlog.Error("background gallery refresh failed", "error", err) + return + } + availableModelsMu.Lock() + availableModelsCache = models + galleryGeneration.Add(1) + availableModelsMu.Unlock() + }() +} + // List available backends func AvailableBackends(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) { return availableBackendsWithFilter(galleries, systemState, true) diff --git a/core/gallery/importers/diffuser.go b/core/gallery/importers/diffuser.go index c702da3d3025..1060899aa7a5 100644 --- a/core/gallery/importers/diffuser.go +++ b/core/gallery/importers/diffuser.go @@ -93,7 +93,7 @@ func (i *DiffuserImporter) Import(details Details) (gallery.ModelConfig, error) modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"image"}, + KnownUsecaseStrings: []string{config.UsecaseImage}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ diff --git a/core/gallery/importers/llama-cpp.go b/core/gallery/importers/llama-cpp.go index edd9387913c3..45e91154e347 100644 --- a/core/gallery/importers/llama-cpp.go +++ b/core/gallery/importers/llama-cpp.go @@ -104,7 +104,7 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Options: []string{"use_jinja:true"}, Backend: "llama-cpp", TemplateConfig: config.TemplateConfig{ diff --git a/core/gallery/importers/local.go b/core/gallery/importers/local.go index 2a456cc6020d..73020ceceea6 100644 --- 
a/core/gallery/importers/local.go +++ b/core/gallery/importers/local.go @@ -42,7 +42,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { cfg := &config.ModelConfig{ Name: name, Backend: "llama-cpp", - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Options: []string{"use_jinja:true"}, } cfg.Model = relPath(ggufFile) @@ -60,7 +60,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { cfg := &config.ModelConfig{ Name: name, Backend: "transformers", - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, } cfg.Model = baseModel cfg.TemplateConfig.UseTokenizerTemplate = true @@ -76,7 +76,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { cfg := &config.ModelConfig{ Name: name, Backend: "transformers", - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, } cfg.Model = baseModel cfg.TemplateConfig.UseTokenizerTemplate = true @@ -91,7 +91,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { cfg := &config.ModelConfig{ Name: name, Backend: "transformers", - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, } cfg.Model = relPath(dirPath) cfg.TemplateConfig.UseTokenizerTemplate = true diff --git a/core/gallery/importers/mlx.go b/core/gallery/importers/mlx.go index 7ab513f6dd19..feac13129015 100644 --- a/core/gallery/importers/mlx.go +++ b/core/gallery/importers/mlx.go @@ -69,7 +69,7 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) { modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ diff --git a/core/gallery/importers/transformers.go 
b/core/gallery/importers/transformers.go index 5a4732ca896c..dbed7402dcf1 100644 --- a/core/gallery/importers/transformers.go +++ b/core/gallery/importers/transformers.go @@ -83,7 +83,7 @@ func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, err modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ diff --git a/core/gallery/importers/vllm.go b/core/gallery/importers/vllm.go index 88baef1fefa8..20439da52451 100644 --- a/core/gallery/importers/vllm.go +++ b/core/gallery/importers/vllm.go @@ -73,7 +73,7 @@ func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) { modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ diff --git a/core/gallery/models_types.go b/core/gallery/models_types.go index f70a5b222567..c3de03efcbfa 100644 --- a/core/gallery/models_types.go +++ b/core/gallery/models_types.go @@ -52,3 +52,23 @@ func (m *GalleryModel) GetTags() []string { func (m *GalleryModel) GetDescription() string { return m.Description } + +// GetKnownUsecases extracts known_usecases from the model's Overrides and +// returns the parsed usecase flags. Returns nil when no usecases are declared. 
+func (m *GalleryModel) GetKnownUsecases() *config.ModelConfigUsecase { + raw, ok := m.Overrides["known_usecases"] + if !ok { + return nil + } + list, ok := raw.([]any) + if !ok { + return nil + } + strs := make([]string, 0, len(list)) + for _, v := range list { + if s, ok := v.(string); ok { + strs = append(strs, s) + } + } + return config.GetUsecasesFromYAML(strs) +} diff --git a/core/http/endpoints/localai/config_meta.go b/core/http/endpoints/localai/config_meta.go index 22d055b999a8..103081c30eb8 100644 --- a/core/http/endpoints/localai/config_meta.go +++ b/core/http/endpoints/localai/config_meta.go @@ -120,13 +120,13 @@ func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a capability := strings.TrimPrefix(provider, "models:") var filterFn config.ModelConfigFilterFn switch capability { - case "chat": + case config.UsecaseChat: filterFn = config.BuildUsecaseFilterFn(config.FLAG_CHAT) - case "tts": + case config.UsecaseTTS: filterFn = config.BuildUsecaseFilterFn(config.FLAG_TTS) - case "vad": + case config.UsecaseVAD: filterFn = config.BuildUsecaseFilterFn(config.FLAG_VAD) - case "transcript": + case config.UsecaseTranscript: filterFn = config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT) default: filterFn = config.NoFilterFn diff --git a/core/http/endpoints/localai/import_model.go b/core/http/endpoints/localai/import_model.go index a1931bae9117..41921c6848d8 100644 --- a/core/http/endpoints/localai/import_model.go +++ b/core/http/endpoints/localai/import_model.go @@ -51,18 +51,17 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl } estCtx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second) defer cancel() - result, err := vram.EstimateModel(estCtx, vram.ModelEstimateInput{ - Files: files, - Options: vram.EstimateOptions{ContextLength: 8192}, - }) + result, err := vram.EstimateModelMultiContext(estCtx, vram.ModelEstimateInput{ + Files: files, + }, []uint32{8192}) if err == nil { if 
result.SizeBytes > 0 { resp.EstimatedSizeBytes = result.SizeBytes resp.EstimatedSizeDisplay = result.SizeDisplay } - if result.VRAMBytes > 0 { - resp.EstimatedVRAMBytes = result.VRAMBytes - resp.EstimatedVRAMDisplay = result.VRAMDisplay + if v := result.VRAMForContext(8192); v > 0 { + resp.EstimatedVRAMBytes = v + resp.EstimatedVRAMDisplay = vram.FormatBytes(v) } } } diff --git a/core/http/endpoints/localai/vram.go b/core/http/endpoints/localai/vram.go index fe7b312bef80..5ac7b0fcaf40 100644 --- a/core/http/endpoints/localai/vram.go +++ b/core/http/endpoints/localai/vram.go @@ -2,9 +2,9 @@ package localai import ( "context" - "fmt" "net/http" "path/filepath" + "slices" "strings" "time" @@ -14,16 +14,10 @@ import ( ) type vramEstimateRequest struct { - Model string `json:"model"` // model name (must be installed) - ContextSize uint32 `json:"context_size,omitempty"` // context length to estimate for (default 8192) - GPULayers int `json:"gpu_layers,omitempty"` // number of layers to offload to GPU (0 = all) - KVQuantBits int `json:"kv_quant_bits,omitempty"` // KV cache quantization bits (0 = fp16) -} - -type vramEstimateResponse struct { - vram.EstimateResult - ContextNote string `json:"context_note,omitempty"` // note when context_size was defaulted - ModelMaxContext uint64 `json:"model_max_context,omitempty"` // model's trained maximum context length + Model string `json:"model"` // model name (must be installed) + ContextSizes []uint32 `json:"context_sizes,omitempty"` // context sizes to estimate (default [8192]) + GPULayers int `json:"gpu_layers,omitempty"` // number of layers to offload to GPU (0 = all) + KVQuantBits int `json:"kv_quant_bits,omitempty"` // KV cache quantization bits (0 = fp16) } // resolveModelURI converts a relative model path to a file:// URI so the @@ -36,8 +30,8 @@ func resolveModelURI(uri, modelsPath string) string { return "file://" + filepath.Join(modelsPath, uri) } -// addWeightFile appends a resolved weight file to files and tracks the 
first GGUF. -func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, firstGGUF *string, seen map[string]bool) { +// addWeightFile appends a resolved weight file to files. +func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, seen map[string]bool) { if !vram.IsWeightFile(uri) { return } @@ -47,21 +41,17 @@ func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, firstGGUF *s } seen[resolved] = true *files = append(*files, vram.FileInput{URI: resolved, Size: 0}) - if *firstGGUF == "" && vram.IsGGUF(uri) { - *firstGGUF = resolved - } } // VRAMEstimateEndpoint returns a handler that estimates VRAM usage for an -// installed model configuration. For uninstalled models (gallery URLs), use -// the gallery-level estimates in /api/models instead. +// installed model configuration at multiple context sizes. // @Summary Estimate VRAM usage for a model -// @Description Estimates VRAM based on model weight files, context size, and GPU layers +// @Description Estimates VRAM based on model weight files at multiple context sizes // @Tags config // @Accept json // @Produce json // @Param request body vramEstimateRequest true "VRAM estimation parameters" -// @Success 200 {object} vramEstimateResponse "VRAM estimate" +// @Success 200 {object} vram.MultiContextEstimate "VRAM estimate" // @Router /api/models/vram-estimate [post] func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { @@ -82,17 +72,16 @@ func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic modelsPath := appConfig.SystemState.Model.ModelsPath var files []vram.FileInput - var firstGGUF string seen := make(map[string]bool) for _, f := range modelConfig.DownloadFiles { - addWeightFile(string(f.URI), modelsPath, &files, &firstGGUF, seen) + addWeightFile(string(f.URI), modelsPath, &files, seen) } if modelConfig.Model != "" { - addWeightFile(modelConfig.Model, modelsPath, 
&files, &firstGGUF, seen) + addWeightFile(modelConfig.Model, modelsPath, &files, seen) } if modelConfig.MMProj != "" { - addWeightFile(modelConfig.MMProj, modelsPath, &files, &firstGGUF, seen) + addWeightFile(modelConfig.MMProj, modelsPath, &files, seen) } if len(files) == 0 { @@ -101,45 +90,36 @@ func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic }) } - contextDefaulted := false - opts := vram.EstimateOptions{ - ContextLength: req.ContextSize, - GPULayers: req.GPULayers, - KVQuantBits: req.KVQuantBits, - } - if opts.ContextLength == 0 { + contextSizes := req.ContextSizes + if len(contextSizes) == 0 { if modelConfig.ContextSize != nil { - opts.ContextLength = uint32(*modelConfig.ContextSize) + contextSizes = []uint32{uint32(*modelConfig.ContextSize)} } else { - opts.ContextLength = 8192 - contextDefaulted = true + contextSizes = []uint32{8192} + } + } + + // Include model's configured context size alongside requested sizes + if modelConfig.ContextSize != nil { + modelCtx := uint32(*modelConfig.ContextSize) + if !slices.Contains(contextSizes, modelCtx) { + contextSizes = append(contextSizes, modelCtx) } } + opts := vram.EstimateOptions{ + GPULayers: req.GPULayers, + KVQuantBits: req.KVQuantBits, + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - result, err := vram.Estimate(ctx, files, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader()) + result, err := vram.EstimateMultiContext(ctx, files, contextSizes, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader()) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()}) } - resp := vramEstimateResponse{EstimateResult: result} - - // When context was defaulted to 8192, read the GGUF metadata to report - // the model's trained maximum context length so callers know the estimate - // may be conservative. 
- if contextDefaulted && firstGGUF != "" { - ggufMeta, err := vram.DefaultCachedGGUFReader().ReadMetadata(ctx, firstGGUF) - if err == nil && ggufMeta != nil && ggufMeta.MaximumContextLength > 0 { - resp.ModelMaxContext = ggufMeta.MaximumContextLength - resp.ContextNote = fmt.Sprintf( - "Estimate used default context_size=8192. The model's trained maximum context is %d; VRAM usage will be higher at larger context sizes.", - ggufMeta.MaximumContextLength, - ) - } - } - - return c.JSON(http.StatusOK, resp) + return c.JSON(http.StatusOK, result) } } diff --git a/core/http/react-ui/e2e/models-gallery.spec.js b/core/http/react-ui/e2e/models-gallery.spec.js index ed5be1e56f5a..f0936c436299 100644 --- a/core/http/react-ui/e2e/models-gallery.spec.js +++ b/core/http/react-ui/e2e/models-gallery.spec.js @@ -2,13 +2,13 @@ import { test, expect } from '@playwright/test' const MOCK_MODELS_RESPONSE = { models: [ - { name: 'llama-model', description: 'A llama model', backend: 'llama-cpp', installed: false, tags: ['llm'] }, - { name: 'whisper-model', description: 'A whisper model', backend: 'whisper', installed: true, tags: ['stt'] }, + { name: 'llama-model', description: 'A llama model', backend: 'llama-cpp', installed: false, tags: ['chat'] }, + { name: 'whisper-model', description: 'A whisper model', backend: 'whisper', installed: true, tags: ['transcript'] }, { name: 'stablediffusion-model', description: 'An image model', backend: 'stablediffusion', installed: false, tags: ['sd'] }, { name: 'unknown-model', description: 'No backend', backend: '', installed: false, tags: [] }, ], allBackends: ['llama-cpp', 'stablediffusion', 'whisper'], - allTags: ['llm', 'sd', 'stt'], + allTags: ['chat', 'sd', 'transcript'], availableModels: 4, installedModels: 1, totalPages: 1, @@ -78,3 +78,121 @@ test.describe('Models Gallery - Backend Features', () => { await expect(detail.locator('text=llama-cpp')).toBeVisible() }) }) + +const BACKEND_USECASES_MOCK = { + 'llama-cpp': ['chat', 'embeddings', 
'vision'], + 'whisper': ['transcript'], + 'stablediffusion': ['image'], +} + +test.describe('Models Gallery - Multi-select Filters', () => { + test.beforeEach(async ({ page }) => { + await page.route('**/api/models*', (route) => { + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify(MOCK_MODELS_RESPONSE), + }) + }) + await page.route('**/api/backends/usecases', (route) => { + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify(BACKEND_USECASES_MOCK), + }) + }) + await page.goto('/app/models') + await expect(page.locator('th', { hasText: 'Backend' })).toBeVisible({ timeout: 10_000 }) + }) + + test('multi-select toggle: click Chat, TTS, then Chat again', async ({ page }) => { + const chatBtn = page.locator('.filter-btn', { hasText: 'Chat' }) + const ttsBtn = page.locator('.filter-btn', { hasText: 'TTS' }) + + await chatBtn.click() + await expect(chatBtn).toHaveClass(/active/) + + await ttsBtn.click() + await expect(chatBtn).toHaveClass(/active/) + await expect(ttsBtn).toHaveClass(/active/) + + // Click Chat again to deselect it + await chatBtn.click() + await expect(chatBtn).not.toHaveClass(/active/) + await expect(ttsBtn).toHaveClass(/active/) + }) + + test('"All" clears selection', async ({ page }) => { + const chatBtn = page.locator('.filter-btn', { hasText: 'Chat' }) + const allBtn = page.locator('.filter-btn', { hasText: 'All' }) + + await chatBtn.click() + await expect(chatBtn).toHaveClass(/active/) + + await allBtn.click() + await expect(allBtn).toHaveClass(/active/) + await expect(chatBtn).not.toHaveClass(/active/) + }) + + test('query param sent correctly with multiple filters', async ({ page }) => { + const chatBtn = page.locator('.filter-btn', { hasText: 'Chat' }) + const ttsBtn = page.locator('.filter-btn', { hasText: 'TTS' }) + + // Click Chat and wait for its request to settle + await chatBtn.click() + await page.waitForResponse(resp => resp.url().includes('/api/models')) + + // Now click TTS and capture the 
resulting request + const [request] = await Promise.all([ + page.waitForRequest(req => { + if (!req.url().includes('/api/models')) return false + const u = new URL(req.url()) + const tag = u.searchParams.get('tag') + return tag && tag.split(',').length >= 2 + }), + ttsBtn.click(), + ]) + + const url = new URL(request.url()) + const tags = url.searchParams.get('tag').split(',').sort() + expect(tags).toEqual(['chat', 'tts']) + }) + + test('backend greys out unavailable filters', async ({ page }) => { + // Select llama-cpp backend via dropdown + await page.locator('button', { hasText: 'All Backends' }).click() + const dropdown = page.locator('input[placeholder="Search backends..."]').locator('..').locator('..') + await dropdown.locator('text=llama-cpp').click() + + // Wait for filter state to update + const ttsBtn = page.locator('.filter-btn', { hasText: 'TTS' }) + const sttBtn = page.locator('.filter-btn', { hasText: 'STT' }) + const imageBtn = page.locator('.filter-btn', { hasText: 'Image' }) + + // TTS, STT, Image should be disabled for llama-cpp + await expect(ttsBtn).toBeDisabled() + await expect(sttBtn).toBeDisabled() + await expect(imageBtn).toBeDisabled() + + // Chat, Embeddings, Vision should remain enabled + const chatBtn = page.locator('.filter-btn', { hasText: 'Chat' }) + const embBtn = page.locator('.filter-btn', { hasText: 'Embeddings' }) + const visBtn = page.locator('.filter-btn', { hasText: 'Vision' }) + await expect(chatBtn).toBeEnabled() + await expect(embBtn).toBeEnabled() + await expect(visBtn).toBeEnabled() + }) + + test('backend clears incompatible filters', async ({ page }) => { + // Select TTS filter first + const ttsBtn = page.locator('.filter-btn', { hasText: 'TTS' }) + await ttsBtn.click() + await expect(ttsBtn).toHaveClass(/active/) + + // Now select llama-cpp backend (which doesn't support TTS) + await page.locator('button', { hasText: 'All Backends' }).click() + const dropdown = page.locator('input[placeholder="Search 
backends..."]').locator('..').locator('..') + await dropdown.locator('text=llama-cpp').click() + + // TTS should be auto-removed from selection + await expect(ttsBtn).not.toHaveClass(/active/) + }) +}) diff --git a/core/http/react-ui/src/pages/Backends.jsx b/core/http/react-ui/src/pages/Backends.jsx index 61e6468b9b69..aa63b463ebf4 100644 --- a/core/http/react-ui/src/pages/Backends.jsx +++ b/core/http/react-ui/src/pages/Backends.jsx @@ -139,11 +139,11 @@ export default function Backends() { const FILTERS = [ { key: '', label: 'All', icon: 'fa-layer-group' }, - { key: 'llm', label: 'LLM', icon: 'fa-brain' }, + { key: 'chat', label: 'Chat', icon: 'fa-brain' }, { key: 'image', label: 'Image', icon: 'fa-image' }, { key: 'video', label: 'Video', icon: 'fa-video' }, { key: 'tts', label: 'TTS', icon: 'fa-microphone' }, - { key: 'stt', label: 'STT', icon: 'fa-headphones' }, + { key: 'transcript', label: 'STT', icon: 'fa-headphones' }, { key: 'vision', label: 'Vision', icon: 'fa-eye' }, ] diff --git a/core/http/react-ui/src/pages/Models.jsx b/core/http/react-ui/src/pages/Models.jsx index fd1f3f6c9bd9..2aaab9d5ff6d 100644 --- a/core/http/react-ui/src/pages/Models.jsx +++ b/core/http/react-ui/src/pages/Models.jsx @@ -86,16 +86,19 @@ function GalleryLoader() { } +const CONTEXT_SIZES = [8192, 16384, 32768, 65536, 131072, 262144] +const CONTEXT_LABELS = ['8K', '16K', '32K', '64K', '128K', '256K'] + const FILTERS = [ { key: '', label: 'All', icon: 'fa-layer-group' }, - { key: 'llm', label: 'LLM', icon: 'fa-brain' }, - { key: 'sd', label: 'Image', icon: 'fa-image' }, + { key: 'chat', label: 'Chat', icon: 'fa-brain' }, + { key: 'image', label: 'Image', icon: 'fa-image' }, { key: 'multimodal', label: 'Multimodal', icon: 'fa-shapes' }, { key: 'vision', label: 'Vision', icon: 'fa-eye' }, { key: 'tts', label: 'TTS', icon: 'fa-microphone' }, - { key: 'stt', label: 'STT', icon: 'fa-headphones' }, - { key: 'embedding', label: 'Embedding', icon: 'fa-vector-square' }, - { key: 'reranker', 
label: 'Rerank', icon: 'fa-sort' }, + { key: 'transcript', label: 'STT', icon: 'fa-headphones' }, + { key: 'embeddings', label: 'Embeddings', icon: 'fa-vector-square' }, + { key: 'rerank', label: 'Rerank', icon: 'fa-sort' }, ] export default function Models() { @@ -108,7 +111,7 @@ export default function Models() { const [page, setPage] = useState(1) const [totalPages, setTotalPages] = useState(1) const [search, setSearch] = useState('') - const [filter, setFilter] = useState('') + const [filters, setFilters] = useState([]) const [sort, setSort] = useState('') const [order, setOrder] = useState('asc') const [installing, setInstalling] = useState(new Map()) @@ -117,6 +120,9 @@ export default function Models() { const [stats, setStats] = useState({ total: 0, installed: 0, repositories: 0 }) const [backendFilter, setBackendFilter] = useState('') const [allBackends, setAllBackends] = useState([]) + const [backendUsecases, setBackendUsecases] = useState({}) + const [estimates, setEstimates] = useState({}) + const [contextSize, setContextSize] = useState(CONTEXT_SIZES[0]) const debounceRef = useRef(null) const [confirmDialog, setConfirmDialog] = useState(null) @@ -127,14 +133,14 @@ export default function Models() { try { setLoading(true) const searchVal = params.search !== undefined ? params.search : search - const filterVal = params.filter !== undefined ? params.filter : filter + const filtersVal = params.filters !== undefined ? params.filters : filters const sortVal = params.sort !== undefined ? params.sort : sort const backendVal = params.backendFilter !== undefined ? 
params.backendFilter : backendFilter const queryParams = { page: params.page || page, items: 9, } - if (filterVal) queryParams.tag = filterVal + if (filtersVal.length > 0) queryParams.tag = filtersVal.join(',') if (searchVal) queryParams.term = searchVal if (backendVal) queryParams.backend = backendVal if (sortVal) { @@ -154,17 +160,50 @@ export default function Models() { } finally { setLoading(false) } - }, [page, search, filter, sort, order, backendFilter, addToast]) + }, [page, search, filters, sort, order, backendFilter, addToast]) useEffect(() => { fetchModels() - }, [page, filter, sort, order, backendFilter]) + }, [page, filters, sort, order, backendFilter]) + + // Fetch backend→usecase mapping once on mount + useEffect(() => { + modelsApi.backendUsecases().then(setBackendUsecases).catch(() => {}) + }, []) + + // When backend changes, remove selected filters that aren't available + useEffect(() => { + if (backendFilter && backendUsecases[backendFilter]) { + setFilters(prev => { + const possible = backendUsecases[backendFilter] + const filtered = prev.filter(k => k === 'multimodal' || possible.includes(k)) + return filtered.length !== prev.length ? filtered : prev + }) + } + }, [backendFilter, backendUsecases]) // Re-fetch when operations change (install/delete completion) useEffect(() => { if (!loading) fetchModels() }, [operations.length]) + // Fetch VRAM/size estimates asynchronously for visible models. 
+ useEffect(() => { + if (models.length === 0) return + let cancelled = false + models.forEach(model => { + const id = model.name || model.id + if (estimates[id]) return + modelsApi.estimate(id, CONTEXT_SIZES).then(est => { + if (cancelled) return + if (est && (est.sizeBytes || est.estimates)) { + setEstimates(prev => ({ ...prev, [id]: est })) + } + }).catch(() => {}) + }) + return () => { cancelled = true } + }, [models]) + const handleSearch = (value) => { setSearch(value) if (debounceRef.current) clearTimeout(debounceRef.current) @@ -174,6 +213,20 @@ export default function Models() { }, 500) } + const toggleFilter = (key) => { + if (key === '') { setFilters([]); setPage(1); return } + setFilters(prev => + prev.includes(key) ? prev.filter(k => k !== key) : [...prev, key] + ) + setPage(1) + } + + const isFilterAvailable = (key) => { + if (!backendFilter || key === '' || key === 'multimodal') return true + const possible = backendUsecases[backendFilter] + return !possible || possible.includes(key) + } + const handleSort = (col) => { if (sort === col) { setOrder(o => o === 'asc' ? 'desc' : 'asc') @@ -292,16 +345,23 @@ export default function Models() { {/* Filter buttons */}
- {search || filter || backendFilter + {search || filters.length > 0 || backendFilter ? 'No models match your current search or filters.' : 'The model gallery is empty.'}
- {(search || filter || backendFilter) && ( + {(search || filters.length > 0 || backendFilter) && ( @@ -359,9 +438,14 @@ export default function Models() { {models.map((model, idx) => { const name = model.name || model.id + const estData = estimates[name] + const sizeDisplay = estData?.sizeDisplay + const ctxEst = estData?.estimates?.[String(contextSize)] + const vramDisplay = ctxEst?.vramDisplay + const vramBytes = ctxEst?.vramBytes const installing = isInstalling(name) const progress = getOperationProgress(name) - const fit = fitsGpu(model.estimated_vram_bytes) + const fit = fitsGpu(vramBytes) const isExpanded = expandedRow === idx return ( @@ -428,15 +512,15 @@ export default function Models() { {/* Size / VRAM */}