diff --git a/cmd/pilotctl/main.go b/cmd/pilotctl/main.go index ca24ee60..cb2547e3 100644 --- a/cmd/pilotctl/main.go +++ b/cmd/pilotctl/main.go @@ -52,8 +52,15 @@ func defaultSocket() string { } func configDir() string { + // PILOT_HOME, when set, overrides the location of the .pilot config + // directory — matching the override the other components honor and + // letting callers (tests, sandboxes, multi-identity setups) relocate + // state without rewriting $HOME. + if h := os.Getenv("PILOT_HOME"); h != "" { + return filepath.Join(h, defaultConfigDir) + } home, _ := os.UserHomeDir() - return home + "/" + defaultConfigDir + return filepath.Join(home, defaultConfigDir) } func configPath() string { return configDir() + "/" + defaultConfigFile } @@ -275,11 +282,23 @@ func getRegistry() string { func loadConfig() map[string]interface{} { f, err := os.Open(configPath()) if err != nil { + if !errors.Is(err, os.ErrNotExist) { + // A missing config is normal (defaults apply); anything else — + // permission denied, an I/O error — is a real problem the user + // should hear about rather than silently get defaults for. + slog.Warn("could not read config; using defaults", + "path", configPath(), "error", err) + } return map[string]interface{}{} } defer f.Close() var cfg map[string]interface{} if err := json.NewDecoder(f).Decode(&cfg); err != nil { + // A corrupt config silently falling back to defaults can mask a + // truncated/half-written file and make admin_token/registry + // settings vanish without explanation. Surface it. + slog.Warn("config file is corrupt; using defaults", + "path", configPath(), "error", err) return map[string]interface{}{} } return cfg @@ -311,14 +330,50 @@ func saveConfig(cfg map[string]interface{}) error { if err := os.MkdirAll(dir, 0700); err != nil { return err } - f, err := os.Create(configPath()) + + // config.json holds secrets (admin_token, webhook, registry, etc.), so + // keep it 0600 and write it atomically: serialize to a temp file in the + // same directory, fsync, then rename over the target. A crash mid-write + // can never leave a truncated or world-readable config behind, and we do + // not inherit loose permissions from any pre-existing file. + data, err := json.MarshalIndent(cfg, "", " ") if err != nil { return err } - defer f.Close() - enc := json.NewEncoder(f) - enc.SetIndent("", " ") - return enc.Encode(cfg) + data = append(data, '\n') + + tmp, err := os.CreateTemp(dir, defaultConfigFile+".tmp-*") + if err != nil { + return err + } + tmpName := tmp.Name() + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpName) + } + }() + + if err := tmp.Chmod(0600); err != nil { + tmp.Close() // #nosec G104 -- best-effort cleanup on the error path; the Chmod error is the one returned + return err + } + if _, err := tmp.Write(data); err != nil { + tmp.Close() // #nosec G104 -- best-effort cleanup on the error path; the Write error is the one returned + return err + } + if err := tmp.Sync(); err != nil { + tmp.Close() // #nosec G104 -- best-effort cleanup on the error path; the Sync error is the one returned + return err + } + if err := tmp.Close(); err != nil { + return err + } + if err := os.Rename(tmpName, configPath()); err != nil { + return err + } + cleanup = false + return nil } // --- Arg parsing helpers --- @@ -340,7 +395,7 @@ func parseFlags(args []string) (map[string]string, []string) { if key != "" { if idx := strings.Index(key, "="); idx >= 0 { flags[key[:idx]] = key[idx+1:] - } else if i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { + } else if i+1 < len(args) && isFlagValue(args[i+1]) { flags[key] = args[i+1] i++ } else { @@ -353,6 +408,39 @@ func parseFlags(args []string) (map[string]string, []string) { return flags, pos } +// isFlagValue reports whether tok, appearing immediately after a flag key, +// should be consumed as that flag's value rather than parsed as the next +// flag. Values that begin with "-" used to be silently dropped (forcing +// users onto --key=value); this lets "--key -value" work for the common +// cases without making "--flag --next" ambiguous: +// - anything not starting with "-" is a value +// - a bare "-" (stdin sentinel) is a value +// - "-1", "-3.14" and other negative numbers are values +// - "-=foo" or "-3x" (a "-" followed by a digit) is a value, since those +// are not valid flag names +// +// A token shaped like a flag ("--name" or "-name") is treated as the next +// flag, not a value — use --key=value to pass such a literal. +func isFlagValue(tok string) bool { + if !strings.HasPrefix(tok, "-") { + return true + } + if tok == "-" { + return true + } + // "--something" is always a flag. + if strings.HasPrefix(tok, "--") { + return false + } + // Single dash: a value if the remainder is numeric ("-1", "-3.14") or + // begins with a digit ("-3x"); otherwise it looks like a "-name" flag. + rest := tok[1:] + if isNumericFlag(rest) { + return true + } + return rest[0] >= '0' && rest[0] <= '9' +} + // isNumericFlag reports whether s looks like a bare number (e.g. "1", "3.14"), // so that negative number positional args like "-1" are not treated as flags. func isNumericFlag(s string) bool { @@ -2735,20 +2823,31 @@ func cmdDaemonStop() { } pid := readPID() + discovered := false if pid <= 0 { - // Try socket - d, err := driver.Connect(getSocket()) + // PID file is missing (manual start, lost file, crash). If the socket + // is live the daemon is still running — discover its PID from the + // socket owner and stop it instead of telling the user to kill it by + // hand. + socket := getSocket() + d, err := driver.Connect(socket) if err != nil { fatalCode("not_running", "daemon is not running") } - d.Close() - fatalHint("not_running", - fmt.Sprintf("find and kill the process manually: lsof -U | grep %s", getSocket()), - "daemon socket is active but PID file is missing") + d.Close() // #nosec G104 -- probe connection only; we just needed to confirm the socket is live + pid = discoverDaemonPID(socket) + if pid <= 0 { + fatalHint("not_running", + fmt.Sprintf("find and kill the process manually: lsof -U | grep %s", socket), + "daemon socket is active but PID file is missing and the owning process could not be identified") + } + discovered = true } if !processExists(pid) { - os.Remove(pidFilePath()) + if !discovered { + os.Remove(pidFilePath()) // #nosec G104 -- best-effort removal of stale PID file; failure is non-fatal + } fatalCode("not_running", "daemon is not running (cleaned up stale state)") } @@ -2953,6 +3052,34 @@ func processExists(pid int) bool { return proc.Signal(syscall.Signal(0)) == nil } +// discoverDaemonPID finds the PID of the process holding the daemon's Unix +// socket. It is the fallback for `daemon stop` when the PID file is missing +// (manual start, crash that lost the file, etc.) so we can stop the daemon +// instead of punting to "kill it yourself". Returns 0 if the owner can't be +// determined (e.g. lsof unavailable). Only same-UID processes are returned by +// lsof for our own socket, so this can't be used to signal other users' procs. +func discoverDaemonPID(socketPath string) int { + if socketPath == "" { + return 0 + } + // #nosec G204 -- fixed argv (lsof -t -U -a); socketPath is our own config-derived daemon socket passed as a separate arg, no shell, no injection + out, err := exec.Command("lsof", "-t", "-U", "-a", socketPath).Output() + if err != nil { + return 0 + } + self := os.Getpid() + for _, line := range strings.Fields(string(out)) { + pid, perr := strconv.Atoi(line) + if perr != nil || pid <= 0 || pid == self { + continue + } + if processExists(pid) { + return pid + } + } + return 0 +} + // ===================== GATEWAY ===================== // cmdGatewayStart exec's pilot-gateway in the foreground. The flags @@ -3207,14 +3334,19 @@ func cmdSetHostname(args []string) { fatalCode("connection_failed", "set-hostname: %v", err) } - // Persist to config.json so hostname survives daemon restart + // Persist to config.json so hostname survives daemon restart. Report a + // save failure rather than letting persisted config silently diverge + // from the running daemon. cfg := loadConfig() if hostname != "" { cfg["hostname"] = hostname } else { delete(cfg, "hostname") } - saveConfig(cfg) + if err := saveConfig(cfg); err != nil { + fatalCode("internal", + "hostname set on daemon but could not persist to %s: %v", configPath(), err) + } if jsonOutput { outputOK(map[string]interface{}{ @@ -3237,10 +3369,14 @@ func cmdClearHostname() { fatalCode("connection_failed", "clear-hostname: %v", err) } - // Persist to config.json so hostname stays cleared on daemon restart + // Persist to config.json so hostname stays cleared on daemon restart. + // Report a save failure rather than silently diverging from the daemon. cfg := loadConfig() delete(cfg, "hostname") - saveConfig(cfg) + if err := saveConfig(cfg); err != nil { + fatalCode("internal", + "hostname cleared on daemon but could not persist to %s: %v", configPath(), err) + } if jsonOutput { outputOK(map[string]interface{}{ @@ -3408,7 +3544,9 @@ func cmdConnect(args []string) { // --message mode: send one message, read one response, exit if message != "" { - conn, err := d.DialAddr(target, port) + // Honor --timeout on the dial itself so the daemon-side dial is + // cancelled (not left dangling) when the peer never answers. + conn, err := d.DialAddrTimeout(target, port, timeout) if err != nil { hint := classifyDaemonError(err) if hint == "" { @@ -3470,20 +3608,27 @@ func cmdConnect(args []string) { "--message is required (interactive mode not supported)") } - // Read all piped stdin + // Read all piped stdin. bufio.Scanner's default 64 KiB token limit + // would make a single long line fail silently, and scanner.Err() must be + // checked or read errors are swallowed. Raise the buffer and surface Err(). var stdinData []byte scanner := bufio.NewScanner(os.Stdin) + const maxLine = 16 * 1024 * 1024 // 16 MiB per line + scanner.Buffer(make([]byte, 0, 64*1024), maxLine) for scanner.Scan() { if len(stdinData) > 0 { stdinData = append(stdinData, '\n') } stdinData = append(stdinData, scanner.Bytes()...) } + if err := scanner.Err(); err != nil { + fatalCode("invalid_argument", "reading stdin: %v", err) + } if len(stdinData) == 0 { fatalCode("invalid_argument", "no data on stdin — use --message or pipe data") } - conn, err := d.DialAddr(target, port) + conn, err := d.DialAddrTimeout(target, port, timeout) if err != nil { fatalHint("connection_failed", fmt.Sprintf("check that %s is reachable: pilotctl ping %s", target, target), @@ -4166,18 +4311,34 @@ func cmdSendMessage(args []string) { if len(traceEvents) > 0 { result["trace"] = traceEvents } - outputOK(result) if waitDur > 0 { - if !jsonOutput { + // Defer all output until the reply is in (or times out) so that + // --json emits a SINGLE document: a machine parser reading stdout + // must not see the send-result object followed by a second reply + // object. In JSON mode we fold the reply into one envelope; in + // human mode we still print the send result first, then the reply. + if jsonOutput { + stop := startWaitProgress("waiting for reply") + reply, err := waitForInboxReply(agentHint, inboxCutoff, waitDur) + stop() + if err != nil { + fatalCode("timeout", "%v", err) + } + result["reply"] = reply + outputOK(result) + } else { + outputOK(result) fmt.Fprintf(os.Stderr, "waiting for reply from %s (up to %s)...\n", pos[0], waitDur) + stop := startWaitProgress("waiting for reply") + reply, err := waitForInboxReply(agentHint, inboxCutoff, waitDur) + stop() + if err != nil { + fatalCode("timeout", "%v", err) + } + output(reply) } - stop := startWaitProgress("waiting for reply") - reply, err := waitForInboxReply(agentHint, inboxCutoff, waitDur) - stop() - if err != nil { - fatalCode("timeout", "%v", err) - } - output(reply) + } else { + outputOK(result) } } else if reuseConn { // --reuse-conn: one dial shared across all N sends. Seq 0 pays dial @@ -5260,29 +5421,15 @@ func cmdPing(args []string) { // reuseConn mode: one shared connection across all iterations. // nil = needs dial. Reconnects only on error to avoid the ~1.5×RTT // TCP-handshake cost on packets 2+. Disabled by default (ablation flag). - type dialResult struct { - conn *driver.Conn - err error - } dialOnce := func() (*driver.Conn, time.Duration, error) { - ch := make(chan dialResult, 1) - go func() { - c, e := d.DialAddr(target, protocol.PortEcho) - ch <- dialResult{c, e} - }() + // DialAddrTimeout enforces the per-attempt budget at the daemon-IPC + // layer and cancels the dial on expiry, so a timed-out ping against a + // ghost peer doesn't leak a goroutine or leave a dangling daemon-side + // dial (the previous goroutine+select drained but could not cancel the + // underlying dial). t0 := time.Now() - select { - case dr := <-ch: - return dr.conn, time.Since(t0), dr.err - case <-time.After(perAttempt): - // Drain the goroutine asynchronously. - go func() { - if dr := <-ch; dr.conn != nil { - dr.conn.Close() - } - }() - return nil, time.Since(t0), fmt.Errorf("dial timeout after %s", perAttempt) - } + c, e := d.DialAddrTimeout(target, protocol.PortEcho, perAttempt) + return c, time.Since(t0), e } var sharedConn *driver.Conn @@ -5510,26 +5657,14 @@ func cmdTraceroute(args []string) { fmt.Printf("TRACEROUTE %s\n", target) } + // Tunnel negotiation against a slow or unreachable peer blocks silently + // for up to --timeout; show elapsed progress on a TTY. DialAddrTimeout + // cancels the underlying daemon dial on expiry so a timed-out traceroute + // does not leave a dangling connection behind in the daemon. start := time.Now() - connDone := make(chan *driver.Conn) - var dialErr error - go func() { - conn, err := d.DialAddr(target, protocol.PortEcho) - dialErr = err - connDone <- conn - }() - - // Tunnel negotiation against a slow or unreachable peer blocks here - // silently for up to --timeout; show elapsed progress on a TTY. stopProgress := startWaitProgress(fmt.Sprintf("tracing %s", target)) - var conn *driver.Conn - select { - case conn = <-connDone: - stopProgress() - case <-time.After(timeout): - stopProgress() - fatalCode("timeout", "dial timeout") - } + conn, dialErr := d.DialAddrTimeout(target, protocol.PortEcho, timeout) + stopProgress() setupTime := time.Since(start) if dialErr != nil { @@ -5553,7 +5688,20 @@ func cmdTraceroute(args []string) { for i := 0; i < 3; i++ { pingStart := time.Now() payload := fmt.Sprintf("trace-%d", i) - conn.Write([]byte(payload)) + if _, werr := conn.Write([]byte(payload)); werr != nil { + // The connection dropped mid-trace: record the write failure and + // don't block in Read waiting for an echo that will never come. + rtt := time.Since(pingStart) + sample := map[string]interface{}{ + "rtt_ms": float64(rtt.Microseconds()) / 1000.0, + "error": werr.Error(), + } + rttSamples = append(rttSamples, sample) + if !jsonOutput { + fmt.Printf(" rtt=%v write error: %v\n", rtt, werr) + } + break + } buf := make([]byte, 1024) n, err := conn.Read(buf) @@ -5597,6 +5745,30 @@ func cmdBench(args []string) { timeout := flagDuration(flags, "timeout", 120*time.Second) + // Validate the size before touching the daemon so a nonsense size fails + // fast (and can't tie up the transport). Negatives/zero/NaN/Inf would + // underflow the int conversion or busy-loop the send; an absurd size + // would run the echo path indefinitely. Cap at a generous 4 GiB. + totalSize := 1024 * 1024 + if len(pos) > 1 { + sizeMB, perr := strconv.ParseFloat(pos[1], 64) + if perr != nil { + fatalCode("invalid_argument", "invalid size: %v", perr) + } + const maxSizeMB = 4096.0 + if !(sizeMB > 0) { + fatalCode("invalid_argument", "size must be a positive number of MB, got %q", pos[1]) + } + if sizeMB > maxSizeMB { + fatalCode("invalid_argument", "size %g MB exceeds maximum of %g MB", sizeMB, maxSizeMB) + } + totalSize = int(sizeMB * 1024 * 1024) + if totalSize <= 0 { + fatalCode("invalid_argument", "size %q is too small to send", pos[1]) + } + } + const chunkSize = 4096 + d := connectDriver() defer d.Close() @@ -5606,21 +5778,11 @@ func cmdBench(args []string) { } maybeAutoHandshake(d, target, flagBool(flags, "no-auto-handshake")) - totalSize := 1024 * 1024 - if len(pos) > 1 { - sizeMB, err := strconv.ParseFloat(pos[1], 64) - if err != nil { - fatalCode("invalid_argument", "invalid size: %v", err) - } - totalSize = int(sizeMB * 1024 * 1024) - } - const chunkSize = 4096 - if !jsonOutput { fmt.Printf("BENCH %s — sending %s via echo port\n", target, formatBytes(uint64(totalSize))) } - conn, err := d.DialAddr(target, protocol.PortEcho) + conn, err := d.DialAddrTimeout(target, protocol.PortEcho, timeout) if err != nil { fatalHint("connection_failed", fmt.Sprintf("check that %s is reachable: pilotctl ping %s", target, target), diff --git a/cmd/pilotctl/zz_config_safety_test.go b/cmd/pilotctl/zz_config_safety_test.go new file mode 100644 index 00000000..23d6ac38 --- /dev/null +++ b/cmd/pilotctl/zz_config_safety_test.go @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package main + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +// TestSaveConfigPerms verifies the secret-bearing config.json is written +// 0600 and that a pre-existing loose-permission file does not retain its +// loose mode after a save (atomic rename installs a fresh 0600 file). +func TestSaveConfigPerms(t *testing.T) { + withTempHomeFull(t) + + // Pre-create a world-readable config to prove perms are not preserved. + dir := configDir() + if err := os.MkdirAll(dir, 0o700); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(configPath(), []byte(`{"old":"loose"}`), 0o644); err != nil { + t.Fatalf("seed: %v", err) + } + + if err := saveConfig(map[string]interface{}{"admin_token": "s3cr3t"}); err != nil { + t.Fatalf("saveConfig: %v", err) + } + + info, err := os.Stat(configPath()) + if err != nil { + t.Fatalf("stat: %v", err) + } + if perm := info.Mode().Perm(); perm != 0o600 { + t.Errorf("config perms = %o, want 0600", perm) + } + + got := loadConfig() + if got["admin_token"] != "s3cr3t" { + t.Errorf("admin_token = %v", got["admin_token"]) + } + if _, ok := got["old"]; ok { + t.Errorf("stale key survived overwrite: %v", got) + } +} + +// TestSaveConfigAtomicNoTemp verifies a successful save leaves no temp file +// behind in the config dir — only config.json. +func TestSaveConfigAtomicNoTemp(t *testing.T) { + withTempHomeFull(t) + + if err := saveConfig(map[string]interface{}{"hostname": "atomic"}); err != nil { + t.Fatalf("saveConfig: %v", err) + } + + entries, err := os.ReadDir(configDir()) + if err != nil { + t.Fatalf("readdir: %v", err) + } + for _, e := range entries { + if strings.Contains(e.Name(), ".tmp-") { + t.Errorf("leftover temp file: %s", e.Name()) + } + } +} + +// TestConfigDirHonorsPilotHome verifies PILOT_HOME relocates the .pilot dir +// and takes precedence over HOME. +func TestConfigDirHonorsPilotHome(t *testing.T) { + home := withTempHomeFull(t) + ph := t.TempDir() + t.Setenv("PILOT_HOME", ph) + + if got, want := configDir(), filepath.Join(ph, ".pilot"); got != want { + t.Errorf("configDir = %s, want %s (PILOT_HOME should win over HOME=%s)", got, want, home) + } + + // And config written/read round-trips under PILOT_HOME. + if err := saveConfig(map[string]interface{}{"hostname": "ph-agent"}); err != nil { + t.Fatalf("saveConfig: %v", err) + } + if _, err := os.Stat(filepath.Join(ph, ".pilot", defaultConfigFile)); err != nil { + t.Errorf("config not written under PILOT_HOME: %v", err) + } + if loadConfig()["hostname"] != "ph-agent" { + t.Error("config not read back under PILOT_HOME") + } +} + +// TestCmdBenchRejectsBadSize covers the size-validation guard: negative, +// zero, and absurdly-large sizes are rejected before any daemon dial, so +// these run without a daemon and must exit non-zero with invalid_argument. +func TestCmdBenchRejectsBadSize(t *testing.T) { + t.Parallel() + bad := []string{"-5", "0", "999999", "nan"} + for _, sz := range bad { + sz := sz + t.Run(sz, func(t *testing.T) { + t.Parallel() + _, stderr, code := runCLI(t, []string{ + "bench", "0:0000.0000.002A", sz, + }, map[string]string{"PILOT_SOCKET": "/tmp/nope-bench-" + sz + ".sock"}) + if code == 0 { + t.Errorf("size %q: expected non-zero exit", sz) + } + if !strings.Contains(stderr, "size") && !strings.Contains(stderr, "invalid") { + t.Errorf("size %q: expected size/invalid hint, got: %s", sz, stderr) + } + }) + } +} + +// TestCmdSendMessageJSONWaitSingleDoc verifies that send-message --json --wait +// emits exactly ONE JSON document (with the reply folded in), so machine +// parsers reading stdout don't choke on a second concatenated document. +func TestCmdSendMessageJSONWaitSingleDoc(t *testing.T) { + sd := newStreamDaemon(t) + home := sd.useDaemonNoRegistry(t) + + // Seed an inbox reply whose "from" matches the resolved target address. + inbox := filepath.Join(home, ".pilot", "inbox") + if err := os.MkdirAll(inbox, 0o700); err != nil { + t.Fatalf("mkdir inbox: %v", err) + } + reply := map[string]interface{}{ + "from": "0:0000.0000.002A", + "data": "pong", + } + body, _ := json.Marshal(reply) + replyPath := filepath.Join(inbox, "TEXT-reply.json") + if err := os.WriteFile(replyPath, body, 0o600); err != nil { + t.Fatalf("write reply: %v", err) + } + // Ensure mtime is after the cutoff send-message computes (now-1s). + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(replyPath, future, future); err != nil { + t.Fatalf("chtimes: %v", err) + } + + out := captureStdout(t, func() { + withJSON(func() { + cmdSendMessage([]string{"0:0000.0000.002A", "--data", "ping", "--wait", "5s"}) + }) + }) + + // Exactly one JSON document: a second Decode must hit EOF. + dec := json.NewDecoder(strings.NewReader(out)) + var env map[string]interface{} + if err := dec.Decode(&env); err != nil { + t.Fatalf("first decode: %v\n%s", err, out) + } + var extra interface{} + if err := dec.Decode(&extra); err == nil { + t.Fatalf("expected single JSON document, found a second: %v\n%s", extra, out) + } + + // The reply must be folded into the single envelope. + data, ok := env["data"].(map[string]interface{}) + if !ok { + t.Fatalf("missing data object: %s", out) + } + replyObj, ok := data["reply"].(map[string]interface{}) + if !ok { + t.Fatalf("missing folded reply: %s", out) + } + if replyObj["data"] != "pong" { + t.Errorf("reply.data = %v, want pong", replyObj["data"]) + } +} diff --git a/cmd/pilotctl/zz_lifecycle_test.go b/cmd/pilotctl/zz_lifecycle_test.go index 2594e6cd..6f04b181 100644 --- a/cmd/pilotctl/zz_lifecycle_test.go +++ b/cmd/pilotctl/zz_lifecycle_test.go @@ -19,6 +19,7 @@ func withTempHomeFull(t *testing.T) string { tmp := t.TempDir() t.Setenv("HOME", tmp) // Clear env overrides so the test exercises the loadConfig fallback path. + t.Setenv("PILOT_HOME", "") t.Setenv("PILOT_SOCKET", "") t.Setenv("PILOT_REGISTRY", "") t.Setenv("PILOT_ADMIN_TOKEN", "") diff --git a/cmd/pilotctl/zz_parsers_test.go b/cmd/pilotctl/zz_parsers_test.go index 07948c31..43f86dca 100644 --- a/cmd/pilotctl/zz_parsers_test.go +++ b/cmd/pilotctl/zz_parsers_test.go @@ -87,13 +87,55 @@ func TestParseFlagsBasic(t *testing.T) { wantPos: nil, }, { - name: "flag value that starts with hyphen does not consume", + name: "negative-number value is consumed by preceding flag", args: []string{"--count", "-1"}, wantFlag: map[string]string{ - // next arg starts with '-' so we treat --count as bool - "count": "true", + // "-1" is a value, not a flag, so --count consumes it. + "count": "-1", }, - wantPos: []string{"-1"}, + wantPos: nil, + }, + { + name: "decimal-negative value is consumed", + args: []string{"--offset", "-3.14"}, + wantFlag: map[string]string{ + "offset": "-3.14", + }, + wantPos: nil, + }, + { + name: "bare-dash value (stdin) is consumed", + args: []string{"--file", "-"}, + wantFlag: map[string]string{ + "file": "-", + }, + wantPos: nil, + }, + { + name: "dash-digit value is consumed", + args: []string{"--rate", "-3x"}, + wantFlag: map[string]string{ + "rate": "-3x", + }, + wantPos: nil, + }, + { + name: "next long flag is not consumed as a value", + args: []string{"--data", "--trace"}, + wantFlag: map[string]string{ + "data": "true", + "trace": "true", + }, + wantPos: nil, + }, + { + name: "next single-dash flag is not consumed as a value", + args: []string{"--data", "-email", "x@y.com"}, + wantFlag: map[string]string{ + "data": "true", + "email": "x@y.com", + }, + wantPos: nil, }, } for _, tc := range cases { diff --git a/cmd/pilotctl/zz_updates_test.go b/cmd/pilotctl/zz_updates_test.go index d1e31cdd..bc8ef13e 100644 --- a/cmd/pilotctl/zz_updates_test.go +++ b/cmd/pilotctl/zz_updates_test.go @@ -151,6 +151,7 @@ func withTempHome(t *testing.T) string { t.Helper() dir := t.TempDir() t.Setenv("HOME", dir) + t.Setenv("PILOT_HOME", "") return dir }