Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/aima/tooldeps_deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func buildDeployDeps(ac *appContext, deps *mcp.ToolDeps,
PortSpecs: append([]knowledge.StartupPort(nil), resolved.PortSpecs...),
InitCommands: resolved.InitCommands,
ModelPath: modelPath,
ModelType: catalogModelType(cat, modelName),
Config: resolved.Config,
RuntimeClassName: resolved.RuntimeClassName,
CPUArch: resolved.CPUArch,
Expand Down
65 changes: 64 additions & 1 deletion internal/knowledge/configflags.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,81 @@ func FormatConfigFlag(key string, value any) []string {
}
}

// ConfigFlagContext describes the runtime command surface used to decide
// whether a resolved config key is a real CLI argument or only a resolver hint.
type ConfigFlagContext struct {
Command []string
ModelPath string
Engine string
ModelType string
}

// ShouldIncludeConfigFlag reports whether a resolved config key should be emitted
// as a runtime CLI flag for the given startup command and local model path.
// Some keys, such as quantization, are selection hints for a model artifact rather
// than portable runtime flags across every engine.
func ShouldIncludeConfigFlag(command []string, modelPath, key string, value any) bool {
return ShouldIncludeConfigFlagFor(ConfigFlagContext{Command: command, ModelPath: modelPath}, key, value)
}

// ShouldIncludeConfigFlagFor is the engine-aware form of ShouldIncludeConfigFlag.
// It keeps legacy behavior for unknown engines, but prevents LLM-only knobs from
// leaking into image/audio/OCR service wrappers that do not expose those flags.
func ShouldIncludeConfigFlagFor(ctx ConfigFlagContext, key string, value any) bool {
switch strings.ToLower(strings.TrimSpace(key)) {
case "":
return false
case "quantization":
return shouldIncludeQuantizationFlag(command, modelPath, value)
return shouldIncludeQuantizationFlag(ctx.Command, ctx.ModelPath, value)
default:
if isLLMOnlyConfigKey(key) && !commandAcceptsLLMConfig(ctx) {
return false
}
return true
}
}

func isLLMOnlyConfigKey(key string) bool {
switch strings.ToLower(strings.TrimSpace(key)) {
case "max_model_len", "max_seq_len", "max_seq_length", "max_context_len", "max_context_tokens",
"context_length", "ctx_size", "n_ctx", "gpu_memory_utilization", "mem_fraction_static",
"tensor_parallel_size", "pipeline_parallel_size", "kv_cache_dtype", "dtype", "trust_remote_code",
"enforce_eager", "disable_log_stats", "served_model_name", "speculative_config",
"mm_encoder_attn_backend":
return true
default:
return false
}
}

func commandAcceptsLLMConfig(ctx ConfigFlagContext) bool {
if isLLMModelType(ctx.ModelType) {
return true
}
if strings.TrimSpace(ctx.Engine) == "" && strings.TrimSpace(ctx.ModelType) == "" && len(ctx.Command) == 0 {
return true
}
for _, value := range []string{ctx.Engine, strings.Join(ctx.Command, " ")} {
lower := strings.ToLower(value)
switch {
case strings.Contains(lower, "vllm"),
strings.Contains(lower, "sglang"),
strings.Contains(lower, "llama"),
strings.Contains(lower, "ollama"),
strings.Contains(lower, "transformers serve"),
strings.Contains(lower, "qwen-asr-serve"):
return true
}
}
return false
}

func isLLMModelType(modelType string) bool {
switch strings.ToLower(strings.TrimSpace(modelType)) {
case "llm", "vlm", "embedding":
return true
default:
return false
}
}

Expand Down
30 changes: 30 additions & 0 deletions internal/knowledge/configflags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,34 @@ func TestShouldIncludeConfigFlag(t *testing.T) {
t.Fatal("quantization flag should be kept when config.json declares quantization_config")
}
})

t.Run("skip LLM-only flags for non-LLM service wrappers", func(t *testing.T) {
for _, key := range []string{"max_model_len", "gpu_memory_utilization", "mem_fraction_static"} {
if ShouldIncludeConfigFlagFor(
ConfigFlagContext{
Command: []string{"python3", "server.py"},
Engine: "z-image-diffusers",
ModelType: "image_gen",
},
key,
8192,
) {
t.Fatalf("%s should be omitted for image service wrappers", key)
}
}
})

t.Run("keep LLM-only flags for vLLM-like service wrappers", func(t *testing.T) {
if !ShouldIncludeConfigFlagFor(
ConfigFlagContext{
Command: []string{"qwen-asr-serve", "{{.ModelPath}}"},
Engine: "vllm-nightly-audio",
ModelType: "asr",
},
"max_model_len",
8192,
) {
t.Fatal("max_model_len should be kept for vLLM audio wrappers")
}
})
}
7 changes: 6 additions & 1 deletion internal/knowledge/podgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,12 @@ func GeneratePod(resolved *ResolvedConfig) ([]byte, error) {
if _, reserved := portKeys[k]; reserved {
continue
}
if !ShouldIncludeConfigFlag(resolved.Command, resolved.ModelPath, k, resolved.Config[k]) {
if !ShouldIncludeConfigFlagFor(ConfigFlagContext{
Command: resolved.Command,
ModelPath: resolved.ModelPath,
Engine: resolved.Engine,
ModelType: resolved.ModelType,
}, k, resolved.Config[k]) {
continue
}
flagName := "--" + strings.ReplaceAll(k, "_", "-")
Expand Down
32 changes: 32 additions & 0 deletions internal/knowledge/podgen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,35 @@ func TestGeneratePodEnvMerge(t *testing.T) {
t.Errorf("SHARED_VAR = %q, want engine-wins (engine overrides hw)", envMap["SHARED_VAR"])
}
}

func TestGeneratePodFiltersLLMOnlyArgsForImageService(t *testing.T) {
resolved := &ResolvedConfig{
Engine: "z-image-diffusers",
EngineImage: "qujing-z-image:latest",
ModelPath: "/data/models/stable-diffusion-v1-5",
ModelName: "stable-diffusion-v1-5",
ModelType: "image_gen",
Slot: "default",
Config: map[string]any{
"port": 8188,
"max_model_len": 8192,
"gpu_memory_utilization": 0.5,
},
Command: []string{"python3", "server.py"},
PortSpecs: []StartupPort{
{Name: "http", Flag: "--port", ConfigKey: "port", Primary: true},
},
}

podYAML, err := GeneratePod(resolved)
if err != nil {
t.Fatalf("GeneratePod: %v", err)
}
s := string(podYAML)
if strings.Contains(s, "--max-model-len") || strings.Contains(s, "--gpu-memory-utilization") {
t.Fatalf("LLM-only args should not be emitted for image services:\n%s", s)
}
if !strings.Contains(s, "--port") {
t.Fatalf("expected service port flag to remain:\n%s", s)
}
}
2 changes: 2 additions & 0 deletions internal/knowledge/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type ResolvedConfig struct {
EngineImage string
ModelPath string
ModelName string
ModelType string
ModelFormat string
Slot string
Config map[string]any
Expand Down Expand Up @@ -221,6 +222,7 @@ func (c *Catalog) Resolve(hw HardwareInfo, modelName, engineType string, userOve
Engine: engineType,
EngineAssetName: engineAssetName,
ModelName: model.Metadata.Name,
ModelType: model.Metadata.Type,
ModelFormat: variant.Format,
Slot: slot.Name,
Config: config,
Expand Down
7 changes: 6 additions & 1 deletion internal/runtime/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,12 @@ func (r *DockerRuntime) buildRunArgs(name string, req *DeployRequest) []string {
command = knowledge.AppendPortBindings(command, portBindings)

// Append config values as CLI flags, with template substitution
for _, f := range configToFlags(req.Config, req.Command, req.ModelPath, knowledge.PortConfigKeys(req.PortSpecs)) {
for _, f := range configToFlagsFor(req.Config, knowledge.ConfigFlagContext{
Command: req.Command,
ModelPath: req.ModelPath,
Engine: req.Engine,
ModelType: req.ModelType,
}, knowledge.PortConfigKeys(req.PortSpecs)) {
f = strings.ReplaceAll(f, "{{.ModelName}}", req.Name)
f = strings.ReplaceAll(f, "{{.ModelPath}}", containerModelPath)
command = append(command, f)
Expand Down
1 change: 1 addition & 0 deletions internal/runtime/k3s.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func toResolvedConfig(req *DeployRequest) *knowledge.ResolvedConfig {
EngineImage: req.Image,
ModelPath: req.ModelPath,
ModelName: req.Name,
ModelType: req.ModelType,
Slot: slot,
Config: config,
Command: req.Command,
Expand Down
7 changes: 6 additions & 1 deletion internal/runtime/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,12 @@ func (r *NativeRuntime) Deploy(ctx context.Context, req *DeployRequest) error {
command = knowledge.AppendPortBindings(command, portBindings)

// Append other config values as CLI flags, with template substitution
for _, f := range configToFlags(req.Config, req.Command, req.ModelPath, knowledge.PortConfigKeys(req.PortSpecs)) {
for _, f := range configToFlagsFor(req.Config, knowledge.ConfigFlagContext{
Command: req.Command,
ModelPath: req.ModelPath,
Engine: req.Engine,
ModelType: req.ModelType,
}, knowledge.PortConfigKeys(req.PortSpecs)) {
f = strings.ReplaceAll(f, "{{.ModelName}}", req.Name)
f = strings.ReplaceAll(f, "{{.ModelPath}}", req.ModelPath)
command = append(command, f)
Expand Down
7 changes: 6 additions & 1 deletion internal/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type DeployRequest struct {
PortSpecs []knowledge.StartupPort
InitCommands []string // pre-commands to run before main server (K3S, Docker)
ModelPath string // host path to model files
ModelType string // catalog model type (llm, asr, tts, image_gen, ...)
Port int // legacy fallback; prefer Config + PortSpecs
Config map[string]any
Partition *PartitionRequest // resource limits (K3S+HAMi); native ignores
Expand Down Expand Up @@ -147,6 +148,10 @@ func isPortFlag(s string) bool {
// ShouldIncludeConfigFlag. Serialization rules (bool/map/scalar) live in
// FormatConfigFlag so K3S podgen and Docker/Native runtime stay consistent.
func configToFlags(config map[string]any, command []string, modelPath string, reservedKeys map[string]struct{}) []string {
return configToFlagsFor(config, knowledge.ConfigFlagContext{Command: command, ModelPath: modelPath}, reservedKeys)
}

func configToFlagsFor(config map[string]any, flagCtx knowledge.ConfigFlagContext, reservedKeys map[string]struct{}) []string {
if len(config) == 0 {
return nil
}
Expand All @@ -155,7 +160,7 @@ func configToFlags(config map[string]any, command []string, modelPath string, re
if _, reserved := reservedKeys[k]; reserved {
continue
}
if !knowledge.ShouldIncludeConfigFlag(command, modelPath, k, config[k]) {
if !knowledge.ShouldIncludeConfigFlagFor(flagCtx, k, config[k]) {
continue
}
keys = append(keys, k)
Expand Down
22 changes: 22 additions & 0 deletions internal/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"strings"
"testing"

"github.com/jguan/aima/internal/knowledge"
)

func TestConfigToFlagsSkipsSelectionOnlyQuantization(t *testing.T) {
Expand Down Expand Up @@ -81,3 +83,23 @@ func TestConfigToFlagsMapEmitsJSON(t *testing.T) {
t.Fatalf("expected method=mtp, got %v", parsed["method"])
}
}

func TestConfigToFlagsFiltersLLMOnlyArgsForImageService(t *testing.T) {
flags := configToFlagsFor(
map[string]any{
"port": 8188,
"max_model_len": 8192,
"gpu_memory_utilization": 0.5,
},
knowledge.ConfigFlagContext{
Command: []string{"python3", "server.py"},
Engine: "z-image-diffusers",
ModelType: "image_gen",
},
map[string]struct{}{"port": {}},
)
got := strings.Join(flags, " ")
if strings.Contains(got, "--max-model-len") || strings.Contains(got, "--gpu-memory-utilization") {
t.Fatalf("LLM-only flags should be omitted for image services, got %q", got)
}
}
Loading