diff --git a/cmd/aima/context.go b/cmd/aima/context.go index 89265d4..f6d49db 100644 --- a/cmd/aima/context.go +++ b/cmd/aima/context.go @@ -1,6 +1,8 @@ package main import ( + "time" + "github.com/jguan/aima/internal/agent" "github.com/jguan/aima/internal/k3s" "github.com/jguan/aima/internal/knowledge" @@ -15,17 +17,18 @@ import ( // It collects the local variables that buildToolDeps() closures previously // captured, enabling those closures to be split into separate files. type appContext struct { - cat *knowledge.Catalog - db *state.DB - kStore *knowledge.Store - rt runtime.Runtime // default runtime (K3S > Docker > Native) - nativeRt runtime.Runtime - dockerRt runtime.Runtime - k3sRt runtime.Runtime - proxy *proxy.Server - k3s *k3s.Client - dataDir string - digests map[string]string // factory catalog digests - support *support.Service - eventBus *agent.EventBus // shared EventBus for Explorer events + cat *knowledge.Catalog + db *state.DB + kStore *knowledge.Store + rt runtime.Runtime // default runtime (K3S > Docker > Native) + nativeRt runtime.Runtime + dockerRt runtime.Runtime + k3sRt runtime.Runtime + proxy *proxy.Server + k3s *k3s.Client + dataDir string + catalogLoadedAt time.Time + digests map[string]string // factory catalog digests + support *support.Service + eventBus *agent.EventBus // shared EventBus for Explorer events } diff --git a/cmd/aima/diagnostics.go b/cmd/aima/diagnostics.go index 60c9b7c..32369d2 100644 --- a/cmd/aima/diagnostics.go +++ b/cmd/aima/diagnostics.go @@ -92,7 +92,8 @@ func buildDiagnosticsBundle(ctx context.Context, ac *appContext, deps *mcp.ToolD "goarch": goruntime.GOARCH, "data_dir": redactHomePath(dataDir), }, - "sections": map[string]any{}, + "service_context": serviceContextStatus(ac), + "sections": map[string]any{}, } sections := bundle["sections"].(map[string]any) diff --git a/cmd/aima/diagnostics_test.go b/cmd/aima/diagnostics_test.go index 9612998..61cd8ee 100644 --- a/cmd/aima/diagnostics_test.go +++ b/cmd/aima/diagnostics_test.go @@ -124,3 +124,30 @@ func TestBuildDiagnosticsBundleRecordsSectionErrors(t *testing.T) { t.Fatalf("hardware error = %v, want context canceled", hardware["error"]) } } + +func TestServiceContextReportsStaleOverlayHint(t *testing.T) { + tmp := t.TempDir() + overlayDir := tmp + string(os.PathSeparator) + "catalog" + string(os.PathSeparator) + "user" + string(os.PathSeparator) + "models" + if err := os.MkdirAll(overlayDir, 0o755); err != nil { + t.Fatalf("mkdir overlay: %v", err) + } + overlayFile := overlayDir + string(os.PathSeparator) + "demo.patch.yaml" + if err := os.WriteFile(overlayFile, []byte("kind: model_asset_patch\nmetadata:\n name: demo\n"), 0o644); err != nil { + t.Fatalf("write overlay: %v", err) + } + newer := time.Now().Add(5 * time.Minute) + if err := os.Chtimes(overlayFile, newer, newer); err != nil { + t.Fatalf("chtimes overlay: %v", err) + } + + status := serviceContextStatus(&appContext{ + dataDir: tmp, + catalogLoadedAt: time.Now(), + }) + if status["overlay_newer_than_catalog"] != true { + t.Fatalf("overlay_newer_than_catalog = %v, want true; status=%v", status["overlay_newer_than_catalog"], status) + } + if status["reload_hint"] == "" { + t.Fatalf("reload_hint missing: %v", status) + } +} diff --git a/cmd/aima/main.go b/cmd/aima/main.go index 92a11b6..1f4f7ba 100644 --- a/cmd/aima/main.go +++ b/cmd/aima/main.go @@ -154,18 +154,19 @@ func run() error { mcpServer := mcp.NewServer() supportSvc := support.NewService(db, support.WithLogger(slog.Default())) ac := &appContext{ - cat: cat, - db: db, - kStore: knowledgeStore, - rt: rt, - nativeRt: nativeRt, - dockerRt: dockerRt, - k3sRt: k3sRt, - proxy: proxyServer, - k3s: k3sClient, - dataDir: dataDir, - digests: factoryDigests, - support: supportSvc, + cat: cat, + db: db, + kStore: knowledgeStore, + rt: rt, + nativeRt: nativeRt, + dockerRt: dockerRt, + k3sRt: k3sRt, + proxy: proxyServer, + k3s: k3sClient, + dataDir: dataDir, + catalogLoadedAt: time.Now().UTC(), + digests: factoryDigests, + support: supportSvc, } deps := buildToolDeps(ac) diff --git a/cmd/aima/resolve.go b/cmd/aima/resolve.go index 33a38ab..ed322b3 100644 --- a/cmd/aima/resolve.go +++ b/cmd/aima/resolve.go @@ -29,9 +29,11 @@ var autoDetectWarned sync.Map // resolvedDeployment holds the shared result of resolve + CheckFit, // used by both DeployApply and DeployDryRun. type resolvedDeployment struct { - ModelName string - Resolved *knowledge.ResolvedConfig - Fit *knowledge.FitReport + ModelName string + Resolved *knowledge.ResolvedConfig + ResolvedConfig map[string]any + ResolvedProvenance map[string]string + Fit *knowledge.FitReport } // queryGoldenOverrides returns config overrides from the best golden configuration @@ -117,6 +119,8 @@ func resolveDeployment(ctx context.Context, cat *knowledge.Catalog, db *state.DB return nil, err } + resolvedConfig := cloneAnyMap(resolved.Config) + resolvedProvenance := cloneStringMap(resolved.Provenance) fit := knowledge.CheckFit(resolved, hwInfo) for k, v := range fit.Adjustments { resolved.Config[k] = v @@ -124,12 +128,36 @@ func resolveDeployment(ctx context.Context, cat *knowledge.Catalog, db *state.DB } return &resolvedDeployment{ - ModelName: canonicalName, - Resolved: resolved, - Fit: fit, + ModelName: canonicalName, + Resolved: resolved, + ResolvedConfig: resolvedConfig, + ResolvedProvenance: resolvedProvenance, + Fit: fit, }, nil } +func cloneAnyMap(in map[string]any) map[string]any { + if in == nil { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneStringMap(in map[string]string) map[string]string { + if in == nil { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + // normalizeAutoPortOverrides removes "auto" sentinels from port-like override keys // before resolution. This preserves the engine YAML default port so Go-side host // port allocation can still choose a free host port later in deploy.apply. diff --git a/cmd/aima/resolve_test.go b/cmd/aima/resolve_test.go index 2895970..4c3f399 100644 --- a/cmd/aima/resolve_test.go +++ b/cmd/aima/resolve_test.go @@ -363,6 +363,70 @@ func TestNormalizeAutoPortOverrides(t *testing.T) { } } +func TestResolveDeploymentKeepsResolvedAndEffectiveConfigSeparate(t *testing.T) { + ctx := context.Background() + db, err := state.Open(ctx, ":memory:") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer db.Close() + + cat := &knowledge.Catalog{ + EngineAssets: []knowledge.EngineAsset{{ + Metadata: knowledge.EngineMetadata{ + Name: "vllm-test", + Type: "vllm", + Version: "1.0", + SupportedFormats: []string{"safetensors"}, + }, + Hardware: knowledge.EngineHardware{GPUArch: "*"}, + Startup: knowledge.EngineStartup{ + Command: []string{"vllm", "serve", "{{.ModelPath}}"}, + DefaultArgs: map[string]any{"gpu_memory_utilization": 0.85, "port": 8000}, + }, + Runtime: knowledge.EngineRuntime{Default: "container"}, + }}, + ModelAssets: []knowledge.ModelAsset{{ + Metadata: knowledge.ModelMetadata{Name: "demo-model", Type: "llm"}, + Storage: knowledge.ModelStorage{DefaultPathPattern: "/models/demo"}, + Variants: []knowledge.ModelVariant{{ + Name: "demo-model-vllm", + Engine: "vllm", + Format: "safetensors", + Hardware: knowledge.ModelVariantHardware{GPUArch: "Blackwell"}, + }}, + }}, + } + + rd, err := resolveDeployment(ctx, cat, db, nil, knowledge.HardwareInfo{ + GPUArch: "Blackwell", + GPUVRAMMiB: 122880, + GPUMemFreeMiB: 64000, + GPUMemUsedMiB: 58880, + UnifiedMemory: false, + Platform: "linux/arm64", + }, "demo-model", "vllm", "", nil, t.TempDir()) + if err != nil { + t.Fatalf("resolveDeployment: %v", err) + } + + if got := rd.ResolvedConfig["gpu_memory_utilization"]; got != 0.85 { + t.Fatalf("resolved_config gpu_memory_utilization = %v, want 0.85", got) + } + if got := rd.Resolved.Config["gpu_memory_utilization"]; got != 0.51 { + t.Fatalf("effective config gpu_memory_utilization = %v, want 0.51", got) + } + if got := rd.Fit.Adjustments["gpu_memory_utilization"]; got != 0.51 { + t.Fatalf("fit adjustment gpu_memory_utilization = %v, want 0.51", got) + } + if got := rd.ResolvedProvenance["gpu_memory_utilization"]; got != "L0" { + t.Fatalf("resolved provenance = %q, want L0", got) + } + if got := rd.Resolved.Provenance["gpu_memory_utilization"]; got != "L0-auto" { + t.Fatalf("effective provenance = %q, want L0-auto", got) + } +} + func TestResolveCatalogWithLocalEngineOverlayUsesInstalledContainerAsset(t *testing.T) { ctx := context.Background() db, err := state.Open(ctx, ":memory:") diff --git a/cmd/aima/service_context.go b/cmd/aima/service_context.go new file mode 100644 index 0000000..4482987 --- /dev/null +++ b/cmd/aima/service_context.go @@ -0,0 +1,65 @@ +package main + +import ( + "os" + "path/filepath" + "strings" + "time" +) + +func serviceContextStatus(ac *appContext) map[string]any { + dataDir := "" + var loadedAt time.Time + if ac != nil { + dataDir = strings.TrimSpace(ac.dataDir) + loadedAt = ac.catalogLoadedAt + } + overlayDir := "" + if dataDir != "" { + overlayDir = filepath.Join(dataDir, "catalog") + } + latestOverlay := latestModTime(overlayDir) + + status := map[string]any{ + "data_dir": dataDir, + "overlay_dir": overlayDir, + } + if !loadedAt.IsZero() { + status["catalog_loaded_at"] = loadedAt.Format(time.RFC3339) + } + if !latestOverlay.IsZero() { + status["overlay_latest_mtime"] = latestOverlay.Format(time.RFC3339) + if !loadedAt.IsZero() && latestOverlay.After(loadedAt) { + status["overlay_newer_than_catalog"] = true + status["reload_hint"] = "catalog overlays changed after this AIMA process loaded; restart aima-serve or reload catalog before trusting UI dry-run results" + } + } + if user := os.Getenv("USER"); user != "" { + status["user"] = user + } + if home, err := os.UserHomeDir(); err == nil && home != "" { + status["home_dir"] = home + } + return status +} + +func latestModTime(root string) time.Time { + if strings.TrimSpace(root) == "" { + return time.Time{} + } + var latest time.Time + _ = filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil || d == nil { + return nil + } + info, statErr := d.Info() + if statErr != nil { + return nil + } + if info.ModTime().After(latest) { + latest = info.ModTime() + } + return nil + }) + return latest +} diff --git a/cmd/aima/tooldeps_deploy.go b/cmd/aima/tooldeps_deploy.go index 9df4e86..f9cc4c6 100644 --- a/cmd/aima/tooldeps_deploy.go +++ b/cmd/aima/tooldeps_deploy.go @@ -91,6 +91,7 @@ func buildDeployDeps(ac *appContext, deps *mcp.ToolDeps, PortSpecs: append([]knowledge.StartupPort(nil), resolved.PortSpecs...), InitCommands: resolved.InitCommands, ModelPath: modelPath, + ModelType: catalogModelType(cat, modelName), Config: resolved.Config, RuntimeClassName: resolved.RuntimeClassName, CPUArch: resolved.CPUArch, @@ -346,14 +347,19 @@ func buildDeployDeps(ac *appContext, deps *mcp.ToolDeps, } result := map[string]any{ - "model": rd.ModelName, - "engine": resolved.Engine, - "engine_image": resolved.EngineImage, - "slot": resolved.Slot, - "runtime": runtimeName, - "config": resolved.Config, - "ports": knowledge.ResolvePortBindingsFromSpecs(resolved.PortSpecs, resolved.Config), - "provenance": resolved.Provenance, + "model": rd.ModelName, + "engine": resolved.Engine, + "engine_image": resolved.EngineImage, + "slot": resolved.Slot, + "runtime": runtimeName, + "config": resolved.Config, + "resolved_config": rd.ResolvedConfig, + "effective_config": resolved.Config, + "fit_adjustments": rd.Fit.Adjustments, + "ports": knowledge.ResolvePortBindingsFromSpecs(resolved.PortSpecs, resolved.Config), + "provenance": resolved.Provenance, + "resolved_provenance": rd.ResolvedProvenance, + "effective_provenance": resolved.Provenance, "fit_report": map[string]any{ "fit": rd.Fit.Fit, "reason": rd.Fit.Reason, diff --git a/cmd/aima/tooldeps_knowledge.go b/cmd/aima/tooldeps_knowledge.go index a6746e1..bbe43d3 100644 --- a/cmd/aima/tooldeps_knowledge.go +++ b/cmd/aima/tooldeps_knowledge.go @@ -591,10 +591,11 @@ func buildKnowledgeDeps(ac *appContext, deps *mcp.ToolDeps) { } } status := map[string]any{ - "factory_assets": catalogSize(factoryCat), - "overlay_assets": catalogSize(overlayCat), - "shadowed": shadowed, - "parse_warnings": parseWarnings, + "factory_assets": catalogSize(factoryCat), + "overlay_assets": catalogSize(overlayCat), + "shadowed": shadowed, + "parse_warnings": parseWarnings, + "service_context": serviceContextStatus(ac), } return json.Marshal(status) } diff --git a/cmd/aima/tooldeps_system.go b/cmd/aima/tooldeps_system.go index c0bfe47..25ad773 100644 --- a/cmd/aima/tooldeps_system.go +++ b/cmd/aima/tooldeps_system.go @@ -132,6 +132,9 @@ func buildSystemDeps(ac *appContext, deps *mcp.ToolDeps) { if b, e := json.Marshal(ac.support.Status(ctx)); e == nil { status["support"] = b } + if b, e := json.Marshal(serviceContextStatus(ac)); e == nil { + status["service_context"] = b + } if deps.OpenClawStatus != nil { if b, e := deps.OpenClawStatus(ctx); e == nil { status["openclaw"] = b diff --git a/internal/knowledge/configflags.go b/internal/knowledge/configflags.go index 3833b98..7ef9ba9 100644 --- a/internal/knowledge/configflags.go +++ b/internal/knowledge/configflags.go @@ -36,18 +36,81 @@ func FormatConfigFlag(key string, value any) []string { } } +// ConfigFlagContext describes the runtime command surface used to decide +// whether a resolved config key is a real CLI argument or only a resolver hint. +type ConfigFlagContext struct { + Command []string + ModelPath string + Engine string + ModelType string +} + // ShouldIncludeConfigFlag reports whether a resolved config key should be emitted // as a runtime CLI flag for the given startup command and local model path. // Some keys, such as quantization, are selection hints for a model artifact rather // than portable runtime flags across every engine. func ShouldIncludeConfigFlag(command []string, modelPath, key string, value any) bool { + return ShouldIncludeConfigFlagFor(ConfigFlagContext{Command: command, ModelPath: modelPath}, key, value) +} + +// ShouldIncludeConfigFlagFor is the engine-aware form of ShouldIncludeConfigFlag. +// It keeps legacy behavior for unknown engines, but prevents LLM-only knobs from +// leaking into image/audio/OCR service wrappers that do not expose those flags. +func ShouldIncludeConfigFlagFor(ctx ConfigFlagContext, key string, value any) bool { switch strings.ToLower(strings.TrimSpace(key)) { case "": return false case "quantization": - return shouldIncludeQuantizationFlag(command, modelPath, value) + return shouldIncludeQuantizationFlag(ctx.Command, ctx.ModelPath, value) + default: + if isLLMOnlyConfigKey(key) && !commandAcceptsLLMConfig(ctx) { + return false + } + return true + } +} + +func isLLMOnlyConfigKey(key string) bool { + switch strings.ToLower(strings.TrimSpace(key)) { + case "max_model_len", "max_seq_len", "max_seq_length", "max_context_len", "max_context_tokens", + "context_length", "ctx_size", "n_ctx", "gpu_memory_utilization", "mem_fraction_static", + "tensor_parallel_size", "pipeline_parallel_size", "kv_cache_dtype", "dtype", "trust_remote_code", + "enforce_eager", "disable_log_stats", "served_model_name", "speculative_config", + "mm_encoder_attn_backend": + return true default: + return false + } +} + +func commandAcceptsLLMConfig(ctx ConfigFlagContext) bool { + if isLLMModelType(ctx.ModelType) { + return true + } + if strings.TrimSpace(ctx.Engine) == "" && strings.TrimSpace(ctx.ModelType) == "" && len(ctx.Command) == 0 { + return true + } + for _, value := range []string{ctx.Engine, strings.Join(ctx.Command, " ")} { + lower := strings.ToLower(value) + switch { + case strings.Contains(lower, "vllm"), + strings.Contains(lower, "sglang"), + strings.Contains(lower, "llama"), + strings.Contains(lower, "ollama"), + strings.Contains(lower, "transformers serve"), + strings.Contains(lower, "qwen-asr-serve"): + return true + } + } + return false +} + +func isLLMModelType(modelType string) bool { + switch strings.ToLower(strings.TrimSpace(modelType)) { + case "llm", "vlm", "embedding": return true + default: + return false } } diff --git a/internal/knowledge/configflags_test.go b/internal/knowledge/configflags_test.go index a0db97b..7c0c3ec 100644 --- a/internal/knowledge/configflags_test.go +++ b/internal/knowledge/configflags_test.go @@ -127,4 +127,34 @@ func TestShouldIncludeConfigFlag(t *testing.T) { t.Fatal("quantization flag should be kept when config.json declares quantization_config") } }) + + t.Run("skip LLM-only flags for non-LLM service wrappers", func(t *testing.T) { + for _, key := range []string{"max_model_len", "gpu_memory_utilization", "mem_fraction_static"} { + if ShouldIncludeConfigFlagFor( + ConfigFlagContext{ + Command: []string{"python3", "server.py"}, + Engine: "z-image-diffusers", + ModelType: "image_gen", + }, + key, + 8192, + ) { + t.Fatalf("%s should be omitted for image service wrappers", key) + } + } + }) + + t.Run("keep LLM-only flags for vLLM-like service wrappers", func(t *testing.T) { + if !ShouldIncludeConfigFlagFor( + ConfigFlagContext{ + Command: []string{"qwen-asr-serve", "{{.ModelPath}}"}, + Engine: "vllm-nightly-audio", + ModelType: "asr", + }, + "max_model_len", + 8192, + ) { + t.Fatal("max_model_len should be kept for vLLM audio wrappers") + } + }) } diff --git a/internal/knowledge/podgen.go b/internal/knowledge/podgen.go index ecd490d..f5f00c5 100644 --- a/internal/knowledge/podgen.go +++ b/internal/knowledge/podgen.go @@ -283,7 +283,12 @@ func GeneratePod(resolved *ResolvedConfig) ([]byte, error) { if _, reserved := portKeys[k]; reserved { continue } - if !ShouldIncludeConfigFlag(resolved.Command, resolved.ModelPath, k, resolved.Config[k]) { + if !ShouldIncludeConfigFlagFor(ConfigFlagContext{ + Command: resolved.Command, + ModelPath: resolved.ModelPath, + Engine: resolved.Engine, + ModelType: resolved.ModelType, + }, k, resolved.Config[k]) { continue } flagName := "--" + strings.ReplaceAll(k, "_", "-") diff --git a/internal/knowledge/podgen_test.go b/internal/knowledge/podgen_test.go index c05908b..a7ab802 100644 --- a/internal/knowledge/podgen_test.go +++ b/internal/knowledge/podgen_test.go @@ -395,3 +395,35 @@ func TestGeneratePodEnvMerge(t *testing.T) { t.Errorf("SHARED_VAR = %q, want engine-wins (engine overrides hw)", envMap["SHARED_VAR"]) } } + +func TestGeneratePodFiltersLLMOnlyArgsForImageService(t *testing.T) { + resolved := &ResolvedConfig{ + Engine: "z-image-diffusers", + EngineImage: "qujing-z-image:latest", + ModelPath: "/data/models/stable-diffusion-v1-5", + ModelName: "stable-diffusion-v1-5", + ModelType: "image_gen", + Slot: "default", + Config: map[string]any{ + "port": 8188, + "max_model_len": 8192, + "gpu_memory_utilization": 0.5, + }, + Command: []string{"python3", "server.py"}, + PortSpecs: []StartupPort{ + {Name: "http", Flag: "--port", ConfigKey: "port", Primary: true}, + }, + } + + podYAML, err := GeneratePod(resolved) + if err != nil { + t.Fatalf("GeneratePod: %v", err) + } + s := string(podYAML) + if strings.Contains(s, "--max-model-len") || strings.Contains(s, "--gpu-memory-utilization") { + t.Fatalf("LLM-only args should not be emitted for image services:\n%s", s) + } + if !strings.Contains(s, "--port") { + t.Fatalf("expected service port flag to remain:\n%s", s) + } +} diff --git a/internal/knowledge/resolver.go b/internal/knowledge/resolver.go index b30d3ec..ad262cc 100644 --- a/internal/knowledge/resolver.go +++ b/internal/knowledge/resolver.go @@ -51,6 +51,7 @@ type ResolvedConfig struct { EngineImage string ModelPath string ModelName string + ModelType string ModelFormat string Slot string Config map[string]any @@ -221,6 +222,7 @@ func (c *Catalog) Resolve(hw HardwareInfo, modelName, engineType string, userOve Engine: engineType, EngineAssetName: engineAssetName, ModelName: model.Metadata.Name, + ModelType: model.Metadata.Type, ModelFormat: variant.Format, Slot: slot.Name, Config: config, diff --git a/internal/runtime/docker.go b/internal/runtime/docker.go index cc37972..3fa0698 100644 --- a/internal/runtime/docker.go +++ b/internal/runtime/docker.go @@ -208,7 +208,12 @@ func (r *DockerRuntime) buildRunArgs(name string, req *DeployRequest) []string { command = knowledge.AppendPortBindings(command, portBindings) // Append config values as CLI flags, with template substitution - for _, f := range configToFlags(req.Config, req.Command, req.ModelPath, knowledge.PortConfigKeys(req.PortSpecs)) { + for _, f := range configToFlagsFor(req.Config, knowledge.ConfigFlagContext{ + Command: req.Command, + ModelPath: req.ModelPath, + Engine: req.Engine, + ModelType: req.ModelType, + }, knowledge.PortConfigKeys(req.PortSpecs)) { f = strings.ReplaceAll(f, "{{.ModelName}}", req.Name) f = strings.ReplaceAll(f, "{{.ModelPath}}", containerModelPath) command = append(command, f) diff --git a/internal/runtime/k3s.go b/internal/runtime/k3s.go index 34f5ffb..ba0cf28 100644 --- a/internal/runtime/k3s.go +++ b/internal/runtime/k3s.go @@ -149,6 +149,7 @@ func toResolvedConfig(req *DeployRequest) *knowledge.ResolvedConfig { EngineImage: req.Image, ModelPath: req.ModelPath, ModelName: req.Name, + ModelType: req.ModelType, Slot: slot, Config: config, Command: req.Command, diff --git a/internal/runtime/native.go b/internal/runtime/native.go index 3956abd..45db917 100644 --- a/internal/runtime/native.go +++ b/internal/runtime/native.go @@ -154,7 +154,12 @@ func (r *NativeRuntime) Deploy(ctx context.Context, req *DeployRequest) error { command = knowledge.AppendPortBindings(command, portBindings) // Append other config values as CLI flags, with template substitution - for _, f := range configToFlags(req.Config, req.Command, req.ModelPath, knowledge.PortConfigKeys(req.PortSpecs)) { + for _, f := range configToFlagsFor(req.Config, knowledge.ConfigFlagContext{ + Command: req.Command, + ModelPath: req.ModelPath, + Engine: req.Engine, + ModelType: req.ModelType, + }, knowledge.PortConfigKeys(req.PortSpecs)) { f = strings.ReplaceAll(f, "{{.ModelName}}", req.Name) f = strings.ReplaceAll(f, "{{.ModelPath}}", req.ModelPath) command = append(command, f) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 43b643c..b75f3f5 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -34,6 +34,7 @@ type DeployRequest struct { PortSpecs []knowledge.StartupPort InitCommands []string // pre-commands to run before main server (K3S, Docker) ModelPath string // host path to model files + ModelType string // catalog model type (llm, asr, tts, image_gen, ...) Port int // legacy fallback; prefer Config + PortSpecs Config map[string]any Partition *PartitionRequest // resource limits (K3S+HAMi); native ignores @@ -147,6 +148,10 @@ func isPortFlag(s string) bool { // ShouldIncludeConfigFlag. Serialization rules (bool/map/scalar) live in // FormatConfigFlag so K3S podgen and Docker/Native runtime stay consistent. func configToFlags(config map[string]any, command []string, modelPath string, reservedKeys map[string]struct{}) []string { + return configToFlagsFor(config, knowledge.ConfigFlagContext{Command: command, ModelPath: modelPath}, reservedKeys) +} + +func configToFlagsFor(config map[string]any, flagCtx knowledge.ConfigFlagContext, reservedKeys map[string]struct{}) []string { if len(config) == 0 { return nil } @@ -155,7 +160,7 @@ func configToFlags(config map[string]any, command []string, modelPath string, re if _, reserved := reservedKeys[k]; reserved { continue } - if !knowledge.ShouldIncludeConfigFlag(command, modelPath, k, config[k]) { + if !knowledge.ShouldIncludeConfigFlagFor(flagCtx, k, config[k]) { continue } keys = append(keys, k) diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 8b75589..81fa599 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "strings" "testing" + + "github.com/jguan/aima/internal/knowledge" ) func TestConfigToFlagsSkipsSelectionOnlyQuantization(t *testing.T) { @@ -81,3 +83,23 @@ func TestConfigToFlagsMapEmitsJSON(t *testing.T) { t.Fatalf("expected method=mtp, got %v", parsed["method"]) } } + +func TestConfigToFlagsFiltersLLMOnlyArgsForImageService(t *testing.T) { + flags := configToFlagsFor( + map[string]any{ + "port": 8188, + "max_model_len": 8192, + "gpu_memory_utilization": 0.5, + }, + knowledge.ConfigFlagContext{ + Command: []string{"python3", "server.py"}, + Engine: "z-image-diffusers", + ModelType: "image_gen", + }, + map[string]struct{}{"port": {}}, + ) + got := strings.Join(flags, " ") + if strings.Contains(got, "--max-model-len") || strings.Contains(got, "--gpu-memory-utilization") { + t.Fatalf("LLM-only flags should be omitted for image services, got %q", got) + } +}