diff --git a/cmd/cli/commands/context.go b/cmd/cli/commands/context.go new file mode 100644 index 000000000..4a1d1489c --- /dev/null +++ b/cmd/cli/commands/context.go @@ -0,0 +1,359 @@ +package commands + +import ( + "bytes" + "fmt" + "net/url" + "os" + "path/filepath" + "sort" + "time" + + "github.com/docker/cli/cli/command" + "github.com/docker/model-runner/cmd/cli/commands/formatter" + "github.com/docker/model-runner/cmd/cli/pkg/modelctx" + "github.com/spf13/cobra" +) + +// newContextCmd returns the "docker model context" parent command. Its +// subcommands manage named Model Runner contexts stored on disk, so they do +// not require a running backend and override PersistentPreRunE accordingly. +func newContextCmd(cli *command.DockerCli) *cobra.Command { + c := &cobra.Command{ + Use: "context", + Short: "Manage Docker Model Runner contexts", + // Context management commands need only CLI initialisation, not a + // running backend. Override PersistentPreRunE to skip DetectContext. + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + return initDockerCLI(cmd, args, cli, globalOptions) + }, + } + + c.AddCommand( + newContextCreateCmd(), + newContextUseCmd(), + newContextLsCmd(), + newContextRmCmd(), + newContextInspectCmd(), + ) + return c +} + +// contextStore opens the context store using the Docker config directory +// derived from the current CLI configuration. +func contextStore() (*modelctx.Store, error) { + dir, err := dockerConfigDir() + if err != nil { + return nil, fmt.Errorf("unable to determine Docker config directory: %w", err) + } + return modelctx.New(dir) +} + +// dockerConfigDir returns the Docker configuration directory. It honours the +// DOCKER_CONFIG environment variable and falls back to ~/.docker. +func dockerConfigDir() (string, error) { + if dockerCLI != nil && dockerCLI.ConfigFile() != nil { + return filepath.Dir(dockerCLI.ConfigFile().Filename), nil + } + // Fallback used during testing or when CLI is not yet initialised. + if d := os.Getenv("DOCKER_CONFIG"); d != "" { + return d, nil + } + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("unable to determine home directory: %w", err) + } + return filepath.Join(home, ".docker"), nil +} + +// newContextCreateCmd returns the "context create" command. +func newContextCreateCmd() *cobra.Command { + var ( + host string + tls bool + tlsSkipVerify bool + tlsCACert string + description string + ) + + c := &cobra.Command{ + Use: "create NAME", + Short: "Create a named Model Runner context", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := args[0] + + // Validate and normalise the host URL. + if host == "" { + return fmt.Errorf("--host is required") + } + + u, err := url.ParseRequestURI(host) + if err != nil { + return fmt.Errorf("invalid --host URL: %w", err) + } + if u.Scheme == "" || u.Host == "" { + return fmt.Errorf("invalid --host URL: must include scheme and host, e.g. http://192.168.1.100:12434") + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("invalid --host URL: unsupported scheme %q (must be http or https)", u.Scheme) + } + + // Normalise the host string. + host = u.String() + + // Validate the CA cert path if provided. + tlsCACertAbs := "" + if tlsCACert != "" { + abs, err := filepath.Abs(tlsCACert) + if err != nil { + return fmt.Errorf("invalid --tls-ca-cert path: %w", err) + } + if _, err := os.ReadFile(abs); err != nil { + return fmt.Errorf( + "--tls-ca-cert: cannot read %q: %w", abs, err, + ) + } + tlsCACertAbs = abs + } + + store, err := contextStore() + if err != nil { + return fmt.Errorf("unable to open context store: %w", err) + } + + cfg := modelctx.ContextConfig{ + Host: host, + TLS: modelctx.TLSConfig{ + Enabled: tls, + SkipVerify: tlsSkipVerify, + CACert: tlsCACertAbs, + }, + Description: description, + CreatedAt: time.Now().UTC(), + } + if err := store.Create(name, cfg); err != nil { + return err + } + + fmt.Fprintf(cmd.OutOrStdout(), "Context %q created.\n", name) + return nil + }, + } + + c.Flags().StringVar(&host, "host", "", + "Model Runner API base URL (e.g. http://192.168.1.100:12434)") + c.Flags().BoolVar(&tls, "tls", false, + "Enable TLS for connections to this context") + c.Flags().BoolVar(&tlsSkipVerify, "tls-skip-verify", false, + "Skip TLS server certificate verification") + c.Flags().StringVar(&tlsCACert, "tls-ca-cert", "", + "Path to a custom CA certificate PEM file for TLS verification") + c.Flags().StringVar(&description, "description", "", + "Optional human-readable description for this context") + return c +} + +// newContextUseCmd returns the "context use" command. +func newContextUseCmd() *cobra.Command { + return &cobra.Command{ + Use: "use NAME", + Short: "Set the active Model Runner context", + Long: `Set the active Model Runner context. Pass "default" to revert to +automatic backend detection.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := args[0] + + store, err := contextStore() + if err != nil { + return fmt.Errorf("unable to open context store: %w", err) + } + + if err := store.SetActive(name); err != nil { + return err + } + + fmt.Fprintf( + cmd.OutOrStdout(), + "Current context is now %q.\n", name, + ) + return nil + }, + } +} + +// contextListRow holds the data for one row in the "context ls" table. +type contextListRow struct { + name string + host string + description string + active bool +} + +// newContextLsCmd returns the "context ls" command. +func newContextLsCmd() *cobra.Command { + return &cobra.Command{ + Use: "ls", + Aliases: []string{"list"}, + Short: "List Model Runner contexts", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + store, err := contextStore() + if err != nil { + return fmt.Errorf("unable to open context store: %w", err) + } + + contexts, err := store.List() + if err != nil { + return fmt.Errorf("unable to list contexts: %w", err) + } + + activeName, err := store.Active() + if err != nil { + return fmt.Errorf("unable to determine active context: %w", err) + } + + // Warn if MODEL_RUNNER_HOST overrides the active context. + if envHost := os.Getenv("MODEL_RUNNER_HOST"); envHost != "" { + fmt.Fprintf( + cmd.ErrOrStderr(), + "Warning: MODEL_RUNNER_HOST=%q overrides the active context.\n", + envHost, + ) + } + + // Build rows: synthetic "default" first, then named contexts sorted. + rows := []contextListRow{ + { + name: modelctx.DefaultContextName, + host: "(auto-detect)", + description: "Auto-detected Docker context", + active: activeName == modelctx.DefaultContextName, + }, + } + + names := make([]string, 0, len(contexts)) + for n := range contexts { + names = append(names, n) + } + sort.Strings(names) + + for _, n := range names { + cfg := contexts[n] + rows = append(rows, contextListRow{ + name: n, + host: cfg.Host, + description: cfg.Description, + active: activeName == n, + }) + } + + var buf bytes.Buffer + table := newTable(&buf) + table.Header([]string{"NAME", "HOST", "DESCRIPTION", "CURRENT"}) + for _, row := range rows { + current := "" + if row.active { + current = "*" + } + table.Append([]string{ + row.name, + row.host, + row.description, + current, + }) + } + table.Render() + + fmt.Fprint(cmd.OutOrStdout(), buf.String()) + return nil + }, + } +} + +// newContextRmCmd returns the "context rm" command. +func newContextRmCmd() *cobra.Command { + return &cobra.Command{ + Use: "rm NAME [NAME...]", + Aliases: []string{"remove"}, + Short: "Remove one or more Model Runner contexts", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + store, err := contextStore() + if err != nil { + return fmt.Errorf("unable to open context store: %w", err) + } + + // Attempt removal of all named contexts; collect errors. + var errs []error + for _, name := range args { + if err := store.Remove(name); err != nil { + errs = append(errs, fmt.Errorf("%s: %w", name, err)) + continue + } + fmt.Fprintf(cmd.OutOrStdout(), "Context %q removed.\n", name) + } + + if len(errs) > 0 { + for _, e := range errs { + fmt.Fprintln(cmd.ErrOrStderr(), "Error:", e) + } + return fmt.Errorf("one or more contexts could not be removed") + } + return nil + }, + } +} + +// namedContextInspect is the JSON-serialisable representation of a named +// context returned by "context inspect". +type namedContextInspect struct { + Name string `json:"name"` + modelctx.ContextConfig +} + +// newContextInspectCmd returns the "context inspect" command. +func newContextInspectCmd() *cobra.Command { + return &cobra.Command{ + Use: "inspect NAME [NAME...]", + Short: "Display detailed information about one or more contexts", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + store, err := contextStore() + if err != nil { + return fmt.Errorf("unable to open context store: %w", err) + } + + results := make([]namedContextInspect, 0, len(args)) + for _, name := range args { + if name == modelctx.DefaultContextName { + // Return a synthetic entry for "default". + results = append(results, namedContextInspect{ + Name: modelctx.DefaultContextName, + ContextConfig: modelctx.ContextConfig{ + Host: "(auto-detect)", + Description: "Auto-detected Docker context", + }, + }) + continue + } + cfg, err := store.Get(name) + if err != nil { + return err + } + results = append(results, namedContextInspect{ + Name: name, + ContextConfig: cfg, + }) + } + + output, err := formatter.ToStandardJSON(results) + if err != nil { + return err + } + fmt.Fprint(cmd.OutOrStdout(), output) + return nil + }, + } +} diff --git a/cmd/cli/commands/context_test.go b/cmd/cli/commands/context_test.go new file mode 100644 index 000000000..7522c5822 --- /dev/null +++ b/cmd/cli/commands/context_test.go @@ -0,0 +1,397 @@ +package commands + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/docker/model-runner/cmd/cli/pkg/modelctx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupContextTest creates an isolated context store for a single test. It +// sets DOCKER_CONFIG to a temporary directory and clears dockerCLI so that +// dockerConfigDir() falls back to the env var rather than the real CLI +// config, keeping tests hermetic. +func setupContextTest(t *testing.T) *modelctx.Store { + t.Helper() + dir := t.TempDir() + t.Setenv("DOCKER_CONFIG", dir) + dockerCLI = nil // force dockerConfigDir() to use DOCKER_CONFIG + + store, err := modelctx.New(dir) + require.NoError(t, err) + return store +} + +// TestContextCreate verifies that "context create" writes the context and +// prints a confirmation message. +func TestContextCreate(t *testing.T) { + setupContextTest(t) + + cmd := newContextCreateCmd() + out := new(bytes.Buffer) + cmd.SetOut(out) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"myremote", "--host", "http://192.168.1.100:12434", "--description", "lab"}) + + require.NoError(t, cmd.Execute()) + assert.Contains(t, out.String(), `"myremote"`) + + // Verify the context was actually stored. + store, err := contextStore() + require.NoError(t, err) + cfg, err := store.Get("myremote") + require.NoError(t, err) + assert.Equal(t, "http://192.168.1.100:12434", cfg.Host) + assert.Equal(t, "lab", cfg.Description) +} + +// TestContextCreate_missingHost verifies that --host is required. +func TestContextCreate_missingHost(t *testing.T) { + setupContextTest(t) + + cmd := newContextCreateCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"myremote"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "--host") +} + +// TestContextCreate_invalidName verifies that a name starting with a dash is +// rejected. +func TestContextCreate_invalidName(t *testing.T) { + setupContextTest(t) + + cmd := newContextCreateCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"-badname", "--host", "http://localhost:12434"}) + + // Cobra itself will reject args beginning with "-" as flag-like, so we + // test with a name that passes Cobra but fails our validation. + cmd2 := newContextCreateCmd() + cmd2.SetOut(new(bytes.Buffer)) + cmd2.SetErr(new(bytes.Buffer)) + cmd2.SetArgs([]string{"has space", "--host", "http://localhost:12434"}) + err := cmd2.Execute() + require.Error(t, err) +} + +// TestContextCreate_invalidHostURL verifies that hosts without a proper +// scheme or host component are rejected early. +func TestContextCreate_invalidHostURL(t *testing.T) { + setupContextTest(t) + + tests := []struct { + name string + host string + want string + }{ + {"no scheme", "192.168.1.100:12434", "invalid --host URL"}, + {"bare word", "localhost", "invalid --host URL"}, + {"ftp scheme", "ftp://example.com:12434", "unsupported scheme"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := newContextCreateCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"test", "--host", tt.host}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), tt.want) + }) + } +} + +// TestContextCreate_reservedName verifies that "default" is rejected. +func TestContextCreate_reservedName(t *testing.T) { + setupContextTest(t) + + cmd := newContextCreateCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"default", "--host", "http://localhost:12434"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "reserved") +} + +// TestContextCreate_duplicate verifies that creating a context that already +// exists returns an error. +func TestContextCreate_duplicate(t *testing.T) { + setupContextTest(t) + + for range 2 { + cmd := newContextCreateCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"myremote", "--host", "http://localhost:12434"}) + _ = cmd.Execute() + } + + cmd := newContextCreateCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"myremote", "--host", "http://localhost:12434"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "already exists") +} + +// TestContextLs_empty verifies that "context ls" always shows the "default" +// row even when no named contexts exist. +func TestContextLs_empty(t *testing.T) { + setupContextTest(t) + + cmd := newContextLsCmd() + out := new(bytes.Buffer) + cmd.SetOut(out) + cmd.SetErr(new(bytes.Buffer)) + + require.NoError(t, cmd.Execute()) + assert.Contains(t, out.String(), "default") + assert.Contains(t, out.String(), "*") // default is active +} + +// TestContextLs_withContexts verifies that named contexts appear in the list +// and that the active one is marked with "*". +func TestContextLs_withContexts(t *testing.T) { + store := setupContextTest(t) + require.NoError(t, store.Create("remote", modelctx.ContextConfig{ + Host: "http://remote:12434", + Description: "remote box", + })) + require.NoError(t, store.SetActive("remote")) + + cmd := newContextLsCmd() + out := new(bytes.Buffer) + cmd.SetOut(out) + cmd.SetErr(new(bytes.Buffer)) + + require.NoError(t, cmd.Execute()) + output := out.String() + assert.Contains(t, output, "remote") + assert.Contains(t, output, "http://remote:12434") + + // "remote" should be marked active; "default" should not have "*". + lines := strings.Split(strings.TrimSpace(output), "\n") + for _, line := range lines { + if strings.Contains(line, "remote") && !strings.Contains(line, "default") { + assert.Contains(t, line, "*", "active context should show *") + } + if strings.Contains(line, "default") { + assert.NotContains(t, line, "*", "default should not show * when inactive") + } + } +} + +// TestContextLs_envVarWarning verifies that a MODEL_RUNNER_HOST env var +// triggers a warning on stderr. +func TestContextLs_envVarWarning(t *testing.T) { + setupContextTest(t) + t.Setenv("MODEL_RUNNER_HOST", "http://override:9999") + + cmd := newContextLsCmd() + out := new(bytes.Buffer) + errOut := new(bytes.Buffer) + cmd.SetOut(out) + cmd.SetErr(errOut) + + require.NoError(t, cmd.Execute()) + assert.Contains(t, errOut.String(), "MODEL_RUNNER_HOST") + assert.Contains(t, errOut.String(), "override:9999") +} + +// TestContextUse verifies that "context use" switches the active context. +func TestContextUse(t *testing.T) { + store := setupContextTest(t) + require.NoError(t, store.Create("myctx", modelctx.ContextConfig{ + Host: "http://localhost:12434", + })) + + cmd := newContextUseCmd() + out := new(bytes.Buffer) + cmd.SetOut(out) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"myctx"}) + + require.NoError(t, cmd.Execute()) + assert.Contains(t, out.String(), "myctx") + + active, err := store.Active() + require.NoError(t, err) + assert.Equal(t, "myctx", active) +} + +// TestContextUse_default verifies that "context use default" resets to +// auto-detection. +func TestContextUse_default(t *testing.T) { + store := setupContextTest(t) + require.NoError(t, store.Create("myctx", modelctx.ContextConfig{ + Host: "http://localhost:12434", + })) + require.NoError(t, store.SetActive("myctx")) + + cmd := newContextUseCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"default"}) + + require.NoError(t, cmd.Execute()) + + active, err := store.Active() + require.NoError(t, err) + assert.Equal(t, modelctx.DefaultContextName, active) +} + +// TestContextUse_notFound verifies that "context use" returns an error for +// an unknown context name. +func TestContextUse_notFound(t *testing.T) { + setupContextTest(t) + + cmd := newContextUseCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"nosuchctx"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// TestContextRm verifies that "context rm" removes a context. +func TestContextRm(t *testing.T) { + store := setupContextTest(t) + require.NoError(t, store.Create("myctx", modelctx.ContextConfig{ + Host: "http://localhost:12434", + })) + + cmd := newContextRmCmd() + out := new(bytes.Buffer) + cmd.SetOut(out) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"myctx"}) + + require.NoError(t, cmd.Execute()) + assert.Contains(t, out.String(), "myctx") + + _, err := store.Get("myctx") + require.Error(t, err) // should be gone +} + +// TestContextRm_default verifies that "context rm default" returns an error. +func TestContextRm_default(t *testing.T) { + setupContextTest(t) + + cmd := newContextRmCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"default"}) + + err := cmd.Execute() + require.Error(t, err) +} + +// TestContextRm_active verifies that the active context cannot be removed. +func TestContextRm_active(t *testing.T) { + store := setupContextTest(t) + require.NoError(t, store.Create("myctx", modelctx.ContextConfig{ + Host: "http://localhost:12434", + })) + require.NoError(t, store.SetActive("myctx")) + + cmd := newContextRmCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"myctx"}) + + err := cmd.Execute() + require.Error(t, err) + + // Context must still exist. + _, getErr := store.Get("myctx") + require.NoError(t, getErr) +} + +// TestContextRm_notFound verifies that removing an unknown context returns an +// error. +func TestContextRm_notFound(t *testing.T) { + setupContextTest(t) + + cmd := newContextRmCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"nosuchctx"}) + + err := cmd.Execute() + require.Error(t, err) +} + +// TestContextInspect verifies that "context inspect" outputs valid JSON +// containing the stored host. +func TestContextInspect(t *testing.T) { + store := setupContextTest(t) + require.NoError(t, store.Create("myctx", modelctx.ContextConfig{ + Host: "http://192.168.1.100:12434", + Description: "lab box", + })) + + cmd := newContextInspectCmd() + out := new(bytes.Buffer) + cmd.SetOut(out) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"myctx"}) + + require.NoError(t, cmd.Execute()) + + var results []map[string]any + require.NoError(t, json.Unmarshal(out.Bytes(), &results)) + require.Len(t, results, 1) + assert.Equal(t, "myctx", results[0]["name"]) + assert.Equal(t, "http://192.168.1.100:12434", results[0]["host"]) + assert.Equal(t, "lab box", results[0]["description"]) +} + +// TestContextInspect_default verifies that "context inspect default" returns +// a synthetic JSON entry. +func TestContextInspect_default(t *testing.T) { + setupContextTest(t) + + cmd := newContextInspectCmd() + out := new(bytes.Buffer) + cmd.SetOut(out) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"default"}) + + require.NoError(t, cmd.Execute()) + + var results []map[string]any + require.NoError(t, json.Unmarshal(out.Bytes(), &results)) + require.Len(t, results, 1) + assert.Equal(t, "default", results[0]["name"]) +} + +// TestContextInspect_notFound verifies that inspecting an unknown context +// returns an error. +func TestContextInspect_notFound(t *testing.T) { + setupContextTest(t) + + cmd := newContextInspectCmd() + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + cmd.SetArgs([]string{"nosuchctx"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} diff --git a/cmd/cli/commands/root.go b/cmd/cli/commands/root.go index 69f0f655a..200361216 100644 --- a/cmd/cli/commands/root.go +++ b/cmd/cli/commands/root.go @@ -13,6 +13,38 @@ import ( // dockerCLI is the Docker CLI environment associated with the command. var dockerCLI *command.DockerCli +// globalOptions holds the Docker client options used in standalone mode. It is +// set during NewRootCmd and referenced by initDockerCLI. +var globalOptions *flags.ClientOptions + +// initDockerCLI performs Docker CLI / plugin initialisation. It is called by +// both the root PersistentPreRunE and the context-command PersistentPreRunE. +// After this call dockerCLI is set and the Docker config is available. +func initDockerCLI(cmd *cobra.Command, args []string, cli *command.DockerCli, opts *flags.ClientOptions) error { + if plugin.RunningStandalone() { + opts.SetDefaultOptions(cmd.Root().Flags()) + if err := cli.Initialize(opts); err != nil { + return fmt.Errorf("unable to configure CLI: %w", err) + } + } else if err := plugin.PersistentPreRunE(cmd, args); err != nil { + return err + } + dockerCLI = cli + return nil +} + +// initModelRunner detects the active Model Runner backend and initialises the +// shared desktopClient. It must be called after initDockerCLI. +func initModelRunner(cmd *cobra.Command, cli *command.DockerCli) error { + var err error + modelRunner, err = desktop.DetectContext(cmd.Context(), cli, asPrinter(cmd)) + if err != nil { + return fmt.Errorf("unable to detect model runner context: %w", err) + } + desktopClient = desktop.New(modelRunner) + return nil +} + // getDockerCLI is an accessor for dockerCLI that can be passed to other // packages. func getDockerCLI() *command.DockerCli { @@ -34,38 +66,15 @@ func getDesktopClient() *desktop.Client { } func NewRootCmd(cli *command.DockerCli) *cobra.Command { - // If we're running in standalone mode, then we're responsible for - // initializing the CLI. In this case, we'll need to initialize the client - // options as well, which we'll add as global flags on the root command. We - // perform that initialization below so that we can register flags with the - // root command. - var globalOptions *flags.ClientOptions - // Set up the root command. - var rootCmd *cobra.Command - rootCmd = &cobra.Command{ + rootCmd := &cobra.Command{ Use: "model", Short: "Docker Model Runner", PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - // Finalize initialization of the CLI. - if plugin.RunningStandalone() { - globalOptions.SetDefaultOptions(rootCmd.Flags()) - if err := cli.Initialize(globalOptions); err != nil { - return fmt.Errorf("unable to configure CLI: %w", err) - } - } else if err := plugin.PersistentPreRunE(cmd, args); err != nil { + if err := initDockerCLI(cmd, args, cli, globalOptions); err != nil { return err } - dockerCLI = cli - - // Detect the model runner context and create a client for it. - var err error - modelRunner, err = desktop.DetectContext(cmd.Context(), dockerCLI, asPrinter(cmd)) - if err != nil { - return fmt.Errorf("unable to detect model runner context: %w", err) - } - desktopClient = desktop.New(modelRunner) - return nil + return initModelRunner(cmd, cli) }, // If running standalone, then we'll register global Docker flags as // top-level flags on the root command, so we'll have to set @@ -87,6 +96,7 @@ func NewRootCmd(cli *command.DockerCli) *cobra.Command { // Runner management commands - these manage the runner itself and don't need automatic runner initialization. rootCmd.AddCommand( newVersionCmd(), + newContextCmd(cli), newInstallRunner(), newUninstallRunner(), newStartRunner(), diff --git a/cmd/cli/desktop/context.go b/cmd/cli/desktop/context.go index a3e6ed46a..7e6546b2c 100644 --- a/cmd/cli/desktop/context.go +++ b/cmd/cli/desktop/context.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "os" + "path/filepath" "runtime" "strconv" "strings" @@ -16,6 +17,7 @@ import ( "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/connhelper" "github.com/docker/cli/cli/context/docker" + "github.com/docker/model-runner/cmd/cli/pkg/modelctx" "github.com/docker/model-runner/cmd/cli/pkg/standalone" "github.com/docker/model-runner/cmd/cli/pkg/types" "github.com/docker/model-runner/pkg/inference" @@ -231,6 +233,16 @@ func wakeUpCloudIfIdle(ctx context.Context, cli *command.DockerCli) error { return nil } +// namedContextStore returns a modelctx.Store rooted in the Docker config +// directory. Errors are non-fatal — callers fall back to auto-detection. +func namedContextStore(cli *command.DockerCli) (*modelctx.Store, error) { + if cli == nil || cli.ConfigFile() == nil { + return nil, fmt.Errorf("CLI not initialised") + } + configDir := filepath.Dir(cli.ConfigFile().Filename) + return modelctx.New(configDir) +} + // DetectContext determines the current Docker Model Runner context. func DetectContext(ctx context.Context, cli *command.DockerCli, printer standalone.StatusPrinter) (*ModelRunnerContext, error) { // Check for an explicit endpoint setting. @@ -240,10 +252,46 @@ func DetectContext(ctx context.Context, cli *command.DockerCli, printer standalo // testing purposes. treatDesktopAsMoby := os.Getenv("_MODEL_RUNNER_TREAT_DESKTOP_AS_MOBY") == "1" - // Check if TLS should be used - useTLS := os.Getenv("MODEL_RUNNER_TLS") == "true" - tlsSkipVerify := os.Getenv("MODEL_RUNNER_TLS_SKIP_VERIFY") == "true" - tlsCACert := os.Getenv("MODEL_RUNNER_TLS_CA_CERT") + // Read TLS env vars with LookupEnv so that unset and explicitly-set values + // can be distinguished. This lets named-context TLS settings be overridden + // field-by-field via environment variables. + tlsVal, tlsSet := os.LookupEnv("MODEL_RUNNER_TLS") + tlsSkipVerifyVal, tlsSkipVerifySet := os.LookupEnv("MODEL_RUNNER_TLS_SKIP_VERIFY") + tlsCACertVal, tlsCACertSet := os.LookupEnv("MODEL_RUNNER_TLS_CA_CERT") + useTLS := tlsSet && tlsVal == "true" + tlsSkipVerify := tlsSkipVerifySet && tlsSkipVerifyVal == "true" + tlsCACert := tlsCACertVal + + // If MODEL_RUNNER_HOST is not set, check whether a named context is active + // and use its host and TLS settings as the base configuration. Explicitly + // set env vars always win and overlay the stored values. + if modelRunnerHost == "" { + store, err := namedContextStore(cli) + if err != nil { + printer.Printf("Warning: unable to open context store: %v\n", err) + } else { + activeName, err := store.Active() + if err != nil { + printer.Printf("Warning: unable to determine active context: %v\n", err) + } else if activeName != modelctx.DefaultContextName { + cfg, err := store.Get(activeName) + if err != nil { + printer.Printf("Warning: unable to read context %q: %v\n", activeName, err) + } else { + modelRunnerHost = cfg.Host + if !tlsSet { + useTLS = cfg.TLS.Enabled + } + if !tlsSkipVerifySet { + tlsSkipVerify = cfg.TLS.SkipVerify + } + if !tlsCACertSet && cfg.TLS.CACert != "" { + tlsCACert = cfg.TLS.CACert + } + } + } + } + } // Detect the associated engine type. kind := types.ModelRunnerEngineKindMoby diff --git a/cmd/cli/docs/reference/docker_model.yaml b/cmd/cli/docs/reference/docker_model.yaml index 88068e051..955667c97 100644 --- a/cmd/cli/docs/reference/docker_model.yaml +++ b/cmd/cli/docs/reference/docker_model.yaml @@ -7,6 +7,7 @@ pname: docker plink: docker.yaml cname: - docker model bench + - docker model context - docker model df - docker model inspect - docker model install-runner @@ -35,6 +36,7 @@ cname: - docker model version clink: - docker_model_bench.yaml + - docker_model_context.yaml - docker_model_df.yaml - docker_model_inspect.yaml - docker_model_install-runner.yaml diff --git a/cmd/cli/docs/reference/docker_model_context.yaml b/cmd/cli/docs/reference/docker_model_context.yaml new file mode 100644 index 000000000..05a532ba7 --- /dev/null +++ b/cmd/cli/docs/reference/docker_model_context.yaml @@ -0,0 +1,24 @@ +command: docker model context +short: Manage Docker Model Runner contexts +long: Manage Docker Model Runner contexts +pname: docker model +plink: docker_model.yaml +cname: + - docker model context create + - docker model context inspect + - docker model context ls + - docker model context rm + - docker model context use +clink: + - docker_model_context_create.yaml + - docker_model_context_inspect.yaml + - docker_model_context_ls.yaml + - docker_model_context_rm.yaml + - docker_model_context_use.yaml +deprecated: false +hidden: false +experimental: false +experimentalcli: false +kubernetes: false +swarm: false + diff --git a/cmd/cli/docs/reference/docker_model_context_create.yaml b/cmd/cli/docs/reference/docker_model_context_create.yaml new file mode 100644 index 000000000..47e5bb49d --- /dev/null +++ b/cmd/cli/docs/reference/docker_model_context_create.yaml @@ -0,0 +1,61 @@ +command: docker model context create +short: Create a named Model Runner context +long: Create a named Model Runner context +usage: docker model context create NAME +pname: docker model context +plink: docker_model_context.yaml +options: + - option: description + value_type: string + description: Optional human-readable description for this context + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: host + value_type: string + description: Model Runner API base URL (e.g. http://192.168.1.100:12434) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls + value_type: bool + default_value: "false" + description: Enable TLS for connections to this context + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-ca-cert + value_type: string + description: Path to a custom CA certificate PEM file for TLS verification + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-skip-verify + value_type: bool + default_value: "false" + description: Skip TLS server certificate verification + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false +deprecated: false +hidden: false +experimental: false +experimentalcli: false +kubernetes: false +swarm: false + diff --git a/cmd/cli/docs/reference/docker_model_context_inspect.yaml b/cmd/cli/docs/reference/docker_model_context_inspect.yaml new file mode 100644 index 000000000..82c897dd3 --- /dev/null +++ b/cmd/cli/docs/reference/docker_model_context_inspect.yaml @@ -0,0 +1,13 @@ +command: docker model context inspect +short: Display detailed information about one or more contexts +long: Display detailed information about one or more contexts +usage: docker model context inspect NAME [NAME...] +pname: docker model context +plink: docker_model_context.yaml +deprecated: false +hidden: false +experimental: false +experimentalcli: false +kubernetes: false +swarm: false + diff --git a/cmd/cli/docs/reference/docker_model_context_ls.yaml b/cmd/cli/docs/reference/docker_model_context_ls.yaml new file mode 100644 index 000000000..cf97594f1 --- /dev/null +++ b/cmd/cli/docs/reference/docker_model_context_ls.yaml @@ -0,0 +1,14 @@ +command: docker model context ls +aliases: docker model context ls, docker model context list +short: List Model Runner contexts +long: List Model Runner contexts +usage: docker model context ls +pname: docker model context +plink: docker_model_context.yaml +deprecated: false +hidden: false +experimental: false +experimentalcli: false +kubernetes: false +swarm: false + diff --git a/cmd/cli/docs/reference/docker_model_context_rm.yaml b/cmd/cli/docs/reference/docker_model_context_rm.yaml new file mode 100644 index 000000000..2efa303bb --- /dev/null +++ b/cmd/cli/docs/reference/docker_model_context_rm.yaml @@ -0,0 +1,14 @@ +command: docker model context rm +aliases: docker model context rm, docker model context remove +short: Remove one or more Model Runner contexts +long: Remove one or more Model Runner contexts +usage: docker model context rm NAME [NAME...] +pname: docker model context +plink: docker_model_context.yaml +deprecated: false +hidden: false +experimental: false +experimentalcli: false +kubernetes: false +swarm: false + diff --git a/cmd/cli/docs/reference/docker_model_context_use.yaml b/cmd/cli/docs/reference/docker_model_context_use.yaml new file mode 100644 index 000000000..720281626 --- /dev/null +++ b/cmd/cli/docs/reference/docker_model_context_use.yaml @@ -0,0 +1,15 @@ +command: docker model context use +short: Set the active Model Runner context +long: |- + Set the active Model Runner context. Pass "default" to revert to + automatic backend detection. +usage: docker model context use NAME +pname: docker model context +plink: docker_model_context.yaml +deprecated: false +hidden: false +experimental: false +experimentalcli: false +kubernetes: false +swarm: false + diff --git a/cmd/cli/docs/reference/model.md b/cmd/cli/docs/reference/model.md index c34bfbdc4..f041e5e10 100644 --- a/cmd/cli/docs/reference/model.md +++ b/cmd/cli/docs/reference/model.md @@ -8,6 +8,7 @@ Docker Model Runner | Name | Description | |:------------------------------------------------|:-----------------------------------------------------------------------| | [`bench`](model_bench.md) | Benchmark a model's performance at different concurrency levels | +| [`context`](model_context.md) | Manage Docker Model Runner contexts | | [`df`](model_df.md) | Show Docker Model Runner disk usage | | [`inspect`](model_inspect.md) | Display detailed information on one model | | [`install-runner`](model_install-runner.md) | Install Docker Model Runner (Docker Engine only) | diff --git a/cmd/cli/docs/reference/model_context.md b/cmd/cli/docs/reference/model_context.md new file mode 100644 index 000000000..d5c05658c --- /dev/null +++ b/cmd/cli/docs/reference/model_context.md @@ -0,0 +1,19 @@ +# docker model context + + +Manage Docker Model Runner contexts + +### Subcommands + +| Name | Description | +|:--------------------------------------|:--------------------------------------------------------| +| [`create`](model_context_create.md) | Create a named Model Runner context | +| [`inspect`](model_context_inspect.md) | Display detailed information about one or more contexts | +| [`ls`](model_context_ls.md) | List Model Runner contexts | +| [`rm`](model_context_rm.md) | Remove one or more Model Runner contexts | +| [`use`](model_context_use.md) | Set the active Model Runner context | + + + + + diff --git a/cmd/cli/docs/reference/model_context_create.md b/cmd/cli/docs/reference/model_context_create.md new file mode 100644 index 000000000..cee0c6338 --- /dev/null +++ b/cmd/cli/docs/reference/model_context_create.md @@ -0,0 +1,18 @@ +# docker model context create + + +Create a named Model Runner context + +### Options + +| Name | Type | Default | Description | +|:--------------------|:---------|:--------|:--------------------------------------------------------------| +| `--description` | `string` | | Optional human-readable description for this context | +| `--host` | `string` | | Model Runner API base URL (e.g. http://192.168.1.100:12434) | +| `--tls` | `bool` | | Enable TLS for connections to this context | +| `--tls-ca-cert` | `string` | | Path to a custom CA certificate PEM file for TLS verification | +| `--tls-skip-verify` | `bool` | | Skip TLS server certificate verification | + + + + diff --git a/cmd/cli/docs/reference/model_context_inspect.md b/cmd/cli/docs/reference/model_context_inspect.md new file mode 100644 index 000000000..75357d4a1 --- /dev/null +++ b/cmd/cli/docs/reference/model_context_inspect.md @@ -0,0 +1,8 @@ +# docker model context inspect + + +Display detailed information about one or more contexts + + + + diff --git a/cmd/cli/docs/reference/model_context_ls.md b/cmd/cli/docs/reference/model_context_ls.md new file mode 100644 index 000000000..5d9e980d8 --- /dev/null +++ b/cmd/cli/docs/reference/model_context_ls.md @@ -0,0 +1,12 @@ +# docker model context ls + + +List Model Runner contexts + +### Aliases + +`docker model context ls`, `docker model context list` + + + + diff --git a/cmd/cli/docs/reference/model_context_rm.md b/cmd/cli/docs/reference/model_context_rm.md new file mode 100644 index 000000000..65a408590 --- /dev/null +++ b/cmd/cli/docs/reference/model_context_rm.md @@ -0,0 +1,12 @@ +# docker model context rm + + +Remove one or more Model Runner contexts + +### Aliases + +`docker model context rm`, `docker model context remove` + + + + diff --git a/cmd/cli/docs/reference/model_context_use.md b/cmd/cli/docs/reference/model_context_use.md new file mode 100644 index 000000000..f6544f63f --- /dev/null +++ b/cmd/cli/docs/reference/model_context_use.md @@ -0,0 +1,9 @@ +# docker model context use + + +Set the active Model Runner context. Pass "default" to revert to +automatic backend detection. + + + + diff --git a/cmd/cli/pkg/modelctx/lock_unix.go b/cmd/cli/pkg/modelctx/lock_unix.go new file mode 100644 index 000000000..8065fb30b --- /dev/null +++ b/cmd/cli/pkg/modelctx/lock_unix.go @@ -0,0 +1,20 @@ +//go:build !windows + +package modelctx + +import ( + "os" + + "golang.org/x/sys/unix" +) + +// lockFile acquires an exclusive advisory lock on the given file using flock(2). +// The lock is automatically released when the file is closed. +func lockFile(f *os.File) error { + return unix.Flock(int(f.Fd()), unix.LOCK_EX) +} + +// unlockFile releases the advisory lock on the given file. +func unlockFile(f *os.File) error { + return unix.Flock(int(f.Fd()), unix.LOCK_UN) +} diff --git a/cmd/cli/pkg/modelctx/lock_windows.go b/cmd/cli/pkg/modelctx/lock_windows.go new file mode 100644 index 000000000..00896eca7 --- /dev/null +++ b/cmd/cli/pkg/modelctx/lock_windows.go @@ -0,0 +1,37 @@ +//go:build windows + +package modelctx + +import ( + "os" + + "golang.org/x/sys/windows" +) + +// lockFile acquires an exclusive lock on the given file using LockFileEx. +// The lock is automatically released when the file is closed. +func lockFile(f *os.File) error { + // LOCKFILE_EXCLUSIVE_LOCK requests an exclusive lock. + // The zero Overlapped struct locks starting at offset 0. + ol := new(windows.Overlapped) + return windows.LockFileEx( + windows.Handle(f.Fd()), + windows.LOCKFILE_EXCLUSIVE_LOCK, + 0, // reserved + 1, // nNumberOfBytesToLockLow + 0, // nNumberOfBytesToLockHigh + ol, + ) +} + +// unlockFile releases the lock on the given file. +func unlockFile(f *os.File) error { + ol := new(windows.Overlapped) + return windows.UnlockFileEx( + windows.Handle(f.Fd()), + 0, // reserved + 1, // nNumberOfBytesToUnlockLow + 0, // nNumberOfBytesToUnlockHigh + ol, + ) +} diff --git a/cmd/cli/pkg/modelctx/store.go b/cmd/cli/pkg/modelctx/store.go new file mode 100644 index 000000000..80fca15bb --- /dev/null +++ b/cmd/cli/pkg/modelctx/store.go @@ -0,0 +1,283 @@ +// Package modelctx provides persistent storage for named Docker Model Runner +// contexts, allowing users to switch between different Model Runner backends +// without setting environment variables each time. +package modelctx + +import ( + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "regexp" + "time" +) + +// DefaultContextName is the reserved name for the auto-detected context. +// It is never written to disk; a missing or empty "current" value implies it. +const DefaultContextName = "default" + +// contextFileVersion is the version of the on-disk context file format. +const contextFileVersion = 1 + +// validContextName matches names that follow Docker's context naming rules. +var validContextName = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*$`) + +// TLSConfig holds optional TLS settings for a named context. +type TLSConfig struct { + // Enabled indicates whether TLS is used for this context. + Enabled bool `json:"enabled"` + // SkipVerify disables TLS server certificate verification. + SkipVerify bool `json:"skipVerify,omitempty"` + // CACert is the absolute path to a custom CA certificate PEM file. + CACert string `json:"caCert,omitempty"` +} + +// ContextConfig is the configuration for a named Model Runner context. +type ContextConfig struct { + // Host is the Model Runner API base URL (e.g. "http://192.168.1.100:12434"). + Host string `json:"host"` + // TLS holds optional TLS settings. + TLS TLSConfig `json:"tls,omitempty"` + // Description is an optional human-readable note. + Description string `json:"description,omitempty"` + // CreatedAt records when the context was created. + CreatedAt time.Time `json:"createdAt"` +} + +// contextFile is the versioned on-disk representation of the context store. +type contextFile struct { + // Version is the schema version; currently always 1. + Version int `json:"version"` + // Current is the active context name; empty means DefaultContextName. + Current string `json:"current,omitempty"` + // Contexts is a map from context name to its configuration. + Contexts map[string]ContextConfig `json:"contexts"` +} + +// Store manages named Model Runner contexts stored in a single JSON file. +type Store struct { + // path is the absolute path to the contexts.json file. + path string +} + +// New returns a Store that persists contexts in +// /model/contexts.json. It creates the parent directory if +// it does not exist. +func New(dockerConfigDir string) (*Store, error) { + dir := filepath.Join(dockerConfigDir, "model") + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, fmt.Errorf("unable to create model context directory: %w", err) + } + return &Store{path: filepath.Join(dir, "contexts.json")}, nil +} + +// ValidateName returns an error if name is reserved or does not match the +// allowed pattern. +func ValidateName(name string) error { + if name == DefaultContextName { + return fmt.Errorf("context name %q is reserved", name) + } + if !validContextName.MatchString(name) { + return fmt.Errorf( + "invalid context name %q: must match %s", + name, validContextName, + ) + } + return nil +} + +// List returns all named contexts. The synthetic "default" context is not +// included. +func (s *Store) List() (map[string]ContextConfig, error) { + cf, err := s.read() + if err != nil { + return nil, err + } + return cf.Contexts, nil +} + +// Get returns the configuration for the named context. +func (s *Store) Get(name string) (ContextConfig, error) { + cf, err := s.read() + if err != nil { + return ContextConfig{}, err + } + cfg, ok := cf.Contexts[name] + if !ok { + return ContextConfig{}, fmt.Errorf("context %q not found", name) + } + return cfg, nil +} + +// Create writes a new named context. It returns an error if the name is +// reserved, fails validation, or already exists. +func (s *Store) Create(name string, cfg ContextConfig) error { + if err := ValidateName(name); err != nil { + return err + } + return s.update(func(cf *contextFile) error { + if _, exists := cf.Contexts[name]; exists { + return fmt.Errorf("context %q already exists", name) + } + cf.Contexts[name] = cfg + return nil + }) +} + +// Remove deletes the named context. It returns an error if name is +// DefaultContextName or if the context is currently active. +func (s *Store) Remove(name string) error { + if name == DefaultContextName { + return fmt.Errorf("context name %q is reserved and cannot be removed", name) + } + return s.update(func(cf *contextFile) error { + if _, exists := cf.Contexts[name]; !exists { + return fmt.Errorf("context %q not found", name) + } + if cf.Current == name { + return fmt.Errorf( + "context %q is currently active; switch to another context first", + name, + ) + } + delete(cf.Contexts, name) + return nil + }) +} + +// Active returns the name of the currently active context, or +// DefaultContextName if none has been set. +func (s *Store) Active() (string, error) { + cf, err := s.read() + if err != nil { + return "", err + } + if cf.Current == "" { + return DefaultContextName, nil + } + return cf.Current, nil +} + +// SetActive makes the named context active. Pass DefaultContextName to revert +// to auto-detection. The named context must already exist unless name is +// DefaultContextName. +func (s *Store) SetActive(name string) error { + if name != DefaultContextName { + if err := ValidateName(name); err != nil { + return err + } + } + return s.update(func(cf *contextFile) error { + if name != DefaultContextName { + if _, exists := cf.Contexts[name]; !exists { + return fmt.Errorf("context %q not found", name) + } + } + // Store DefaultContextName as an empty string so the JSON omits the + // field, keeping the file clean. + if name == DefaultContextName { + cf.Current = "" + } else { + cf.Current = name + } + return nil + }) +} + +// read loads the context file from disk. A missing file is treated as an +// empty store rather than an error. +func (s *Store) read() (contextFile, error) { + data, err := os.ReadFile(s.path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return contextFile{ + Version: contextFileVersion, + Contexts: make(map[string]ContextConfig), + }, nil + } + return contextFile{}, fmt.Errorf("unable to read context file: %w", err) + } + var cf contextFile + if err := json.Unmarshal(data, &cf); err != nil { + return contextFile{}, fmt.Errorf("unable to parse context file: %w", err) + } + if cf.Contexts == nil { + cf.Contexts = make(map[string]ContextConfig) + } + return cf, nil +} + +// update applies a mutation function under a file lock and writes the result +// atomically. This serialises concurrent writers while allowing readers to +// always see a complete file. +func (s *Store) update(mutate func(*contextFile) error) error { + lockPath := filepath.Join(filepath.Dir(s.path), "contexts.lock") + + // Open (or create) the lock file, then acquire an exclusive advisory + // lock via flock(2) (Unix) or LockFileEx (Windows). The OS-level lock + // prevents concurrent processes from entering this critical section at + // the same time. The lock is released on close. + lf, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600) + if err != nil { + return fmt.Errorf("unable to open context lock file: %w", err) + } + defer func() { + _ = unlockFile(lf) + _ = lf.Close() + }() + + if err := lockFile(lf); err != nil { + return fmt.Errorf("unable to acquire context lock: %w", err) + } + + // Re-read under lock to pick up any changes made since the caller last read. + cf, err := s.read() + if err != nil { + return err + } + + // Apply the mutation. + if err := mutate(&cf); err != nil { + return err + } + + // Serialise the updated state. + data, err := json.MarshalIndent(cf, "", " ") + if err != nil { + return fmt.Errorf("unable to serialise context file: %w", err) + } + data = append(data, '\n') + + // Write to a uniquely named temp file then rename atomically. + var rndBuf [8]byte + if _, err := rand.Read(rndBuf[:]); err != nil { + return fmt.Errorf("unable to generate random bytes for temp file: %w", err) + } + tmpPath := fmt.Sprintf( + "%s.tmp.%d.%x", + s.path, os.Getpid(), rndBuf, + ) + f, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + if err != nil { + return fmt.Errorf("unable to write context file: %w", err) + } + if _, err := f.Write(data); err != nil { + _ = f.Close() + _ = os.Remove(tmpPath) + return fmt.Errorf("unable to write context file: %w", err) + } + if err := f.Sync(); err != nil { + _ = f.Close() + _ = os.Remove(tmpPath) + return fmt.Errorf("unable to sync context file: %w", err) + } + _ = f.Close() + + if err := os.Rename(tmpPath, s.path); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("unable to commit context file: %w", err) + } + return nil +} diff --git a/cmd/cli/pkg/modelctx/store_test.go b/cmd/cli/pkg/modelctx/store_test.go new file mode 100644 index 000000000..d419b51cd --- /dev/null +++ b/cmd/cli/pkg/modelctx/store_test.go @@ -0,0 +1,256 @@ +package modelctx + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// sampleConfig returns a ContextConfig suitable for use in tests. +func sampleConfig(host string) ContextConfig { + return ContextConfig{ + Host: host, + Description: "test context", + CreatedAt: time.Now().UTC().Truncate(time.Second), + } +} + +// newTestStore creates a Store backed by a temporary directory. +func newTestStore(t *testing.T) *Store { + t.Helper() + store, err := New(t.TempDir()) + require.NoError(t, err) + return store +} + +// TestNew verifies that New creates the storage directory and that opening an +// existing store on the same path succeeds. +func TestNew(t *testing.T) { + dir := t.TempDir() + store, err := New(dir) + require.NoError(t, err) + require.NotNil(t, store) + + // Re-opening the same path should also succeed. + store2, err := New(dir) + require.NoError(t, err) + require.NotNil(t, store2) +} + +// TestCreate verifies that a newly created context can be retrieved via Get. +func TestCreate(t *testing.T) { + store := newTestStore(t) + cfg := sampleConfig("http://localhost:12434") + require.NoError(t, store.Create("myctx", cfg)) + + got, err := store.Get("myctx") + require.NoError(t, err) + assert.Equal(t, cfg.Host, got.Host) + assert.Equal(t, cfg.Description, got.Description) +} + +// TestCreate_reservedName verifies that "default" cannot be used as a context +// name. +func TestCreate_reservedName(t *testing.T) { + store := newTestStore(t) + err := store.Create(DefaultContextName, sampleConfig("http://localhost:12434")) + require.Error(t, err) + assert.Contains(t, err.Error(), "reserved") +} + +// TestCreate_invalidNames verifies that names violating the naming rules are +// rejected. +func TestCreate_invalidNames(t *testing.T) { + store := newTestStore(t) + cfg := sampleConfig("http://localhost:12434") + for _, name := range []string{"", "-leading-dash", "has space", "has/slash"} { + err := store.Create(name, cfg) + require.Errorf(t, err, "expected error for name %q", name) + } +} + +// TestCreate_duplicate verifies that creating a context with an already-used +// name returns an error. +func TestCreate_duplicate(t *testing.T) { + store := newTestStore(t) + cfg := sampleConfig("http://localhost:12434") + require.NoError(t, store.Create("myctx", cfg)) + + err := store.Create("myctx", sampleConfig("http://other:12434")) + require.Error(t, err) + assert.Contains(t, err.Error(), "already exists") +} + +// TestGet_notFound verifies that Get returns an error for unknown names. +func TestGet_notFound(t *testing.T) { + store := newTestStore(t) + _, err := store.Get("nosuchctx") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// TestList verifies that List returns an empty map when no contexts exist and +// returns all created contexts afterwards. +func TestList(t *testing.T) { + store := newTestStore(t) + + // Empty store. + contexts, err := store.List() + require.NoError(t, err) + assert.Empty(t, contexts) + + // Add two contexts. + require.NoError(t, store.Create("alpha", sampleConfig("http://alpha:12434"))) + require.NoError(t, store.Create("beta", sampleConfig("http://beta:12434"))) + + contexts, err = store.List() + require.NoError(t, err) + assert.Len(t, contexts, 2) + assert.Contains(t, contexts, "alpha") + assert.Contains(t, contexts, "beta") +} + +// TestRemove verifies that a context is gone after removal. +func TestRemove(t *testing.T) { + store := newTestStore(t) + require.NoError(t, store.Create("myctx", sampleConfig("http://localhost:12434"))) + + require.NoError(t, store.Remove("myctx")) + + _, err := store.Get("myctx") + require.Error(t, err) +} + +// TestRemove_default verifies that the reserved "default" name cannot be +// removed. +func TestRemove_default(t *testing.T) { + store := newTestStore(t) + err := store.Remove(DefaultContextName) + require.Error(t, err) + assert.Contains(t, err.Error(), "reserved") +} + +// TestRemove_notFound verifies that removing an unknown context returns an +// error. +func TestRemove_notFound(t *testing.T) { + store := newTestStore(t) + err := store.Remove("nosuchctx") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// TestRemove_activeContext verifies that the currently active context cannot +// be removed and that it remains in the store after the attempt. +func TestRemove_activeContext(t *testing.T) { + store := newTestStore(t) + require.NoError(t, store.Create("myctx", sampleConfig("http://localhost:12434"))) + require.NoError(t, store.SetActive("myctx")) + + err := store.Remove("myctx") + require.Error(t, err) + assert.Contains(t, err.Error(), "currently active") + + // Context must still be present. + _, err = store.Get("myctx") + require.NoError(t, err) +} + +// TestActive_default verifies that Active returns DefaultContextName when no +// context file exists. +func TestActive_default(t *testing.T) { + store := newTestStore(t) + active, err := store.Active() + require.NoError(t, err) + assert.Equal(t, DefaultContextName, active) +} + +// TestSetActive verifies that SetActive changes the value returned by Active. +func TestSetActive(t *testing.T) { + store := newTestStore(t) + require.NoError(t, store.Create("myctx", sampleConfig("http://localhost:12434"))) + + require.NoError(t, store.SetActive("myctx")) + + active, err := store.Active() + require.NoError(t, err) + assert.Equal(t, "myctx", active) +} + +// TestSetActive_backToDefault verifies that SetActive("default") resets the +// active context to the auto-detect sentinel. +func TestSetActive_backToDefault(t *testing.T) { + store := newTestStore(t) + require.NoError(t, store.Create("myctx", sampleConfig("http://localhost:12434"))) + require.NoError(t, store.SetActive("myctx")) + + require.NoError(t, store.SetActive(DefaultContextName)) + + active, err := store.Active() + require.NoError(t, err) + assert.Equal(t, DefaultContextName, active) +} + +// TestSetActive_notFound verifies that SetActive returns an error when the +// named context does not exist. +func TestSetActive_notFound(t *testing.T) { + store := newTestStore(t) + err := store.SetActive("nosuchctx") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// TestSetActive_invalidName verifies that SetActive rejects invalid names. +func TestSetActive_invalidName(t *testing.T) { + store := newTestStore(t) + err := store.SetActive("has space") + require.Error(t, err) +} + +// TestPersistence verifies that context data written by one Store instance is +// readable by a new instance opened on the same directory. +func TestPersistence(t *testing.T) { + dir := t.TempDir() + cfg := sampleConfig("http://remote:12434") + + // Write with the first instance. + s1, err := New(dir) + require.NoError(t, err) + require.NoError(t, s1.Create("remote", cfg)) + require.NoError(t, s1.SetActive("remote")) + + // Read back with a new instance. + s2, err := New(dir) + require.NoError(t, err) + + active, err := s2.Active() + require.NoError(t, err) + assert.Equal(t, "remote", active) + + got, err := s2.Get("remote") + require.NoError(t, err) + assert.Equal(t, cfg.Host, got.Host) + assert.Equal(t, cfg.Description, got.Description) +} + +// TestTLSConfig verifies that TLS settings are stored and retrieved correctly. +func TestTLSConfig(t *testing.T) { + store := newTestStore(t) + cfg := ContextConfig{ + Host: "https://secure:12444", + TLS: TLSConfig{ + Enabled: true, + SkipVerify: false, + CACert: "/etc/ssl/certs/ca.pem", + }, + CreatedAt: time.Now().UTC(), + } + require.NoError(t, store.Create("secure", cfg)) + + got, err := store.Get("secure") + require.NoError(t, err) + assert.True(t, got.TLS.Enabled) + assert.False(t, got.TLS.SkipVerify) + assert.Equal(t, "/etc/ssl/certs/ca.pem", got.TLS.CACert) +}