diff --git a/cmd/thv/app/commands.go b/cmd/thv/app/commands.go index 14fe7bc348..fad9b560c0 100644 --- a/cmd/thv/app/commands.go +++ b/cmd/thv/app/commands.go @@ -71,6 +71,7 @@ func NewRootCmd(enableUpdates bool) *cobra.Command { rootCmd.AddCommand(inspectorCommand()) rootCmd.AddCommand(newMCPCommand()) rootCmd.AddCommand(newVMCPCommand()) + rootCmd.AddCommand(newLLMCommand()) rootCmd.AddCommand(groupCmd) rootCmd.AddCommand(skillCmd) rootCmd.AddCommand(statusCmd) @@ -113,6 +114,7 @@ func IsInformationalCommand(args []string) bool { "mcp": true, "skill": true, "vmcp": true, + "llm": true, } return informationalCommands[command] diff --git a/cmd/thv/app/llm.go b/cmd/thv/app/llm.go new file mode 100644 index 0000000000..e071783efe --- /dev/null +++ b/cmd/thv/app/llm.go @@ -0,0 +1,299 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package app + +import ( + "context" + "encoding/json" + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/stacklok/toolhive/pkg/auth/secrets" + "github.com/stacklok/toolhive/pkg/config" + "github.com/stacklok/toolhive/pkg/llm" + pkgsecrets "github.com/stacklok/toolhive/pkg/secrets" +) + +func newLLMCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "llm", + Hidden: true, + Short: "Manage LLM gateway authentication", + Long: `Configure and manage authentication for OIDC-protected LLM gateways. + +The llm command bridges AI coding tools to LLM gateways by handling OIDC +authentication transparently. Two modes are planned: + + Proxy mode — a localhost reverse proxy injects fresh tokens for tools + that only accept static API keys (e.g. Cursor). + Token helper — "thv llm token" prints a fresh JWT suitable for use as + apiKeyHelper or auth.command in OIDC-capable tools + (e.g. Claude Code). + +To configure the gateway connection settings, use: + + thv llm config set --gateway-url https://llm.example.com \ + --issuer https://auth.example.com \ + --client-id my-client-id + +Use "thv llm config show" to view the current configuration.`, + } + + cmd.AddCommand(newConfigCommand()) + cmd.AddCommand(newLLMSetupCommand()) + cmd.AddCommand(newLLMTeardownCommand()) + cmd.AddCommand(newLLMProxyCommand()) + cmd.AddCommand(newLLMTokenCommand()) + + return cmd +} + +// ── config subcommand group ─────────────────────────────────────────────────── + +func newConfigCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "config", + Short: "Manage LLM gateway configuration", + Long: "The config command provides subcommands to manage LLM gateway connection settings.", + } + + cmd.AddCommand(newConfigSetCommand()) + cmd.AddCommand(newConfigShowCommand()) + cmd.AddCommand(newConfigResetCommand()) + + return cmd +} + +func newConfigSetCommand() *cobra.Command { + var ( + gatewayURL string + issuer string + clientID string + audience string + proxyPort int + callbackPort int + ) + + cmd := &cobra.Command{ + Use: "set", + Short: "Set LLM gateway connection settings", + Long: `Persist LLM gateway connection settings to config.yaml. + +Example: + thv llm config set \ + --gateway-url https://llm.example.com \ + --issuer https://auth.example.com \ + --client-id my-client-id`, + Args: cobra.NoArgs, + RunE: func(_ *cobra.Command, _ []string) error { + return config.UpdateConfig(func(c *config.Config) error { + if gatewayURL != "" { + c.LLM.GatewayURL = gatewayURL + } + if issuer != "" { + c.LLM.OIDC.Issuer = issuer + } + if clientID != "" { + c.LLM.OIDC.ClientID = clientID + } + if audience != "" { + c.LLM.OIDC.Audience = audience + } + if proxyPort != 0 { + c.LLM.Proxy.ListenPort = proxyPort + } + if callbackPort != 0 { + c.LLM.OIDC.CallbackPort = callbackPort + } + // Always validate format/range for any fields that were set, + // so invalid values (e.g. http:// URL, out-of-range port) are + // rejected immediately rather than silently persisted. + // Full validation (required-field checks) only runs once the + // mandatory trio is present, allowing incremental configuration. + if !c.LLM.IsConfigured() { + return c.LLM.ValidatePartial() + } + return c.LLM.Validate() + }) + }, + } + + cmd.Flags().StringVar(&gatewayURL, "gateway-url", "", "LLM gateway base URL (must use HTTPS)") + cmd.Flags().StringVar(&issuer, "issuer", "", "OIDC issuer URL") + cmd.Flags().StringVar(&clientID, "client-id", "", "OIDC client ID") + cmd.Flags().StringVar(&audience, "audience", "", "OIDC audience (optional)") + cmd.Flags().IntVar(&proxyPort, "proxy-port", 0, "Localhost proxy listen port (default 14000)") + cmd.Flags().IntVar(&callbackPort, "callback-port", 0, "OIDC callback port (default: ephemeral)") + + return cmd +} + +func newConfigShowCommand() *cobra.Command { + var outputFormat string + + cmd := &cobra.Command{ + Use: "show", + Short: "Display current LLM gateway configuration", + Args: cobra.NoArgs, + PreRunE: ValidateFormat(&outputFormat, FormatJSON, FormatText), + RunE: func(_ *cobra.Command, _ []string) error { + provider := config.NewDefaultProvider() + llmCfg := provider.GetConfig().LLM + + if outputFormat == FormatJSON { + enc, err := json.MarshalIndent(llmCfg, "", " ") + if err != nil { + return fmt.Errorf("failed to encode config as JSON: %w", err) + } + fmt.Println(string(enc)) + return nil + } + + if !llmCfg.IsConfigured() { + fmt.Println("LLM gateway is not configured. Run \"thv llm config set\" to configure it.") + return nil + } + + fmt.Printf("Gateway URL: %s\n", llmCfg.GatewayURL) + fmt.Printf("OIDC Issuer: %s\n", llmCfg.OIDC.Issuer) + fmt.Printf("OIDC Client: %s\n", llmCfg.OIDC.ClientID) + if llmCfg.OIDC.Audience != "" { + fmt.Printf("Audience: %s\n", llmCfg.OIDC.Audience) + } + fmt.Printf("Proxy Port: %d\n", llmCfg.EffectiveProxyPort()) + fmt.Printf("Scopes: %v\n", llmCfg.OIDC.EffectiveScopes()) + if len(llmCfg.ConfiguredTools) > 0 { + fmt.Println("Configured tools:") + for _, t := range llmCfg.ConfiguredTools { + fmt.Printf(" - %s (%s) %s\n", t.Tool, t.Mode, t.ConfigPath) + } + } + + return nil + }, + } + + AddFormatFlag(cmd, &outputFormat, FormatJSON, FormatText) + + return cmd +} + +func newConfigResetCommand() *cobra.Command { + return &cobra.Command{ + Use: "reset", + Short: "Clear all LLM gateway configuration and cached tokens", + Long: `Remove all LLM gateway settings from config.yaml and delete cached OIDC +tokens from the secrets provider.`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + // Delete cached tokens from the secrets provider first. + if err := deleteLLMSecrets(cmd.Context()); err != nil { + // Non-fatal: log and continue so the config is still cleared. + fmt.Fprintf(os.Stderr, "Warning: could not remove cached LLM tokens: %v\n", err) + } + + return config.UpdateConfig(func(c *config.Config) error { + c.LLM = llm.Config{} + return nil + }) + }, + } +} + +// deleteLLMSecrets removes all secrets stored under the LLM scope. +func deleteLLMSecrets(ctx context.Context) error { + provider, err := secrets.GetSystemSecretsProvider() + if err != nil { + return fmt.Errorf("failed to get secrets provider: %w", err) + } + scoped := pkgsecrets.NewScopedProvider(provider, pkgsecrets.ScopeLLM) + descs, err := scoped.ListSecrets(ctx) + if err != nil { + return err + } + if len(descs) == 0 { + return nil + } + names := make([]string, len(descs)) + for i, d := range descs { + names[i] = d.Key + } + return scoped.DeleteSecrets(ctx, names) +} + +// ── setup / teardown stubs ──────────────────────────────────────────────────── + +func newLLMSetupCommand() *cobra.Command { + return &cobra.Command{ + Use: "setup", + Short: "Detect installed AI tools, configure them, and trigger OIDC login (coming soon)", + Args: cobra.NoArgs, + RunE: func(_ *cobra.Command, _ []string) error { + return fmt.Errorf("not implemented: coming in a future release") + }, + } +} + +func newLLMTeardownCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "teardown [tool-name]", + Short: "Remove LLM gateway configuration from all tools and stop the proxy (coming soon)", + Args: cobra.MaximumNArgs(1), + RunE: func(_ *cobra.Command, _ []string) error { + return fmt.Errorf("not implemented: coming in a future release") + }, + } + + cmd.Flags().Bool("purge-tokens", false, "Also delete cached OIDC tokens from the secrets provider") + + return cmd +} + +// ── proxy subcommand group ──────────────────────────────────────────────────── + +func newLLMProxyCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "proxy", + Short: "Manage the LLM gateway localhost proxy", + } + + cmd.AddCommand(newLLMProxyStartCommand()) + + return cmd +} + +func newLLMProxyStartCommand() *cobra.Command { + return &cobra.Command{ + Use: "start", + Short: "Start the LLM proxy in the foreground (coming soon)", + Args: cobra.NoArgs, + RunE: func(_ *cobra.Command, _ []string) error { + return fmt.Errorf("not implemented: coming in a future release") + }, + } +} + +// ── token helper (hidden) ───────────────────────────────────────────────────── + +func newLLMTokenCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "token", + Hidden: true, + Short: "Print a fresh LLM gateway access token to stdout", + Long: `Print a fresh OIDC access token to stdout (all other output on stderr). +Intended for use as apiKeyHelper or auth.command in OIDC-capable AI tools. +Runs non-interactively — will not launch a browser flow.`, + Args: cobra.NoArgs, + RunE: func(_ *cobra.Command, _ []string) error { + // Print the error to stderr so it doesn't corrupt the token value + // expected on stdout by callers such as apiKeyHelper or auth.command. + return fmt.Errorf("thv llm token is not yet implemented — " + + "run \"thv llm setup\" once it is available to configure your tools") + }, + } + + return cmd +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 01222d8e77..de20808ce1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -19,6 +19,7 @@ import ( "github.com/stacklok/toolhive-core/env" "github.com/stacklok/toolhive/pkg/container/templates" + "github.com/stacklok/toolhive/pkg/llm" "github.com/stacklok/toolhive/pkg/lockfile" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -47,6 +48,7 @@ type Config struct { BuildAuthFiles map[string]string `yaml:"build_auth_files,omitempty"` RuntimeConfigs map[string]*templates.RuntimeConfig `yaml:"runtime_configs,omitempty"` RegistryAuth RegistryAuth `yaml:"registry_auth,omitempty"` + LLM llm.Config `yaml:"llm,omitempty"` } // RegistryAuthTypeOAuth is the auth type for OAuth/OIDC authentication. diff --git a/pkg/llm/config.go b/pkg/llm/config.go new file mode 100644 index 0000000000..1b6bba5db0 --- /dev/null +++ b/pkg/llm/config.go @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package llm + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/stacklok/toolhive/pkg/networking" +) + +const ( + // DefaultProxyListenPort is the default port the localhost proxy listens on. + DefaultProxyListenPort = 14000 + + // DefaultOIDCScopes are the default OAuth scopes requested during login. + DefaultOIDCScopes = "openid offline_access" +) + +// Config holds all LLM gateway settings persisted under the llm: key in +// ToolHive's config.yaml. +type Config struct { + GatewayURL string `yaml:"gateway_url,omitempty" json:"gateway_url,omitempty"` + OIDC OIDCConfig `yaml:"oidc,omitempty" json:"oidc,omitempty"` + Proxy ProxyConfig `yaml:"proxy,omitempty" json:"proxy,omitempty"` + ConfiguredTools []ToolConfig `yaml:"configured_tools,omitempty" json:"configured_tools,omitempty"` +} + +// OIDCConfig holds OIDC provider settings and cached token state for the LLM +// gateway. Cached fields follow the same pattern as RegistryOAuthConfig in +// pkg/config/config.go — token values are never stored here, only references +// and expiry metadata. +type OIDCConfig struct { + Issuer string `yaml:"issuer,omitempty" json:"issuer,omitempty"` + ClientID string `yaml:"client_id,omitempty" json:"client_id,omitempty"` + Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"` + Audience string `yaml:"audience,omitempty" json:"audience,omitempty"` + CallbackPort int `yaml:"callback_port,omitempty" json:"callback_port,omitempty"` + + // CachedRefreshTokenRef is the secrets-provider key under which the refresh + // token is stored (never the token value itself). + CachedRefreshTokenRef string `yaml:"cached_refresh_token_ref,omitempty" json:"cached_refresh_token_ref,omitempty"` + // CachedTokenExpiry is the expiry of the most recently cached access token, + // used to surface helpful messages when the token is about to expire. + CachedTokenExpiry time.Time `yaml:"cached_token_expiry,omitempty" json:"cached_token_expiry,omitempty"` +} + +// ProxyConfig holds configuration for the localhost reverse proxy. +type ProxyConfig struct { + ListenPort int `yaml:"listen_port,omitempty" json:"listen_port,omitempty"` +} + +// ToolConfig records a tool that setup has configured, so teardown knows +// exactly what to reverse. +type ToolConfig struct { + // Tool is the canonical tool identifier (e.g. "claude-code", "cursor"). + Tool string `yaml:"tool" json:"tool"` + // Mode is the authentication mode: "direct" or "proxy". + Mode string `yaml:"mode" json:"mode"` + // ConfigPath is the absolute path to the tool's config file that was patched. + ConfigPath string `yaml:"config_path" json:"config_path"` +} + +// IsConfigured reports whether the minimum required fields are present for the +// LLM gateway to be usable: gateway URL, OIDC issuer, and OIDC client ID. +func (c *Config) IsConfigured() bool { + return c.GatewayURL != "" && c.OIDC.Issuer != "" && c.OIDC.ClientID != "" +} + +// ValidatePartial validates any fields that are explicitly set, without +// requiring the mandatory trio (gateway_url, oidc.issuer, oidc.client_id). +// Use this to catch URL format or port range errors during incremental +// configuration, before all required fields have been provided. +func (c *Config) ValidatePartial() error { + var errs []error + + if c.GatewayURL != "" { + if err := networking.ValidateHTTPSURL(c.GatewayURL); err != nil { + errs = append(errs, fmt.Errorf("gateway_url: %w", err)) + } + } + + if c.OIDC.Issuer != "" { + if err := networking.ValidateIssuerURL(c.OIDC.Issuer); err != nil { + errs = append(errs, fmt.Errorf("oidc.issuer: %w", err)) + } + } + + if c.Proxy.ListenPort != 0 && (c.Proxy.ListenPort < 1024 || c.Proxy.ListenPort > 65535) { + errs = append(errs, fmt.Errorf("proxy.listen_port must be between 1024 and 65535, got: %d", c.Proxy.ListenPort)) + } + + // Reuse networking.ValidateCallbackPort for the OIDC callback port — same + // range check (1024–65535), zero means ephemeral (auto-assigned). Pass the + // client ID so the validator applies strict availability checking for + // pre-registered clients (clientID != ""). + if err := networking.ValidateCallbackPort(c.OIDC.CallbackPort, c.OIDC.ClientID); err != nil { + errs = append(errs, fmt.Errorf("oidc.callback_port: %w", err)) + } + + return errors.Join(errs...) +} + +// Validate performs full validation of the LLM config, including HTTPS +// enforcement, port range checks, and OIDC field requirements. +func (c *Config) Validate() error { + var errs []error + + if c.GatewayURL == "" { + errs = append(errs, errors.New("gateway_url is required")) + } + + if c.OIDC.Issuer == "" { + errs = append(errs, errors.New("oidc.issuer is required")) + } + + if c.OIDC.ClientID == "" { + errs = append(errs, errors.New("oidc.client_id is required")) + } + + return errors.Join(append(errs, c.ValidatePartial())...) +} + +// EffectiveProxyPort returns the configured proxy listen port, or +// DefaultProxyListenPort if none is set. +func (c *Config) EffectiveProxyPort() int { + if c.Proxy.ListenPort > 0 { + return c.Proxy.ListenPort + } + return DefaultProxyListenPort +} + +// EffectiveScopes returns the configured OIDC scopes, or the default scopes +// (openid, offline_access) if none are set. +func (c *OIDCConfig) EffectiveScopes() []string { + if len(c.Scopes) > 0 { + return c.Scopes + } + return strings.Fields(DefaultOIDCScopes) +} diff --git a/pkg/llm/config_test.go b/pkg/llm/config_test.go new file mode 100644 index 0000000000..f2d69f9a4d --- /dev/null +++ b/pkg/llm/config_test.go @@ -0,0 +1,339 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package llm + +import ( + "testing" +) + +func TestConfig_IsConfigured(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg Config + want bool + }{ + { + name: "fully configured", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + }, + want: true, + }, + { + name: "missing gateway URL", + cfg: Config{ + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + }, + want: false, + }, + { + name: "missing issuer", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + ClientID: "my-client", + }, + }, + want: false, + }, + { + name: "missing client ID", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + }, + }, + want: false, + }, + { + name: "empty config", + cfg: Config{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := tt.cfg.IsConfigured() + if got != tt.want { + t.Errorf("IsConfigured() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfig_Validate(t *testing.T) { + t.Parallel() + + valid := Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + } + + tests := []struct { + name string + cfg Config + wantErr bool + }{ + { + name: "valid config", + cfg: valid, + wantErr: false, + }, + { + name: "missing gateway URL", + cfg: Config{ + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + }, + wantErr: true, + }, + { + name: "HTTP gateway URL rejected", + cfg: Config{ + GatewayURL: "http://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + }, + wantErr: true, + }, + { + name: "missing issuer", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + ClientID: "my-client", + }, + }, + wantErr: true, + }, + { + name: "missing client ID", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + }, + }, + wantErr: true, + }, + { + name: "proxy port below range", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + Proxy: ProxyConfig{ListenPort: 80}, + }, + wantErr: true, + }, + { + name: "proxy port above range", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + Proxy: ProxyConfig{ListenPort: 99999}, + }, + wantErr: true, + }, + { + name: "valid custom proxy port", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + }, + Proxy: ProxyConfig{ListenPort: 8080}, + }, + wantErr: false, + }, + { + name: "callback port below range", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + CallbackPort: 100, + }, + }, + wantErr: true, + }, + { + name: "valid callback port", + cfg: Config{ + GatewayURL: "https://llm.example.com", + OIDC: OIDCConfig{ + Issuer: "https://auth.example.com", + ClientID: "my-client", + CallbackPort: 9000, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.cfg.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfig_ValidatePartial(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg Config + wantErr bool + }{ + { + name: "empty config is valid", + cfg: Config{}, + wantErr: false, + }, + { + name: "valid gateway URL only", + cfg: Config{GatewayURL: "https://llm.example.com"}, + wantErr: false, + }, + { + name: "HTTP gateway URL rejected", + cfg: Config{GatewayURL: "http://llm.example.com"}, + wantErr: true, + }, + { + name: "valid issuer only", + cfg: Config{OIDC: OIDCConfig{Issuer: "https://auth.example.com"}}, + wantErr: false, + }, + { + name: "invalid issuer rejected", + cfg: Config{OIDC: OIDCConfig{Issuer: "not-a-url"}}, + wantErr: true, + }, + { + name: "proxy port below range rejected", + cfg: Config{Proxy: ProxyConfig{ListenPort: 80}}, + wantErr: true, + }, + { + name: "proxy port above range rejected", + cfg: Config{Proxy: ProxyConfig{ListenPort: 99999}}, + wantErr: true, + }, + { + name: "valid proxy port accepted", + cfg: Config{Proxy: ProxyConfig{ListenPort: 8080}}, + wantErr: false, + }, + { + name: "callback port below range rejected", + cfg: Config{OIDC: OIDCConfig{CallbackPort: 100}}, + wantErr: true, + }, + { + name: "valid callback port accepted", + cfg: Config{OIDC: OIDCConfig{CallbackPort: 9000}}, + wantErr: false, + }, + { + name: "multiple invalid fields all reported", + cfg: Config{ + GatewayURL: "http://llm.example.com", + Proxy: ProxyConfig{ListenPort: 80}, + }, + wantErr: true, + }, + { + name: "required fields absent but valid values accepted", + cfg: Config{ + GatewayURL: "https://llm.example.com", + Proxy: ProxyConfig{ListenPort: 8080}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.cfg.ValidatePartial() + if (err != nil) != tt.wantErr { + t.Errorf("ValidatePartial() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfig_EffectiveProxyPort(t *testing.T) { + t.Parallel() + + t.Run("returns default when not set", func(t *testing.T) { + t.Parallel() + cfg := Config{} + if got := cfg.EffectiveProxyPort(); got != DefaultProxyListenPort { + t.Errorf("EffectiveProxyPort() = %d, want %d", got, DefaultProxyListenPort) + } + }) + + t.Run("returns configured port", func(t *testing.T) { + t.Parallel() + cfg := Config{Proxy: ProxyConfig{ListenPort: 8080}} + if got := cfg.EffectiveProxyPort(); got != 8080 { + t.Errorf("EffectiveProxyPort() = %d, want 8080", got) + } + }) +} + +func TestOIDCConfig_EffectiveScopes(t *testing.T) { + t.Parallel() + + t.Run("returns defaults when not set", func(t *testing.T) { + t.Parallel() + cfg := OIDCConfig{} + scopes := cfg.EffectiveScopes() + if len(scopes) == 0 { + t.Error("EffectiveScopes() returned empty slice for zero-value config") + } + }) + + t.Run("returns configured scopes", func(t *testing.T) { + t.Parallel() + cfg := OIDCConfig{Scopes: []string{"openid", "email"}} + scopes := cfg.EffectiveScopes() + if len(scopes) != 2 || scopes[0] != "openid" || scopes[1] != "email" { + t.Errorf("EffectiveScopes() = %v, want [openid email]", scopes) + } + }) +} diff --git a/pkg/llm/doc.go b/pkg/llm/doc.go new file mode 100644 index 0000000000..6a4f9d2e8e --- /dev/null +++ b/pkg/llm/doc.go @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package llm provides configuration types and public API for the thv llm +// command group, which bridges AI coding tools to OIDC-protected LLM gateways. +// +// Two authentication modes are planned: +// - Proxy mode: a localhost reverse proxy that injects fresh OIDC tokens for +// tools that only accept static API keys (e.g. Cursor). +// - Token helper mode: thv llm token prints a fresh JWT to stdout, suitable +// for use as apiKeyHelper or auth.command in OIDC-capable tools (e.g. Claude Code). +// +// Both modes are under active development; the corresponding CLI commands +// currently return not-implemented errors. +// +// Configuration is persisted in ToolHive's config.yaml under the llm: key via +// the existing UpdateConfig() mechanism. +package llm diff --git a/pkg/networking/utilities.go b/pkg/networking/utilities.go index 67f020d244..112621888e 100644 --- a/pkg/networking/utilities.go +++ b/pkg/networking/utilities.go @@ -95,6 +95,41 @@ func validateEndpointURLWithSkip(endpoint string, skipValidation bool) error { return nil } +// ValidateHTTPSURL checks that rawURL is a valid URL using the https scheme. +// Unlike ValidateEndpointURL, no localhost exception is made — HTTPS is always +// required (suitable for gateway URLs and other production endpoints). +func ValidateHTTPSURL(rawURL string) error { + parsed, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + if parsed.Host == "" { + return fmt.Errorf("URL must include a host: %s", rawURL) + } + if parsed.Scheme != HttpsScheme { + return fmt.Errorf("must use HTTPS, got scheme %q", parsed.Scheme) + } + return nil +} + +// ValidateIssuerURL validates that an OIDC issuer URL is well-formed and uses +// HTTPS. HTTP is permitted only for localhost (development). Per OIDC Core +// Section 3.1.2.1 and RFC 8414 Section 2, the issuer MUST use the "https" +// scheme. +func ValidateIssuerURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid issuer URL %q: %w", rawURL, err) + } + if u.Host == "" { + return fmt.Errorf("issuer URL must include a host: %s", rawURL) + } + if u.Scheme != HttpsScheme && !IsLocalhost(u.Host) { + return fmt.Errorf("issuer URL must use HTTPS (except localhost for development): %s", rawURL) + } + return nil +} + // IsLocalhost checks if a host is localhost (for development) func IsLocalhost(host string) bool { return strings.HasPrefix(host, "localhost:") || diff --git a/pkg/networking/utilities_test.go b/pkg/networking/utilities_test.go index 58d1517d29..377bca2495 100644 --- a/pkg/networking/utilities_test.go +++ b/pkg/networking/utilities_test.go @@ -490,3 +490,131 @@ func TestValidateEndpointURL(t *testing.T) { }) } } + +func TestValidateHTTPSURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + expectError bool + }{ + { + name: "valid HTTPS URL", + url: "https://llm.example.com", + expectError: false, + }, + { + name: "valid HTTPS URL with path", + url: "https://llm.example.com/api/v1", + expectError: false, + }, + { + name: "valid HTTPS URL with port", + url: "https://llm.example.com:8443", + expectError: false, + }, + { + name: "HTTP rejected even for localhost", + url: "http://localhost:8080", + expectError: true, + }, + { + name: "HTTP rejected for remote host", + url: "http://llm.example.com", + expectError: true, + }, + { + name: "missing host", + url: "https://", + expectError: true, + }, + { + name: "unsupported scheme", + url: "ftp://llm.example.com", + expectError: true, + }, + { + name: "invalid URL format", + url: "not-a-url", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateHTTPSURL(tt.url) + if tt.expectError { + assert.Error(t, err, "Expected error for URL: %s", tt.url) + } else { + assert.NoError(t, err, "Expected no error for URL: %s", tt.url) + } + }) + } +} + +func TestValidateIssuerURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + expectError bool + }{ + { + name: "valid HTTPS issuer", + url: "https://auth.example.com", + expectError: false, + }, + { + name: "valid HTTPS issuer with path", + url: "https://auth.example.com/realms/myrealm", + expectError: false, + }, + { + name: "localhost HTTP allowed for development", + url: "http://localhost:8080", + expectError: false, + }, + { + name: "127.0.0.1 HTTP allowed for development", + url: "http://127.0.0.1:9000", + expectError: false, + }, + { + name: "HTTP rejected for remote host", + url: "http://auth.example.com", + expectError: true, + }, + { + name: "missing host", + url: "https://", + expectError: true, + }, + { + name: "invalid URL format", + url: "not-a-url", + expectError: true, + }, + { + name: "unsupported scheme", + url: "ftp://auth.example.com", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateIssuerURL(tt.url) + if tt.expectError { + assert.Error(t, err, "Expected error for URL: %s", tt.url) + } else { + assert.NoError(t, err, "Expected no error for URL: %s", tt.url) + } + }) + } +} diff --git a/pkg/registry/auth/issuer_validation.go b/pkg/registry/auth/issuer_validation.go index 35db4038a0..9fe07e7ec2 100644 --- a/pkg/registry/auth/issuer_validation.go +++ b/pkg/registry/auth/issuer_validation.go @@ -3,26 +3,11 @@ package auth -import ( - "fmt" - "net/url" - - "github.com/stacklok/toolhive/pkg/networking" -) +import "github.com/stacklok/toolhive/pkg/networking" // validateIssuerURL validates that the issuer is a well-formed URL using HTTPS. // HTTP is permitted only for localhost (development). Per OIDC Core Section 3.1.2.1 // and RFC 8414 Section 2, the issuer MUST use the "https" scheme. func validateIssuerURL(rawURL string) error { - u, err := url.Parse(rawURL) - if err != nil { - return fmt.Errorf("invalid issuer URL %q: %w", rawURL, err) - } - if u.Host == "" { - return fmt.Errorf("issuer URL must include a host: %s", rawURL) - } - if u.Scheme != "https" && !networking.IsLocalhost(u.Host) { - return fmt.Errorf("issuer URL must use HTTPS (except localhost for development): %s", rawURL) - } - return nil + return networking.ValidateIssuerURL(rawURL) } diff --git a/pkg/secrets/scoped.go b/pkg/secrets/scoped.go index b3dc539cff..57bedc04f9 100644 --- a/pkg/secrets/scoped.go +++ b/pkg/secrets/scoped.go @@ -21,8 +21,8 @@ import ( // ends and the name begins. // // All constants declared in this package (ScopeRegistry, ScopeWorkloads, -// ScopeAuth) satisfy these invariants. Custom scopes introduced in the future -// must be validated against them. +// ScopeAuth, ScopeLLM) satisfy these invariants. Custom scopes introduced in +// the future must be validated against them. type SecretScope string const ( @@ -42,6 +42,9 @@ const ( // ScopeAuth is reserved for enterprise CLI/Desktop login tokens. ScopeAuth SecretScope = "auth" + + // ScopeLLM is the scope for LLM gateway OIDC refresh tokens. + ScopeLLM SecretScope = "llm" ) // ErrReservedKeyName is returned when a user command attempts to manage a diff --git a/pkg/secrets/scoped_test.go b/pkg/secrets/scoped_test.go index 794a049b9a..f7e01ffaef 100644 --- a/pkg/secrets/scoped_test.go +++ b/pkg/secrets/scoped_test.go @@ -802,6 +802,7 @@ func TestSecretScopeInvariants(t *testing.T) { secrets.ScopeRegistry, secrets.ScopeWorkloads, secrets.ScopeAuth, + secrets.ScopeLLM, } for _, scope := range scopes {