From 80c3cf573eaf9d74a290309ee706402622dcc323 Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:32:15 +0200 Subject: [PATCH 01/15] Add daemon-managed CLI runtime and manual test lab --- cmd/diode/app.go | 207 ++++- cmd/diode/cli_status.go | 23 + cmd/diode/config_server.go | 5 + cmd/diode/daemon.go | 1175 +++++++++++++++++++++++++ cmd/diode/daemon_manage.go | 573 ++++++++++++ cmd/diode/daemon_test.go | 305 +++++++ cmd/diode/daemon_transport_unix.go | 163 ++++ cmd/diode/daemon_transport_windows.go | 134 +++ cmd/diode/diode.go | 4 + cmd/diode/fetch.go | 26 +- cmd/diode/files.go | 13 +- cmd/diode/gateway.go | 4 + cmd/diode/join.go | 34 + cmd/diode/mode_helpers.go | 20 + cmd/diode/output_helpers.go | 37 + cmd/diode/publish.go | 69 +- cmd/diode/publish_render.go | 58 ++ cmd/diode/pushpull.go | 16 +- cmd/diode/socksd.go | 4 + cmd/diode/ssh.go | 64 +- cmd/diode/token.go | 3 +- cmd/diode/update.go | 38 +- config/flag.go | 19 +- docs/manual-cli-test-plan.md | 675 ++++++++++++++ scripts/manual/setup_cli_lab.sh | 258 ++++++ 25 files changed, 3773 insertions(+), 154 deletions(-) create mode 100644 cmd/diode/cli_status.go create mode 100644 cmd/diode/daemon.go create mode 100644 cmd/diode/daemon_manage.go create mode 100644 cmd/diode/daemon_test.go create mode 100644 cmd/diode/daemon_transport_unix.go create mode 100644 cmd/diode/daemon_transport_windows.go create mode 100644 cmd/diode/mode_helpers.go create mode 100644 cmd/diode/output_helpers.go create mode 100644 cmd/diode/publish_render.go create mode 100644 docs/manual-cli-test-plan.md create mode 100755 scripts/manual/setup_cli_lab.sh diff --git a/cmd/diode/app.go b/cmd/diode/app.go index 5ee31b42..897ace7e 100644 --- a/cmd/diode/app.go +++ b/cmd/diode/app.go @@ -4,7 +4,9 @@ package main import ( + "flag" "fmt" + "io" "math/rand" "net" "net/http" @@ -45,45 +47,56 @@ var ( } ) -func init() { - cfg := &config.Config{} - diodeCmd.Flag.StringVar(&cfg.DBPath, "dbpath", util.DefaultDBPath(), "file path to db file") - diodeCmd.Flag.IntVar(&cfg.RetryTimes, "retrytimes", 3, "retry times to connect the remote rpc server") - diodeCmd.Flag.DurationVar(&cfg.EdgeE2ETimeout, "e2etimeout", 15*time.Second, "timeout seconds for edge e2e handshake") - // should put to httpd or other command - diodeCmd.Flag.BoolVar(&cfg.EnableUpdate, "update", true, "enable update when start diode") - diodeCmd.Flag.BoolVar(&cfg.EnableMetrics, "metrics", false, "enable metrics stats") - diodeCmd.Flag.BoolVar(&cfg.EnableTray, "tray", false, "show a system tray icon") - diodeCmd.Flag.BoolVar(&cfg.BlockquickDowngrade, "bqdowngrade", false, "reset blockquick window after repeated validation failures") - diodeCmd.Flag.BoolVar(&cfg.Debug, "debug", false, "turn on debug mode") - diodeCmd.Flag.BoolVar(&cfg.EnableAPIServer, "api", false, "turn on the config api") - diodeCmd.Flag.StringVar(&cfg.APIServerAddr, "apiaddr", "localhost:1081", "define config api server address") - diodeCmd.Flag.IntVar(&cfg.RlimitNofile, "rlimit_nofile", 0, "specify the file descriptor numbers that can be opened by this process") - diodeCmd.Flag.StringVar(&cfg.LogFilePath, "logfilepath", "", "absolute path to the log file") - diodeCmd.Flag.BoolVar(&cfg.LogDateTime, "logdatetime", false, "show the date time in log") - diodeCmd.Flag.StringVar(&cfg.ConfigFilePath, "configpath", "", "yaml file path to config file") - diodeCmd.Flag.StringVar(&cfg.CPUProfile, "cpuprofile", "", "file path for cpu profiling") - // diodeCmd.Flag.IntVar(&cfg.CPUProfileRate, "cpuprofilerate", 100, "the CPU profiling rate to hz samples per second") - diodeCmd.Flag.StringVar(&cfg.MEMProfile, "memprofile", "", "file path for memory profiling") - diodeCmd.Flag.IntVar(&cfg.PProfPort, "pprofport", 0, "localhost port for pprof for memory debugging") - diodeCmd.Flag.StringVar(&cfg.BlockProfile, "blockprofile", "", "file path for block profiling") - diodeCmd.Flag.IntVar(&cfg.BlockProfileRate, "blockprofilerate", 1, "the fraction of goroutine blocking events that are reported in the blocking profile") - diodeCmd.Flag.StringVar(&cfg.MutexProfile, "mutexprofile", "", "file path for mutex profiling") - diodeCmd.Flag.IntVar(&cfg.MutexProfileRate, "mutexprofilerate", 1, "the fraction of mutex contention events that are reported in the mutex profile") +func registerRootFlags(fs *flag.FlagSet, cfg *config.Config) { + fs.StringVar(&cfg.DBPath, "dbpath", util.DefaultDBPath(), "file path to db file") + fs.IntVar(&cfg.RetryTimes, "retrytimes", 3, "retry times to connect the remote rpc server") + fs.DurationVar(&cfg.EdgeE2ETimeout, "e2etimeout", 15*time.Second, "timeout seconds for edge e2e handshake") + fs.BoolVar(&cfg.EnableUpdate, "update", true, "enable update when start diode") + fs.BoolVar(&cfg.EnableMetrics, "metrics", false, "enable metrics stats") + fs.BoolVar(&cfg.EnableTray, "tray", false, "show a system tray icon") + fs.BoolVar(&cfg.DisableDaemon, "no-daemon", false, "run this command in standalone mode instead of using the diode daemon") + fs.BoolVar(&cfg.BlockquickDowngrade, "bqdowngrade", false, "reset blockquick window after repeated validation failures") + fs.BoolVar(&cfg.Debug, "debug", false, "turn on debug mode") + fs.BoolVar(&cfg.EnableAPIServer, "api", false, "turn on the config api") + fs.StringVar(&cfg.APIServerAddr, "apiaddr", "localhost:1081", "define config api server address") + fs.IntVar(&cfg.RlimitNofile, "rlimit_nofile", 0, "specify the file descriptor numbers that can be opened by this process") + fs.StringVar(&cfg.LogFilePath, "logfilepath", "", "absolute path to the log file") + fs.BoolVar(&cfg.LogDateTime, "logdatetime", false, "show the date time in log") + fs.StringVar(&cfg.ConfigFilePath, "configpath", "", "yaml file path to config file") + fs.StringVar(&cfg.CPUProfile, "cpuprofile", "", "file path for cpu profiling") + fs.StringVar(&cfg.MEMProfile, "memprofile", "", "file path for memory profiling") + fs.IntVar(&cfg.PProfPort, "pprofport", 0, "localhost port for pprof for memory debugging") + fs.StringVar(&cfg.BlockProfile, "blockprofile", "", "file path for block profiling") + fs.IntVar(&cfg.BlockProfileRate, "blockprofilerate", 1, "the fraction of goroutine blocking events that are reported in the blocking profile") + fs.StringVar(&cfg.MutexProfile, "mutexprofile", "", "file path for mutex profiling") + fs.IntVar(&cfg.MutexProfileRate, "mutexprofilerate", 1, "the fraction of mutex contention events that are reported in the mutex profile") var fleetFake string - diodeCmd.Flag.StringVar(&fleetFake, "fleet", "", "@deprecated. Use: 'diode config set fleet=0x1234' instead") - - diodeCmd.Flag.DurationVar(&cfg.RemoteRPCTimeout, "timeout", 5*time.Second, "timeout seconds to connect to the remote rpc server") - diodeCmd.Flag.DurationVar(&cfg.RetryWait, "retrywait", 1*time.Second, "wait seconds before next retry") - diodeCmd.Flag.Var(&cfg.RemoteRPCAddrs, "diodeaddrs", "addresses of Diode node server (default: asia.prenet.diode.io:41046, europe.prenet.diode.io:41046, usa.prenet.diode.io:41046)") - diodeCmd.Flag.Var(&cfg.SBlockdomains, "blockdomains", "domains (bns names) that are not allowed") - diodeCmd.Flag.Var(&cfg.SBlocklists, "blocklists", "addresses are not allowed to connect to published resource (used when allowlists is empty)") - diodeCmd.Flag.Var(&cfg.SAllowlists, "allowlists", "addresses are allowed to connect to published resource (used when blocklists is empty)") - diodeCmd.Flag.Var(&cfg.SBinds, "bind", "bind a remote port to a local port. -bind :::(udp|tcp|tls)") - diodeCmd.Flag.DurationVar(&cfg.ResolveCacheTime, "resolvecachetime", 10*time.Minute, "time for member and bns resolvers cache. (default: 10 minutes)") - diodeCmd.Flag.DurationVar(&cfg.ResolveCacheTime, "bnscachetime", 10*time.Minute, "(Deprecated. Please use resolvecachetime) time for bns address resolve cache. (default: 10 minutes)") - diodeCmd.Flag.IntVar(&cfg.MaxPortsPerDevice, "maxports", 0, "maximum concurrent ports per device (0 = unlimited)") + fs.StringVar(&fleetFake, "fleet", "", "@deprecated. Use: 'diode config set fleet=0x1234' instead") + + fs.DurationVar(&cfg.RemoteRPCTimeout, "timeout", 5*time.Second, "timeout seconds to connect to the remote rpc server") + fs.DurationVar(&cfg.RetryWait, "retrywait", 1*time.Second, "wait seconds before next retry") + fs.Var(&cfg.RemoteRPCAddrs, "diodeaddrs", "addresses of Diode node server (default: asia.prenet.diode.io:41046, europe.prenet.diode.io:41046, usa.prenet.diode.io:41046)") + fs.Var(&cfg.SBlockdomains, "blockdomains", "domains (bns names) that are not allowed") + fs.Var(&cfg.SBlocklists, "blocklists", "addresses are not allowed to connect to published resource (used when allowlists is empty)") + fs.Var(&cfg.SAllowlists, "allowlists", "addresses are allowed to connect to published resource (used when blocklists is empty)") + fs.Var(&cfg.SBinds, "bind", "bind a remote port to a local port. -bind :::(udp|tcp|tls)") + fs.DurationVar(&cfg.ResolveCacheTime, "resolvecachetime", 10*time.Minute, "time for member and bns resolvers cache. (default: 10 minutes)") + fs.DurationVar(&cfg.ResolveCacheTime, "bnscachetime", 10*time.Minute, "(Deprecated. Please use resolvecachetime) time for bns address resolve cache. (default: 10 minutes)") + fs.IntVar(&cfg.MaxPortsPerDevice, "maxports", 0, "maximum concurrent ports per device (0 = unlimited)") +} + +func newRootConfig() *config.Config { + cfg := &config.Config{} + fs := flag.NewFlagSet("diode-root-defaults", flag.ContinueOnError) + fs.SetOutput(io.Discard) + registerRootFlags(fs, cfg) + return cfg +} + +func init() { + cfg := newRootConfig() + registerRootFlags(&diodeCmd.Flag, cfg) config.AppConfig = cfg // Add diode commands diodeCmd.AddSubCommand(bnsCmd) @@ -212,6 +225,13 @@ type Diode struct { deferals []func() closeCh chan struct{} cmd *command.Command + startMu sync.Mutex + started bool + modeMu sync.Mutex + activeMode string + modeDeferals []func() + modeStopCh chan struct{} + modeDoneCh chan struct{} } // NewDiode return diode application @@ -255,7 +275,7 @@ func (dio *Diode) Init() error { shouldUpdateDiode = diff.Hours() >= 24 } if shouldUpdateDiode { - doUpdate() + _, _ = doUpdate(updateRestartStandalone) } } @@ -373,22 +393,49 @@ func (dio *Diode) Defer(deferal func()) { dio.deferals = append(dio.deferals, deferal) } +// ModeDefer registers cleanup tied to the active daemon mode. +func (dio *Diode) ModeDefer(deferal func()) { + dio.modeDeferals = append(dio.modeDeferals, deferal) +} + +func (dio *Diode) SetCommand(cmd *command.Command) { + dio.cmd = cmd +} + +func (dio *Diode) resolveCommand() (*command.Command, error) { + if dio.cmd != nil { + return dio.cmd, nil + } + dio.cmd = diodeCmd.SubCommand() + if dio.cmd == nil { + return nil, fmt.Errorf("could not determine command to start") + } + return dio.cmd, nil +} + // Start the diode application func (dio *Diode) Start() error { cfg := dio.config - dio.cmd = diodeCmd.SubCommand() - if dio.cmd == nil { - return fmt.Errorf("could not determine command to start") + cmd, err := dio.resolveCommand() + if err != nil { + return err + } + + dio.startMu.Lock() + firstStart := !dio.started + if firstStart { + cfg.PrintLabel("Client address", cfg.ClientAddr.HexString()) + cfg.PrintLabel("Fleet address", cfg.FleetAddr.HexString()) + dio.clientManager.Start() + dio.started = true } - cfg.PrintLabel("Client address", cfg.ClientAddr.HexString()) - cfg.PrintLabel("Fleet address", cfg.FleetAddr.HexString()) - dio.clientManager.Start() + dio.startMu.Unlock() - if dio.cmd.Type == command.EmptyConnectionCommand { + if cmd.Type == command.EmptyConnectionCommand { return nil } - isOneOffCommand := dio.cmd.Type == command.OneOffCommand + isOneOffCommand := cmd.Type == command.OneOffCommand //onlyNeedOne := dio.cmd.SingleConnection || isOneOffCommand if len(dio.config.RemoteRPCAddrs) < 1 { @@ -469,15 +516,78 @@ func (dio *Diode) Wait() { // go func() { // listen to signal sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) sig := <-sigChan switch sig { - case syscall.SIGINT: + case syscall.SIGINT, syscall.SIGTERM: dio.Close() } // }() } +func (dio *Diode) BeginMode(mode string) { + dio.modeMu.Lock() + defer dio.modeMu.Unlock() + dio.activeMode = mode + if dio.modeStopCh == nil { + dio.modeStopCh = make(chan struct{}) + } +} + +func (dio *Diode) ModeStopChan() <-chan struct{} { + dio.modeMu.Lock() + defer dio.modeMu.Unlock() + if dio.modeStopCh == nil { + dio.modeStopCh = make(chan struct{}) + } + return dio.modeStopCh +} + +func (dio *Diode) SetModeDone(done chan struct{}) { + dio.modeMu.Lock() + dio.modeDoneCh = done + dio.modeMu.Unlock() +} + +func (dio *Diode) StopMode() { + dio.modeMu.Lock() + stopCh := dio.modeStopCh + doneCh := dio.modeDoneCh + modeDeferals := dio.modeDeferals + dio.modeStopCh = nil + dio.modeDoneCh = nil + dio.modeDeferals = nil + dio.activeMode = "" + dio.modeMu.Unlock() + + if stopCh != nil { + close(stopCh) + } + if doneCh != nil { + <-doneCh + } + for _, fun := range modeDeferals { + fun() + } + if dio.socksServer != nil { + dio.socksServer.Close() + dio.socksServer = nil + } + if dio.proxyServer != nil { + dio.proxyServer.Close() + dio.proxyServer = nil + } + if dio.configAPIServer != nil { + dio.configAPIServer.Close() + dio.configAPIServer = nil + } + socksServerStarted = false + lastAppliedBindSignature = "" + if dio.clientManager != nil { + dio.clientManager.GetPool().SetPublishedPorts(map[int]*config.Port{}) + } +} + // Closed returns the whether diode application has been closed func (dio *Diode) isClosed(closedCh <-chan struct{}) bool { select { @@ -497,6 +607,7 @@ func (dio *Diode) Closed() bool { func (dio *Diode) Close() { dio.cd.Do(func() { cfg := config.AppConfig + dio.StopMode() for _, fun := range dio.deferals { fun() } diff --git a/cmd/diode/cli_status.go b/cmd/diode/cli_status.go new file mode 100644 index 00000000..0c5d0a16 --- /dev/null +++ b/cmd/diode/cli_status.go @@ -0,0 +1,23 @@ +package main + +import "fmt" + +type exitStatusError struct { + code int + msg string +} + +func (e *exitStatusError) Error() string { + return e.msg +} + +func (e *exitStatusError) Status() int { + return e.code +} + +func newExitStatusError(code int, format string, args ...interface{}) error { + return &exitStatusError{ + code: code, + msg: fmt.Sprintf(format, args...), + } +} diff --git a/cmd/diode/config_server.go b/cmd/diode/config_server.go index 236fec40..851687a8 100644 --- a/cmd/diode/config_server.go +++ b/cmd/diode/config_server.go @@ -567,6 +567,11 @@ func (configAPIServer *ConfigAPIServer) apiHandleFunc() func(w http.ResponseWrit return } configAPIServer.successResponse(w, "ok") + if daemonState != nil { + daemonState.baseConfig = sanitizedDaemonBaseConfig(configAPIServer.appConfig) + configAPIServer.appConfig.Logger.Info("Updated config in daemon mode without process restart") + return + } go func() { // restart diode go client // TODO: gracefully restart go client diff --git a/cmd/diode/daemon.go b/cmd/diode/daemon.go new file mode 100644 index 00000000..f2db3b1d --- /dev/null +++ b/cmd/diode/daemon.go @@ -0,0 +1,1175 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net" + "os" + "os/signal" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/diodechain/diode_client/command" + "github.com/diodechain/diode_client/config" + "github.com/diodechain/diode_client/rpc" + "github.com/diodechain/diode_client/util" +) + +const ( + daemonCommandName = "__daemon__" + envDaemonReadyFD = "DIODE_DAEMON_READY_FD" + envDaemonStartupSpec = "DIODE_DAEMON_STARTUP_SPEC" + envDaemonRestoreArgs = "DIODE_DAEMON_RESTORE_ARGS" + daemonProtocolVersion = 1 + daemonRequestRunTask = "run_task" + daemonRequestApplyMode = "apply_mode" + daemonRequestLease = "lease_local_proxy" + daemonRequestRelease = "release_local_proxy" + daemonRequestUpdate = "update" + daemonRequestManage = "manage" +) + +var ( + daemonCmd = &command.Command{ + Name: daemonCommandName, + Run: daemonHandler, + Type: command.DaemonCommand, + Hidden: true, + SkipParentHooks: true, + } + daemonExecMu sync.Mutex + activeDaemonReqKind string + activeDaemonReqMu sync.Mutex + daemonState *runtimeDaemon + daemonStartupFlagNames = map[string]bool{ + "-dbpath": true, + "-retrytimes": true, + "-e2etimeout": true, + "-update": true, + "-metrics": true, + "-tray": true, + "-bqdowngrade": true, + "-debug": true, + "-api": true, + "-apiaddr": true, + "-rlimit_nofile": true, + "-logfilepath": true, + "-logdatetime": true, + "-configpath": true, + "-cpuprofile": true, + "-memprofile": true, + "-pprofport": true, + "-blockprofile": true, + "-blockprofilerate": true, + "-mutexprofile": true, + "-mutexprofilerate": true, + "-timeout": true, + "-retrywait": true, + "-diodeaddrs": true, + "-blockdomains": true, + "-blocklists": true, + "-allowlists": true, + "-resolvecachetime": true, + "-bnscachetime": true, + "-maxports": true, + "-no-daemon": true, + } + localBypassCommands = map[string]bool{"": true, "version": true, "mcp": true, "ssh-proxy": true, daemonCommandName: true} + daemonApplyModeCmds = map[string]bool{"publish": true, "gateway": true, "socksd": true, "join": true, "files": true} + daemonRunnableCmds = map[string]bool{"query": true, "time": true, "fetch": true, "token": true, "bns": true, "config": true, "reset": true, "push": true, "pull": true, "publish": true, "gateway": true, "socksd": true, "join": true, "files": true, "ssh": true, "update": true} +) + +type daemonStartupSpec struct { + DBPath string `json:"dbpath"` + RetryTimes int `json:"retrytimes"` + EdgeE2ETimeout time.Duration `json:"e2etimeout"` + EnableUpdate bool `json:"update"` + EnableMetrics bool `json:"metrics"` + EnableTray bool `json:"tray"` + BlockquickDowngrade bool `json:"bqdowngrade"` + Debug bool `json:"debug"` + EnableAPIServer bool `json:"api"` + APIServerAddr string `json:"apiaddr"` + RlimitNofile int `json:"rlimit_nofile"` + LogFilePath string `json:"logfilepath"` + LogDateTime bool `json:"logdatetime"` + ConfigFilePath string `json:"configpath"` + CPUProfile string `json:"cpuprofile"` + MEMProfile string `json:"memprofile"` + PProfPort int `json:"pprofport"` + BlockProfile string `json:"blockprofile"` + BlockProfileRate int `json:"blockprofilerate"` + MutexProfile string `json:"mutexprofile"` + MutexProfileRate int `json:"mutexprofilerate"` + RemoteRPCTimeout time.Duration `json:"timeout"` + RetryWait time.Duration `json:"retrywait"` + RemoteRPCAddrs config.StringValues `json:"diodeaddrs"` + SBlockdomains config.StringValues `json:"blockdomains"` + SBlocklists config.StringValues `json:"blocklists"` + SAllowlists config.StringValues `json:"allowlists"` + ResolveCacheTime time.Duration `json:"resolvecachetime"` + MaxPortsPerDevice int `json:"maxports"` +} + +type daemonMetadata struct { + PID int `json:"pid"` + SocketPath string `json:"socket_path"` + StartupSpec daemonStartupSpec `json:"startup_spec"` +} + +type daemonRequest struct { + Version int `json:"version"` + Kind string `json:"kind"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + LeaseID string `json:"lease_id,omitempty"` +} + +type daemonResponse struct { + Version int `json:"version"` + Stdout string `json:"stdout,omitempty"` + Stderr string `json:"stderr,omitempty"` + ExitCode int `json:"exit_code"` + Error string `json:"error,omitempty"` + ProxyAddr string `json:"proxy_addr,omitempty"` + LeaseID string `json:"lease_id,omitempty"` + RestartPath string `json:"-"` + Shutdown bool `json:"-"` +} + +type runtimeDaemon struct { + socketPath string + metaPath string + listener net.Listener + startup daemonStartupSpec + baseConfig config.Config + leasesMu sync.Mutex + leases map[string]*rpc.Server + stateMu sync.Mutex + activeMode string + activeArgs []string + ports map[int]*config.Port + binds []config.Bind + socksAddr string + socksOn bool + apiAddr string + apiOn bool +} + +func init() { + diodeCmd.AddSubCommand(daemonCmd) +} + +func daemonHandler() error { + cfg := config.AppConfig + cfg.DisableDaemon = true + startupSpec := daemonStartupSpecFromConfig(cfg) + if raw := os.Getenv(envDaemonStartupSpec); raw != "" { + if err := json.Unmarshal([]byte(raw), &startupSpec); err != nil { + return err + } + applyDaemonStartupSpec(cfg, startupSpec) + } + if err := prepareDiode(); err != nil { + return err + } + defer cleanDiode() + + socketPath, metaPath, err := daemonPaths() + if err != nil { + return err + } + _ = os.Remove(socketPath) + ln, err := daemonListen(socketPath) + if err != nil { + return err + } + + daemonState = &runtimeDaemon{ + socketPath: socketPath, + metaPath: metaPath, + listener: ln, + startup: startupSpec, + baseConfig: sanitizedDaemonBaseConfig(cfg), + leases: map[string]*rpc.Server{}, + } + app.Defer(func() { + daemonState.closeLeases() + _ = ln.Close() + cleanupDaemonTransport(socketPath) + _ = os.Remove(metaPath) + }) + + if err := writeDaemonMetadata(metaPath, daemonMetadata{ + PID: os.Getpid(), + SocketPath: socketPath, + StartupSpec: daemonState.startup, + }); err != nil { + return err + } + restoreArgs, err := daemonRestoreArgsFromEnv() + if err != nil { + return err + } + if len(restoreArgs) > 0 { + if err := runDaemonCommandAsKind(daemonRequestApplyMode, restoreArgs); err != nil { + logDaemonInternalError("Couldn't restore daemon mode after restart", err) + } else { + daemonState.updateModeSnapshot(restoreArgs[0], restoreArgs, config.AppConfig) + } + } + if err := signalDaemonReady(); err != nil { + return err + } + go serveDaemon(ln) + sigCtx, stop := signal.NotifyContext(context.Background(), daemonSignals()...) + defer stop() + select { + case <-sigCtx.Done(): + case <-app.closeCh: + } + app.Close() + return nil +} + +func maybeHandleDaemonCLI(args []string) (bool, int) { + inv, err := parseRootInvocation(args) + if err != nil { + return false, 0 + } + if inv.command == "daemon" { + return handleDaemonManagerCLI(inv.commandArgs) + } + if inv.disableDaemon || inv.help || localBypassCommands[inv.command] || !daemonRunnableCmds[inv.command] { + return false, 0 + } + + req := daemonRequest{ + Version: daemonProtocolVersion, + Command: inv.command, + Args: inv.execArgs, + } + switch inv.command { + case "ssh": + req.Kind = daemonRequestLease + case "update": + req.Kind = daemonRequestUpdate + default: + if daemonApplyModeCmds[inv.command] { + req.Kind = daemonRequestApplyMode + } else { + req.Kind = daemonRequestRunTask + } + } + + resp, handled, reason, err := dispatchViaDaemon(inv.startupSpec, req) + if !handled { + if reason != "" { + stderrln(reason) + } + return false, 0 + } + if err != nil { + stderrln(err.Error()) + return true, 1 + } + if resp.Stdout != "" { + io.WriteString(stdoutWriter(), resp.Stdout) + } + if resp.Stderr != "" { + io.WriteString(stderrWriter(), resp.Stderr) + } + if req.Kind == daemonRequestLease { + return true, runSSHViaDaemonLease(inv.commandArgs, resp) + } + if req.Kind == daemonRequestApplyMode && resp.ExitCode == 0 { + stdoutf("Daemon mode active: %s\n", inv.command) + stdoutln("Use `diode daemon status` to inspect or manage the running daemon.") + } + return true, resp.ExitCode +} + +type rootInvocation struct { + command string + commandArgs []string + execArgs []string + help bool + disableDaemon bool + startupSpec daemonStartupSpec +} + +func parseRootInvocation(args []string) (rootInvocation, error) { + cfg := newRootConfig() + fs := flag.NewFlagSet("diode-root-parse", flag.ContinueOnError) + fs.SetOutput(io.Discard) + registerRootFlags(fs, cfg) + err := fs.Parse(args) + if err == flag.ErrHelp { + return rootInvocation{help: true}, nil + } + if err != nil { + return rootInvocation{}, err + } + rest := fs.Args() + commandName := "" + execArgs := append([]string{}, args...) + if len(rest) > 0 { + commandName = rest[0] + } + if commandName == "" && containsHelpArg(args) { + return rootInvocation{help: true, disableDaemon: cfg.DisableDaemon, startupSpec: daemonStartupSpecFromConfig(cfg)}, nil + } + if commandName == "" { + commandName = "publish" + rest = append([]string{commandName}, rest...) + execArgs = append(execArgs, commandName) + } + if len(rest) > 1 && containsHelpArg(rest[1:]) { + return rootInvocation{help: true, disableDaemon: cfg.DisableDaemon, startupSpec: daemonStartupSpecFromConfig(cfg)}, nil + } + return rootInvocation{ + command: commandName, + commandArgs: rest, + execArgs: execArgs, + help: false, + disableDaemon: cfg.DisableDaemon, + startupSpec: daemonStartupSpecFromConfig(cfg), + }, nil +} + +func containsHelpArg(args []string) bool { + for _, arg := range args { + switch arg { + case "--help", "-help", "-h": + return true + } + } + return false +} + +func dispatchViaDaemon(spec daemonStartupSpec, req daemonRequest) (daemonResponse, bool, string, error) { + meta, metaErr := readDaemonMetadata() + if metaErr == nil && !reflect.DeepEqual(meta.StartupSpec, spec) { + if _, err := dialDaemon(meta.SocketPath); err == nil { + return daemonResponse{}, false, "running daemon is incompatible with this invocation; using standalone mode. Run `diode daemon restart` to reload the daemon with the current binary and flags.", nil + } + cleanupDaemonArtifacts(meta.SocketPath, metaPathFromSocket(meta.SocketPath)) + metaErr = os.ErrNotExist + } + + socketPath := "" + if metaErr == nil { + socketPath = meta.SocketPath + } + conn, err := dialDaemon(socketPath) + if err != nil { + cleanupDaemonArtifacts(socketPath, metaPathFromSocket(socketPath)) + if err := spawnDaemon(spec); err != nil { + return daemonResponse{}, true, "", err + } + meta, err = readDaemonMetadata() + if err != nil { + return daemonResponse{}, true, "", err + } + conn, err = dialDaemon(meta.SocketPath) + if err != nil { + return daemonResponse{}, true, "", err + } + } + defer conn.Close() + if err := json.NewEncoder(conn).Encode(req); err != nil { + return daemonResponse{}, true, "", err + } + var resp daemonResponse + if err := json.NewDecoder(conn).Decode(&resp); err != nil { + return daemonResponse{}, true, "", err + } + return resp, true, "", nil +} + +func serveDaemon(ln net.Listener) { + for { + conn, err := ln.Accept() + if err != nil { + if app.Closed() { + return + } + time.Sleep(50 * time.Millisecond) + continue + } + go handleDaemonConn(conn) + } +} + +func handleDaemonConn(conn net.Conn) { + defer conn.Close() + var req daemonRequest + if err := json.NewDecoder(conn).Decode(&req); err != nil { + _ = json.NewEncoder(conn).Encode(daemonResponse{Version: daemonProtocolVersion, ExitCode: 1, Error: err.Error()}) + return + } + resp := executeDaemonRequest(req) + if err := json.NewEncoder(conn).Encode(resp); err != nil { + logDaemonInternalError("Couldn't encode daemon response", err) + return + } + if resp.Shutdown { + go app.Close() + } + if resp.RestartPath != "" { + if err := daemonRestartSelf(resp.RestartPath, daemonState.startup); err != nil { + logDaemonInternalError("Couldn't restart daemon after update", err) + } + } +} + +func executeDaemonRequest(req daemonRequest) daemonResponse { + daemonExecMu.Lock() + defer daemonExecMu.Unlock() + + resp := daemonResponse{Version: daemonProtocolVersion} + if daemonState == nil { + resp.ExitCode = 1 + resp.Error = "daemon state is not initialized" + return resp + } + + switch req.Kind { + case daemonRequestLease: + addr, leaseID, err := daemonLeaseLocalProxy() + if err != nil { + resp.ExitCode = 1 + resp.Error = err.Error() + return resp + } + resp.ProxyAddr = addr + resp.LeaseID = leaseID + return resp + case daemonRequestRelease: + if err := daemonReleaseLocalProxy(req.LeaseID); err != nil { + resp.ExitCode = 1 + resp.Error = err.Error() + } + return resp + case daemonRequestUpdate: + return executeDaemonBufferedRequest(req.Kind, func() (string, error) { + return runDaemonUpdate(req.Args) + }) + case daemonRequestManage: + manageResp := daemonResponse{Version: daemonProtocolVersion} + buffered := executeDaemonBufferedRequest(req.Kind, func() (string, error) { + return "", runDaemonManage(req.Args, &manageResp) + }) + manageResp.Stdout = buffered.Stdout + manageResp.Stderr = buffered.Stderr + manageResp.ExitCode = buffered.ExitCode + manageResp.Error = buffered.Error + return manageResp + } + if req.Kind == daemonRequestApplyMode && req.Command == "publish" { + req.Args = mergeImplicitPublishArgs(req.Args) + } + resp = executeDaemonBufferedRequest(req.Kind, func() (string, error) { + return "", runDaemonCommandArgs(req.Args) + }) + if req.Kind == daemonRequestApplyMode && resp.ExitCode == 0 { + daemonState.updateModeSnapshot(req.Command, req.Args, config.AppConfig) + } + return resp +} + +func executeDaemonBufferedRequest(kind string, fn func() (string, error)) daemonResponse { + resp := daemonResponse{Version: daemonProtocolVersion} + var stdout bytes.Buffer + var stderr bytes.Buffer + *config.AppConfig = cloneDaemonConfig(&daemonState.baseConfig) + resetTransientConfig(config.AppConfig) + resetRequestGlobals() + config.AppConfig.StdoutWriter = &stdout + config.AppConfig.StderrWriter = &stderr + + activeDaemonReqMu.Lock() + activeDaemonReqKind = kind + activeDaemonReqMu.Unlock() + restartPath, err := fn() + activeDaemonReqMu.Lock() + activeDaemonReqKind = "" + activeDaemonReqMu.Unlock() + + resp.Stdout = stdout.String() + resp.Stderr = stderr.String() + resp.RestartPath = restartPath + if err != nil { + resp.ExitCode = exitCodeFromError(err) + resp.Error = err.Error() + if resp.ExitCode == 0 { + resp.ExitCode = 1 + } + } else { + resp.ExitCode = 0 + } + daemonState.baseConfig = sanitizedDaemonBaseConfig(config.AppConfig) + return resp +} + +func runDaemonCommandArgs(args []string) error { + if len(args) == 0 { + return newExitStatusError(2, "missing command") + } + if err := diodeCmd.Flag.Parse(args); err != nil { + return err + } + if err := refreshRequestDerivedConfig(config.AppConfig); err != nil { + return err + } + subCmd := diodeCmd.SubCommand() + if subCmd == nil { + return newExitStatusError(2, "unknown command: %s", args[0]) + } + app.SetCommand(subCmd) + rootArgs := diodeCmd.Flag.Args() + if len(rootArgs) > 1 { + if !subCmd.PassThroughArgs { + if err := subCmd.Flag.Parse(rootArgs[1:]); err != nil { + return err + } + } + } else if !subCmd.PassThroughArgs { + _ = subCmd.Flag.Parse([]string{}) + } + return subCmd.Run() +} + +func refreshRequestDerivedConfig(cfg *config.Config) error { + if cfg == nil { + return nil + } + cfg.SBinds = dedupeStringValues(cfg.SBinds) + cfg.Binds = make([]config.Bind, 0, len(cfg.SBinds)) + for _, str := range cfg.SBinds { + bind, err := parseBind(str) + if err != nil { + return err + } + cfg.Binds = append(cfg.Binds, *bind) + } + if len(cfg.SAllowlists) == 0 { + cfg.Allowlists = nil + return nil + } + cfg.Allowlists = make(map[util.Address]bool, len(cfg.SAllowlists)) + for _, raw := range cfg.SAllowlists { + addr, err := util.DecodeAddress(raw) + if err != nil { + return err + } + cfg.Allowlists[addr] = true + } + return nil +} + +func dedupeStringValues(values config.StringValues) config.StringValues { + if len(values) < 2 { + return values + } + out := make(config.StringValues, 0, len(values)) + for _, value := range values { + if !util.StringsContain(out, value) { + out = append(out, value) + } + } + return out +} + +func mergeImplicitPublishArgs(args []string) []string { + if daemonState == nil { + return args + } + pre, post, ok := splitPublishExecArgs(args) + if !ok || len(post) > 0 { + return args + } + mode, existingArgs := daemonModeArgs() + if mode != "publish" || len(existingArgs) == 0 { + return args + } + existingPre, existingPost, ok := splitPublishExecArgs(existingArgs) + if !ok || len(existingPost) == 0 { + return args + } + merged := mergeImplicitPublishPreArgs(existingPre, pre) + merged = append(merged, "publish") + merged = append(merged, existingPost...) + return merged +} + +func splitPublishExecArgs(args []string) (pre []string, post []string, ok bool) { + for i, arg := range args { + if arg == "publish" { + return append([]string{}, args[:i]...), append([]string{}, args[i+1:]...), true + } + } + return nil, nil, false +} + +func sanitizeModeArgs(mode string, args []string) []string { + if len(args) == 0 { + return nil + } + cmdIdx := -1 + for i, arg := range args { + if arg == mode { + cmdIdx = i + break + } + } + if cmdIdx < 0 { + return append([]string{}, args...) + } + preItems := parseRootExecItems(args[:cmdIdx]) + sanitized := make([]string, 0, len(args)) + for _, item := range preItems { + if daemonStartupFlagNames[item.flagName] { + continue + } + sanitized = append(sanitized, item.args...) + } + sanitized = append(sanitized, args[cmdIdx:]...) + return sanitized +} + +func mergeImplicitPublishPreArgs(existingPre, currentPre []string) []string { + items := append(parseRootExecItems(existingPre), parseRootExecItems(currentPre)...) + if len(items) == 0 { + return nil + } + bindValues := make(config.StringValues, 0) + out := make([]rootExecItem, 0, len(items)) + for _, item := range items { + if item.flagName == "-bind" { + if item.value == "" || util.StringsContain(bindValues, item.value) { + continue + } + bindValues = append(bindValues, item.value) + } + out = append(out, item) + } + merged := make([]string, 0, len(existingPre)+len(currentPre)) + for _, item := range out { + merged = append(merged, item.args...) + } + return merged +} + +type rootExecItem struct { + flagName string + value string + args []string +} + +func parseRootExecItems(args []string) []rootExecItem { + items := make([]rootExecItem, 0, len(args)) + for i := 0; i < len(args); i++ { + arg := args[i] + if !strings.HasPrefix(arg, "-") || arg == "-" { + items = append(items, rootExecItem{args: []string{arg}}) + continue + } + flagName := arg + value := "" + itemArgs := []string{arg} + if idx := strings.Index(arg, "="); idx >= 0 { + flagName = arg[:idx] + value = arg[idx+1:] + } else if i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { + value = args[i+1] + itemArgs = append(itemArgs, args[i+1]) + i++ + } + items = append(items, rootExecItem{ + flagName: flagName, + value: value, + args: itemArgs, + }) + } + return items +} + +func isDaemonApplyRequest() bool { + activeDaemonReqMu.Lock() + defer activeDaemonReqMu.Unlock() + return activeDaemonReqKind == daemonRequestApplyMode +} + +func cloneDaemonConfig(cfg *config.Config) config.Config { + cp := *cfg + cp.RemoteRPCAddrs = append(config.StringValues{}, cfg.RemoteRPCAddrs...) + cp.SBlockdomains = append(config.StringValues{}, cfg.SBlockdomains...) + cp.SBlocklists = append(config.StringValues{}, cfg.SBlocklists...) + cp.SAllowlists = append(config.StringValues{}, cfg.SAllowlists...) + cp.SBinds = append(config.StringValues{}, cfg.SBinds...) + cp.PublicPublishedPorts = append(config.StringValues{}, cfg.PublicPublishedPorts...) + cp.ProtectedPublishedPorts = append(config.StringValues{}, cfg.ProtectedPublishedPorts...) + cp.PrivatePublishedPorts = append(config.StringValues{}, cfg.PrivatePublishedPorts...) + cp.SSHPublishedServices = append(config.StringValues{}, cfg.SSHPublishedServices...) + cp.ConfigDelete = append(config.StringValues{}, cfg.ConfigDelete...) + cp.ConfigSet = append(config.StringValues{}, cfg.ConfigSet...) + return cp +} + +func sanitizedDaemonBaseConfig(cfg *config.Config) config.Config { + cp := cloneDaemonConfig(cfg) + resetTransientConfig(&cp) + return cp +} + +func resetTransientConfig(cfg *config.Config) { + cfg.StdoutWriter = nil + cfg.StderrWriter = nil + cfg.DisableDaemon = false + cfg.QueryAddress = "" + cfg.ConfigUnsafe = false + cfg.ConfigList = false + cfg.ConfigDelete = nil + cfg.ConfigSet = nil + cfg.PublicPublishedPorts = nil + cfg.ProtectedPublishedPorts = nil + cfg.PrivatePublishedPorts = nil + cfg.SSHPublishedServices = nil + cfg.PublishedPorts = nil + cfg.SBinds = nil + cfg.Binds = nil + cfg.EnableProxyServer = false + cfg.EnableSProxyServer = false + cfg.EnableSocksServer = false + cfg.SocksServerHost = "127.0.0.1" + cfg.SocksServerPort = 1080 + cfg.SocksFallback = "localhost" + cfg.ProxyServerHost = "127.0.0.1" + cfg.ProxyServerPort = 80 + cfg.SProxyServerHost = "127.0.0.1" + cfg.SProxyServerPort = 443 + cfg.SProxyServerPorts = "" + cfg.SProxyServerCertPath = "./priv/fullchain.pem" + cfg.SProxyServerPrivPath = "./priv/privkey.pem" + cfg.AllowRedirectToSProxy = false + cfg.BNSForce = false + cfg.BNSRegister = "" + cfg.BNSUnregister = "" + cfg.BNSTransfer = "" + cfg.BNSLookup = "" + cfg.BNSAccount = "" + cfg.Experimental = false +} + +func resetRequestGlobals() { + enableStaticServer = false + scfg.RootDirectory = "" + scfg.Host = "127.0.0.1" + scfg.Port = 8080 + scfg.Indexed = false + filesFileroot = "" + edgeACME = false + edgeACMEEmail = "" + edgeACMEAddtlCerts = "" + if fetchCfg != nil { + *fetchCfg = fetchConfig{Method: "GET"} + } + if tokenCfg != nil { + *tokenCfg = tokenConfig{Gas: "21000"} + } + dryRun = false + network = "mainnet" + contractAddress = "" + oasisClient = nil + wantWireGuard = false + wgSuffix = "" +} + +func exitCodeFromError(err error) int { + type statusError interface{ Status() int } + type codeError interface{ Code() int } + if err == nil { + return 0 + } + if se, ok := err.(statusError); ok { + return se.Status() + } + if ce, ok := err.(codeError); ok { + return ce.Code() + } + return 1 +} + +func daemonStartupSpecFromConfig(cfg *config.Config) daemonStartupSpec { + return daemonStartupSpec{ + DBPath: cfg.DBPath, + RetryTimes: cfg.RetryTimes, + EdgeE2ETimeout: cfg.EdgeE2ETimeout, + EnableUpdate: cfg.EnableUpdate, + EnableMetrics: cfg.EnableMetrics, + EnableTray: cfg.EnableTray, + BlockquickDowngrade: cfg.BlockquickDowngrade, + Debug: cfg.Debug, + EnableAPIServer: cfg.EnableAPIServer, + APIServerAddr: cfg.APIServerAddr, + RlimitNofile: cfg.RlimitNofile, + LogFilePath: cfg.LogFilePath, + LogDateTime: cfg.LogDateTime, + ConfigFilePath: cfg.ConfigFilePath, + CPUProfile: cfg.CPUProfile, + MEMProfile: cfg.MEMProfile, + PProfPort: cfg.PProfPort, + BlockProfile: cfg.BlockProfile, + BlockProfileRate: cfg.BlockProfileRate, + MutexProfile: cfg.MutexProfile, + MutexProfileRate: cfg.MutexProfileRate, + RemoteRPCTimeout: cfg.RemoteRPCTimeout, + RetryWait: cfg.RetryWait, + RemoteRPCAddrs: append(config.StringValues{}, cfg.RemoteRPCAddrs...), + SBlockdomains: append(config.StringValues{}, cfg.SBlockdomains...), + SBlocklists: append(config.StringValues{}, cfg.SBlocklists...), + SAllowlists: append(config.StringValues{}, cfg.SAllowlists...), + ResolveCacheTime: cfg.ResolveCacheTime, + MaxPortsPerDevice: cfg.MaxPortsPerDevice, + } +} + +func applyDaemonStartupSpec(cfg *config.Config, spec daemonStartupSpec) { + cfg.DBPath = spec.DBPath + cfg.RetryTimes = spec.RetryTimes + cfg.EdgeE2ETimeout = spec.EdgeE2ETimeout + cfg.EnableUpdate = spec.EnableUpdate + cfg.EnableMetrics = spec.EnableMetrics + cfg.EnableTray = spec.EnableTray + cfg.BlockquickDowngrade = spec.BlockquickDowngrade + cfg.Debug = spec.Debug + cfg.EnableAPIServer = spec.EnableAPIServer + cfg.APIServerAddr = spec.APIServerAddr + cfg.RlimitNofile = spec.RlimitNofile + cfg.LogFilePath = spec.LogFilePath + cfg.LogDateTime = spec.LogDateTime + cfg.ConfigFilePath = spec.ConfigFilePath + cfg.CPUProfile = spec.CPUProfile + cfg.MEMProfile = spec.MEMProfile + cfg.PProfPort = spec.PProfPort + cfg.BlockProfile = spec.BlockProfile + cfg.BlockProfileRate = spec.BlockProfileRate + cfg.MutexProfile = spec.MutexProfile + cfg.MutexProfileRate = spec.MutexProfileRate + cfg.RemoteRPCTimeout = spec.RemoteRPCTimeout + cfg.RetryWait = spec.RetryWait + cfg.RemoteRPCAddrs = append(config.StringValues{}, spec.RemoteRPCAddrs...) + cfg.SBlockdomains = append(config.StringValues{}, spec.SBlockdomains...) + cfg.SBlocklists = append(config.StringValues{}, spec.SBlocklists...) + cfg.SAllowlists = append(config.StringValues{}, spec.SAllowlists...) + cfg.ResolveCacheTime = spec.ResolveCacheTime + cfg.MaxPortsPerDevice = spec.MaxPortsPerDevice +} + +func readDaemonMetadata() (daemonMetadata, error) { + socketPath, metaPath, err := daemonPaths() + if err != nil { + return daemonMetadata{}, err + } + _ = socketPath + buf, err := os.ReadFile(metaPath) + if err != nil { + return daemonMetadata{}, err + } + var meta daemonMetadata + if err := json.Unmarshal(buf, &meta); err != nil { + return daemonMetadata{}, err + } + return meta, nil +} + +func writeDaemonMetadata(path string, meta daemonMetadata) error { + buf, err := json.Marshal(meta) + if err != nil { + return err + } + return os.WriteFile(path, buf, 0600) +} + +func cleanupDaemonArtifacts(socketPath, metaPath string) { + if socketPath != "" { + cleanupDaemonTransport(socketPath) + } + if metaPath != "" { + _ = os.Remove(metaPath) + } +} + +func signalDaemonReady() error { + fdStr := strings.TrimSpace(os.Getenv(envDaemonReadyFD)) + if fdStr == "" { + return nil + } + fd, err := strconv.Atoi(fdStr) + if err != nil { + return err + } + f := os.NewFile(uintptr(fd), "daemon-ready") + if f == nil { + return fmt.Errorf("invalid daemon ready file descriptor") + } + defer f.Close() + _, err = f.Write([]byte{1}) + return err +} + +func daemonRestartEnv(spec daemonStartupSpec) ([]string, error) { + specBytes, err := json.Marshal(spec) + if err != nil { + return nil, err + } + env := make([]string, 0, len(os.Environ())+1) + for _, item := range os.Environ() { + if strings.HasPrefix(item, envDaemonReadyFD+"=") || + strings.HasPrefix(item, envDaemonStartupSpec+"=") || + strings.HasPrefix(item, envDaemonRestoreArgs+"=") { + continue + } + env = append(env, item) + } + env = append(env, fmt.Sprintf("%s=%s", envDaemonStartupSpec, string(specBytes))) + if restoreArgs := daemonRestoreArgsForRestart(); len(restoreArgs) > 0 { + restoreBytes, err := json.Marshal(restoreArgs) + if err != nil { + return nil, err + } + env = append(env, fmt.Sprintf("%s=%s", envDaemonRestoreArgs, string(restoreBytes))) + } + return env, nil +} + +func daemonRestoreArgsFromEnv() ([]string, error) { + raw := strings.TrimSpace(os.Getenv(envDaemonRestoreArgs)) + if raw == "" { + return nil, nil + } + var args []string + if err := json.Unmarshal([]byte(raw), &args); err != nil { + return nil, err + } + return args, nil +} + +func logDaemonInternalError(msg string, err error) { + cfg := config.AppConfig + if cfg != nil && cfg.Logger != nil { + cfg.Logger.Error("%s: %v", msg, err) + return + } + fmt.Fprintf(os.Stderr, "%s: %v\n", msg, err) +} + +func (rd *runtimeDaemon) updateModeSnapshot(mode string, args []string, cfg *config.Config) { + if rd == nil || cfg == nil { + return + } + rd.stateMu.Lock() + defer rd.stateMu.Unlock() + rd.activeMode = mode + rd.activeArgs = sanitizeModeArgs(mode, args) + rd.ports = clonePortMap(cfg.PublishedPorts) + rd.binds = append([]config.Bind{}, cfg.Binds...) + rd.socksOn = cfg.EnableSocksServer + rd.socksAddr = cfg.SocksServerAddr() + rd.apiOn = cfg.EnableAPIServer + rd.apiAddr = cfg.APIServerAddr +} + +func (rd *runtimeDaemon) clearModeSnapshot() { + if rd == nil { + return + } + rd.stateMu.Lock() + defer rd.stateMu.Unlock() + rd.activeMode = "" + rd.activeArgs = nil + rd.ports = map[int]*config.Port{} + rd.binds = nil + rd.socksOn = false + rd.socksAddr = "" + rd.apiOn = false + rd.apiAddr = "" +} + +func daemonRestoreArgsForRestart() []string { + if daemonState == nil { + return nil + } + daemonState.stateMu.Lock() + defer daemonState.stateMu.Unlock() + if daemonState.activeMode == "" || len(daemonState.activeArgs) == 0 { + return nil + } + return append([]string{}, daemonState.activeArgs...) +} + +func (rd *runtimeDaemon) snapshotStatus() daemonRuntimeStatus { + status := daemonRuntimeStatus{} + if rd == nil { + return status + } + rd.stateMu.Lock() + defer rd.stateMu.Unlock() + status.ActiveMode = rd.activeMode + status.ActiveArgs = append([]string{}, rd.activeArgs...) + status.PublishedPorts = clonePortMap(rd.ports) + status.Binds = append([]config.Bind{}, rd.binds...) + status.SocksEnabled = rd.socksOn + status.SocksAddr = rd.socksAddr + status.APIEnabled = rd.apiOn + status.APIAddr = rd.apiAddr + return status +} + +func clonePortMap(in map[int]*config.Port) map[int]*config.Port { + if len(in) == 0 { + return map[int]*config.Port{} + } + out := make(map[int]*config.Port, len(in)) + for k, v := range in { + if v == nil { + out[k] = nil + continue + } + cp := *v + cp.Allowlist = cloneAddressSet(v.Allowlist) + cp.BnsAllowlist = cloneStringBoolSet(v.BnsAllowlist) + cp.DriveAllowList = cloneAddressSet(v.DriveAllowList) + cp.DriveMemberAllowList = cloneAddressSet(v.DriveMemberAllowList) + out[k] = &cp + } + return out +} + +func cloneAddressSet(in map[util.Address]bool) map[util.Address]bool { + if len(in) == 0 { + return map[util.Address]bool{} + } + out := make(map[util.Address]bool, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneStringBoolSet(in map[string]bool) map[string]bool { + if len(in) == 0 { + return map[string]bool{} + } + out := make(map[string]bool, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func daemonLeaseLocalProxy() (string, string, error) { + if err := app.Start(); err != nil { + return "", "", err + } + cfg := config.AppConfig + socksCfg := rpc.Config{ + Addr: net.JoinHostPort("127.0.0.1", "0"), + FleetAddr: cfg.FleetAddr, + Blocklists: cfg.Blocklists(), + Allowlists: cfg.Allowlists, + EnableProxy: false, + ProxyServerAddr: cfg.ProxyServerAddr(), + Fallback: cfg.SocksFallback, + } + socksServer, err := rpc.NewSocksServer(socksCfg, app.clientManager) + if err != nil { + return "", "", err + } + if err := socksServer.Start(); err != nil { + return "", "", err + } + addr := socksServer.Addr() + if addr == nil { + socksServer.Close() + return "", "", fmt.Errorf("proxy lease did not expose an address") + } + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + socksServer.Close() + return "", "", fmt.Errorf("unexpected proxy lease address type: %T", addr) + } + host := tcpAddr.IP.String() + if host == "" || host == "::" || host == "0.0.0.0" { + host = "127.0.0.1" + } + leaseID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), tcpAddr.Port) + daemonState.leasesMu.Lock() + daemonState.leases[leaseID] = socksServer + daemonState.leasesMu.Unlock() + return net.JoinHostPort(host, strconv.Itoa(tcpAddr.Port)), leaseID, nil +} + +func daemonReleaseLocalProxy(leaseID string) error { + if daemonState == nil { + return nil + } + daemonState.leasesMu.Lock() + socksServer := daemonState.leases[leaseID] + delete(daemonState.leases, leaseID) + daemonState.leasesMu.Unlock() + if socksServer != nil { + socksServer.Close() + } + return nil +} + +func (rt *runtimeDaemon) closeLeases() { + rt.leasesMu.Lock() + defer rt.leasesMu.Unlock() + for leaseID, server := range rt.leases { + delete(rt.leases, leaseID) + if server != nil { + server.Close() + } + } +} + +func releaseDaemonLease(leaseID string) error { + if leaseID == "" { + return nil + } + meta, err := readDaemonMetadata() + if err != nil { + return err + } + conn, err := dialDaemon(meta.SocketPath) + if err != nil { + return err + } + defer conn.Close() + req := daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestRelease, + LeaseID: leaseID, + } + if err := json.NewEncoder(conn).Encode(req); err != nil { + return err + } + var resp daemonResponse + if err := json.NewDecoder(conn).Decode(&resp); err != nil { + return err + } + if resp.Error != "" { + return fmt.Errorf("%s", resp.Error) + } + return nil +} diff --git a/cmd/diode/daemon_manage.go b/cmd/diode/daemon_manage.go new file mode 100644 index 00000000..db369077 --- /dev/null +++ b/cmd/diode/daemon_manage.go @@ -0,0 +1,573 @@ +package main + +import ( + "encoding/json" + "io" + "os" + "sort" + "strconv" + "strings" + "time" + + "github.com/diodechain/diode_client/command" + "github.com/diodechain/diode_client/config" +) + +type daemonRuntimeStatus struct { + ActiveMode string + ActiveArgs []string + PublishedPorts map[int]*config.Port + Binds []config.Bind + SocksEnabled bool + SocksAddr string + APIEnabled bool + APIAddr string +} + +var ( + daemonManageCmd = &command.Command{ + Name: "daemon", + HelpText: ` Inspect and manage the running diode daemon.`, + ExampleText: " diode daemon status\n diode daemon stop\n diode daemon ports remove 80 443\n diode daemon ports clear", + Run: daemonManageHandler, + Type: command.EmptyConnectionCommand, + PassThroughArgs: true, + } +) + +func init() { + diodeCmd.AddSubCommand(daemonManageCmd) +} + +func daemonManageHandler() error { + return nil +} + +func handleDaemonManagerCLI(args []string) (bool, int) { + if len(args) == 0 || args[0] != "daemon" { + return false, 0 + } + subArgs := args[1:] + if len(subArgs) == 0 { + subArgs = []string{"status"} + } + switch subArgs[0] { + case "status": + return true, runDaemonManagerStatus() + case "stop": + return true, runDaemonManagerAction([]string{"daemon", "stop"}) + case "restart": + return true, runDaemonManagerRestart() + case "ports": + return true, runDaemonManagerPorts(subArgs[1:]) + default: + stderrln("usage: diode daemon [status|stop|restart|ports]") + stderrln(" diode daemon ports [remove|clear]") + return true, 2 + } +} + +func runDaemonManagerStatus() int { + resp, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "status"}, + }) + if err != nil { + stderrln(err.Error()) + return 1 + } + if !running { + stdoutln("Daemon status: not running") + return 0 + } + if resp.Stdout != "" { + _, _ = io.WriteString(stdoutWriter(), resp.Stdout) + } + if resp.Stderr != "" { + _, _ = io.WriteString(stderrWriter(), resp.Stderr) + } + return resp.ExitCode +} + +func runDaemonManagerPorts(args []string) int { + if len(args) == 0 { + stderrln("usage: diode daemon ports [remove|clear]") + return 2 + } + switch args[0] { + case "remove", "rm": + if len(args) < 2 { + stderrln("usage: diode daemon ports remove [...]") + return 2 + } + reqArgs := []string{"daemon", "ports", "remove"} + reqArgs = append(reqArgs, args[1:]...) + return runDaemonManagerAction(reqArgs) + case "clear": + return runDaemonManagerAction([]string{"daemon", "ports", "clear"}) + default: + stderrln("usage: diode daemon ports [remove|clear]") + return 2 + } +} + +func runDaemonManagerAction(args []string) int { + resp, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: args, + }) + if err != nil { + stderrln(err.Error()) + return 1 + } + if !running { + stdoutln("Daemon status: not running") + return 1 + } + if resp.Stdout != "" { + _, _ = io.WriteString(stdoutWriter(), resp.Stdout) + } + if resp.Stderr != "" { + _, _ = io.WriteString(stderrWriter(), resp.Stderr) + } + if len(args) >= 2 && args[0] == "daemon" && args[1] == "stop" && resp.ExitCode == 0 { + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + time.Sleep(100 * time.Millisecond) + _, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "status"}, + }) + if err == nil && !running { + return 0 + } + } + stderrln("daemon stop timed out waiting for the daemon to exit") + return 1 + } + return resp.ExitCode +} + +func runDaemonManagerRestart() int { + resp, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "restart"}, + }) + if err != nil { + stderrln(err.Error()) + return 1 + } + if !running { + stdoutln("Daemon status: not running") + return 1 + } + if resp.Stdout != "" { + _, _ = io.WriteString(stdoutWriter(), resp.Stdout) + } + if resp.Stderr != "" { + _, _ = io.WriteString(stderrWriter(), resp.Stderr) + } + if resp.ExitCode != 0 { + return resp.ExitCode + } + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + time.Sleep(200 * time.Millisecond) + statusResp, ok, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "status"}, + }) + if err == nil && ok && statusResp.ExitCode == 0 { + stdoutln("Daemon restarted.") + return 0 + } + } + stderrln("daemon restart timed out waiting for the daemon to come back") + return 1 +} + +func dispatchToRunningDaemon(req daemonRequest) (daemonResponse, bool, error) { + var lastMeta daemonMetadata + for attempt := 0; attempt < 10; attempt++ { + meta, err := readDaemonMetadata() + if err != nil { + if os.IsNotExist(err) { + if attempt == 9 { + return daemonResponse{}, false, nil + } + time.Sleep(100 * time.Millisecond) + continue + } + return daemonResponse{}, false, err + } + lastMeta = meta + conn, err := dialDaemon(meta.SocketPath) + if err != nil { + if attempt == 9 { + cleanupDaemonArtifacts(meta.SocketPath, metaPathFromSocket(meta.SocketPath)) + return daemonResponse{}, false, nil + } + time.Sleep(100 * time.Millisecond) + continue + } + defer conn.Close() + if err := json.NewEncoder(conn).Encode(req); err != nil { + return daemonResponse{}, true, err + } + var resp daemonResponse + if err := json.NewDecoder(conn).Decode(&resp); err != nil { + return daemonResponse{}, true, err + } + return resp, true, nil + } + if lastMeta.SocketPath != "" { + cleanupDaemonArtifacts(lastMeta.SocketPath, metaPathFromSocket(lastMeta.SocketPath)) + } + return daemonResponse{}, false, nil +} + +func runDaemonManage(args []string, resp *daemonResponse) error { + if len(args) == 0 || args[0] != "daemon" { + return newExitStatusError(2, "missing daemon action") + } + action := "status" + if len(args) > 1 { + action = args[1] + } + switch action { + case "status": + renderDaemonStatus() + return nil + case "stop": + stdoutln("Stopping diode daemon.") + app.StopMode() + daemonState.clearModeSnapshot() + if resp != nil { + resp.Shutdown = true + } + return nil + case "restart": + exePath, err := os.Executable() + if err != nil { + exePath = os.Args[0] + } + stdoutln("Restarting diode daemon.") + if resp != nil { + resp.RestartPath = exePath + } + return nil + case "ports": + if len(args) < 3 { + return newExitStatusError(2, "usage: diode daemon ports [remove|clear]") + } + switch args[2] { + case "remove", "rm": + if len(args) < 4 { + return newExitStatusError(2, "usage: diode daemon ports remove [...]") + } + ports := make([]int, 0, len(args)-3) + for _, raw := range args[3:] { + port, err := strconv.Atoi(raw) + if err != nil || port < 1 || port > 65535 { + return newExitStatusError(2, "invalid port: %s", raw) + } + ports = append(ports, port) + } + return daemonRemoveManagedPorts(ports) + case "clear": + return daemonClearManagedPorts() + default: + return newExitStatusError(2, "usage: diode daemon ports [remove|clear]") + } + default: + return newExitStatusError(2, "unknown daemon action: %s", action) + } +} + +func renderDaemonStatus() { + status := daemonState.snapshotStatus() + cfg := config.AppConfig + + cfg.PrintLabel("Daemon status", "running") + cfg.PrintLabel("PID", strconv.Itoa(os.Getpid())) + cfg.PrintLabel("Socket", daemonState.socketPath) + mode := status.ActiveMode + if mode == "" { + mode = "none" + } + cfg.PrintLabel("Active mode", mode) + cfg.PrintLabel("Client address", cfg.ClientAddr.HexString()) + cfg.PrintLabel("Fleet address", cfg.FleetAddr.HexString()) + if cfg.ClientName != "" { + cfg.PrintLabel("Client name", cfg.ClientName+".diode") + } + if status.SocksEnabled { + cfg.PrintLabel("SOCKS proxy", status.SocksAddr) + } else { + cfg.PrintLabel("SOCKS proxy", "disabled") + } + if status.APIEnabled { + cfg.PrintLabel("Config API", status.APIAddr) + } else { + cfg.PrintLabel("Config API", "disabled") + } + if len(status.ActiveArgs) > 0 { + cfg.PrintLabel("Mode args", strings.Join(status.ActiveArgs, " ")) + } + if len(status.PublishedPorts) > 0 { + renderPublishedPortMap(cfg, status.PublishedPorts) + } else { + cfg.PrintLabel("Published ports", "none") + } + if len(status.Binds) > 0 { + renderBindMap(cfg, status.Binds) + } else { + cfg.PrintLabel("Binds", "none") + } +} + +func portAllowlistStrings(port *config.Port) []string { + if port == nil { + return nil + } + addrs := make([]string, 0, len(port.Allowlist)+len(port.BnsAllowlist)+len(port.DriveAllowList)+len(port.DriveMemberAllowList)) + for addr := range port.Allowlist { + addrs = append(addrs, addr.HexString()) + } + for bnsName := range port.BnsAllowlist { + addrs = append(addrs, bnsName) + } + for drive := range port.DriveAllowList { + addrs = append(addrs, drive.HexString()) + } + for driveMember := range port.DriveMemberAllowList { + addrs = append(addrs, driveMember.HexString()) + } + sort.Strings(addrs) + return addrs +} + +func daemonRemoveManagedPorts(ports []int) error { + mode, args := daemonModeArgs() + switch mode { + case "": + return newExitStatusError(1, "daemon has no active mode") + case "publish": + return daemonReapplyPublishWithoutPorts(args, ports) + case "files": + current := daemonState.snapshotStatus() + if len(current.PublishedPorts) != 1 { + return newExitStatusError(1, "files mode is in an unexpected state") + } + for _, port := range ports { + if _, ok := current.PublishedPorts[port]; ok { + app.StopMode() + daemonState.clearModeSnapshot() + stdoutf("Removed published port: %d\n", port) + stdoutln("Files mode stopped because no published ports remain.") + return nil + } + } + return newExitStatusError(1, "requested port is not active in files mode") + default: + return newExitStatusError(1, "port removal is only supported for publish/files modes; current mode is %s", mode) + } +} + +func daemonClearManagedPorts() error { + mode, _ := daemonModeArgs() + switch mode { + case "": + stdoutln("Daemon has no active mode.") + return nil + case "publish", "files": + app.StopMode() + daemonState.clearModeSnapshot() + stdoutln("Removed all published ports and stopped the active publish mode.") + return nil + default: + return newExitStatusError(1, "clearing published ports is only supported for publish/files modes; current mode is %s", mode) + } +} + +func daemonModeArgs() (string, []string) { + daemonState.stateMu.Lock() + defer daemonState.stateMu.Unlock() + return daemonState.activeMode, append([]string{}, daemonState.activeArgs...) +} + +func daemonReapplyPublishWithoutPorts(args []string, ports []int) error { + portSet := make(map[int]bool, len(ports)) + for _, port := range ports { + portSet[port] = true + } + newArgs, removed, err := filterPublishCommandArgs(args, portSet) + if err != nil { + return err + } + if len(removed) == 0 { + return newExitStatusError(1, "none of the requested ports are currently configured in publish mode") + } + sort.Ints(removed) + if countPublishManagedFlags(newArgs) == 0 && len(config.AppConfig.Binds) == 0 { + app.StopMode() + daemonState.clearModeSnapshot() + stdoutf("Removed published ports: %s\n", joinPorts(removed)) + stdoutln("No published ports remain; publish mode stopped.") + return nil + } + if err := runDaemonCommandAsKind(daemonRequestApplyMode, newArgs); err != nil { + return err + } + daemonState.updateModeSnapshot("publish", newArgs, config.AppConfig) + stdoutf("Removed published ports: %s\n", joinPorts(removed)) + stdoutln("Publish mode was reapplied successfully.") + return nil +} + +func filterPublishCommandArgs(args []string, removePorts map[int]bool) ([]string, []int, error) { + if len(args) == 0 || args[0] != "publish" { + return nil, nil, newExitStatusError(1, "daemon is not tracking a publish command") + } + filtered := []string{args[0]} + removed := make(map[int]bool) + for i := 1; i < len(args); i++ { + arg := args[i] + flagName, inlineValue, matched := parseManagedPublishFlag(arg) + if !matched { + filtered = append(filtered, arg) + continue + } + value := inlineValue + if value == "" { + if i+1 >= len(args) { + return nil, nil, newExitStatusError(2, "flag %s is missing a value", flagName) + } + i++ + value = args[i] + } + externPort, err := managedFlagExternPort(flagName, value) + if err != nil { + return nil, nil, err + } + if removePorts[externPort] { + removed[externPort] = true + continue + } + if inlineValue == "" { + filtered = append(filtered, flagName, value) + } else { + filtered = append(filtered, flagName+"="+value) + } + } + out := make([]int, 0, len(removed)) + for port := range removed { + out = append(out, port) + } + return filtered, out, nil +} + +func parseManagedPublishFlag(arg string) (string, string, bool) { + for _, name := range []string{"-public", "-protected", "-private", "-sshd", "-files"} { + if arg == name { + return name, "", true + } + prefix := name + "=" + if strings.HasPrefix(arg, prefix) { + return name, strings.TrimPrefix(arg, prefix), true + } + } + return "", "", false +} + +func managedFlagExternPort(flagName, value string) (int, error) { + switch flagName { + case "-public", "-protected", "-private": + return extractExternPortFromPortSpec(value) + case "-files": + portSpec, _, err := expandFilesSpec(value) + if err != nil { + return 0, err + } + return extractExternPortFromPortSpec(portSpec) + case "-sshd": + head := sshServicePattern.FindStringSubmatch(strings.TrimSpace(strings.Split(value, ",")[0])) + if len(head) != 4 { + return 0, newExitStatusError(2, "invalid ssh publish spec: %s", value) + } + port, err := strconv.Atoi(head[2]) + if err != nil { + return 0, newExitStatusError(2, "invalid ssh publish spec: %s", value) + } + return port, nil + default: + return 0, newExitStatusError(2, "unsupported publish flag: %s", flagName) + } +} + +func extractExternPortFromPortSpec(value string) (int, error) { + head := strings.TrimSpace(strings.Split(value, ",")[0]) + match := portPattern.FindStringSubmatch(head) + if len(match) != 8 { + return 0, newExitStatusError(2, "invalid publish spec: %s", value) + } + srcPort, err := strconv.Atoi(match[3]) + if err != nil { + return 0, err + } + if match[5] == "" { + return srcPort, nil + } + toPort, err := strconv.Atoi(match[5]) + if err != nil { + return 0, err + } + return toPort, nil +} + +func countPublishManagedFlags(args []string) int { + count := 0 + for i := 1; i < len(args); i++ { + flagName, inlineValue, matched := parseManagedPublishFlag(args[i]) + if !matched { + continue + } + count++ + if inlineValue == "" && i+1 < len(args) { + i++ + } + _ = flagName + } + return count +} + +func joinPorts(ports []int) string { + items := make([]string, 0, len(ports)) + for _, port := range ports { + items = append(items, strconv.Itoa(port)) + } + return strings.Join(items, ", ") +} + +func runDaemonCommandAsKind(kind string, args []string) error { + activeDaemonReqMu.Lock() + prev := activeDaemonReqKind + activeDaemonReqKind = kind + activeDaemonReqMu.Unlock() + defer func() { + activeDaemonReqMu.Lock() + activeDaemonReqKind = prev + activeDaemonReqMu.Unlock() + }() + return runDaemonCommandArgs(args) +} diff --git a/cmd/diode/daemon_test.go b/cmd/diode/daemon_test.go new file mode 100644 index 00000000..2377ad04 --- /dev/null +++ b/cmd/diode/daemon_test.go @@ -0,0 +1,305 @@ +package main + +import ( + "strings" + "testing" + "time" + + "github.com/diodechain/diode_client/config" +) + +func TestParseRootInvocationExtractsCommandAndStartupFlags(t *testing.T) { + inv, err := parseRootInvocation([]string{ + "-debug=true", + "-dbpath", "/tmp/diode.db", + "query", + "-address", "0xabc", + }) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + if inv.command != "query" { + t.Fatalf("command = %q, want query", inv.command) + } + if len(inv.commandArgs) != 3 || inv.commandArgs[0] != "query" { + t.Fatalf("commandArgs = %#v, want query subcommand args", inv.commandArgs) + } + if !inv.startupSpec.Debug { + t.Fatalf("startupSpec.Debug = false, want true") + } + if inv.startupSpec.DBPath != "/tmp/diode.db" { + t.Fatalf("startupSpec.DBPath = %q, want /tmp/diode.db", inv.startupSpec.DBPath) + } +} + +func TestParseRootInvocationDetectsHelpAndNoDaemon(t *testing.T) { + inv, err := parseRootInvocation([]string{"-no-daemon", "publish", "--help"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + if !inv.help { + t.Fatalf("help = false, want true") + } + if !inv.disableDaemon { + t.Fatalf("disableDaemon = false, want true") + } +} + +func TestParseRootInvocationDefaultsToPublishForRootFlagsOnly(t *testing.T) { + inv, err := parseRootInvocation([]string{"-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + if inv.command != "publish" { + t.Fatalf("command = %q, want publish", inv.command) + } + if len(inv.commandArgs) != 1 || inv.commandArgs[0] != "publish" { + t.Fatalf("commandArgs = %#v, want implicit publish only", inv.commandArgs) + } + if len(inv.execArgs) != 3 || inv.execArgs[0] != "-bind" || inv.execArgs[2] != "publish" { + t.Fatalf("execArgs = %#v, want root flags plus implicit publish", inv.execArgs) + } +} + +func TestSanitizedDaemonBaseConfigResetsTransientState(t *testing.T) { + cfg := newRootConfig() + cfg.ConfigList = true + cfg.QueryAddress = "0x1" + cfg.EnableProxyServer = true + cfg.SocksServerPort = 9999 + cfg.PublicPublishedPorts = config.StringValues{"80:80"} + cfg.BNSLookup = "example" + cfg.StdoutWriter = testingLoggerWriter{t} + cfg.PublishedPorts = map[int]*config.Port{80: {}} + + sanitized := sanitizedDaemonBaseConfig(cfg) + if sanitized.ConfigList { + t.Fatalf("ConfigList = true, want false") + } + if sanitized.QueryAddress != "" { + t.Fatalf("QueryAddress = %q, want empty", sanitized.QueryAddress) + } + if sanitized.EnableProxyServer { + t.Fatalf("EnableProxyServer = true, want false") + } + if sanitized.SocksServerPort != 1080 { + t.Fatalf("SocksServerPort = %d, want 1080", sanitized.SocksServerPort) + } + if len(sanitized.PublicPublishedPorts) != 0 { + t.Fatalf("PublicPublishedPorts = %#v, want empty", sanitized.PublicPublishedPorts) + } + if sanitized.BNSLookup != "" { + t.Fatalf("BNSLookup = %q, want empty", sanitized.BNSLookup) + } + if sanitized.StdoutWriter != nil || sanitized.StderrWriter != nil { + t.Fatalf("stdout/stderr writers should be cleared") + } + if sanitized.PublishedPorts != nil { + t.Fatalf("PublishedPorts = %#v, want nil", sanitized.PublishedPorts) + } +} + +func TestDaemonStartupSpecFromConfigCopiesRootScopedValues(t *testing.T) { + cfg := newRootConfig() + cfg.RemoteRPCTimeout = 7 * time.Second + cfg.RetryWait = 2 * time.Second + cfg.ResolveCacheTime = 5 * time.Minute + cfg.SBlocklists = config.StringValues{"0x1"} + + spec := daemonStartupSpecFromConfig(cfg) + if spec.RemoteRPCTimeout != 7*time.Second { + t.Fatalf("RemoteRPCTimeout = %v, want 7s", spec.RemoteRPCTimeout) + } + if spec.RetryWait != 2*time.Second { + t.Fatalf("RetryWait = %v, want 2s", spec.RetryWait) + } + if spec.ResolveCacheTime != 5*time.Minute { + t.Fatalf("ResolveCacheTime = %v, want 5m", spec.ResolveCacheTime) + } + if len(spec.SBlocklists) != 1 || spec.SBlocklists[0] != "0x1" { + t.Fatalf("SBlocklists = %#v, want [0x1]", spec.SBlocklists) + } +} + +func TestDaemonRestartEnvReplacesDaemonSpecificVars(t *testing.T) { + t.Setenv(envDaemonReadyFD, "3") + t.Setenv(envDaemonStartupSpec, `{"debug":false}`) + daemonState = nil + + env, err := daemonRestartEnv(daemonStartupSpec{Debug: true}) + if err != nil { + t.Fatalf("daemonRestartEnv() error = %v", err) + } + + startupVars := 0 + for _, item := range env { + if strings.HasPrefix(item, envDaemonReadyFD+"=") { + t.Fatalf("daemon restart env still contains %s: %q", envDaemonReadyFD, item) + } + if strings.HasPrefix(item, envDaemonStartupSpec+"=") { + startupVars++ + if !strings.Contains(item, `"debug":true`) { + t.Fatalf("startup spec env = %q, want debug=true", item) + } + } + } + if startupVars != 1 { + t.Fatalf("startup spec env vars = %d, want 1", startupVars) + } +} + +func TestDaemonRestartEnvPreservesActiveModeArgs(t *testing.T) { + prev := daemonState + defer func() { daemonState = prev }() + daemonState = &runtimeDaemon{ + activeMode: "publish", + activeArgs: []string{"publish", "-public", "80"}, + } + + env, err := daemonRestartEnv(daemonStartupSpec{Debug: true}) + if err != nil { + t.Fatalf("daemonRestartEnv() error = %v", err) + } + + restoreVars := 0 + for _, item := range env { + if strings.HasPrefix(item, envDaemonRestoreArgs+"=") { + restoreVars++ + if !strings.Contains(item, `"publish"`) || !strings.Contains(item, `"-public"`) || !strings.Contains(item, `"80"`) { + t.Fatalf("restore args env = %q", item) + } + } + } + if restoreVars != 1 { + t.Fatalf("restore args env vars = %d, want 1", restoreVars) + } +} + +func TestFilterPublishCommandArgsRemovesRequestedPorts(t *testing.T) { + args := []string{ + "publish", + "-public", "80:80", + "-private=127.0.0.1:22:2222,0x1234567890123456789012345678901234567890", + "-sshd", "private:2022:ubuntu,0x1234567890123456789012345678901234567890", + "-socksd=true", + } + filtered, removed, err := filterPublishCommandArgs(args, map[int]bool{80: true, 2022: true}) + if err != nil { + t.Fatalf("filterPublishCommandArgs() error = %v", err) + } + if got := strings.Join(filtered, " "); got != "publish -private=127.0.0.1:22:2222,0x1234567890123456789012345678901234567890 -socksd=true" { + t.Fatalf("filtered = %q", got) + } + if len(removed) != 2 { + t.Fatalf("removed = %#v, want 2 items", removed) + } + gotRemoved := map[int]bool{removed[0]: true, removed[1]: true} + if !gotRemoved[80] || !gotRemoved[2022] { + t.Fatalf("removed = %#v, want ports 80 and 2022", removed) + } +} + +func TestManagedFlagExternPortSupportsFilesSpec(t *testing.T) { + port, err := managedFlagExternPort("-files", "8080,example.diode") + if err != nil { + t.Fatalf("managedFlagExternPort(-files) error = %v", err) + } + if port != 8080 { + t.Fatalf("port = %d, want 8080", port) + } +} + +func TestRefreshRequestDerivedConfigDedupesBinds(t *testing.T) { + cfg := newRootConfig() + cfg.SBinds = config.StringValues{ + "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + } + if err := refreshRequestDerivedConfig(cfg); err != nil { + t.Fatalf("refreshRequestDerivedConfig() error = %v", err) + } + if len(cfg.SBinds) != 1 { + t.Fatalf("SBinds = %#v, want one unique bind", cfg.SBinds) + } + if len(cfg.Binds) != 1 { + t.Fatalf("Binds = %#v, want one parsed bind", cfg.Binds) + } +} + +func TestMergeImplicitPublishArgsPreservesExistingPublishFlagsAndDedupesBinds(t *testing.T) { + prev := daemonState + defer func() { daemonState = prev }() + daemonState = &runtimeDaemon{ + activeMode: "publish", + activeArgs: []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", "-public", "80", + }, + } + + got := mergeImplicitPublishArgs([]string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", + }) + + want := []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", "-public", "80", + } + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Fatalf("mergeImplicitPublishArgs() = %#v, want %#v", got, want) + } +} + +func TestMergeImplicitPublishArgsAppendsNewBindToExistingPublishState(t *testing.T) { + prev := daemonState + defer func() { daemonState = prev }() + daemonState = &runtimeDaemon{ + activeMode: "publish", + activeArgs: []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", "-public", "80", + }, + } + + got := mergeImplicitPublishArgs([]string{ + "-bind", "8081:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", + }) + + want := []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "-bind", "8081:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", "-public", "80", + } + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Fatalf("mergeImplicitPublishArgs() = %#v, want %#v", got, want) + } +} + +func TestSanitizeModeArgsRemovesStartupFlagsButKeepsModeFlags(t *testing.T) { + got := sanitizeModeArgs("publish", []string{ + "-update=false", + "-dbpath", "/tmp/test.db", + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", + "-public", "80", + }) + want := []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", + "-public", "80", + } + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Fatalf("sanitizeModeArgs() = %#v, want %#v", got, want) + } +} + +type testingLoggerWriter struct { + t *testing.T +} + +func (w testingLoggerWriter) Write(p []byte) (int, error) { + w.t.Helper() + return len(p), nil +} diff --git a/cmd/diode/daemon_transport_unix.go b/cmd/diode/daemon_transport_unix.go new file mode 100644 index 00000000..fe6d1991 --- /dev/null +++ b/cmd/diode/daemon_transport_unix.go @@ -0,0 +1,163 @@ +//go:build !windows + +package main + +import ( + "encoding/json" + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "syscall" + "time" +) + +func daemonPaths() (string, string, error) { + base, err := os.UserConfigDir() + if err != nil { + return "", "", err + } + dir := filepath.Join(base, "diode") + if err := os.MkdirAll(dir, 0700); err != nil { + return "", "", err + } + socketPath := filepath.Join(dir, "daemon.sock") + return socketPath, metaPathFromSocket(socketPath), nil +} + +func metaPathFromSocket(socketPath string) string { + if socketPath == "" { + socketPath, _, _ = daemonPaths() + } + return socketPath + ".json" +} + +func daemonListen(socketPath string) (net.Listener, error) { + ln, err := net.Listen("unix", socketPath) + if err != nil { + return nil, err + } + if err := os.Chmod(socketPath, 0600); err != nil { + _ = ln.Close() + _ = os.Remove(socketPath) + return nil, err + } + return ln, nil +} + +func dialDaemon(socketPath string) (net.Conn, error) { + if socketPath == "" { + path, _, err := daemonPaths() + if err != nil { + return nil, err + } + socketPath = path + } + return net.DialTimeout("unix", socketPath, 500*time.Millisecond) +} + +func cleanupDaemonTransport(socketPath string) { + _ = os.Remove(socketPath) +} + +func daemonSignals() []os.Signal { + return []os.Signal{os.Interrupt, syscall.SIGTERM} +} + +func spawnDaemon(spec daemonStartupSpec) error { + specBytes, err := json.Marshal(spec) + if err != nil { + return err + } + r, w, err := os.Pipe() + if err != nil { + return err + } + defer r.Close() + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0600) + if err != nil { + _ = w.Close() + return err + } + defer devNull.Close() + + cmd := exec.Command(os.Args[0], daemonCommandName) + cmd.Stdin = devNull + cmd.Stdout = devNull + cmd.Stderr = devNull + cmd.ExtraFiles = []*os.File{w} + cmd.Env = append(os.Environ(), + fmt.Sprintf("%s=3", envDaemonReadyFD), + fmt.Sprintf("%s=%s", envDaemonStartupSpec, string(specBytes)), + ) + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + if err := cmd.Start(); err != nil { + _ = w.Close() + return err + } + _ = w.Close() + + done := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, err := r.Read(buf) + done <- err + }() + select { + case err := <-done: + return err + case <-time.After(10 * time.Second): + _ = cmd.Process.Kill() + return fmt.Errorf("timed out waiting for daemon startup") + } +} + +func daemonRestartSelf(cmd string, startup daemonStartupSpec) error { + env, err := daemonRestartEnv(startup) + if err != nil { + return err + } + r, w, err := os.Pipe() + if err != nil { + return err + } + defer r.Close() + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0600) + if err != nil { + _ = w.Close() + return err + } + defer devNull.Close() + + child := exec.Command(cmd, daemonCommandName) + child.Stdin = devNull + child.Stdout = devNull + child.Stderr = devNull + child.ExtraFiles = []*os.File{w} + child.Env = append(env, fmt.Sprintf("%s=3", envDaemonReadyFD)) + child.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + if err := child.Start(); err != nil { + _ = w.Close() + return err + } + _ = w.Close() + + done := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, err := r.Read(buf) + done <- err + }() + select { + case err := <-done: + if err != nil { + return err + } + case <-time.After(10 * time.Second): + _ = child.Process.Kill() + return fmt.Errorf("timed out waiting for restarted daemon startup") + } + os.Exit(0) + return nil +} diff --git a/cmd/diode/daemon_transport_windows.go b/cmd/diode/daemon_transport_windows.go new file mode 100644 index 00000000..ca3759cd --- /dev/null +++ b/cmd/diode/daemon_transport_windows.go @@ -0,0 +1,134 @@ +//go:build windows + +package main + +import ( + "crypto/sha1" + "encoding/hex" + "encoding/json" + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "syscall" + "time" + + "github.com/Microsoft/go-winio" +) + +func daemonPaths() (string, string, error) { + base, err := os.UserConfigDir() + if err != nil { + return "", "", err + } + dir := filepath.Join(base, "diode") + if err := os.MkdirAll(dir, 0700); err != nil { + return "", "", err + } + sum := sha1.Sum([]byte(dir)) + socketPath := `\\.\pipe\diode-client-` + hex.EncodeToString(sum[:8]) + return socketPath, metaPathFromSocket(socketPath), nil +} + +func metaPathFromSocket(socketPath string) string { + base, err := os.UserConfigDir() + if err != nil { + return "daemon.json" + } + return filepath.Join(base, "diode", "daemon.json") +} + +func daemonListen(socketPath string) (net.Listener, error) { + return winio.ListenPipe(socketPath, &winio.PipeConfig{ + SecurityDescriptor: "D:P(A;;GA;;;SY)(A;;GA;;;BA)(A;;GA;;;OW)", + }) +} + +func dialDaemon(socketPath string) (net.Conn, error) { + if socketPath == "" { + path, _, err := daemonPaths() + if err != nil { + return nil, err + } + socketPath = path + } + timeout := 500 * time.Millisecond + return winio.DialPipe(socketPath, &timeout) +} + +func cleanupDaemonTransport(socketPath string) {} + +func daemonSignals() []os.Signal { + return []os.Signal{os.Interrupt} +} + +func spawnDaemon(spec daemonStartupSpec) error { + specBytes, err := json.Marshal(spec) + if err != nil { + return err + } + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0600) + if err != nil { + return err + } + defer devNull.Close() + + cmd := exec.Command(os.Args[0], daemonCommandName) + cmd.Stdin = devNull + cmd.Stdout = devNull + cmd.Stderr = devNull + cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", envDaemonStartupSpec, string(specBytes))) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + if err := cmd.Start(); err != nil { + return err + } + + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + meta, err := readDaemonMetadata() + if err == nil && meta.PID == cmd.Process.Pid { + if _, err := dialDaemon(meta.SocketPath); err == nil { + return nil + } + } + time.Sleep(100 * time.Millisecond) + } + _ = cmd.Process.Kill() + return fmt.Errorf("timed out waiting for daemon startup") +} + +func daemonRestartSelf(cmd string, startup daemonStartupSpec) error { + env, err := daemonRestartEnv(startup) + if err != nil { + return err + } + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0600) + if err != nil { + return err + } + defer devNull.Close() + + child := exec.Command(cmd, daemonCommandName) + child.Stdin = devNull + child.Stdout = devNull + child.Stderr = devNull + child.Env = env + child.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + if err := child.Start(); err != nil { + return err + } + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + meta, err := readDaemonMetadata() + if err == nil && meta.PID == child.Process.Pid { + if _, err := dialDaemon(meta.SocketPath); err == nil { + os.Exit(0) + return nil + } + } + time.Sleep(100 * time.Millisecond) + } + _ = child.Process.Kill() + return fmt.Errorf("timed out waiting for restarted daemon startup") +} diff --git a/cmd/diode/diode.go b/cmd/diode/diode.go index 8ea56cd4..ad9e0039 100644 --- a/cmd/diode/diode.go +++ b/cmd/diode/diode.go @@ -27,6 +27,10 @@ func main() { os.Exit(0) } + if handled, exitCode := maybeHandleDaemonCLI(os.Args[1:]); handled { + os.Exit(exitCode) + } + cfg := config.AppConfig err := diodeCmd.Execute() if err != nil { diff --git a/cmd/diode/fetch.go b/cmd/diode/fetch.go index 59edeea5..847a5fdd 100644 --- a/cmd/diode/fetch.go +++ b/cmd/diode/fetch.go @@ -66,12 +66,12 @@ func (fp *fetchProgress) Read(p []byte) (int, error) { if fp.read == 0 { if fp.contentLength > 0 { fp.pointSize = float64(fp.contentLength) / 60 - fmt.Printf("Downloading %d bytes into '%s'.\n", fp.contentLength, fp.name) - fmt.Println("[------------------------------------------------------------]") - fmt.Printf("[") + stdoutf("Downloading %d bytes into '%s'.\n", fp.contentLength, fp.name) + stdoutln("[------------------------------------------------------------]") + stdoutf("[") } else { fp.pointSize = 0 - fmt.Printf("Downloading into '%s'.\n", fp.name) + stdoutf("Downloading into '%s'.\n", fp.name) } } n, err := fp.Reader.Read(p) @@ -79,13 +79,13 @@ func (fp *fetchProgress) Read(p []byte) (int, error) { if fp.pointSize > 0 { for int64(float64(fp.read)/fp.pointSize) > fp.points { - fmt.Printf("#") + stdoutf("#") fp.points++ } } if err == io.EOF && fp.contentLength > 0 && fp.read == fp.contentLength { - fmt.Printf("] Done!\n") + stdoutf("] Done!\n") } return n, err @@ -187,16 +187,28 @@ func fetchHandler() (err error) { defer resp.Body.Close() var f *os.File + var out io.Writer if len(fetchCfg.Output) > 0 { f, err = os.OpenFile(fetchCfg.Output, os.O_CREATE|os.O_WRONLY, 0600) if err != nil { return } + out = f } else if fetchCfg.Verbose { - f = os.Stdout + out = stdoutWriter() src = resp.Body } + if out != nil { + _, err = io.Copy(out, src) + if f != nil { + f.Close() + } + if err != nil { + return + } + return + } if f != nil { io.Copy(f, src) f.Close() diff --git a/cmd/diode/files.go b/cmd/diode/files.go index 9cd6dfcc..7a900701 100644 --- a/cmd/diode/files.go +++ b/cmd/diode/files.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "net/http" - "os" "strconv" "strings" "time" @@ -87,9 +86,9 @@ func filesHandler() error { cfg := config.AppConfig spec := strings.TrimSpace(filesCmd.Flag.Arg(0)) if spec == "" { - fmt.Fprintln(os.Stderr, "usage: diode files [-fileroot ] ") - fmt.Fprintln(os.Stderr, " files-spec: for public, or ,... for private") - os.Exit(2) + stderrln("usage: diode files [-fileroot ] ") + stderrln(" files-spec: for public, or ,... for private") + return newExitStatusError(2, "missing files publish spec") } portStr, mode, err := expandFilesSpec(spec) @@ -109,12 +108,13 @@ func filesHandler() error { if err := app.Start(); err != nil { return err } + beginRuntimeMode("files") cleanup, err := startFileListener(p, filesFileroot) if err != nil { return err } - app.Defer(cleanup) + registerRuntimeCleanup(cleanup) portMap := map[int]*config.Port{p.To: p} cfg.PublishedPorts = portMap @@ -157,6 +157,9 @@ func filesHandler() error { cfg.PrintLabel(fmt.Sprintf("Port %5d", bind.LocalPort), fmt.Sprintf("%5s %11s:%d", config.ProtocolName(bind.Protocol), bind.To, bind.ToPort)) } } + if isDaemonApplyRequest() { + return nil + } app.Wait() return nil diff --git a/cmd/diode/gateway.go b/cmd/diode/gateway.go index 0f8ab4d2..4504957d 100644 --- a/cmd/diode/gateway.go +++ b/cmd/diode/gateway.go @@ -52,6 +52,7 @@ func gatewayHandler() (err error) { if err != nil { return } + beginRuntimeMode("gateway") cfg := config.AppConfig cfg.EnableProxyServer = true if cfg.EnableAPIServer { @@ -108,6 +109,9 @@ func gatewayHandler() (err error) { if err := proxyServer.Start(); err != nil { cfg.Logger.Error(err.Error()) } + if isDaemonApplyRequest() { + return nil + } app.Wait() return } diff --git a/cmd/diode/join.go b/cmd/diode/join.go index 1d3da19a..e8736e04 100644 --- a/cmd/diode/join.go +++ b/cmd/diode/join.go @@ -3082,6 +3082,28 @@ func runContractController(cfg *config.Config) error { } } +func runContractControllerUntil(cfg *config.Config, stopCh <-chan struct{}) error { + for { + select { + case <-stopCh: + cfg.Logger.Info("Join controller stopped") + return nil + default: + } + + if err := contractSync(cfg); err != nil { + cfg.Logger.Warn("Perimeter contract sync failed: %v", err) + } + + select { + case <-stopCh: + cfg.Logger.Info("Join controller stopped") + return nil + case <-time.After(30 * time.Second): + } + } +} + // updatePublishedPorts updates the published ports based on contract configuration func updatePublishedPorts(client *rpc.Client, props map[string]string) error { cfg := config.AppConfig @@ -3243,6 +3265,9 @@ func joinHandler() (err error) { if err != nil { return } + if isDaemonApplyRequest() { + beginRuntimeMode("join") + } // Initial contract sync to apply perimeter before starting services if syncErr := runContractControllerOnce(cfg); syncErr != nil { @@ -3253,6 +3278,15 @@ func joinHandler() (err error) { if dryRun { return nil } + if isDaemonApplyRequest() { + done := make(chan struct{}) + app.SetModeDone(done) + go func() { + defer close(done) + _ = runContractControllerUntil(cfg, app.ModeStopChan()) + }() + return nil + } return runContractController(cfg) } diff --git a/cmd/diode/mode_helpers.go b/cmd/diode/mode_helpers.go new file mode 100644 index 00000000..ebbd69d6 --- /dev/null +++ b/cmd/diode/mode_helpers.go @@ -0,0 +1,20 @@ +package main + +func beginRuntimeMode(name string) { + if !isDaemonApplyRequest() { + return + } + app.StopMode() + app.BeginMode(name) +} + +func registerRuntimeCleanup(cleanup func()) { + if cleanup == nil { + return + } + if isDaemonApplyRequest() { + app.ModeDefer(cleanup) + return + } + app.Defer(cleanup) +} diff --git a/cmd/diode/output_helpers.go b/cmd/diode/output_helpers.go new file mode 100644 index 00000000..42c7150c --- /dev/null +++ b/cmd/diode/output_helpers.go @@ -0,0 +1,37 @@ +package main + +import ( + "fmt" + "io" + "os" + + "github.com/diodechain/diode_client/config" +) + +func stdoutWriter() io.Writer { + cfg := config.AppConfig + if cfg != nil && cfg.StdoutWriter != nil { + return cfg.StdoutWriter + } + return os.Stdout +} + +func stderrWriter() io.Writer { + cfg := config.AppConfig + if cfg != nil && cfg.StderrWriter != nil { + return cfg.StderrWriter + } + return os.Stderr +} + +func stdoutf(format string, args ...interface{}) { + fmt.Fprintf(stdoutWriter(), format, args...) +} + +func stdoutln(args ...interface{}) { + fmt.Fprintln(stdoutWriter(), args...) +} + +func stderrln(args ...interface{}) { + fmt.Fprintln(stderrWriter(), args...) +} diff --git a/cmd/diode/publish.go b/cmd/diode/publish.go index b4d4019c..ab57b71d 100644 --- a/cmd/diode/publish.go +++ b/cmd/diode/publish.go @@ -6,7 +6,6 @@ package main import ( "fmt" "net" - "os" "regexp" "strconv" "strings" @@ -32,10 +31,10 @@ var ( Type: command.DaemonCommand, SingleConnection: true, } - enableStaticServer = false - staticServer staticserver.StaticHTTPServer - scfg staticserver.Config - publishFileSpecs config.StringValues + enableStaticServer = false + staticServer staticserver.StaticHTTPServer + scfg staticserver.Config + publishFileSpecs config.StringValues publishFileFileroot string ) @@ -381,7 +380,7 @@ func publishHandler() (err error) { return } }() - app.Defer(func() { + registerRuntimeCleanup(func() { // Since we didn't use ListenAndServe, call // ln.Close() instead of staticServer.Close() ln.Close() @@ -397,50 +396,19 @@ func publishHandler() (err error) { } if len(cfg.PublishedPorts) == 0 && len(cfg.Binds) == 0 { - fmt.Println() - fmt.Println("ERROR: Can't run publish without any arguments!") - fmt.Println(" HINT: Try 'diode publish -public 8080:80' to publish a local port") - fmt.Println(" HINT: Check our docs to learn more about publishing ports: https://diode.io/docs/getting-started.html") - fmt.Println(" HINT: Or run 'diode --help' to see all commands") - os.Exit(2) + stdoutln() + stdoutln("ERROR: Can't run publish without any arguments!") + stdoutln(" HINT: Try 'diode publish -public 8080:80' to publish a local port") + stdoutln(" HINT: Check our docs to learn more about publishing ports: https://diode.io/docs/getting-started.html") + stdoutln(" HINT: Or run 'diode --help' to see all commands") + return newExitStatusError(2, "publish requires at least one published port or bind") } + beginRuntimeMode("publish") + if len(cfg.PublishedPorts) > 0 { - cfg.PrintInfo("") - name := cfg.ClientAddr.HexString() - if cfg.ClientName != "" { - name = cfg.ClientName - } app.clientManager.GetPool().SetPublishedPorts(cfg.PublishedPorts) - for _, port := range cfg.PublishedPorts { - if port.Mode == config.PublicPublishedMode { - if port.To == httpPort { - cfg.PrintLabel("HTTP Gateway Enabled", fmt.Sprintf("http://%s.diode.link/", name)) - } - if (8000 <= port.To && port.To <= 8100) || (8400 <= port.To && port.To <= 8500) { - cfg.PrintLabel("HTTP Gateway Enabled", fmt.Sprintf("https://%s.diode.link:%d/", name, port.To)) - } - } - } - cfg.PrintLabel("Port ", " ") - for _, port := range cfg.PublishedPorts { - - addrs := make([]string, 0, len(port.Allowlist)+len(port.BnsAllowlist)) - for addr := range port.Allowlist { - addrs = append(addrs, addr.HexString()) - } - for bnsName := range port.BnsAllowlist { - addrs = append(addrs, bnsName) - } - for drive := range port.DriveAllowList { - addrs = append(addrs, drive.HexString()) - } - for driveMember := range port.DriveMemberAllowList { - addrs = append(addrs, driveMember.HexString()) - } - host := publishedPortDisplayHost(port) - cfg.PrintLabel(fmt.Sprintf("Port %12s", host), fmt.Sprintf("%8d %10s %s %s", port.To, config.ModeName(port.Mode), config.ProtocolName(port.Protocol), strings.Join(addrs, ","))) - } + renderPublishedPortMap(cfg, cfg.PublishedPorts) } if cfg.EnableAPIServer { @@ -471,11 +439,10 @@ func publishHandler() (err error) { if len(cfg.Binds) > 0 { socksServer.SetBinds(cfg.Binds) cfg.Binds = socksServer.GetBinds() // resolve "auto" ports for logs and API - cfg.PrintInfo("") - cfg.PrintLabel("Bind ", " ") - for _, bind := range cfg.Binds { - cfg.PrintLabel(fmt.Sprintf("Port %5d", bind.LocalPort), fmt.Sprintf("%5s %11s:%d", config.ProtocolName(bind.Protocol), bind.To, bind.ToPort)) - } + renderBindMap(cfg, cfg.Binds) + } + if isDaemonApplyRequest() { + return nil } for { app.Wait() diff --git a/cmd/diode/publish_render.go b/cmd/diode/publish_render.go new file mode 100644 index 00000000..fbb85560 --- /dev/null +++ b/cmd/diode/publish_render.go @@ -0,0 +1,58 @@ +package main + +import ( + "fmt" + "sort" + "strings" + + "github.com/diodechain/diode_client/config" +) + +func renderPublishedPortMap(cfg *config.Config, ports map[int]*config.Port) { + if len(ports) == 0 { + return + } + cfg.PrintInfo("") + name := cfg.ClientAddr.HexString() + if cfg.ClientName != "" { + name = cfg.ClientName + } + + keys := make([]int, 0, len(ports)) + for key := range ports { + keys = append(keys, key) + } + sort.Ints(keys) + + for _, key := range keys { + port := ports[key] + if port.Mode == config.PublicPublishedMode { + if port.To == httpPort { + cfg.PrintLabel("HTTP Gateway Enabled", fmt.Sprintf("http://%s.diode.link/", name)) + } + if (8000 <= port.To && port.To <= 8100) || (8400 <= port.To && port.To <= 8500) { + cfg.PrintLabel("HTTP Gateway Enabled", fmt.Sprintf("https://%s.diode.link:%d/", name, port.To)) + } + } + } + + cfg.PrintLabel("Port ", " ") + for _, key := range keys { + port := ports[key] + cfg.PrintLabel( + fmt.Sprintf("Port %12s", publishedPortDisplayHost(port)), + fmt.Sprintf("%8d %10s %s %s", port.To, config.ModeName(port.Mode), config.ProtocolName(port.Protocol), strings.Join(portAllowlistStrings(port), ",")), + ) + } +} + +func renderBindMap(cfg *config.Config, binds []config.Bind) { + if len(binds) == 0 { + return + } + cfg.PrintInfo("") + cfg.PrintLabel("Bind ", " ") + for _, bind := range binds { + cfg.PrintLabel(fmt.Sprintf("Port %5d", bind.LocalPort), fmt.Sprintf("%5s %11s:%d", config.ProtocolName(bind.Protocol), bind.To, bind.ToPort)) + } +} diff --git a/cmd/diode/pushpull.go b/cmd/diode/pushpull.go index 628bf121..3ec0121d 100644 --- a/cmd/diode/pushpull.go +++ b/cmd/diode/pushpull.go @@ -135,9 +135,9 @@ func parsePeerPort(s string) (host string, port int, err error) { func pushHandler() error { args := pushCmd.Flag.Args() if len(args) < 2 { - fmt.Fprintln(os.Stderr, "usage: diode push :") - fmt.Fprintln(os.Stderr, " diode push ::") - os.Exit(2) + stderrln("usage: diode push :") + stderrln(" diode push ::") + return newExitStatusError(2, "missing push arguments") } if len(args) > 2 { return fmt.Errorf("too many arguments (quote :: as one token)") @@ -208,15 +208,15 @@ func pushHandler() error { } return fmt.Errorf("push failed: %s", msg) } - config.AppConfig.Logger.Info("push ok: %s -> %s", localPath, urlStr) + config.AppConfig.PrintInfo(fmt.Sprintf("push ok: %s -> %s", localPath, urlStr)) return nil } func pullHandler() error { args := pullCmd.Flag.Args() if len(args) < 1 { - fmt.Fprintln(os.Stderr, "usage: diode pull :: []") - os.Exit(2) + stderrln("usage: diode pull :: []") + return newExitStatusError(2, "missing pull arguments") } if len(args) > 2 { return fmt.Errorf("too many arguments (quote paths that contain spaces)") @@ -288,7 +288,7 @@ func pullHandler() error { if cerr != nil { return cerr } - config.AppConfig.Logger.Info("pull ok: %s -> ./%s", urlStr, base) + config.AppConfig.PrintInfo(fmt.Sprintf("pull ok: %s -> ./%s", urlStr, base)) return nil } @@ -311,6 +311,6 @@ func pullHandler() error { if cerr != nil { return cerr } - config.AppConfig.Logger.Info("pull ok: %s -> %s", urlStr, dest) + config.AppConfig.PrintInfo(fmt.Sprintf("pull ok: %s -> %s", urlStr, dest)) return nil } diff --git a/cmd/diode/socksd.go b/cmd/diode/socksd.go index 1e5b0240..0ab4150e 100644 --- a/cmd/diode/socksd.go +++ b/cmd/diode/socksd.go @@ -37,6 +37,7 @@ func socksdHandler() (err error) { if err != nil { return } + beginRuntimeMode("socksd") cfg := config.AppConfig cfg.EnableSocksServer = true cfg.EnableProxyServer = true @@ -75,6 +76,9 @@ func socksdHandler() (err error) { return } app.SetSocksServer(socksServer) + if isDaemonApplyRequest() { + return nil + } app.Wait() return } diff --git a/cmd/diode/ssh.go b/cmd/diode/ssh.go index 249be73f..567e5fea 100644 --- a/cmd/diode/ssh.go +++ b/cmd/diode/ssh.go @@ -59,11 +59,29 @@ func sshHandler() (err error) { } defer cleanupProxy() cfg.PrintLabel("Using local diode client", proxyAddr) + sshIndex := ssh_indexFromArgs(os.Args) + if sshIndex == -1 { + cfg.PrintError("ssh command not found", errors.New("ssh command not found")) + os.Exit(1) + } + return runSSHWithProxyAddr(proxyAddr, normalizeSSHArgs(os.Args[sshIndex+1:])) +} +func ssh_indexFromArgs(osArgs []string) int { + for i, arg := range osArgs { + if arg == sshCommandName { + return i + } + } + return -1 +} + +func runSSHWithProxyAddr(proxyAddr string, sshArgs []string) error { + cfg := config.AppConfig diodeExe, err := os.Executable() if err != nil { cfg.PrintError("Could not determine diode executable path", err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) } args := []string{ @@ -71,33 +89,19 @@ func sshHandler() (err error) { "-o", "ProxyCommand=" + buildSSHProxyCommand(runtimeGOOS, diodeExe, proxyAddr), "-o", "StrictHostKeyChecking=accept-new", } - os_args := os.Args - // Remove all args before the ssh command by finding "ssh" and removing all args before it - ssh_index := -1 - for i, arg := range os_args { - if arg == "ssh" { - ssh_index = i - break - } - } - if ssh_index == -1 { - cfg.PrintError("ssh command not found", errors.New("ssh command not found")) - os.Exit(1) - } - sshArgs := normalizeSSHArgs(os_args[ssh_index+1:]) args = append(args, sshArgs...) if target := extractSSHTarget(sshArgs); target != "" { if err := validateSSHTarget(target); err != nil { cfg.PrintError("Invalid SSH target", err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) } } identityFile, cleanup, err := createEphemeralSSHIdentity() if err != nil { cfg.PrintError("Could not create ephemeral ssh identity", err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) } defer cleanup() args = append(args, "-i", identityFile) @@ -105,7 +109,7 @@ func sshHandler() (err error) { ssh, err := findOpenSSHTool("ssh") if err != nil { cfg.PrintError("ssh not found", err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) } cmd := exec.Command(ssh, args[1:]...) cmd.Stdin = os.Stdin @@ -116,12 +120,30 @@ func sshHandler() (err error) { if err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { - os.Exit(exitErr.ExitCode()) + return newExitStatusError(exitErr.ExitCode(), "ssh exited with status %d", exitErr.ExitCode()) } cfg.PrintError("Could not execute ssh", err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) + } + return nil +} + +func runSSHViaDaemonLease(commandArgs []string, resp daemonResponse) int { + if len(commandArgs) == 0 { + stderrln("missing ssh command arguments") + return 1 + } + cfg := config.AppConfig + cfg.PrintLabel("Using diode daemon proxy", resp.ProxyAddr) + defer func() { + if resp.LeaseID != "" { + _ = releaseDaemonLease(resp.LeaseID) + } + }() + if err := runSSHWithProxyAddr(resp.ProxyAddr, normalizeSSHArgs(commandArgs[1:])); err != nil { + return exitCodeFromError(err) } - return + return 0 } func normalizeSSHArgs(args []string) []string { diff --git a/cmd/diode/token.go b/cmd/diode/token.go index 446eeb96..56910ee4 100644 --- a/cmd/diode/token.go +++ b/cmd/diode/token.go @@ -67,8 +67,7 @@ func parseUnitAndValue(src string) (val int, unit string) { func tokenHandler() (err error) { if tokenCfg.CheckBalance { - showBalance() - return + return showBalance() } valWei, _ := parseUnitAndValue(tokenCfg.Value) diff --git a/cmd/diode/update.go b/cmd/diode/update.go index b9082034..f78b1c55 100644 --- a/cmd/diode/update.go +++ b/cmd/diode/update.go @@ -38,11 +38,25 @@ func writeLastUpdateAt() { } func updateHandler() (err error) { - doUpdate() - return + _, err = doUpdate(updateRestartStandalone) + return err } -func doUpdate() int { +type updateRestartMode int + +const ( + updateRestartStandalone updateRestartMode = iota + updateRestartDeferred +) + +func runDaemonUpdate(args []string) (string, error) { + if len(args) == 0 || args[0] != "update" { + return "", newExitStatusError(2, "missing update command") + } + return doUpdate(updateRestartDeferred) +} + +func doUpdate(restartMode updateRestartMode) (string, error) { cfg := config.AppConfig m := &update.Manager{ Command: "diode", @@ -62,12 +76,15 @@ func doUpdate() int { // Will recheck for an update in 24 hours go func() { time.Sleep(time.Hour * 24) - doUpdate() + _, _ = doUpdate(updateRestartStandalone) }() if err == nil { writeLastUpdateAt() } - return 0 + if err != nil { + return "", newExitStatusError(1, "%s", err.Error()) + } + return "", nil } // searching for binary in path @@ -86,14 +103,17 @@ func doUpdate() int { dir := filepath.Dir(binExe) if err := m.InstallTo(tarball, dir); err != nil { cfg.PrintError("Error installing", err) - return 129 + return "", newExitStatusError(129, "%s", err.Error()) } cmd := path.Join(dir, m.Command) - fmt.Printf("Updated, restarting %s...\n", cmd) + stdoutf("Updated, restarting %s...\n", cmd) writeLastUpdateAt() + if restartMode == updateRestartDeferred { + return cmd, nil + } update.Restart(cmd) - return 0 + return "", nil } func download(m *update.Manager) (string, bool, error) { @@ -127,7 +147,7 @@ func download(m *update.Manager) (string, bool, error) { } // whitespace - fmt.Println() + stdoutln() // download tarball to a tmp dir tarball, err := a.DownloadProxy(progress.Reader) diff --git a/config/flag.go b/config/flag.go index 9466d86d..a883cee3 100644 --- a/config/flag.go +++ b/config/flag.go @@ -5,6 +5,7 @@ package config import ( "fmt" + "io" "os" "strconv" "strings" @@ -44,6 +45,7 @@ type Config struct { EnableUpdate bool `yaml:"update,omitempty" json:"update,omitempty"` EnableMetrics bool `yaml:"metrics,omitempty" json:"metrics,omitempty"` EnableTray bool `yaml:"tray,omitempty" json:"tray,omitempty"` + DisableDaemon bool `yaml:"-" json:"-"` BlockquickDowngrade bool `yaml:"bqdowngrade,omitempty" json:"bqdowngrade,omitempty"` RemoteRPCAddrs StringValues `yaml:"diodeaddrs,omitempty" json:"diodeaddrs,omitempty"` RemoteRPCTimeout time.Duration `yaml:"timeout,omitempty" json:"timeout,omitempty"` @@ -100,6 +102,8 @@ type Config struct { LogMode int `yaml:"-" json:"-"` LogDateTime bool `yaml:"-" json:"-"` Logger *Logger `yaml:"-" json:"-"` + StdoutWriter io.Writer `yaml:"-" json:"-"` + StderrWriter io.Writer `yaml:"-" json:"-"` ConfigFilePath string `yaml:"-" json:"-"` Binds []Bind `yaml:"-" json:"-"` BNSForce bool `yaml:"-" json:"-"` @@ -228,18 +232,27 @@ func (cfg *Config) Blocklists() map[Address]bool { func (cfg *Config) PrintLabel(label string, value string) { msg := fmt.Sprintf("%-20s : %-42s", label, value) msg = strings.Replace(msg, "%", "%%", -1) + if cfg.StdoutWriter != nil { + fmt.Fprintln(cfg.StdoutWriter, msg) + } cfg.Logger.Info("%s", msg) } func (cfg *Config) PrintError(msg string, err error) { + text := fmt.Sprintf("%s: ", msg) if err != nil { - cfg.Logger.Error(fmt.Sprintf("%s: %s", msg, err.Error())) - } else { - cfg.Logger.Error(fmt.Sprintf("%s: ", msg)) + text = fmt.Sprintf("%s: %s", msg, err.Error()) + } + if cfg.StderrWriter != nil { + fmt.Fprintln(cfg.StderrWriter, text) } + cfg.Logger.Error(text) } func (cfg *Config) PrintInfo(msg string) { + if cfg.StdoutWriter != nil { + fmt.Fprintln(cfg.StdoutWriter, msg) + } cfg.Logger.Info("%s", msg) } diff --git a/docs/manual-cli-test-plan.md b/docs/manual-cli-test-plan.md new file mode 100644 index 00000000..1ac6a802 --- /dev/null +++ b/docs/manual-cli-test-plan.md @@ -0,0 +1,675 @@ +# Manual CLI Test Plan + +This plan covers the full `cmd/diode` CLI surface, including daemon-by-default behavior, multi-wallet access checks, file transfer, and the hidden/internal command paths that are exercised indirectly. + +## Scope + +Public commands covered: + +- top-level `--help` +- `version` +- `config` +- `query` +- `time` +- `fetch` +- `publish` +- `gateway` +- `socksd` +- `files` +- `push` +- `pull` +- `ssh` +- `daemon` +- `join` +- `bns` +- `token` +- `reset` +- `update` +- `mcp` + +Hidden/internal coverage: + +- `ssh-proxy` +- `__daemon__` + +## Lab Setup + +1. Build the lab and fixture environment: + +```bash +./scripts/manual/setup_cli_lab.sh up +source ./.manual/cli-lab/env.sh +``` + +2. The env file exports: + +- `DIODE_OWNER_DB` +- `DIODE_PEER_DB` +- `DIODE_VIEWER_DB` +- `DIODE_OWNER_ADDR` +- `DIODE_PEER_ADDR` +- `DIODE_VIEWER_ADDR` +- `DIODE_OWNER_HTTP_ROOT` +- `DIODE_PEER_HTTP_ROOT` +- `DIODE_OWNER_FILES_ROOT` +- `DIODE_PEER_FILES_ROOT` +- `DIODE_OUTPUT_DIR` + +3. The env file also defines wrapper helpers: + +- `downer ...` + Runs owner wallet commands with `-no-daemon`. +- `dpeer ...` + Runs peer wallet commands with `-no-daemon`. +- `dviewer ...` + Runs viewer wallet commands with `-no-daemon`. +- `ddaemon ...` + Runs owner wallet commands with daemon mode enabled. +- `dstopdaemon` + Stops and cleans the global daemon transport files. + +4. The lab also starts two local HTTP fixtures: + +- `http://127.0.0.1:18080/` backed by `owner-public-fixture` +- `http://127.0.0.1:18081/` backed by `peer-public-fixture` + +## Execution Rules + +- Use `ddaemon` only for daemon coverage and single-wallet daemon-managed runtime tests. +- Use `downer`, `dpeer`, and `dviewer` for multi-wallet tests because the daemon is global per user and cannot safely represent multiple `-dbpath` identities at once. +- Before any daemon section, run `dstopdaemon`. +- When a test starts a long-running command in one terminal, keep that terminal open and run the verification command in a second terminal. +- Record for each test: + - exact command + - exit code + - stdout/stderr + - whether the daemon state or runtime state changed as expected + +## Wallet Roles + +- Owner wallet: + Hosts public/private ports and file listeners. +- Peer wallet: + Allowed client for private access tests. +- Viewer wallet: + Unauthorized client for private access tests. + +## Baseline Checks + +### Top-Level Help + +Command: + +```bash +./diode --help +./diode publish --help +./diode daemon --help || true +``` + +Verify: + +- help exits `0` +- help does not start the daemon +- help lists the expected public commands + +### Version + +Command: + +```bash +./diode version +``` + +Verify: + +- exits `0` +- prints OS/arch/CPU +- does not require a running daemon + +## Config and Identity + +### `config` + +Commands: + +```bash +downer config -list +downer config -set manual_key=manual_value +downer config -list +downer config -delete manual_key +downer config -list +downer config -unsafe -list +``` + +Verify: + +- first `config -list` creates the wallet DB if it does not exist +- output includes `
` and the owner address matches `DIODE_OWNER_ADDR` +- `manual_key` appears after `-set` +- `manual_key` disappears after `-delete` +- `-unsafe` prints the private key material only for the current wallet + +### `reset` + +Prerequisites: + +- funded owner wallet on the target network +- operator accepts destructive wallet/fleet change + +Commands: + +```bash +downer reset +downer reset -experimental +``` + +Verify: + +- deploys a new fleet contract +- prints new fleet address +- persists the fleet into the owner DB +- a second `reset` against an already initialized wallet reports that the client is already initialized + +## Read-Only Network Commands + +### `query` + +Commands: + +```bash +downer query -address "$DIODE_OWNER_ADDR" +downer query -address "$DIODE_PEER_ADDR" +``` + +Verify: + +- exits `0` +- prints account type when decodable +- prints one or more device tickets or a clear resolution failure +- device ticket output includes fleet, server, block, and validation status fields + +### `time` + +Command: + +```bash +downer time +``` + +Verify: + +- exits `0` +- prints minimum and maximum blockchain consensus time +- values are plausible and `maximum >= minimum` + +## Daemon Lifecycle and Dispatch + +### Implicit Daemon Startup + +Commands: + +```bash +dstopdaemon +ddaemon publish -public 18080:18080 +ddaemon daemon status +``` + +Verify: + +- first `publish` autostarts the hidden daemon +- `publish` returns quickly and does not remain attached to the foreground +- `daemon status` shows `Active mode: publish` +- `daemon status` shows the old-style published port map for port `18080` + +### Root-Flag Implicit Publish + +Commands: + +```bash +ddaemon -bind 19090:$DIODE_OWNER_ADDR:18080 +ddaemon daemon status +ddaemon -bind 19090:$DIODE_OWNER_ADDR:18080 +ddaemon daemon status +``` + +Verify: + +- root-only `-bind` is treated as implicit `publish` +- the first bind adds one bind entry +- the second identical bind does not duplicate +- `daemon status` still shows the existing `-public 18080:18080` published port table + +### `daemon status` + +Command: + +```bash +ddaemon daemon status +``` + +Verify: + +- shows PID and socket path +- shows active mode and mode args +- shows published port map and bind map when configured +- reports SOCKS/API state correctly + +### `daemon restart` + +Commands: + +```bash +ddaemon daemon restart +ddaemon daemon status +``` + +Verify: + +- exits `0` +- daemon comes back within the timeout +- active mode and published ports survive restart + +### `daemon ports remove` + +Commands: + +```bash +ddaemon daemon ports remove 18080 +ddaemon daemon status +``` + +Verify: + +- removes only the requested published port +- leaves unrelated binds intact +- if no published ports remain, mode is stopped cleanly + +### `daemon ports clear` + +Commands: + +```bash +ddaemon publish -public 18080:18080 +ddaemon daemon ports clear +ddaemon daemon status +``` + +Verify: + +- clears published ports +- stops the active publish/files mode +- `daemon status` returns `Active mode: none` + +### `daemon stop` + +Commands: + +```bash +ddaemon daemon stop +ddaemon daemon status +``` + +Verify: + +- stop exits `0` +- subsequent `daemon status` reports `not running` + +## Publish, Fetch, and Multi-Wallet Access + +### Public Publish + +Terminal A: + +```bash +dstopdaemon +ddaemon publish -public 18080:18080 +``` + +Terminal B: + +```bash +dpeer fetch -url "http://$DIODE_OWNER_ADDR.diode.link:18080/" -output "$DIODE_OUTPUT_DIR/public-owner.html" +cat "$DIODE_OUTPUT_DIR/public-owner.html" +``` + +Verify: + +- `publish` prints the old-style port map +- `fetch` exits `0` +- downloaded body contains `owner-public-fixture` + +### Private Publish With Allowlist + +Terminal A: + +```bash +ddaemon publish -private "18081:18081,$DIODE_PEER_ADDR" +``` + +Terminal B: + +```bash +dpeer fetch -url "http://$DIODE_OWNER_ADDR.diode.link:18081/" -output "$DIODE_OUTPUT_DIR/private-peer.html" +cat "$DIODE_OUTPUT_DIR/private-peer.html" +``` + +Terminal C: + +```bash +dviewer fetch -url "http://$DIODE_OWNER_ADDR.diode.link:18081/" -output "$DIODE_OUTPUT_DIR/private-viewer.html" +``` + +Verify: + +- peer wallet succeeds and sees `peer-public-fixture` only if owner fixture on `18081` is running +- viewer wallet fails with a clear access error or connection failure +- `daemon status` shows the allowlisted private port + +Note: + +- if you want a dedicated private-only fixture on `18081`, start one manually: + +```bash +python3 -m http.server 18081 --bind 127.0.0.1 --directory "$DIODE_PEER_HTTP_ROOT" +``` + +### `fetch` + +Commands: + +```bash +dpeer fetch -url "http://$DIODE_OWNER_ADDR.diode.link:18080/health.json" -output "$DIODE_OUTPUT_DIR/health.json" +dpeer fetch -method GET -header "accept: application/json" -url "http://$DIODE_OWNER_ADDR.diode.link:18080/health.json" -output "$DIODE_OUTPUT_DIR/health-header.json" +``` + +Verify: + +- both commands exit `0` +- output files exist +- response body matches the hosted JSON fixture + +## Files, Push, and Pull + +### `files` + +Terminal A: + +```bash +dstopdaemon +ddaemon files -fileroot "$DIODE_OWNER_FILES_ROOT" 18180 +``` + +Terminal B: + +```bash +dpeer pull "$DIODE_OWNER_ADDR:18180:missing.txt" "$DIODE_OUTPUT_DIR/missing.txt" +``` + +Verify: + +- `files` prints the file port banner +- missing file pull fails with a non-2xx error + +### `push` + +Command: + +```bash +dpeer push "$DIODE_SAMPLE_UPLOAD" "$DIODE_OWNER_ADDR:18180:uploads/sample-upload.txt" +``` + +Verify: + +- exits `0` +- owner filesystem now contains `$DIODE_OWNER_FILES_ROOT/uploads/sample-upload.txt` +- file content matches `manual-upload-payload` + +### `pull` + +Command: + +```bash +dviewer pull "$DIODE_OWNER_ADDR:18180:uploads/sample-upload.txt" "$DIODE_OUTPUT_DIR/pulled-upload.txt" +cat "$DIODE_OUTPUT_DIR/pulled-upload.txt" +``` + +Verify: + +- exits `0` +- downloaded file exists +- downloaded content matches the uploaded content + +## Proxy and Gateway Modes + +### `socksd` + +Terminal A: + +```bash +downer publish -public 18080:18080 +``` + +Terminal B: + +```bash +dpeer socksd -socksd_host 127.0.0.1 -socksd_port 19082 +``` + +Terminal C: + +```bash +curl --socks5-hostname 127.0.0.1:19082 "http://$DIODE_OWNER_ADDR.diode.link:18080/" +``` + +Verify: + +- SOCKS listener binds to `127.0.0.1:19082` +- curl returns `owner-public-fixture` +- stopping `socksd` tears down the listener cleanly + +### `gateway` + +Terminal A: + +```bash +downer publish -public 18080:18080 +``` + +Terminal B: + +```bash +dpeer gateway -httpd_host 127.0.0.1 -httpd_port 19080 +``` + +Terminal C: + +```bash +curl -H "Host: $DIODE_OWNER_ADDR.diode.link:18080" http://127.0.0.1:19080/ +``` + +Verify: + +- gateway listener binds to `127.0.0.1:19080` +- response body is the owner fixture +- with `-secure` and valid certs, HTTPS listener also starts and serves the same destination + +## SSH + +Prerequisites: + +- OpenSSH `ssh` and `ssh-keygen` installed on the client machine +- owner machine has a local UNIX account to expose via `-sshd` + +Terminal A: + +```bash +dstopdaemon +ddaemon publish -sshd "public:2222:$USER" +``` + +Terminal B: + +```bash +dpeer ssh "$USER@$DIODE_OWNER_ADDR.diode" -p 2222 +``` + +Verify: + +- `ssh` prints the local/daemon proxy address it is using +- the connection launches the system `ssh` client +- login succeeds or reaches normal SSH host key / auth prompts +- `ps` or SSH verbose output shows the hidden `ssh-proxy` path is being used + +Hidden coverage: + +- `ssh-proxy` is considered covered when `diode ssh` succeeds through its ProxyCommand path + +## Blockchain Write Commands + +### `token` + +Prerequisites: + +- funded owner wallet + +Commands: + +```bash +downer token -balance +downer token -to "$DIODE_PEER_ADDR" -value 1wei -gasprice 1gwei +dpeer token -balance +``` + +Verify: + +- balance command exits `0` and prints the wallet balance +- transfer command submits successfully +- peer balance or state changes after confirmation + +### `bns` + +Prerequisites: + +- funded owner wallet +- unique BNS name available for testing + +Commands: + +```bash +downer bns -lookup your-test-name +downer bns -register "your-test-name=$DIODE_OWNER_ADDR" +dviewer bns -lookup your-test-name +downer bns -account your-test-name +downer bns -transfer "your-test-name=$DIODE_PEER_ADDR" +downer bns -unregister your-test-name +``` + +Verify: + +- lookup returns owner and mapped address after registration +- account lookup returns nonce, code, and balance data +- transfer updates the reported owner +- unregister removes the mapping + +### `reset` + +This is already covered above under Config and Identity because it mutates fleet identity. + +## Join + +Prerequisites: + +- valid perimeter contract address +- any required Oasis local or remote environment variables for the selected `-network` +- if WireGuard coverage is required, local WireGuard tooling and permissions + +Commands: + +```bash +downer join -dry +downer join -network mainnet +downer join -network testnet +downer join -wireguard -dry +``` + +Verify: + +- `-dry` validates and prints contract-derived state without starting the daemon loop +- normal join starts the long-lived reconcile loop +- contract-driven published ports, binds, SOCKS, and WireGuard state are applied +- switching away from `join` by running `ddaemon publish ...` stops the join mode cleanly + +## MCP + +Prerequisites: + +- MCP inspector or another stdio MCP client available + +Suggested command: + +```bash +./diode -update=false mcp +``` + +Verify with your MCP client: + +- server starts on stdio and stays attached +- tool list includes version, client info, query address, file push/pull, and deploy tools when enabled +- `-mcp-preset=minimal` reduces the exposed tool set +- `-mcp-tool=...` or the corresponding env var filters tools as expected + +## Update + +Prerequisites: + +- use an official release build, not a local `development` binary, for a meaningful update test + +Commands: + +```bash +./diode update +ddaemon update +``` + +Verify: + +- standalone update either reports `No updates` or installs a newer release and restarts +- daemon-routed update returns the update output to the CLI +- after daemon update, `ddaemon daemon status` works again and the daemon resumes the previous active mode + +## Internal Command Coverage + +### `__daemon__` + +Do not invoke directly in normal manual testing. + +Verify indirectly through: + +- daemon autostart from `ddaemon publish ...` +- `daemon status` +- `daemon restart` +- daemon metadata/socket presence + +### `ssh-proxy` + +Do not invoke directly in normal manual testing. + +Verify indirectly through: + +- successful `diode ssh ...` +- ProxyCommand execution in SSH verbose logs + +## Cleanup + +Commands: + +```bash +dstopdaemon +./scripts/manual/setup_cli_lab.sh down +``` + +Verify: + +- daemon is stopped +- fixture HTTP servers are stopped +- no unexpected long-running `diode` test processes remain diff --git a/scripts/manual/setup_cli_lab.sh b/scripts/manual/setup_cli_lab.sh new file mode 100755 index 00000000..b00da9cf --- /dev/null +++ b/scripts/manual/setup_cli_lab.sh @@ -0,0 +1,258 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +ACTION="${1:-up}" +LAB_DIR="${2:-${DIODE_MANUAL_LAB_DIR:-$ROOT_DIR/.manual/cli-lab}}" +BIN_PATH="$ROOT_DIR/diode" +RUN_DIR="$LAB_DIR/run" +WALLET_DIR="$LAB_DIR/wallets" +HTTP_DIR="$LAB_DIR/http" +FILES_DIR="$LAB_DIR/files" +OUT_DIR="$LAB_DIR/out" +ENV_PATH="$LAB_DIR/env.sh" + +OWNER_DB="$WALLET_DIR/owner/private.db" +PEER_DB="$WALLET_DIR/peer/private.db" +VIEWER_DB="$WALLET_DIR/viewer/private.db" + +OWNER_HTTP_PORT=18080 +PEER_HTTP_PORT=18081 + +OWNER_HTTP_ROOT="$HTTP_DIR/owner" +PEER_HTTP_ROOT="$HTTP_DIR/peer" +OWNER_FILES_ROOT="$FILES_DIR/owner" +PEER_FILES_ROOT="$FILES_DIR/peer" + +usage() { + cat </dev/null 2>&1; then + echo "missing required command: $name" >&2 + exit 1 + fi +} + +stop_pidfile() { + local pidfile="$1" + if [[ ! -f "$pidfile" ]]; then + return + fi + local pid="" + pid="$(cat "$pidfile" 2>/dev/null || true)" + if [[ -n "$pid" ]] && kill -0 "$pid" >/dev/null 2>&1; then + kill "$pid" >/dev/null 2>&1 || true + for _ in $(seq 1 20); do + if ! kill -0 "$pid" >/dev/null 2>&1; then + break + fi + sleep 0.1 + done + if kill -0 "$pid" >/dev/null 2>&1; then + kill -9 "$pid" >/dev/null 2>&1 || true + fi + fi + rm -f "$pidfile" +} + +start_http_server() { + local name="$1" + local root="$2" + local port="$3" + local pidfile="$RUN_DIR/$name.pid" + local logfile="$RUN_DIR/$name.log" + + stop_pidfile "$pidfile" + nohup python3 -m http.server "$port" --bind 127.0.0.1 --directory "$root" >"$logfile" 2>&1 & + local pid=$! + echo "$pid" >"$pidfile" + + for _ in $(seq 1 20); do + if ! kill -0 "$pid" >/dev/null 2>&1; then + echo "failed to start $name fixture server, see $logfile" >&2 + exit 1 + fi + if python3 - </dev/null 2>&1 +import socket +s = socket.socket() +s.settimeout(0.2) +try: + s.connect(("127.0.0.1", $port)) +except OSError: + raise SystemExit(1) +finally: + s.close() +PY + then + return + fi + sleep 0.2 + done + echo "timed out waiting for $name fixture server on 127.0.0.1:$port" >&2 + exit 1 +} + +wallet_address() { + local dbpath="$1" + local output + output="$("$BIN_PATH" -update=false -no-daemon -dbpath "$dbpath" config -list 2>&1)" + printf '%s\n' "$output" | awk -F: '/
/{gsub(/[[:space:]]/, "", $2); print $2; exit}' +} + +write_lab_files() { + mkdir -p "$OWNER_HTTP_ROOT" "$PEER_HTTP_ROOT" "$OWNER_FILES_ROOT" "$PEER_FILES_ROOT" "$OUT_DIR" "$RUN_DIR" + + cat >"$OWNER_HTTP_ROOT/index.html" <<'EOF' +owner-public-fixture +EOF + cat >"$OWNER_HTTP_ROOT/health.json" <<'EOF' +{"service":"owner","status":"ok"} +EOF + cat >"$PEER_HTTP_ROOT/index.html" <<'EOF' +peer-public-fixture +EOF + cat >"$PEER_HTTP_ROOT/health.json" <<'EOF' +{"service":"peer","status":"ok"} +EOF + cat >"$LAB_DIR/sample-upload.txt" <<'EOF' +manual-upload-payload +EOF +} + +write_env() { + local owner_addr="$1" + local peer_addr="$2" + local viewer_addr="$3" + + cat >"$ENV_PATH" </dev/null 2>&1 || true + rm -f "\$HOME/.config/diode/daemon.sock" "\$HOME/.config/diode/daemon.sock.json" +} +EOF +} + +do_up() { + require_cmd go + require_cmd python3 + + mkdir -p "$LAB_DIR" + write_lab_files + + ( + cd "$ROOT_DIR" + go build -o ./diode ./cmd/diode + ) + + local owner_addr peer_addr viewer_addr + owner_addr="$(wallet_address "$OWNER_DB")" + peer_addr="$(wallet_address "$PEER_DB")" + viewer_addr="$(wallet_address "$VIEWER_DB")" + + if [[ -z "$owner_addr" || -z "$peer_addr" || -z "$viewer_addr" ]]; then + echo "failed to derive one or more wallet addresses" >&2 + exit 1 + fi + + start_http_server "owner-http" "$OWNER_HTTP_ROOT" "$OWNER_HTTP_PORT" + start_http_server "peer-http" "$PEER_HTTP_ROOT" "$PEER_HTTP_PORT" + write_env "$owner_addr" "$peer_addr" "$viewer_addr" + + cat </dev/null || true)" + if [[ -n "$pid" ]] && kill -0 "$pid" >/dev/null 2>&1; then + echo "$name: running (pid $pid)" + continue + fi + fi + echo "$name: stopped" + done + if [[ -f "$ENV_PATH" ]]; then + echo "env: $ENV_PATH" + else + echo "env: missing" + fi +} + +case "$ACTION" in +up) + do_up + ;; +down) + do_down + ;; +status) + do_status + ;; +help|-h|--help) + usage + ;; +*) + usage >&2 + exit 2 + ;; +esac From 516b4a98d8eaa60afa66108db65151ecf166bc77 Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Thu, 7 May 2026 18:36:11 +0200 Subject: [PATCH 02/15] Fix command validation and update target --- cmd/diode/query.go | 3 +-- cmd/diode/query_test.go | 32 ++++++++++++++++++++++++++++ cmd/diode/update.go | 37 ++++++++++++++++++--------------- cmd/diode/update_test.go | 45 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 19 deletions(-) create mode 100644 cmd/diode/query_test.go create mode 100644 cmd/diode/update_test.go diff --git a/cmd/diode/query.go b/cmd/diode/query.go index 2792277f..4432715e 100644 --- a/cmd/diode/query.go +++ b/cmd/diode/query.go @@ -33,8 +33,7 @@ func queryHandler() (err error) { cfg := config.AppConfig if cfg.QueryAddress == "" { - cfg.PrintError("Failed to query", fmt.Errorf("-address argument is required")) - return + return newExitStatusError(2, "query requires -address") } err = app.Start() diff --git a/cmd/diode/query_test.go b/cmd/diode/query_test.go new file mode 100644 index 00000000..8fded16f --- /dev/null +++ b/cmd/diode/query_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "strings" + "testing" + + "github.com/diodechain/diode_client/config" +) + +func TestQueryHandlerRequiresAddress(t *testing.T) { + cfg := newSharedControlTestConfig(t) + origCfg := config.AppConfig + config.AppConfig = cfg + t.Cleanup(func() { + config.AppConfig = origCfg + }) + + err := queryHandler() + if err == nil { + t.Fatal("queryHandler() error = nil, want missing address error") + } + if !strings.Contains(err.Error(), "requires -address") { + t.Fatalf("queryHandler() error = %q, want missing address error", err.Error()) + } + statusErr, ok := err.(interface{ Status() int }) + if !ok { + t.Fatalf("queryHandler() error type %T does not expose Status()", err) + } + if statusErr.Status() != 2 { + t.Fatalf("queryHandler() status = %d, want 2", statusErr.Status()) + } +} diff --git a/cmd/diode/update.go b/cmd/diode/update.go index f78b1c55..a389c506 100644 --- a/cmd/diode/update.go +++ b/cmd/diode/update.go @@ -3,8 +3,6 @@ package main import ( "fmt" "os" - "os/exec" - "path" "path/filepath" "runtime" "time" @@ -87,26 +85,13 @@ func doUpdate(restartMode updateRestartMode) (string, error) { return "", nil } - // searching for binary in path - bin, err := exec.LookPath(m.Command) - if err != nil { - // just update local file - bin = os.Args[0] - } - - // find the real path of execute file if the file was symlink - binExe, err := filepath.EvalSymlinks(bin) - if err != nil { - binExe = bin - } - - dir := filepath.Dir(binExe) + dir := updateInstallDir() if err := m.InstallTo(tarball, dir); err != nil { cfg.PrintError("Error installing", err) return "", newExitStatusError(129, "%s", err.Error()) } - cmd := path.Join(dir, m.Command) + cmd := filepath.Join(dir, m.Command) stdoutf("Updated, restarting %s...\n", cmd) writeLastUpdateAt() if restartMode == updateRestartDeferred { @@ -116,6 +101,24 @@ func doUpdate(restartMode updateRestartMode) (string, error) { return "", nil } +func updateInstallDir() string { + bin, err := os.Executable() + if err != nil || bin == "" { + bin = os.Args[0] + } + return updateInstallDirFromExecutable(bin, filepath.EvalSymlinks) +} + +func updateInstallDirFromExecutable(bin string, evalSymlinks func(string) (string, error)) string { + if abs, err := filepath.Abs(bin); err == nil { + bin = abs + } + if resolved, err := evalSymlinks(bin); err == nil { + bin = resolved + } + return filepath.Dir(bin) +} + func download(m *update.Manager) (string, bool, error) { cfg := config.AppConfig ansi.HideCursor() diff --git a/cmd/diode/update_test.go b/cmd/diode/update_test.go new file mode 100644 index 00000000..d8c9ceb8 --- /dev/null +++ b/cmd/diode/update_test.go @@ -0,0 +1,45 @@ +package main + +import ( + "errors" + "path/filepath" + "testing" +) + +func TestUpdateInstallDirFromExecutable(t *testing.T) { + executable := filepath.Join(t.TempDir(), "bin", "diode") + + got := updateInstallDirFromExecutable(executable, func(path string) (string, error) { + return path, nil + }) + + if got != filepath.Dir(executable) { + t.Fatalf("updateInstallDirFromExecutable() = %q, want %q", got, filepath.Dir(executable)) + } +} + +func TestUpdateInstallDirFromExecutableResolvesSymlink(t *testing.T) { + tmp := t.TempDir() + link := filepath.Join(tmp, "link", "diode") + target := filepath.Join(tmp, "target", "diode") + + got := updateInstallDirFromExecutable(link, func(path string) (string, error) { + return target, nil + }) + + if got != filepath.Dir(target) { + t.Fatalf("updateInstallDirFromExecutable() = %q, want %q", got, filepath.Dir(target)) + } +} + +func TestUpdateInstallDirFromExecutableFallsBackWhenResolveFails(t *testing.T) { + executable := filepath.Join(t.TempDir(), "diode") + + got := updateInstallDirFromExecutable(executable, func(path string) (string, error) { + return "", errors.New("not a symlink") + }) + + if got != filepath.Dir(executable) { + t.Fatalf("updateInstallDirFromExecutable() = %q, want %q", got, filepath.Dir(executable)) + } +} From 0043ea8bc2fa8f59b3d00897d7916ffa7f4fc4c7 Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Fri, 8 May 2026 13:25:10 +0200 Subject: [PATCH 03/15] Make daemon modes foreground by default --- cmd/diode/app.go | 2 + cmd/diode/daemon.go | 310 +++++++++++++++++++++++++++++++++---- cmd/diode/daemon_manage.go | 25 +++ cmd/diode/daemon_test.go | 84 ++++++++++ cmd/diode/ssh.go | 23 +++ cmd/diode/ssh_test.go | 19 +++ config/flag.go | 1 + 7 files changed, 435 insertions(+), 29 deletions(-) diff --git a/cmd/diode/app.go b/cmd/diode/app.go index e0fbf7b6..d14bf135 100644 --- a/cmd/diode/app.go +++ b/cmd/diode/app.go @@ -55,6 +55,8 @@ func registerRootFlags(fs *flag.FlagSet, cfg *config.Config) { fs.BoolVar(&cfg.EnableMetrics, "metrics", false, "enable metrics stats") fs.BoolVar(&cfg.EnableTray, "tray", false, "show a system tray icon") fs.BoolVar(&cfg.DisableDaemon, "no-daemon", false, "run this command in standalone mode instead of using the diode daemon") + fs.BoolVar(&cfg.DetachDaemon, "d", false, "run daemon mode in the background") + fs.BoolVar(&cfg.DetachDaemon, "detach", false, "run daemon mode in the background") fs.BoolVar(&cfg.BlockquickDowngrade, "bqdowngrade", false, "reset blockquick window after repeated validation failures") registerSharedControlFlags(fs, cfg, "debug", "api", "apiaddr") fs.IntVar(&cfg.RlimitNofile, "rlimit_nofile", 0, "specify the file descriptor numbers that can be opened by this process") diff --git a/cmd/diode/daemon.go b/cmd/diode/daemon.go index ea10dfef..f3949fa7 100644 --- a/cmd/diode/daemon.go +++ b/cmd/diode/daemon.go @@ -34,6 +34,10 @@ const ( daemonRequestRelease = "release_local_proxy" daemonRequestUpdate = "update" daemonRequestManage = "manage" + daemonStreamStdout = "stdout" + daemonStreamStderr = "stderr" + daemonStreamFinal = "final" + daemonStreamError = "error" ) var ( @@ -80,6 +84,8 @@ var ( "-bnscachetime": true, "-maxports": true, "-no-daemon": true, + "-d": true, + "-detach": true, } localBypassCommands = map[string]bool{"": true, "version": true, "mcp": true, "ssh-proxy": true, daemonCommandName: true} daemonApplyModeCmds = map[string]bool{"publish": true, "gateway": true, "socksd": true, "join": true, "files": true} @@ -130,6 +136,7 @@ type daemonRequest struct { Command string `json:"command"` Args []string `json:"args,omitempty"` LeaseID string `json:"lease_id,omitempty"` + Attach bool `json:"attach,omitempty"` } type daemonResponse struct { @@ -144,6 +151,13 @@ type daemonResponse struct { Shutdown bool `json:"-"` } +type daemonStreamFrame struct { + Type string `json:"type"` + Data string `json:"data,omitempty"` + ExitCode int `json:"exit_code,omitempty"` + Error string `json:"error,omitempty"` +} + type runtimeDaemon struct { socketPath string metaPath string @@ -153,6 +167,7 @@ type runtimeDaemon struct { leasesMu sync.Mutex leases map[string]*rpc.Server stateMu sync.Mutex + modeChange chan struct{} activeMode string activeArgs []string ports map[int]*config.Port @@ -199,6 +214,7 @@ func daemonHandler() error { startup: startupSpec, baseConfig: sanitizedDaemonBaseConfig(cfg), leases: map[string]*rpc.Server{}, + modeChange: make(chan struct{}), } app.Defer(func() { daemonState.closeLeases() @@ -251,22 +267,27 @@ func maybeHandleDaemonCLI(args []string) (bool, int) { return false, 0 } - req := daemonRequest{ - Version: daemonProtocolVersion, - Command: inv.command, - Args: inv.execArgs, + applyDaemonStartupSpec(config.AppConfig, inv.startupSpec) + + req, err := daemonRequestForInvocation(inv) + if err != nil { + stderrln(err.Error()) + return true, exitCodeFromError(err) } - switch inv.command { - case "ssh": - req.Kind = daemonRequestLease - case "update": - req.Kind = daemonRequestUpdate - default: - if daemonApplyModeCmds[inv.command] { - req.Kind = daemonRequestApplyMode - } else { - req.Kind = daemonRequestRunTask + + if req.Attach { + handled, reason, exitCode, err := dispatchViaDaemonAttached(inv.startupSpec, req) + if !handled { + if reason != "" { + stderrln(reason) + } + return false, 0 } + if err != nil { + stderrln(err.Error()) + return true, 1 + } + return true, exitCode } resp, handled, reason, err := dispatchViaDaemon(inv.startupSpec, req) @@ -296,12 +317,40 @@ func maybeHandleDaemonCLI(args []string) (bool, int) { return true, resp.ExitCode } +func daemonRequestForInvocation(inv rootInvocation) (daemonRequest, error) { + req := daemonRequest{ + Version: daemonProtocolVersion, + Command: inv.command, + Args: inv.execArgs, + } + switch inv.command { + case "ssh": + req.Kind = daemonRequestLease + case "update": + req.Kind = daemonRequestUpdate + default: + if daemonApplyModeCmds[inv.command] { + req.Kind = daemonRequestApplyMode + } else { + req.Kind = daemonRequestRunTask + } + } + if inv.detachDaemon && req.Kind != daemonRequestApplyMode { + return req, newExitStatusError(2, "-d is only supported for daemon mode commands") + } + if req.Kind == daemonRequestApplyMode && !inv.detachDaemon { + req.Attach = true + } + return req, nil +} + type rootInvocation struct { command string commandArgs []string execArgs []string help bool disableDaemon bool + detachDaemon bool startupSpec daemonStartupSpec } @@ -324,7 +373,7 @@ func parseRootInvocation(args []string) (rootInvocation, error) { commandName = rest[0] } if commandName == "" && containsHelpArg(args) { - return rootInvocation{help: true, disableDaemon: cfg.DisableDaemon, startupSpec: daemonStartupSpecFromConfig(cfg)}, nil + return rootInvocation{help: true, disableDaemon: cfg.DisableDaemon, detachDaemon: cfg.DetachDaemon, startupSpec: daemonStartupSpecFromConfig(cfg)}, nil } if commandName == "" { commandName = "publish" @@ -332,7 +381,7 @@ func parseRootInvocation(args []string) (rootInvocation, error) { execArgs = append(execArgs, commandName) } if len(rest) > 1 && containsHelpArg(rest[1:]) { - return rootInvocation{help: true, disableDaemon: cfg.DisableDaemon, startupSpec: daemonStartupSpecFromConfig(cfg)}, nil + return rootInvocation{help: true, disableDaemon: cfg.DisableDaemon, detachDaemon: cfg.DetachDaemon, startupSpec: daemonStartupSpecFromConfig(cfg)}, nil } return rootInvocation{ command: commandName, @@ -340,6 +389,7 @@ func parseRootInvocation(args []string) (rootInvocation, error) { execArgs: execArgs, help: false, disableDaemon: cfg.DisableDaemon, + detachDaemon: cfg.DetachDaemon, startupSpec: daemonStartupSpecFromConfig(cfg), }, nil } @@ -355,10 +405,103 @@ func containsHelpArg(args []string) bool { } func dispatchViaDaemon(spec daemonStartupSpec, req daemonRequest) (daemonResponse, bool, string, error) { + conn, handled, reason, err := openDaemonConnection(spec) + if !handled || err != nil { + return daemonResponse{}, handled, reason, err + } + defer conn.Close() + if err := json.NewEncoder(conn).Encode(req); err != nil { + return daemonResponse{}, true, "", err + } + var resp daemonResponse + if err := json.NewDecoder(conn).Decode(&resp); err != nil { + return daemonResponse{}, true, "", err + } + return resp, true, "", nil +} + +func dispatchViaDaemonAttached(spec daemonStartupSpec, req daemonRequest) (bool, string, int, error) { + conn, handled, reason, err := openDaemonConnection(spec) + if !handled || err != nil { + return handled, reason, 0, err + } + defer conn.Close() + + if err := json.NewEncoder(conn).Encode(req); err != nil { + return true, "", 1, err + } + + frameCh := make(chan daemonStreamFrame, 1) + errCh := make(chan error, 1) + go func() { + dec := json.NewDecoder(conn) + for { + var frame daemonStreamFrame + if err := dec.Decode(&frame); err != nil { + errCh <- err + return + } + frameCh <- frame + if frame.Type == daemonStreamFinal || frame.Type == daemonStreamError { + return + } + } + }() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, daemonSignals()...) + defer signal.Stop(sigCh) + stopRequested := false + for { + select { + case frame := <-frameCh: + switch frame.Type { + case daemonStreamStdout: + _, _ = io.WriteString(stdoutWriter(), frame.Data) + case daemonStreamStderr: + _, _ = io.WriteString(stderrWriter(), frame.Data) + case daemonStreamError: + if frame.Error != "" { + return true, "", exitCodeOrDefault(frame.ExitCode, 1), fmt.Errorf("%s", frame.Error) + } + return true, "", exitCodeOrDefault(frame.ExitCode, 1), fmt.Errorf("daemon stream failed") + case daemonStreamFinal: + if frame.Error != "" { + _, _ = io.WriteString(stderrWriter(), frame.Error+"\n") + } + return true, "", frame.ExitCode, nil + default: + return true, "", 1, fmt.Errorf("unknown daemon stream frame: %s", frame.Type) + } + case err := <-errCh: + if err == io.EOF { + return true, "", 0, nil + } + return true, "", 1, err + case <-sigCh: + if stopRequested { + return true, "", 130, nil + } + stopRequested = true + if err := requestDaemonModeStop(); err != nil { + return true, "", 1, err + } + } + } +} + +func exitCodeOrDefault(code int, fallback int) int { + if code != 0 { + return code + } + return fallback +} + +func openDaemonConnection(spec daemonStartupSpec) (net.Conn, bool, string, error) { meta, metaErr := readDaemonMetadata() if metaErr == nil && !reflect.DeepEqual(meta.StartupSpec, spec) { if _, err := dialDaemon(meta.SocketPath); err == nil { - return daemonResponse{}, false, "running daemon is incompatible with this invocation; using standalone mode. Run `diode daemon restart` to reload the daemon with the current binary and flags.", nil + return nil, false, "running daemon is incompatible with this invocation; using standalone mode. Run `diode daemon restart` to reload the daemon with the current binary and flags.", nil } cleanupDaemonArtifacts(meta.SocketPath, metaPathFromSocket(meta.SocketPath)) metaErr = os.ErrNotExist @@ -372,26 +515,18 @@ func dispatchViaDaemon(spec daemonStartupSpec, req daemonRequest) (daemonRespons if err != nil { cleanupDaemonArtifacts(socketPath, metaPathFromSocket(socketPath)) if err := spawnDaemon(spec); err != nil { - return daemonResponse{}, true, "", err + return nil, true, "", err } meta, err = readDaemonMetadata() if err != nil { - return daemonResponse{}, true, "", err + return nil, true, "", err } conn, err = dialDaemon(meta.SocketPath) if err != nil { - return daemonResponse{}, true, "", err + return nil, true, "", err } } - defer conn.Close() - if err := json.NewEncoder(conn).Encode(req); err != nil { - return daemonResponse{}, true, "", err - } - var resp daemonResponse - if err := json.NewDecoder(conn).Decode(&resp); err != nil { - return daemonResponse{}, true, "", err - } - return resp, true, "", nil + return conn, true, "", nil } func serveDaemon(ln net.Listener) { @@ -415,6 +550,10 @@ func handleDaemonConn(conn net.Conn) { _ = json.NewEncoder(conn).Encode(daemonResponse{Version: daemonProtocolVersion, ExitCode: 1, Error: err.Error()}) return } + if req.Kind == daemonRequestApplyMode && req.Attach { + executeDaemonAttachedRequest(conn, req) + return + } resp := executeDaemonRequest(req) if err := json.NewEncoder(conn).Encode(resp); err != nil { logDaemonInternalError("Couldn't encode daemon response", err) @@ -519,6 +658,74 @@ func executeDaemonBufferedRequest(kind string, fn func() (string, error)) daemon return resp } +func executeDaemonAttachedRequest(conn net.Conn, req daemonRequest) { + var frameMu sync.Mutex + enc := json.NewEncoder(conn) + sendFrame := func(frame daemonStreamFrame) error { + frameMu.Lock() + defer frameMu.Unlock() + return enc.Encode(frame) + } + + if daemonState == nil { + _ = sendFrame(daemonStreamFrame{Type: daemonStreamError, ExitCode: 1, Error: "daemon state is not initialized"}) + return + } + if req.Command == "publish" { + req.Args = mergeImplicitPublishArgs(req.Args) + } + + daemonExecMu.Lock() + *config.AppConfig = cloneDaemonConfig(&daemonState.baseConfig) + resetTransientConfig(config.AppConfig) + resetRequestGlobals() + config.AppConfig.StdoutWriter = daemonFrameWriter{typ: daemonStreamStdout, send: sendFrame} + config.AppConfig.StderrWriter = daemonFrameWriter{typ: daemonStreamStderr, send: sendFrame} + + activeDaemonReqMu.Lock() + activeDaemonReqKind = daemonRequestApplyMode + activeDaemonReqMu.Unlock() + err := runDaemonCommandArgs(req.Args) + activeDaemonReqMu.Lock() + activeDaemonReqKind = "" + activeDaemonReqMu.Unlock() + + exitCode := exitCodeFromError(err) + if err != nil && exitCode == 0 { + exitCode = 1 + } + if err == nil { + daemonState.updateModeSnapshot(req.Command, req.Args, config.AppConfig) + } + daemonState.baseConfig = sanitizedDaemonBaseConfig(config.AppConfig) + daemonExecMu.Unlock() + + if err != nil { + _ = sendFrame(daemonStreamFrame{Type: daemonStreamFinal, ExitCode: exitCode, Error: err.Error()}) + return + } + + daemonState.waitForModeExit(req.Command, req.Args, app.closeCh) + config.AppConfig.StdoutWriter = nil + config.AppConfig.StderrWriter = nil + _ = sendFrame(daemonStreamFrame{Type: daemonStreamFinal, ExitCode: 0}) +} + +type daemonFrameWriter struct { + typ string + send func(daemonStreamFrame) error +} + +func (w daemonFrameWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if err := w.send(daemonStreamFrame{Type: w.typ, Data: string(p)}); err != nil { + return 0, err + } + return len(p), nil +} + func runDaemonCommandArgs(args []string) error { if len(args) == 0 { return newExitStatusError(2, "missing command") @@ -734,6 +941,7 @@ func resetTransientConfig(cfg *config.Config) { cfg.StdoutWriter = nil cfg.StderrWriter = nil cfg.DisableDaemon = false + cfg.DetachDaemon = false cfg.QueryAddress = "" cfg.ConfigUnsafe = false cfg.ConfigList = false @@ -986,6 +1194,14 @@ func (rd *runtimeDaemon) updateModeSnapshot(mode string, args []string, cfg *con rd.socksAddr = cfg.SocksServerAddr() rd.apiOn = cfg.EnableAPIServer rd.apiAddr = cfg.APIServerAddr + rd.notifyModeChangedLocked() +} + +func (rd *runtimeDaemon) notifyModeChangedLocked() { + if rd.modeChange != nil { + close(rd.modeChange) + } + rd.modeChange = make(chan struct{}) } func (rd *runtimeDaemon) clearModeSnapshot() { @@ -1002,6 +1218,42 @@ func (rd *runtimeDaemon) clearModeSnapshot() { rd.socksAddr = "" rd.apiOn = false rd.apiAddr = "" + rd.notifyModeChangedLocked() +} + +func (rd *runtimeDaemon) waitForModeExit(mode string, args []string, done <-chan struct{}) { + if rd == nil { + return + } + wantArgs := sanitizeModeArgs(mode, args) + for { + rd.stateMu.Lock() + activeMode := rd.activeMode + activeArgs := append([]string{}, rd.activeArgs...) + modeChange := rd.modeChange + rd.stateMu.Unlock() + + if activeMode == "" || activeMode != mode || !equalStringSlices(activeArgs, wantArgs) { + return + } + select { + case <-modeChange: + case <-done: + return + } + } +} + +func equalStringSlices(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true } func daemonRestoreArgsForRestart() []string { diff --git a/cmd/diode/daemon_manage.go b/cmd/diode/daemon_manage.go index db369077..674793c1 100644 --- a/cmd/diode/daemon_manage.go +++ b/cmd/diode/daemon_manage.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "fmt" "io" "os" "sort" @@ -236,6 +237,25 @@ func dispatchToRunningDaemon(req daemonRequest) (daemonResponse, bool, error) { return daemonResponse{}, false, nil } +func requestDaemonModeStop() error { + resp, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "mode-stop"}, + }) + if err != nil || !running { + return err + } + if resp.ExitCode != 0 { + if resp.Error != "" { + return fmt.Errorf("%s", resp.Error) + } + return newExitStatusError(resp.ExitCode, "daemon mode stop failed") + } + return nil +} + func runDaemonManage(args []string, resp *daemonResponse) error { if len(args) == 0 || args[0] != "daemon" { return newExitStatusError(2, "missing daemon action") @@ -266,6 +286,11 @@ func runDaemonManage(args []string, resp *daemonResponse) error { resp.RestartPath = exePath } return nil + case "mode-stop": + stdoutln("Stopping diode daemon mode.") + app.StopMode() + daemonState.clearModeSnapshot() + return nil case "ports": if len(args) < 3 { return newExitStatusError(2, "usage: diode daemon ports [remove|clear]") diff --git a/cmd/diode/daemon_test.go b/cmd/diode/daemon_test.go index 2377ad04..f595bc42 100644 --- a/cmd/diode/daemon_test.go +++ b/cmd/diode/daemon_test.go @@ -45,6 +45,64 @@ func TestParseRootInvocationDetectsHelpAndNoDaemon(t *testing.T) { } } +func TestParseRootInvocationDetectsDetach(t *testing.T) { + inv, err := parseRootInvocation([]string{"-d", "publish", "-public", "80"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + if !inv.detachDaemon { + t.Fatal("detachDaemon = false, want true") + } +} + +func TestDaemonRequestForInvocationAttachesApplyModeByDefault(t *testing.T) { + inv, err := parseRootInvocation([]string{"publish", "-public", "80"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + req, err := daemonRequestForInvocation(inv) + if err != nil { + t.Fatalf("daemonRequestForInvocation() error = %v", err) + } + if req.Kind != daemonRequestApplyMode { + t.Fatalf("request kind = %q, want %q", req.Kind, daemonRequestApplyMode) + } + if !req.Attach { + t.Fatal("request Attach = false, want true") + } +} + +func TestDaemonRequestForInvocationDetachedApplyMode(t *testing.T) { + inv, err := parseRootInvocation([]string{"-d", "publish", "-public", "80"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + req, err := daemonRequestForInvocation(inv) + if err != nil { + t.Fatalf("daemonRequestForInvocation() error = %v", err) + } + if req.Kind != daemonRequestApplyMode { + t.Fatalf("request kind = %q, want %q", req.Kind, daemonRequestApplyMode) + } + if req.Attach { + t.Fatal("request Attach = true, want false") + } +} + +func TestDaemonRequestForInvocationRejectsDetachedOneOff(t *testing.T) { + inv, err := parseRootInvocation([]string{"-d", "query", "-address", "0xabc"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + _, err = daemonRequestForInvocation(inv) + if err == nil { + t.Fatal("daemonRequestForInvocation() error = nil, want detach rejection") + } + if !strings.Contains(err.Error(), "-d is only supported") { + t.Fatalf("daemonRequestForInvocation() error = %q, want detach rejection", err.Error()) + } +} + func TestParseRootInvocationDefaultsToPublishForRootFlagsOnly(t *testing.T) { inv, err := parseRootInvocation([]string{"-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80"}) if err != nil { @@ -61,6 +119,32 @@ func TestParseRootInvocationDefaultsToPublishForRootFlagsOnly(t *testing.T) { } } +func TestRunDaemonManageModeStopLeavesDaemonRunning(t *testing.T) { + cfg := newSharedControlTestConfig(t) + setupSharedControlTestEnv(t, cfg) + + origState := daemonState + daemonState = &runtimeDaemon{ + modeChange: make(chan struct{}), + activeMode: "publish", + activeArgs: []string{"publish", "-public", "80"}, + } + t.Cleanup(func() { + daemonState = origState + }) + app.BeginMode("publish") + + if err := runDaemonManage([]string{"daemon", "mode-stop"}, &daemonResponse{}); err != nil { + t.Fatalf("runDaemonManage(mode-stop) error = %v", err) + } + if app.Closed() { + t.Fatal("mode-stop closed the daemon app") + } + if status := daemonState.snapshotStatus(); status.ActiveMode != "" { + t.Fatalf("ActiveMode = %q, want empty", status.ActiveMode) + } +} + func TestSanitizedDaemonBaseConfigResetsTransientState(t *testing.T) { cfg := newRootConfig() cfg.ConfigList = true diff --git a/cmd/diode/ssh.go b/cmd/diode/ssh.go index 905a2474..99e8c375 100644 --- a/cmd/diode/ssh.go +++ b/cmd/diode/ssh.go @@ -189,6 +189,10 @@ func runSSHViaDaemonLease(commandArgs []string, resp daemonResponse) int { stderrln("missing ssh command arguments") return 1 } + if err := ensureDaemonSSHForegroundLogger(); err != nil { + stderrln(fmt.Sprintf("could not initialize ssh command logger: %v", err)) + return 1 + } cfg := config.AppConfig cfg.PrintLabel("Using diode daemon proxy", resp.ProxyAddr) defer func() { @@ -202,6 +206,25 @@ func runSSHViaDaemonLease(commandArgs []string, resp daemonResponse) int { return 0 } +func ensureDaemonSSHForegroundLogger() error { + cfg := config.AppConfig + if cfg == nil { + cfg = newRootConfig() + config.AppConfig = cfg + } + if cfg.Logger != nil { + return nil + } + if cfg.LogFilePath != "" { + cfg.LogMode = config.LogToFile + } else { + cfg.LogMode = config.LogToConsole + } + logger, err := config.NewLogger(cfg) + cfg.Logger = &logger + return err +} + // buildSSHLikeToolArgs builds the argv (excluding argv[0]) for an OpenSSH // tool launched by diode. The user's pass-through args are placed *after* // the diode-injected flags so positional arguments (e.g. scp's source/ diff --git a/cmd/diode/ssh_test.go b/cmd/diode/ssh_test.go index 605501c4..95c16b24 100644 --- a/cmd/diode/ssh_test.go +++ b/cmd/diode/ssh_test.go @@ -7,6 +7,8 @@ import ( "errors" "strings" "testing" + + "github.com/diodechain/diode_client/config" ) func TestExtractSSHTarget(t *testing.T) { @@ -133,6 +135,23 @@ func TestBuildSSHLikeToolArgsKeepsUserArgsLast(t *testing.T) { } } +func TestRunSSHViaDaemonLeaseInitializesForegroundLogger(t *testing.T) { + origCfg := config.AppConfig + cfg := newRootConfig() + config.AppConfig = cfg + t.Cleanup(func() { + config.AppConfig = origCfg + }) + + code := runSSHViaDaemonLease([]string{"ssh", "badhost"}, daemonResponse{ProxyAddr: "127.0.0.1:1"}) + if code == 0 { + t.Fatal("runSSHViaDaemonLease() exit code = 0, want validation failure") + } + if cfg.Logger == nil { + t.Fatal("runSSHViaDaemonLease() did not initialize foreground logger") + } +} + func TestFindOpenSSHToolWindowsInstallHelp(t *testing.T) { origLookPath := lookPath origGOOS := runtimeGOOS diff --git a/config/flag.go b/config/flag.go index 1902e5b8..43b27554 100644 --- a/config/flag.go +++ b/config/flag.go @@ -46,6 +46,7 @@ type Config struct { EnableMetrics bool `yaml:"metrics,omitempty" json:"metrics,omitempty"` EnableTray bool `yaml:"tray,omitempty" json:"tray,omitempty"` DisableDaemon bool `yaml:"-" json:"-"` + DetachDaemon bool `yaml:"-" json:"-"` BlockquickDowngrade bool `yaml:"bqdowngrade,omitempty" json:"bqdowngrade,omitempty"` RemoteRPCAddrs StringValues `yaml:"diodeaddrs,omitempty" json:"diodeaddrs,omitempty"` RemoteRPCTimeout time.Duration `yaml:"timeout,omitempty" json:"timeout,omitempty"` From 21fed4f60493d908980012779c423843b3278b47 Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Fri, 8 May 2026 14:08:29 +0200 Subject: [PATCH 04/15] Add Windows daemon transport dependency --- go.mod | 1 + go.sum | 2 ++ 2 files changed, 3 insertions(+) diff --git a/go.mod b/go.mod index 3afbf2d0..2ae2214c 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/diodechain/diode_client go 1.25.9 require ( + github.com/Microsoft/go-winio v0.6.2 github.com/caddyserver/certmagic v0.19.2 github.com/creack/pty v1.1.24 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc diff --git a/go.sum b/go.sum index ebe39b4b..d4df1f17 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,8 @@ github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ= github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA= github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/VictoriaMetrics/fastcache v1.12.1 h1:i0mICQuojGDL3KblA7wUNlY5lOK6a4bwt3uRKnkZU40= From 3a01ec3d26c072d130f7c579d5927cb5d2771aa1 Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Fri, 8 May 2026 14:10:07 +0200 Subject: [PATCH 05/15] Document daemon foreground mode --- README.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/README.md b/README.md index 66e6c3b8..96d5b999 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,8 @@ If the publish succeeds, the client logs the public gateway URL: - `http://.diode.link/` for public port `80` - `https://.diode.link:/` for some public high ports such as `8000-8100` +`publish` stays attached by default and streams daemon output in the foreground. Press `Ctrl-C` to stop the active publish mode while leaving the daemon running idle. Use `./diode -d publish ...` if you want the command to detach and return immediately. + ### 3. Browse Diode services through a local SOCKS proxy ```bash @@ -139,6 +141,7 @@ Top-level commands: - `config`: inspect or change local stored values - `mcp`: run the client as an MCP server over stdin/stdout - `gateway`: run a public HTTP/HTTPS gateway +- `daemon`: inspect or manage the background daemon - `time`: query consensus time - `version`: print build version @@ -156,6 +159,37 @@ Not: ./diode publish -debug=true -public 80:80 ``` +Daemon-routed long-running commands behave like `docker run`: the CLI uses the daemon, stays attached, streams daemon output, and exits when interrupted or when the active mode exits. This applies to: + +- `publish` +- `gateway` +- `socksd` +- `join` +- `files` + +Use root `-d` / `--detach` to keep the daemon mode running in the background: + +```bash +./diode -d publish -public 8080:80 +./diode -d socksd +``` + +`-d` is only valid for those long-running daemon apply modes. One-off daemon-routed commands such as `query`, `fetch`, `config`, `ssh`, `update`, `push`, and `pull` return after their task finishes and reject `-d`. + +`-no-daemon` bypasses daemon routing entirely and keeps the command running in the local CLI process. + +Daemon management commands: + +```bash +./diode daemon status +./diode daemon restart +./diode daemon stop +./diode daemon ports remove 80 443 +./diode daemon ports clear +``` + +`daemon status` reports whether the daemon is running, the active mode, active arguments, published ports, bind rules, SOCKS state, and config API state. `daemon ports remove` and `daemon ports clear` only manage the current daemon-owned publish/files mode; they do not edit local config. + If you plan to contribute code, see [CONTRIBUTING.md](CONTRIBUTING.md). ## Publishing Ports @@ -695,9 +729,11 @@ Common global flags you will actually use: - `-configpath `: load YAML config from a file - `-dbpath `: change the database path +- `-d` / `--detach`: run long-running daemon modes in the background instead of staying attached - `-diodeaddrs `: override the bootstrap RPC peers - `-bind :::`: create a local forward through Diode - `-maxports `: cap concurrent ports per device +- `-no-daemon`: run in the current process instead of routing through the daemon - `-tray=true`: show the tray icon on supported platforms - `-update=false`: disable auto-update on startup - `-debug=true`: enable debug logging From 8c3807a1e060c326b777606cf269079414363148 Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Fri, 8 May 2026 14:25:18 +0200 Subject: [PATCH 06/15] Address daemon review concurrency issues --- cmd/diode/app.go | 65 ++++++++++++++++-------- cmd/diode/daemon.go | 81 ++++++++++++++++++++++++++---- cmd/diode/daemon_test.go | 105 +++++++++++++++++++++++++++++++++++++++ cmd/diode/update.go | 26 +++++++--- 4 files changed, 238 insertions(+), 39 deletions(-) diff --git a/cmd/diode/app.go b/cmd/diode/app.go index d14bf135..f7cb1b58 100644 --- a/cmd/diode/app.go +++ b/cmd/diode/app.go @@ -37,7 +37,8 @@ var ( PreRun: prepareDiode, PostRun: cleanDiode, } - bootDiodeAddrs = [6]string{ + modeStopWaitTimeout = 5 * time.Second + bootDiodeAddrs = [6]string{ "diode://0xceca2f8cf1983b4cf0c1ba51fd382c2bc37aba58@us1.prenet.diode.io:41046", "diode://0x7e4cd38d266902444dc9c8f7c0aa716a32497d0b@us2.prenet.diode.io:41046", "diode://0x68e0bafdda9ef323f692fc080d612718c941d120@as1.prenet.diode.io:41046", @@ -583,36 +584,60 @@ func (dio *Diode) StopMode() { stopCh := dio.modeStopCh doneCh := dio.modeDoneCh modeDeferals := dio.modeDeferals + socksServer := dio.socksServer + proxyServer := dio.proxyServer + configAPIServer := dio.configAPIServer + clientManager := dio.clientManager dio.modeStopCh = nil dio.modeDoneCh = nil dio.modeDeferals = nil dio.activeMode = "" + dio.socksServer = nil + dio.proxyServer = nil + dio.configAPIServer = nil dio.modeMu.Unlock() if stopCh != nil { close(stopCh) } - if doneCh != nil { - <-doneCh - } - for _, fun := range modeDeferals { - fun() - } - if dio.socksServer != nil { - dio.socksServer.Close() - dio.socksServer = nil - } - if dio.proxyServer != nil { - dio.proxyServer.Close() - dio.proxyServer = nil - } - if dio.configAPIServer != nil { - dio.configAPIServer.Close() - dio.configAPIServer = nil + cleanupMode := func() { + for _, fun := range modeDeferals { + fun() + } + if socksServer != nil { + socksServer.Close() + } + if proxyServer != nil { + proxyServer.Close() + } + if configAPIServer != nil { + configAPIServer.Close() + } + if clientManager != nil { + dio.modeMu.Lock() + activeMode := dio.activeMode + dio.modeMu.Unlock() + if activeMode == "" { + clientManager.GetPool().SetPublishedPorts(map[int]*config.Port{}) + } + } } - if dio.clientManager != nil { - dio.clientManager.GetPool().SetPublishedPorts(map[int]*config.Port{}) + if doneCh != nil { + select { + case <-doneCh: + cleanupMode() + case <-time.After(modeStopWaitTimeout): + if dio.config != nil && dio.config.Logger != nil { + dio.config.Logger.Warn("Timed out waiting for mode to finish") + } + go func() { + <-doneCh + cleanupMode() + }() + } + return } + cleanupMode() } // Closed returns the whether diode application has been closed diff --git a/cmd/diode/daemon.go b/cmd/diode/daemon.go index f3949fa7..2ad0f279 100644 --- a/cmd/diode/daemon.go +++ b/cmd/diode/daemon.go @@ -65,6 +65,8 @@ var ( "-apiaddr": true, "-rlimit_nofile": true, "-logfilepath": true, + "-logstats": true, + "-logtarget": true, "-logdatetime": true, "-configpath": true, "-cpuprofile": true, @@ -105,6 +107,8 @@ type daemonStartupSpec struct { APIServerAddr string `json:"apiaddr"` RlimitNofile int `json:"rlimit_nofile"` LogFilePath string `json:"logfilepath"` + LogStats time.Duration `json:"logstats"` + LogTarget string `json:"logtarget"` LogDateTime bool `json:"logdatetime"` ConfigFilePath string `json:"configpath"` CPUProfile string `json:"cpuprofile"` @@ -121,6 +125,7 @@ type daemonStartupSpec struct { SBlocklists config.StringValues `json:"blocklists"` SAllowlists config.StringValues `json:"allowlists"` ResolveCacheTime time.Duration `json:"resolvecachetime"` + BnsCacheTime time.Duration `json:"bnscachetime"` MaxPortsPerDevice int `json:"maxports"` } @@ -545,6 +550,11 @@ func serveDaemon(ln net.Listener) { func handleDaemonConn(conn net.Conn) { defer conn.Close() + defer func() { + if r := recover(); r != nil { + logDaemonInternalError("Panic handling daemon connection", fmt.Errorf("%v", r)) + } + }() var req daemonRequest if err := json.NewDecoder(conn).Decode(&req); err != nil { _ = json.NewEncoder(conn).Encode(daemonResponse{Version: daemonProtocolVersion, ExitCode: 1, Error: err.Error()}) @@ -570,6 +580,10 @@ func handleDaemonConn(conn net.Conn) { } func executeDaemonRequest(req daemonRequest) daemonResponse { + if req.Kind == daemonRequestUpdate { + return executeDaemonUpdateRequest(req) + } + daemonExecMu.Lock() defer daemonExecMu.Unlock() @@ -597,13 +611,9 @@ func executeDaemonRequest(req daemonRequest) daemonResponse { resp.Error = err.Error() } return resp - case daemonRequestUpdate: - return executeDaemonBufferedRequest(req.Kind, func() (string, error) { - return runDaemonUpdate(req.Args) - }) case daemonRequestManage: manageResp := daemonResponse{Version: daemonProtocolVersion} - buffered := executeDaemonBufferedRequest(req.Kind, func() (string, error) { + buffered := executeDaemonBufferedRequest(req.Kind, false, func() (string, error) { return "", runDaemonManage(req.Args, &manageResp) }) manageResp.Stdout = buffered.Stdout @@ -615,7 +625,7 @@ func executeDaemonRequest(req daemonRequest) daemonResponse { if req.Kind == daemonRequestApplyMode && req.Command == "publish" { req.Args = mergeImplicitPublishArgs(req.Args) } - resp = executeDaemonBufferedRequest(req.Kind, func() (string, error) { + resp = executeDaemonBufferedRequest(req.Kind, daemonRequestPersistsBaseConfig(req), func() (string, error) { return "", runDaemonCommandArgs(req.Args) }) if req.Kind == daemonRequestApplyMode && resp.ExitCode == 0 { @@ -624,7 +634,49 @@ func executeDaemonRequest(req daemonRequest) daemonResponse { return resp } -func executeDaemonBufferedRequest(kind string, fn func() (string, error)) daemonResponse { +func executeDaemonUpdateRequest(req daemonRequest) daemonResponse { + resp := daemonResponse{Version: daemonProtocolVersion} + var stdout bytes.Buffer + var stderr bytes.Buffer + + daemonExecMu.Lock() + if daemonState == nil { + daemonExecMu.Unlock() + resp.ExitCode = 1 + resp.Error = "daemon state is not initialized" + return resp + } + cfg := cloneDaemonConfig(&daemonState.baseConfig) + resetTransientConfig(&cfg) + cfg.StdoutWriter = &stdout + cfg.StderrWriter = &stderr + daemonExecMu.Unlock() + + restartPath, err := runDaemonUpdateWithConfig(req.Args, &cfg) + + resp.Stdout = stdout.String() + resp.Stderr = stderr.String() + resp.RestartPath = restartPath + if err != nil { + resp.ExitCode = exitCodeFromError(err) + resp.Error = err.Error() + if resp.ExitCode == 0 { + resp.ExitCode = 1 + } + } else { + resp.ExitCode = 0 + } + return resp +} + +func daemonRequestPersistsBaseConfig(req daemonRequest) bool { + if req.Kind != daemonRequestRunTask { + return false + } + return req.Command == "config" || req.Command == "reset" +} + +func executeDaemonBufferedRequest(kind string, persistBaseConfig bool, fn func() (string, error)) daemonResponse { resp := daemonResponse{Version: daemonProtocolVersion} var stdout bytes.Buffer var stderr bytes.Buffer @@ -654,7 +706,9 @@ func executeDaemonBufferedRequest(kind string, fn func() (string, error)) daemon } else { resp.ExitCode = 0 } - daemonState.baseConfig = sanitizedDaemonBaseConfig(config.AppConfig) + if persistBaseConfig { + daemonState.baseConfig = sanitizedDaemonBaseConfig(config.AppConfig) + } return resp } @@ -697,7 +751,6 @@ func executeDaemonAttachedRequest(conn net.Conn, req daemonRequest) { if err == nil { daemonState.updateModeSnapshot(req.Command, req.Args, config.AppConfig) } - daemonState.baseConfig = sanitizedDaemonBaseConfig(config.AppConfig) daemonExecMu.Unlock() if err != nil { @@ -706,8 +759,6 @@ func executeDaemonAttachedRequest(conn net.Conn, req daemonRequest) { } daemonState.waitForModeExit(req.Command, req.Args, app.closeCh) - config.AppConfig.StdoutWriter = nil - config.AppConfig.StderrWriter = nil _ = sendFrame(daemonStreamFrame{Type: daemonStreamFinal, ExitCode: 0}) } @@ -983,6 +1034,8 @@ func resetRequestGlobals() { scfg.Host = "127.0.0.1" scfg.Port = 8080 scfg.Indexed = false + publishFileSpecs = nil + publishFileFileroot = "" filesFileroot = "" edgeACME = false edgeACMEEmail = "" @@ -1030,6 +1083,8 @@ func daemonStartupSpecFromConfig(cfg *config.Config) daemonStartupSpec { APIServerAddr: cfg.APIServerAddr, RlimitNofile: cfg.RlimitNofile, LogFilePath: cfg.LogFilePath, + LogStats: cfg.LogStats, + LogTarget: cfg.LogTarget, LogDateTime: cfg.LogDateTime, ConfigFilePath: cfg.ConfigFilePath, CPUProfile: cfg.CPUProfile, @@ -1046,6 +1101,7 @@ func daemonStartupSpecFromConfig(cfg *config.Config) daemonStartupSpec { SBlocklists: append(config.StringValues{}, cfg.SBlocklists...), SAllowlists: append(config.StringValues{}, cfg.SAllowlists...), ResolveCacheTime: cfg.ResolveCacheTime, + BnsCacheTime: cfg.BnsCacheTime, MaxPortsPerDevice: cfg.MaxPortsPerDevice, } } @@ -1063,6 +1119,8 @@ func applyDaemonStartupSpec(cfg *config.Config, spec daemonStartupSpec) { cfg.APIServerAddr = spec.APIServerAddr cfg.RlimitNofile = spec.RlimitNofile cfg.LogFilePath = spec.LogFilePath + cfg.LogStats = spec.LogStats + cfg.LogTarget = spec.LogTarget cfg.LogDateTime = spec.LogDateTime cfg.ConfigFilePath = spec.ConfigFilePath cfg.CPUProfile = spec.CPUProfile @@ -1079,6 +1137,7 @@ func applyDaemonStartupSpec(cfg *config.Config, spec daemonStartupSpec) { cfg.SBlocklists = append(config.StringValues{}, spec.SBlocklists...) cfg.SAllowlists = append(config.StringValues{}, spec.SAllowlists...) cfg.ResolveCacheTime = spec.ResolveCacheTime + cfg.BnsCacheTime = spec.BnsCacheTime cfg.MaxPortsPerDevice = spec.MaxPortsPerDevice } diff --git a/cmd/diode/daemon_test.go b/cmd/diode/daemon_test.go index f595bc42..918ea927 100644 --- a/cmd/diode/daemon_test.go +++ b/cmd/diode/daemon_test.go @@ -145,6 +145,24 @@ func TestRunDaemonManageModeStopLeavesDaemonRunning(t *testing.T) { } } +func TestStopModeTimesOutWaitingForDone(t *testing.T) { + prevTimeout := modeStopWaitTimeout + modeStopWaitTimeout = 5 * time.Millisecond + t.Cleanup(func() { + modeStopWaitTimeout = prevTimeout + }) + + dio := &Diode{config: &config.Config{}} + dio.modeStopCh = make(chan struct{}) + dio.modeDoneCh = make(chan struct{}) + + start := time.Now() + dio.StopMode() + if time.Since(start) > time.Second { + t.Fatal("StopMode blocked waiting for mode done channel") + } +} + func TestSanitizedDaemonBaseConfigResetsTransientState(t *testing.T) { cfg := newRootConfig() cfg.ConfigList = true @@ -183,11 +201,67 @@ func TestSanitizedDaemonBaseConfigResetsTransientState(t *testing.T) { } } +func TestDaemonBufferedRequestPersistsBaseConfigOnlyWhenRequested(t *testing.T) { + prevCfg := config.AppConfig + prevState := daemonState + t.Cleanup(func() { + config.AppConfig = prevCfg + daemonState = prevState + }) + + base := newRootConfig() + base.Debug = false + config.AppConfig = newRootConfig() + daemonState = &runtimeDaemon{baseConfig: *base} + + resp := executeDaemonBufferedRequest(daemonRequestRunTask, false, func() (string, error) { + config.AppConfig.Debug = true + return "", nil + }) + if resp.ExitCode != 0 { + t.Fatalf("executeDaemonBufferedRequest() exit = %d err = %q", resp.ExitCode, resp.Error) + } + if daemonState.baseConfig.Debug { + t.Fatal("baseConfig.Debug changed for non-persistent request") + } + + resp = executeDaemonBufferedRequest(daemonRequestRunTask, true, func() (string, error) { + config.AppConfig.Debug = true + return "", nil + }) + if resp.ExitCode != 0 { + t.Fatalf("executeDaemonBufferedRequest(persist) exit = %d err = %q", resp.ExitCode, resp.Error) + } + if !daemonState.baseConfig.Debug { + t.Fatal("baseConfig.Debug did not change for persistent request") + } +} + +func TestResetRequestGlobalsClearsPublishFileGlobals(t *testing.T) { + publishFileSpecs = config.StringValues{"8080", "9090"} + publishFileFileroot = "/tmp/files" + t.Cleanup(func() { + publishFileSpecs = nil + publishFileFileroot = "" + }) + + resetRequestGlobals() + if len(publishFileSpecs) != 0 { + t.Fatalf("publishFileSpecs = %#v, want empty", publishFileSpecs) + } + if publishFileFileroot != "" { + t.Fatalf("publishFileFileroot = %q, want empty", publishFileFileroot) + } +} + func TestDaemonStartupSpecFromConfigCopiesRootScopedValues(t *testing.T) { cfg := newRootConfig() cfg.RemoteRPCTimeout = 7 * time.Second cfg.RetryWait = 2 * time.Second cfg.ResolveCacheTime = 5 * time.Minute + cfg.BnsCacheTime = 6 * time.Minute + cfg.LogStats = 30 * time.Second + cfg.LogTarget = "collector:1234" cfg.SBlocklists = config.StringValues{"0x1"} spec := daemonStartupSpecFromConfig(cfg) @@ -200,11 +274,42 @@ func TestDaemonStartupSpecFromConfigCopiesRootScopedValues(t *testing.T) { if spec.ResolveCacheTime != 5*time.Minute { t.Fatalf("ResolveCacheTime = %v, want 5m", spec.ResolveCacheTime) } + if spec.BnsCacheTime != 6*time.Minute { + t.Fatalf("BnsCacheTime = %v, want 6m", spec.BnsCacheTime) + } + if spec.LogStats != 30*time.Second { + t.Fatalf("LogStats = %v, want 30s", spec.LogStats) + } + if spec.LogTarget != "collector:1234" { + t.Fatalf("LogTarget = %q, want collector:1234", spec.LogTarget) + } if len(spec.SBlocklists) != 1 || spec.SBlocklists[0] != "0x1" { t.Fatalf("SBlocklists = %#v, want [0x1]", spec.SBlocklists) } } +func TestApplyDaemonStartupSpecCopiesLogAndCacheValues(t *testing.T) { + cfg := newRootConfig() + applyDaemonStartupSpec(cfg, daemonStartupSpec{ + LogStats: 15 * time.Second, + LogTarget: "collector:1234", + ResolveCacheTime: 4 * time.Minute, + BnsCacheTime: 3 * time.Minute, + }) + if cfg.LogStats != 15*time.Second { + t.Fatalf("LogStats = %v, want 15s", cfg.LogStats) + } + if cfg.LogTarget != "collector:1234" { + t.Fatalf("LogTarget = %q, want collector:1234", cfg.LogTarget) + } + if cfg.ResolveCacheTime != 4*time.Minute { + t.Fatalf("ResolveCacheTime = %v, want 4m", cfg.ResolveCacheTime) + } + if cfg.BnsCacheTime != 3*time.Minute { + t.Fatalf("BnsCacheTime = %v, want 3m", cfg.BnsCacheTime) + } +} + func TestDaemonRestartEnvReplacesDaemonSpecificVars(t *testing.T) { t.Setenv(envDaemonReadyFD, "3") t.Setenv(envDaemonStartupSpec, `{"debug":false}`) diff --git a/cmd/diode/update.go b/cmd/diode/update.go index a389c506..6cdf0a0f 100644 --- a/cmd/diode/update.go +++ b/cmd/diode/update.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io" "os" "path/filepath" "runtime" @@ -47,15 +48,18 @@ const ( updateRestartDeferred ) -func runDaemonUpdate(args []string) (string, error) { +func runDaemonUpdateWithConfig(args []string, cfg *config.Config) (string, error) { if len(args) == 0 || args[0] != "update" { return "", newExitStatusError(2, "missing update command") } - return doUpdate(updateRestartDeferred) + return doUpdateWithConfig(cfg, updateRestartDeferred) } func doUpdate(restartMode updateRestartMode) (string, error) { - cfg := config.AppConfig + return doUpdateWithConfig(config.AppConfig, restartMode) +} + +func doUpdateWithConfig(cfg *config.Config, restartMode updateRestartMode) (string, error) { m := &update.Manager{ Command: "diode", Store: &github.Store{ @@ -69,7 +73,7 @@ func doUpdate(restartMode updateRestartMode) (string, error) { m.Command += ".exe" } - tarball, ok, err := download(m) + tarball, ok, err := download(cfg, m) if !ok { // Will recheck for an update in 24 hours go func() { @@ -92,7 +96,7 @@ func doUpdate(restartMode updateRestartMode) (string, error) { } cmd := filepath.Join(dir, m.Command) - stdoutf("Updated, restarting %s...\n", cmd) + fmt.Fprintf(updateOutputWriter(cfg), "Updated, restarting %s...\n", cmd) writeLastUpdateAt() if restartMode == updateRestartDeferred { return cmd, nil @@ -119,8 +123,7 @@ func updateInstallDirFromExecutable(bin string, evalSymlinks func(string) (strin return filepath.Dir(bin) } -func download(m *update.Manager) (string, bool, error) { - cfg := config.AppConfig +func download(cfg *config.Config, m *update.Manager) (string, bool, error) { ansi.HideCursor() defer ansi.ShowCursor() @@ -150,7 +153,7 @@ func download(m *update.Manager) (string, bool, error) { } // whitespace - stdoutln() + fmt.Fprintln(updateOutputWriter(cfg)) // download tarball to a tmp dir tarball, err := a.DownloadProxy(progress.Reader) @@ -160,3 +163,10 @@ func download(m *update.Manager) (string, bool, error) { return tarball, true, nil } + +func updateOutputWriter(cfg *config.Config) io.Writer { + if cfg != nil && cfg.StdoutWriter != nil { + return cfg.StdoutWriter + } + return os.Stdout +} From bad4080fad48ec3d22953b0acdf2beeb8471049e Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Fri, 8 May 2026 14:41:00 +0200 Subject: [PATCH 07/15] Remove unreachable fetch file write block --- cmd/diode/control_shared_test.go | 3 +++ cmd/diode/fetch.go | 4 ---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/cmd/diode/control_shared_test.go b/cmd/diode/control_shared_test.go index 9937cd51..4b99e364 100644 --- a/cmd/diode/control_shared_test.go +++ b/cmd/diode/control_shared_test.go @@ -64,6 +64,9 @@ func setupSharedControlTestEnv(t *testing.T, cfg *config.Config) { db.DB = testDB t.Cleanup(func() { + if app != nil { + app.Close() + } app = origApp config.AppConfig = origCfg db.DB = origDB diff --git a/cmd/diode/fetch.go b/cmd/diode/fetch.go index 847a5fdd..604ab2e6 100644 --- a/cmd/diode/fetch.go +++ b/cmd/diode/fetch.go @@ -209,9 +209,5 @@ func fetchHandler() (err error) { } return } - if f != nil { - io.Copy(f, src) - f.Close() - } return } From 3ce026b727a205a67d2243d0ef02715e9e403340 Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Fri, 8 May 2026 14:51:26 +0200 Subject: [PATCH 08/15] Scope daemon instances by wallet dbpath --- README.md | 9 ++++++ cmd/diode/daemon.go | 1 + cmd/diode/daemon_paths.go | 42 +++++++++++++++++++++++++++ cmd/diode/daemon_test.go | 31 ++++++++++++++++++++ cmd/diode/daemon_transport_unix.go | 8 ++--- cmd/diode/daemon_transport_windows.go | 15 +++------- 6 files changed, 89 insertions(+), 17 deletions(-) create mode 100644 cmd/diode/daemon_paths.go diff --git a/README.md b/README.md index 96d5b999..fd630bf0 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,15 @@ Daemon management commands: `daemon status` reports whether the daemon is running, the active mode, active arguments, published ports, bind rules, SOCKS state, and config API state. `daemon ports remove` and `daemon ports clear` only manage the current daemon-owned publish/files mode; they do not edit local config. +Daemon instances are scoped by wallet database path. The default `diode daemon ...` commands manage the daemon for the default `-dbpath`; passing a different root `-dbpath` manages a separate daemon for that wallet: + +```bash +./diode -dbpath ./wallet-a.db -d publish -public 8080:80 +./diode -dbpath ./wallet-b.db -d socksd -socksd_port 1082 +./diode -dbpath ./wallet-a.db daemon status +./diode -dbpath ./wallet-b.db daemon stop +``` + If you plan to contribute code, see [CONTRIBUTING.md](CONTRIBUTING.md). ## Publishing Ports diff --git a/cmd/diode/daemon.go b/cmd/diode/daemon.go index 2ad0f279..7ac700e5 100644 --- a/cmd/diode/daemon.go +++ b/cmd/diode/daemon.go @@ -266,6 +266,7 @@ func maybeHandleDaemonCLI(args []string) (bool, int) { return false, 0 } if inv.command == "daemon" { + applyDaemonStartupSpec(config.AppConfig, inv.startupSpec) return handleDaemonManagerCLI(inv.commandArgs) } if inv.disableDaemon || inv.help || localBypassCommands[inv.command] || !daemonRunnableCmds[inv.command] { diff --git a/cmd/diode/daemon_paths.go b/cmd/diode/daemon_paths.go new file mode 100644 index 00000000..85e92698 --- /dev/null +++ b/cmd/diode/daemon_paths.go @@ -0,0 +1,42 @@ +package main + +import ( + "crypto/sha1" + "encoding/hex" + "os" + "path/filepath" + + "github.com/diodechain/diode_client/config" + "github.com/diodechain/diode_client/util" +) + +func daemonPathID() string { + dbPath := "" + if config.AppConfig != nil { + dbPath = config.AppConfig.DBPath + } + if dbPath == "" { + dbPath = util.DefaultDBPath() + } + if abs, err := filepath.Abs(dbPath); err == nil { + dbPath = abs + } + seed := filepath.Clean(dbPath) + if base, err := os.UserConfigDir(); err == nil { + seed = filepath.Clean(base) + "\x00" + seed + } + sum := sha1.Sum([]byte(seed)) + return hex.EncodeToString(sum[:8]) +} + +func daemonPathDir() (string, error) { + base, err := os.UserConfigDir() + if err != nil { + return "", err + } + dir := filepath.Join(base, "diode", "daemons", daemonPathID()) + if err := os.MkdirAll(dir, 0700); err != nil { + return "", err + } + return dir, nil +} diff --git a/cmd/diode/daemon_test.go b/cmd/diode/daemon_test.go index 918ea927..185d2b8d 100644 --- a/cmd/diode/daemon_test.go +++ b/cmd/diode/daemon_test.go @@ -103,6 +103,37 @@ func TestDaemonRequestForInvocationRejectsDetachedOneOff(t *testing.T) { } } +func TestDaemonPathsAreScopedByDBPath(t *testing.T) { + prevCfg := config.AppConfig + t.Cleanup(func() { + config.AppConfig = prevCfg + }) + + dir := t.TempDir() + cfgA := newRootConfig() + cfgA.DBPath = dir + "/wallet-a.db" + config.AppConfig = cfgA + socketA, metaA, err := daemonPaths() + if err != nil { + t.Fatalf("daemonPaths(wallet-a) error = %v", err) + } + + cfgB := newRootConfig() + cfgB.DBPath = dir + "/wallet-b.db" + config.AppConfig = cfgB + socketB, metaB, err := daemonPaths() + if err != nil { + t.Fatalf("daemonPaths(wallet-b) error = %v", err) + } + + if socketA == socketB { + t.Fatalf("socket path should differ for different dbpaths: %q", socketA) + } + if metaA == metaB { + t.Fatalf("metadata path should differ for different dbpaths: %q", metaA) + } +} + func TestParseRootInvocationDefaultsToPublishForRootFlagsOnly(t *testing.T) { inv, err := parseRootInvocation([]string{"-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80"}) if err != nil { diff --git a/cmd/diode/daemon_transport_unix.go b/cmd/diode/daemon_transport_unix.go index fe6d1991..f28d8edd 100644 --- a/cmd/diode/daemon_transport_unix.go +++ b/cmd/diode/daemon_transport_unix.go @@ -14,14 +14,10 @@ import ( ) func daemonPaths() (string, string, error) { - base, err := os.UserConfigDir() + dir, err := daemonPathDir() if err != nil { return "", "", err } - dir := filepath.Join(base, "diode") - if err := os.MkdirAll(dir, 0700); err != nil { - return "", "", err - } socketPath := filepath.Join(dir, "daemon.sock") return socketPath, metaPathFromSocket(socketPath), nil } @@ -30,7 +26,7 @@ func metaPathFromSocket(socketPath string) string { if socketPath == "" { socketPath, _, _ = daemonPaths() } - return socketPath + ".json" + return filepath.Join(filepath.Dir(socketPath), "daemon.json") } func daemonListen(socketPath string) (net.Listener, error) { diff --git a/cmd/diode/daemon_transport_windows.go b/cmd/diode/daemon_transport_windows.go index ca3759cd..b1cecc69 100644 --- a/cmd/diode/daemon_transport_windows.go +++ b/cmd/diode/daemon_transport_windows.go @@ -3,8 +3,6 @@ package main import ( - "crypto/sha1" - "encoding/hex" "encoding/json" "fmt" "net" @@ -18,25 +16,20 @@ import ( ) func daemonPaths() (string, string, error) { - base, err := os.UserConfigDir() + dir, err := daemonPathDir() if err != nil { return "", "", err } - dir := filepath.Join(base, "diode") - if err := os.MkdirAll(dir, 0700); err != nil { - return "", "", err - } - sum := sha1.Sum([]byte(dir)) - socketPath := `\\.\pipe\diode-client-` + hex.EncodeToString(sum[:8]) + socketPath := `\\.\pipe\diode-client-` + daemonPathID() return socketPath, metaPathFromSocket(socketPath), nil } func metaPathFromSocket(socketPath string) string { - base, err := os.UserConfigDir() + dir, err := daemonPathDir() if err != nil { return "daemon.json" } - return filepath.Join(base, "diode", "daemon.json") + return filepath.Join(dir, "daemon.json") } func daemonListen(socketPath string) (net.Listener, error) { From 9e333c235a68545d98f8cce67641964254179bf0 Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Fri, 8 May 2026 15:41:18 +0200 Subject: [PATCH 09/15] Use SHA-256 for daemon path IDs --- cmd/diode/daemon_paths.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/diode/daemon_paths.go b/cmd/diode/daemon_paths.go index 85e92698..8d7a901f 100644 --- a/cmd/diode/daemon_paths.go +++ b/cmd/diode/daemon_paths.go @@ -1,7 +1,7 @@ package main import ( - "crypto/sha1" + "crypto/sha256" "encoding/hex" "os" "path/filepath" @@ -25,7 +25,7 @@ func daemonPathID() string { if base, err := os.UserConfigDir(); err == nil { seed = filepath.Clean(base) + "\x00" + seed } - sum := sha1.Sum([]byte(seed)) + sum := sha256.Sum256([]byte(seed)) return hex.EncodeToString(sum[:8]) } From c6d79a1c515769c6f4aa1a59bba102ccc2f21fdb Mon Sep 17 00:00:00 2001 From: Tuhalf <37873266+tuhalf@users.noreply.github.com> Date: Fri, 8 May 2026 16:12:18 +0200 Subject: [PATCH 10/15] Fix Windows daemon path build --- cmd/diode/daemon_transport_windows.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/diode/daemon_transport_windows.go b/cmd/diode/daemon_transport_windows.go index b1cecc69..f55283d6 100644 --- a/cmd/diode/daemon_transport_windows.go +++ b/cmd/diode/daemon_transport_windows.go @@ -21,7 +21,7 @@ func daemonPaths() (string, string, error) { return "", "", err } socketPath := `\\.\pipe\diode-client-` + daemonPathID() - return socketPath, metaPathFromSocket(socketPath), nil + return socketPath, filepath.Join(dir, "daemon.json"), nil } func metaPathFromSocket(socketPath string) string { From c033bd2bcfbe2dc98b0cffff022d74d17d58fd6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20AKG=C3=9CL?= <37873266+tuhalf@users.noreply.github.com> Date: Wed, 27 May 2026 00:15:31 +0200 Subject: [PATCH 11/15] Fix daemon gateway status and output --- cmd/diode/app.go | 4 +- cmd/diode/daemon.go | 59 +++++++++++++++++++---------- cmd/diode/daemon_manage.go | 34 +++++++++++++---- cmd/diode/daemon_test.go | 77 ++++++++++++++++++++++++++++++++++++++ cmd/diode/gateway.go | 33 ++++++++++++++++ 5 files changed, 178 insertions(+), 29 deletions(-) diff --git a/cmd/diode/app.go b/cmd/diode/app.go index f7cb1b58..dd3da5f0 100644 --- a/cmd/diode/app.go +++ b/cmd/diode/app.go @@ -441,9 +441,9 @@ func (dio *Diode) Start() error { dio.startMu.Lock() firstStart := !dio.started + cfg.PrintLabel("Client address", cfg.ClientAddr.HexString()) + cfg.PrintLabel("Fleet address", cfg.FleetAddr.HexString()) if firstStart { - cfg.PrintLabel("Client address", cfg.ClientAddr.HexString()) - cfg.PrintLabel("Fleet address", cfg.FleetAddr.HexString()) dio.clientManager.Start() dio.started = true } diff --git a/cmd/diode/daemon.go b/cmd/diode/daemon.go index 7ac700e5..09d03924 100644 --- a/cmd/diode/daemon.go +++ b/cmd/diode/daemon.go @@ -164,23 +164,28 @@ type daemonStreamFrame struct { } type runtimeDaemon struct { - socketPath string - metaPath string - listener net.Listener - startup daemonStartupSpec - baseConfig config.Config - leasesMu sync.Mutex - leases map[string]*rpc.Server - stateMu sync.Mutex - modeChange chan struct{} - activeMode string - activeArgs []string - ports map[int]*config.Port - binds []config.Bind - socksAddr string - socksOn bool - apiAddr string - apiOn bool + socketPath string + metaPath string + listener net.Listener + startup daemonStartupSpec + baseConfig config.Config + leasesMu sync.Mutex + leases map[string]*rpc.Server + stateMu sync.Mutex + modeChange chan struct{} + activeMode string + activeArgs []string + ports map[int]*config.Port + binds []config.Bind + socksAddr string + socksOn bool + gatewayAddr string + gatewayOn bool + secureGatewayAddr string + secureGatewayOn bool + secureGatewayAdditionalAddrs []string + apiAddr string + apiOn bool } func init() { @@ -1250,8 +1255,14 @@ func (rd *runtimeDaemon) updateModeSnapshot(mode string, args []string, cfg *con rd.activeArgs = sanitizeModeArgs(mode, args) rd.ports = cloneDaemonPortMap(cfg.PublishedPorts) rd.binds = append([]config.Bind{}, cfg.Binds...) - rd.socksOn = cfg.EnableSocksServer - rd.socksAddr = cfg.SocksServerAddr() + gatewayStatus := gatewayStatusFromConfig(cfg) + rd.socksOn = gatewayStatus.SocksEnabled + rd.socksAddr = gatewayStatus.SocksAddr + rd.gatewayOn = gatewayStatus.GatewayEnabled + rd.gatewayAddr = gatewayStatus.GatewayAddr + rd.secureGatewayOn = gatewayStatus.SecureGatewayEnabled + rd.secureGatewayAddr = gatewayStatus.SecureGatewayAddr + rd.secureGatewayAdditionalAddrs = append([]string{}, gatewayStatus.SecureGatewayAdditionalAddrs...) rd.apiOn = cfg.EnableAPIServer rd.apiAddr = cfg.APIServerAddr rd.notifyModeChangedLocked() @@ -1276,6 +1287,11 @@ func (rd *runtimeDaemon) clearModeSnapshot() { rd.binds = nil rd.socksOn = false rd.socksAddr = "" + rd.gatewayOn = false + rd.gatewayAddr = "" + rd.secureGatewayOn = false + rd.secureGatewayAddr = "" + rd.secureGatewayAdditionalAddrs = nil rd.apiOn = false rd.apiAddr = "" rd.notifyModeChangedLocked() @@ -1341,6 +1357,11 @@ func (rd *runtimeDaemon) snapshotStatus() daemonRuntimeStatus { status.Binds = append([]config.Bind{}, rd.binds...) status.SocksEnabled = rd.socksOn status.SocksAddr = rd.socksAddr + status.GatewayEnabled = rd.gatewayOn + status.GatewayAddr = rd.gatewayAddr + status.SecureGatewayEnabled = rd.secureGatewayOn + status.SecureGatewayAddr = rd.secureGatewayAddr + status.SecureGatewayAdditionalAddrs = append([]string{}, rd.secureGatewayAdditionalAddrs...) status.APIEnabled = rd.apiOn status.APIAddr = rd.apiAddr return status diff --git a/cmd/diode/daemon_manage.go b/cmd/diode/daemon_manage.go index 674793c1..40226101 100644 --- a/cmd/diode/daemon_manage.go +++ b/cmd/diode/daemon_manage.go @@ -15,14 +15,19 @@ import ( ) type daemonRuntimeStatus struct { - ActiveMode string - ActiveArgs []string - PublishedPorts map[int]*config.Port - Binds []config.Bind - SocksEnabled bool - SocksAddr string - APIEnabled bool - APIAddr string + ActiveMode string + ActiveArgs []string + PublishedPorts map[int]*config.Port + Binds []config.Bind + SocksEnabled bool + SocksAddr string + GatewayEnabled bool + GatewayAddr string + SecureGatewayEnabled bool + SecureGatewayAddr string + SecureGatewayAdditionalAddrs []string + APIEnabled bool + APIAddr string } var ( @@ -341,6 +346,19 @@ func renderDaemonStatus() { } else { cfg.PrintLabel("SOCKS proxy", "disabled") } + if status.GatewayEnabled { + cfg.PrintLabel("HTTP gateway", status.GatewayAddr) + } else { + cfg.PrintLabel("HTTP gateway", "disabled") + } + if status.SecureGatewayEnabled { + cfg.PrintLabel("HTTPS gateway", status.SecureGatewayAddr) + } else { + cfg.PrintLabel("HTTPS gateway", "disabled") + } + if len(status.SecureGatewayAdditionalAddrs) > 0 { + cfg.PrintLabel("HTTPS gateways", strings.Join(status.SecureGatewayAdditionalAddrs, ", ")) + } if status.APIEnabled { cfg.PrintLabel("Config API", status.APIAddr) } else { diff --git a/cmd/diode/daemon_test.go b/cmd/diode/daemon_test.go index 185d2b8d..1c386064 100644 --- a/cmd/diode/daemon_test.go +++ b/cmd/diode/daemon_test.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "strings" "testing" "time" @@ -194,6 +195,82 @@ func TestStopModeTimesOutWaitingForDone(t *testing.T) { } } +func TestStartPrintsIdentityOnSubsequentCalls(t *testing.T) { + cfg := newSharedControlTestConfig(t) + var stdout bytes.Buffer + cfg.StdoutWriter = &stdout + + dio := &Diode{ + config: cfg, + cmd: daemonManageCmd, + controlsLoaded: true, + started: true, + } + + if err := dio.Start(); err != nil { + t.Fatalf("first Start() error = %v", err) + } + if got := strings.Count(stdout.String(), "Client address"); got != 1 { + t.Fatalf("first Start() client address lines = %d, want 1\n%s", got, stdout.String()) + } + + stdout.Reset() + if err := dio.Start(); err != nil { + t.Fatalf("second Start() error = %v", err) + } + if got := strings.Count(stdout.String(), "Client address"); got != 1 { + t.Fatalf("second Start() client address lines = %d, want 1\n%s", got, stdout.String()) + } +} + +func TestRenderDaemonStatusIncludesGatewayListeners(t *testing.T) { + prevCfg := config.AppConfig + prevState := daemonState + t.Cleanup(func() { + config.AppConfig = prevCfg + daemonState = prevState + }) + + cfg := newSharedControlTestConfig(t) + var stdout bytes.Buffer + cfg.StdoutWriter = &stdout + config.AppConfig = cfg + daemonState = &runtimeDaemon{ + socketPath: "/tmp/daemon.sock", + modeChange: make(chan struct{}), + activeMode: "gateway", + activeArgs: []string{"gateway", "-httpd_port", "18080"}, + socksOn: true, + socksAddr: "127.0.0.1:18080", + gatewayOn: true, + gatewayAddr: "127.0.0.1:18081", + secureGatewayOn: true, + secureGatewayAddr: "127.0.0.1:18443", + secureGatewayAdditionalAddrs: []string{"127.0.0.1:18444"}, + ports: map[int]*config.Port{}, + } + + renderDaemonStatus() + + out := stdout.String() + for _, want := range []string{ + "Active mode", + "gateway", + "SOCKS proxy", + "127.0.0.1:18080", + "HTTP gateway", + "127.0.0.1:18081", + "HTTPS gateway", + "127.0.0.1:18443", + "HTTPS gateways", + "127.0.0.1:18444", + } { + if !strings.Contains(out, want) { + t.Fatalf("status output missing %q\n%s", want, out) + } + } +} + func TestSanitizedDaemonBaseConfigResetsTransientState(t *testing.T) { cfg := newRootConfig() cfg.ConfigList = true diff --git a/cmd/diode/gateway.go b/cmd/diode/gateway.go index 6787ed28..5d2e17c2 100644 --- a/cmd/diode/gateway.go +++ b/cmd/diode/gateway.go @@ -5,6 +5,7 @@ package main import ( "fmt" + "strings" "github.com/diodechain/diode_client/command" "github.com/diodechain/diode_client/config" @@ -49,9 +50,41 @@ func gatewayHandler() (err error) { if result.HasValidationErrors() { return fmt.Errorf("couldn't apply gateway controls: %v", result.ValidationErrors) } + printGatewayBanner(config.AppConfig) if isDaemonApplyRequest() { return nil } app.Wait() return } + +func printGatewayBanner(cfg *config.Config) { + status := gatewayStatusFromConfig(cfg) + if status.GatewayEnabled { + cfg.PrintLabel("HTTP gateway", status.GatewayAddr) + } + if status.SecureGatewayEnabled { + cfg.PrintLabel("HTTPS gateway", status.SecureGatewayAddr) + } + if len(status.SecureGatewayAdditionalAddrs) > 0 { + cfg.PrintLabel("HTTPS gateways", strings.Join(status.SecureGatewayAdditionalAddrs, ", ")) + } + if status.SocksEnabled { + cfg.PrintLabel("SOCKS proxy", status.SocksAddr) + } +} + +func gatewayStatusFromConfig(cfg *config.Config) daemonRuntimeStatus { + status := daemonRuntimeStatus{ + SocksEnabled: cfg.EnableSocksServer, + SocksAddr: cfg.SocksServerAddr(), + GatewayEnabled: cfg.EnableProxyServer, + GatewayAddr: cfg.ProxyServerAddr(), + SecureGatewayEnabled: cfg.EnableSProxyServer, + SecureGatewayAddr: cfg.SProxyServerAddr(), + } + for _, port := range cfg.SProxyAdditionalPorts() { + status.SecureGatewayAdditionalAddrs = append(status.SecureGatewayAdditionalAddrs, cfg.SProxyServerAddrForPort(port)) + } + return status +} From 22a74b3b86635e517c4ca3f9c707688f3bf6951c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20AKG=C3=9CL?= <37873266+tuhalf@users.noreply.github.com> Date: Wed, 27 May 2026 01:09:39 +0200 Subject: [PATCH 12/15] Fix daemon mode command routing --- cmd/diode/app.go | 6 ++ cmd/diode/control_shared.go | 35 ++++++-- cmd/diode/daemon.go | 118 ++++++++++++++++++-------- cmd/diode/daemon_manage.go | 76 ++++++++++------- cmd/diode/daemon_paths.go | 16 ++-- cmd/diode/daemon_test.go | 111 ++++++++++++++++++++++-- cmd/diode/daemon_transport_unix.go | 1 + cmd/diode/daemon_transport_windows.go | 1 + cmd/diode/join.go | 4 +- cmd/diode/mode_helpers.go | 3 + cmd/diode/ssh.go | 86 ++++++++++++------- rpc/socks.go | 15 +++- 12 files changed, 347 insertions(+), 125 deletions(-) diff --git a/cmd/diode/app.go b/cmd/diode/app.go index dd3da5f0..43410214 100644 --- a/cmd/diode/app.go +++ b/cmd/diode/app.go @@ -564,6 +564,12 @@ func (dio *Diode) BeginMode(mode string) { } } +func (dio *Diode) ActiveMode() string { + dio.modeMu.Lock() + defer dio.modeMu.Unlock() + return dio.activeMode +} + func (dio *Diode) ModeStopChan() <-chan struct{} { dio.modeMu.Lock() defer dio.modeMu.Unlock() diff --git a/cmd/diode/control_shared.go b/cmd/diode/control_shared.go index 538fa8da..9d051612 100644 --- a/cmd/diode/control_shared.go +++ b/cmd/diode/control_shared.go @@ -1259,22 +1259,45 @@ func (dio *Diode) loadPersistedSharedControls() error { collect(&diodeCmd.Flag) collect(&dio.cmd.Flag) + if err := loadPersistedSharedControlsInto(dio.config, overrides); err != nil { + return err + } + + dio.controlsLoaded = true + return nil +} + +func loadPersistedSharedControlsInto(cfg *config.Config, overrides map[string]bool) error { + if cfg == nil || cfg.LoadFromFile || db.DB == nil { + return nil + } + effects := controlEffect(0) for _, key := range persistedSharedControlKeys { - if overrides[key] { + if overrides != nil && overrides[key] { continue } value, err := db.DB.Get(sharedControlStorageKey(key)) if err != nil { continue } - if _, err := applySharedControlValue(dio.config, key, string(value)); err != nil { + if _, err := applySharedControlValue(cfg, key, string(value)); err != nil { return fmt.Errorf("could not load persisted %s: %w", key, err) } + if spec, ok := sharedControlSpec(key); ok { + effects |= spec.Effects + } } - - config.NormalizeResolveCache(dio.config) - - dio.controlsLoaded = true + if effects&controlEffectServices != 0 { + if err := syncConfigBindsFromSBinds(cfg); err != nil { + return err + } + } + if effects&controlEffectPublished != 0 { + if err := rebuildPublishedPortState(cfg); err != nil { + return err + } + } + config.NormalizeResolveCache(cfg) return nil } diff --git a/cmd/diode/daemon.go b/cmd/diode/daemon.go index 09d03924..f0bc40d7 100644 --- a/cmd/diode/daemon.go +++ b/cmd/diode/daemon.go @@ -91,7 +91,7 @@ var ( } localBypassCommands = map[string]bool{"": true, "version": true, "mcp": true, "ssh-proxy": true, daemonCommandName: true} daemonApplyModeCmds = map[string]bool{"publish": true, "gateway": true, "socksd": true, "join": true, "files": true} - daemonRunnableCmds = map[string]bool{"query": true, "time": true, "fetch": true, "token": true, "bns": true, "config": true, "reset": true, "push": true, "pull": true, "publish": true, "gateway": true, "socksd": true, "join": true, "files": true, "ssh": true, "update": true} + daemonRunnableCmds = map[string]bool{"query": true, "time": true, "fetch": true, "token": true, "bns": true, "config": true, "reset": true, "push": true, "pull": true, "publish": true, "gateway": true, "socksd": true, "join": true, "files": true, "ssh": true, "scp": true, "update": true} ) type daemonStartupSpec struct { @@ -152,6 +152,7 @@ type daemonResponse struct { Error string `json:"error,omitempty"` ProxyAddr string `json:"proxy_addr,omitempty"` LeaseID string `json:"lease_id,omitempty"` + ModeActive bool `json:"mode_active,omitempty"` RestartPath string `json:"-"` Shutdown bool `json:"-"` } @@ -206,6 +207,12 @@ func daemonHandler() error { return err } defer cleanDiode() + if err := loadPersistedSharedControlsInto(cfg, nil); err != nil { + return err + } + if app != nil { + app.controlsLoaded = true + } socketPath, metaPath, err := daemonPaths() if err != nil { @@ -248,7 +255,7 @@ func daemonHandler() error { if err := runDaemonCommandAsKind(daemonRequestApplyMode, restoreArgs); err != nil { logDaemonInternalError("Couldn't restore daemon mode after restart", err) } else { - daemonState.updateModeSnapshot(restoreArgs[0], restoreArgs, config.AppConfig) + daemonState.updateModeSnapshot(daemonModeNameFromArgs(restoreArgs), restoreArgs, config.AppConfig) } } if err := signalDaemonReady(); err != nil { @@ -312,22 +319,29 @@ func maybeHandleDaemonCLI(args []string) (bool, int) { stderrln(err.Error()) return true, 1 } - if resp.Stdout != "" { - io.WriteString(stdoutWriter(), resp.Stdout) - } - if resp.Stderr != "" { - io.WriteString(stderrWriter(), resp.Stderr) - } + writeDaemonResponse(resp) if req.Kind == daemonRequestLease { return true, runSSHViaDaemonLease(inv.commandArgs, resp) } - if req.Kind == daemonRequestApplyMode && resp.ExitCode == 0 { + if req.Kind == daemonRequestApplyMode && resp.ExitCode == 0 && resp.ModeActive { stdoutf("Daemon mode active: %s\n", inv.command) stdoutln("Use `diode daemon status` to inspect or manage the running daemon.") } return true, resp.ExitCode } +func writeDaemonResponse(resp daemonResponse) { + if resp.Stdout != "" { + _, _ = io.WriteString(stdoutWriter(), resp.Stdout) + } + if resp.Stderr != "" { + _, _ = io.WriteString(stderrWriter(), resp.Stderr) + } + if resp.ExitCode != 0 && resp.Error != "" { + stderrln(resp.Error) + } +} + func daemonRequestForInvocation(inv rootInvocation) (daemonRequest, error) { req := daemonRequest{ Version: daemonProtocolVersion, @@ -335,7 +349,7 @@ func daemonRequestForInvocation(inv rootInvocation) (daemonRequest, error) { Args: inv.execArgs, } switch inv.command { - case "ssh": + case "ssh", "scp": req.Kind = daemonRequestLease case "update": req.Kind = daemonRequestUpdate @@ -486,7 +500,7 @@ func dispatchViaDaemonAttached(spec daemonStartupSpec, req daemonRequest) (bool, } case err := <-errCh: if err == io.EOF { - return true, "", 0, nil + return true, "", 1, fmt.Errorf("daemon stream ended before final response") } return true, "", 1, err case <-sigCh: @@ -635,7 +649,7 @@ func executeDaemonRequest(req daemonRequest) daemonResponse { return "", runDaemonCommandArgs(req.Args) }) if req.Kind == daemonRequestApplyMode && resp.ExitCode == 0 { - daemonState.updateModeSnapshot(req.Command, req.Args, config.AppConfig) + resp.ModeActive = updateDaemonModeSnapshotIfActive(req.Command, req.Args) } return resp } @@ -755,7 +769,7 @@ func executeDaemonAttachedRequest(conn net.Conn, req daemonRequest) { exitCode = 1 } if err == nil { - daemonState.updateModeSnapshot(req.Command, req.Args, config.AppConfig) + _ = updateDaemonModeSnapshotIfActive(req.Command, req.Args) } daemonExecMu.Unlock() @@ -763,11 +777,29 @@ func executeDaemonAttachedRequest(conn net.Conn, req daemonRequest) { _ = sendFrame(daemonStreamFrame{Type: daemonStreamFinal, ExitCode: exitCode, Error: err.Error()}) return } + if app.ActiveMode() != req.Command { + _ = sendFrame(daemonStreamFrame{Type: daemonStreamFinal, ExitCode: 0}) + return + } daemonState.waitForModeExit(req.Command, req.Args, app.closeCh) _ = sendFrame(daemonStreamFrame{Type: daemonStreamFinal, ExitCode: 0}) } +func updateDaemonModeSnapshotIfActive(command string, args []string) bool { + if app == nil || daemonState == nil { + return false + } + switch activeMode := app.ActiveMode(); activeMode { + case command: + daemonState.updateModeSnapshot(command, args, config.AppConfig) + return true + case "": + daemonState.clearModeSnapshot() + } + return false +} + type daemonFrameWriter struct { typ string send func(daemonStreamFrame) error @@ -787,6 +819,7 @@ func runDaemonCommandArgs(args []string) error { if len(args) == 0 { return newExitStatusError(2, "missing command") } + resetSharedControlsForArgs(config.AppConfig, args) if err := diodeCmd.Flag.Parse(args); err != nil { return err } @@ -811,6 +844,21 @@ func runDaemonCommandArgs(args []string) error { return subCmd.Run() } +func resetSharedControlsForArgs(cfg *config.Config, args []string) { + if cfg == nil { + return + } + for _, item := range parseRootExecItems(args) { + if !strings.HasPrefix(item.flagName, "-") { + continue + } + name := strings.TrimLeft(item.flagName, "-") + for _, key := range sharedControlFlagKeys(name) { + resetSharedControlValue(cfg, key) + } + } +} + func refreshRequestDerivedConfig(cfg *config.Config) error { if cfg == nil { return nil @@ -1002,29 +1050,9 @@ func resetTransientConfig(cfg *config.Config) { cfg.QueryAddress = "" cfg.ConfigUnsafe = false cfg.ConfigList = false + cfg.ConfigFullValues = false cfg.ConfigDelete = nil cfg.ConfigSet = nil - cfg.PublicPublishedPorts = nil - cfg.ProtectedPublishedPorts = nil - cfg.PrivatePublishedPorts = nil - cfg.SSHPublishedServices = nil - cfg.PublishedPorts = nil - cfg.SBinds = nil - cfg.Binds = nil - cfg.EnableProxyServer = false - cfg.EnableSProxyServer = false - cfg.EnableSocksServer = false - cfg.SocksServerHost = "127.0.0.1" - cfg.SocksServerPort = 1080 - cfg.SocksFallback = "localhost" - cfg.ProxyServerHost = "127.0.0.1" - cfg.ProxyServerPort = 80 - cfg.SProxyServerHost = "127.0.0.1" - cfg.SProxyServerPort = 443 - cfg.SProxyServerPorts = "" - cfg.SProxyServerCertPath = "./priv/fullchain.pem" - cfg.SProxyServerPrivPath = "./priv/privkey.pem" - cfg.AllowRedirectToSProxy = false cfg.BNSForce = false cfg.BNSRegister = "" cfg.BNSUnregister = "" @@ -1077,7 +1105,7 @@ func exitCodeFromError(err error) int { func daemonStartupSpecFromConfig(cfg *config.Config) daemonStartupSpec { return daemonStartupSpec{ - DBPath: cfg.DBPath, + DBPath: canonicalDaemonDBPath(cfg.DBPath), RetryTimes: cfg.RetryTimes, EdgeE2ETimeout: cfg.EdgeE2ETimeout, EnableUpdate: cfg.EnableUpdate, @@ -1112,6 +1140,17 @@ func daemonStartupSpecFromConfig(cfg *config.Config) daemonStartupSpec { } } +func daemonModeNameFromArgs(args []string) string { + inv, err := parseRootInvocation(args) + if err == nil && inv.command != "" { + return inv.command + } + if len(args) == 0 { + return "" + } + return args[0] +} + func applyDaemonStartupSpec(cfg *config.Config, spec daemonStartupSpec) { cfg.DBPath = spec.DBPath cfg.RetryTimes = spec.RetryTimes @@ -1224,6 +1263,15 @@ func daemonRestartEnv(spec daemonStartupSpec) ([]string, error) { return env, nil } +func prepareDaemonForRestart() { + if app != nil { + app.StopMode() + } + if daemonState != nil { + daemonState.clearModeSnapshot() + } +} + func daemonRestoreArgsFromEnv() ([]string, error) { raw := strings.TrimSpace(os.Getenv(envDaemonRestoreArgs)) if raw == "" { diff --git a/cmd/diode/daemon_manage.go b/cmd/diode/daemon_manage.go index 40226101..639cb2d0 100644 --- a/cmd/diode/daemon_manage.go +++ b/cmd/diode/daemon_manage.go @@ -3,7 +3,6 @@ package main import ( "encoding/json" "fmt" - "io" "os" "sort" "strconv" @@ -34,7 +33,7 @@ var ( daemonManageCmd = &command.Command{ Name: "daemon", HelpText: ` Inspect and manage the running diode daemon.`, - ExampleText: " diode daemon status\n diode daemon stop\n diode daemon ports remove 80 443\n diode daemon ports clear", + ExampleText: " diode daemon status\n diode daemon stop\n diode daemon mode-stop\n diode daemon ports remove 80 443\n diode daemon ports clear", Run: daemonManageHandler, Type: command.EmptyConnectionCommand, PassThroughArgs: true, @@ -64,10 +63,21 @@ func handleDaemonManagerCLI(args []string) (bool, int) { return true, runDaemonManagerAction([]string{"daemon", "stop"}) case "restart": return true, runDaemonManagerRestart() + case "mode-stop": + return true, runDaemonManagerAction([]string{"daemon", "mode-stop"}) + case "mode": + if len(subArgs) == 2 && subArgs[1] == "stop" { + return true, runDaemonManagerAction([]string{"daemon", "mode-stop"}) + } + stderrln("usage: diode daemon [status|stop|restart|mode-stop|ports]") + stderrln(" diode daemon mode stop") + stderrln(" diode daemon ports [remove|clear]") + return true, 2 case "ports": return true, runDaemonManagerPorts(subArgs[1:]) default: - stderrln("usage: diode daemon [status|stop|restart|ports]") + stderrln("usage: diode daemon [status|stop|restart|mode-stop|ports]") + stderrln(" diode daemon mode stop") stderrln(" diode daemon ports [remove|clear]") return true, 2 } @@ -88,12 +98,7 @@ func runDaemonManagerStatus() int { stdoutln("Daemon status: not running") return 0 } - if resp.Stdout != "" { - _, _ = io.WriteString(stdoutWriter(), resp.Stdout) - } - if resp.Stderr != "" { - _, _ = io.WriteString(stderrWriter(), resp.Stderr) - } + writeDaemonResponse(resp) return resp.ExitCode } @@ -134,12 +139,7 @@ func runDaemonManagerAction(args []string) int { stdoutln("Daemon status: not running") return 1 } - if resp.Stdout != "" { - _, _ = io.WriteString(stdoutWriter(), resp.Stdout) - } - if resp.Stderr != "" { - _, _ = io.WriteString(stderrWriter(), resp.Stderr) - } + writeDaemonResponse(resp) if len(args) >= 2 && args[0] == "daemon" && args[1] == "stop" && resp.ExitCode == 0 { deadline := time.Now().Add(10 * time.Second) for time.Now().Before(deadline) { @@ -175,12 +175,7 @@ func runDaemonManagerRestart() int { stdoutln("Daemon status: not running") return 1 } - if resp.Stdout != "" { - _, _ = io.WriteString(stdoutWriter(), resp.Stdout) - } - if resp.Stderr != "" { - _, _ = io.WriteString(stderrWriter(), resp.Stderr) - } + writeDaemonResponse(resp) if resp.ExitCode != 0 { return resp.ExitCode } @@ -462,7 +457,7 @@ func daemonReapplyPublishWithoutPorts(args []string, ports []int) error { return newExitStatusError(1, "none of the requested ports are currently configured in publish mode") } sort.Ints(removed) - if countPublishManagedFlags(newArgs) == 0 && len(config.AppConfig.Binds) == 0 { + if countPublishManagedFlags(newArgs) == 0 && len(config.AppConfig.Binds) == 0 && !publishArgsHaveRootBinds(newArgs) { app.StopMode() daemonState.clearModeSnapshot() stdoutf("Removed published ports: %s\n", joinPorts(removed)) @@ -479,13 +474,15 @@ func daemonReapplyPublishWithoutPorts(args []string, ports []int) error { } func filterPublishCommandArgs(args []string, removePorts map[int]bool) ([]string, []int, error) { - if len(args) == 0 || args[0] != "publish" { + pre, post, ok := splitPublishExecArgs(args) + if !ok { return nil, nil, newExitStatusError(1, "daemon is not tracking a publish command") } - filtered := []string{args[0]} + filtered := append([]string{}, pre...) + filtered = append(filtered, "publish") removed := make(map[int]bool) - for i := 1; i < len(args); i++ { - arg := args[i] + for i := 0; i < len(post); i++ { + arg := post[i] flagName, inlineValue, matched := parseManagedPublishFlag(arg) if !matched { filtered = append(filtered, arg) @@ -493,11 +490,11 @@ func filterPublishCommandArgs(args []string, removePorts map[int]bool) ([]string } value := inlineValue if value == "" { - if i+1 >= len(args) { + if i+1 >= len(post) { return nil, nil, newExitStatusError(2, "flag %s is missing a value", flagName) } i++ - value = args[i] + value = post[i] } externPort, err := managedFlagExternPort(flagName, value) if err != nil { @@ -579,14 +576,18 @@ func extractExternPortFromPortSpec(value string) (int, error) { } func countPublishManagedFlags(args []string) int { + _, post, ok := splitPublishExecArgs(args) + if !ok { + return 0 + } count := 0 - for i := 1; i < len(args); i++ { - flagName, inlineValue, matched := parseManagedPublishFlag(args[i]) + for i := 0; i < len(post); i++ { + flagName, inlineValue, matched := parseManagedPublishFlag(post[i]) if !matched { continue } count++ - if inlineValue == "" && i+1 < len(args) { + if inlineValue == "" && i+1 < len(post) { i++ } _ = flagName @@ -594,6 +595,19 @@ func countPublishManagedFlags(args []string) int { return count } +func publishArgsHaveRootBinds(args []string) bool { + pre, _, ok := splitPublishExecArgs(args) + if !ok { + return false + } + for _, item := range parseRootExecItems(pre) { + if item.flagName == "-bind" { + return true + } + } + return false +} + func joinPorts(ports []int) string { items := make([]string, 0, len(ports)) for _, port := range ports { diff --git a/cmd/diode/daemon_paths.go b/cmd/diode/daemon_paths.go index 8d7a901f..a54dc3c8 100644 --- a/cmd/diode/daemon_paths.go +++ b/cmd/diode/daemon_paths.go @@ -15,18 +15,22 @@ func daemonPathID() string { if config.AppConfig != nil { dbPath = config.AppConfig.DBPath } + seed := canonicalDaemonDBPath(dbPath) + if base, err := os.UserConfigDir(); err == nil { + seed = filepath.Clean(base) + "\x00" + seed + } + sum := sha256.Sum256([]byte(seed)) + return hex.EncodeToString(sum[:8]) +} + +func canonicalDaemonDBPath(dbPath string) string { if dbPath == "" { dbPath = util.DefaultDBPath() } if abs, err := filepath.Abs(dbPath); err == nil { dbPath = abs } - seed := filepath.Clean(dbPath) - if base, err := os.UserConfigDir(); err == nil { - seed = filepath.Clean(base) + "\x00" + seed - } - sum := sha256.Sum256([]byte(seed)) - return hex.EncodeToString(sum[:8]) + return filepath.Clean(dbPath) } func daemonPathDir() (string, error) { diff --git a/cmd/diode/daemon_test.go b/cmd/diode/daemon_test.go index 1c386064..3297976c 100644 --- a/cmd/diode/daemon_test.go +++ b/cmd/diode/daemon_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "path/filepath" "strings" "testing" "time" @@ -90,6 +91,20 @@ func TestDaemonRequestForInvocationDetachedApplyMode(t *testing.T) { } } +func TestDaemonRequestForInvocationRoutesSCPThroughLease(t *testing.T) { + inv, err := parseRootInvocation([]string{"scp", "./a", "host.diode:/tmp/a"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + req, err := daemonRequestForInvocation(inv) + if err != nil { + t.Fatalf("daemonRequestForInvocation() error = %v", err) + } + if req.Kind != daemonRequestLease { + t.Fatalf("request kind = %q, want %q", req.Kind, daemonRequestLease) + } +} + func TestDaemonRequestForInvocationRejectsDetachedOneOff(t *testing.T) { inv, err := parseRootInvocation([]string{"-d", "query", "-address", "0xabc"}) if err != nil { @@ -104,6 +119,20 @@ func TestDaemonRequestForInvocationRejectsDetachedOneOff(t *testing.T) { } } +func TestDaemonStartupSpecCanonicalizesDBPath(t *testing.T) { + cfg := newRootConfig() + cfg.DBPath = filepath.Join(".", "relative-wallet.db") + spec := daemonStartupSpecFromConfig(cfg) + want, err := filepath.Abs(cfg.DBPath) + if err != nil { + t.Fatalf("filepath.Abs() error = %v", err) + } + want = filepath.Clean(want) + if spec.DBPath != want { + t.Fatalf("startup DBPath = %q, want %q", spec.DBPath, want) + } +} + func TestDaemonPathsAreScopedByDBPath(t *testing.T) { prevCfg := config.AppConfig t.Cleanup(func() { @@ -271,7 +300,7 @@ func TestRenderDaemonStatusIncludesGatewayListeners(t *testing.T) { } } -func TestSanitizedDaemonBaseConfigResetsTransientState(t *testing.T) { +func TestSanitizedDaemonBaseConfigResetsRequestOnlyState(t *testing.T) { cfg := newRootConfig() cfg.ConfigList = true cfg.QueryAddress = "0x1" @@ -289,14 +318,14 @@ func TestSanitizedDaemonBaseConfigResetsTransientState(t *testing.T) { if sanitized.QueryAddress != "" { t.Fatalf("QueryAddress = %q, want empty", sanitized.QueryAddress) } - if sanitized.EnableProxyServer { - t.Fatalf("EnableProxyServer = true, want false") + if !sanitized.EnableProxyServer { + t.Fatalf("EnableProxyServer = false, want true") } - if sanitized.SocksServerPort != 1080 { - t.Fatalf("SocksServerPort = %d, want 1080", sanitized.SocksServerPort) + if sanitized.SocksServerPort != 9999 { + t.Fatalf("SocksServerPort = %d, want 9999", sanitized.SocksServerPort) } - if len(sanitized.PublicPublishedPorts) != 0 { - t.Fatalf("PublicPublishedPorts = %#v, want empty", sanitized.PublicPublishedPorts) + if len(sanitized.PublicPublishedPorts) != 1 { + t.Fatalf("PublicPublishedPorts = %#v, want preserved", sanitized.PublicPublishedPorts) } if sanitized.BNSLookup != "" { t.Fatalf("BNSLookup = %q, want empty", sanitized.BNSLookup) @@ -304,8 +333,8 @@ func TestSanitizedDaemonBaseConfigResetsTransientState(t *testing.T) { if sanitized.StdoutWriter != nil || sanitized.StderrWriter != nil { t.Fatalf("stdout/stderr writers should be cleared") } - if sanitized.PublishedPorts != nil { - t.Fatalf("PublishedPorts = %#v, want nil", sanitized.PublishedPorts) + if len(sanitized.PublishedPorts) != 1 { + t.Fatalf("PublishedPorts = %#v, want preserved", sanitized.PublishedPorts) } } @@ -496,6 +525,70 @@ func TestFilterPublishCommandArgsRemovesRequestedPorts(t *testing.T) { } } +func TestFilterPublishCommandArgsPreservesRootBinds(t *testing.T) { + args := []string{ + "-bind", "auto:0x1234567890123456789012345678901234567890:80:tcp", + "publish", + "-public", "80:80", + } + filtered, removed, err := filterPublishCommandArgs(args, map[int]bool{80: true}) + if err != nil { + t.Fatalf("filterPublishCommandArgs() error = %v", err) + } + want := []string{ + "-bind", "auto:0x1234567890123456789012345678901234567890:80:tcp", + "publish", + } + if strings.Join(filtered, "\x00") != strings.Join(want, "\x00") { + t.Fatalf("filtered = %#v, want %#v", filtered, want) + } + if len(removed) != 1 || removed[0] != 80 { + t.Fatalf("removed = %#v, want [80]", removed) + } + if !publishArgsHaveRootBinds(filtered) { + t.Fatal("publishArgsHaveRootBinds() = false, want true") + } + if got := countPublishManagedFlags(filtered); got != 0 { + t.Fatalf("countPublishManagedFlags() = %d, want 0", got) + } +} + +func TestDaemonModeNameFromArgsFindsImplicitPublish(t *testing.T) { + got := daemonModeNameFromArgs([]string{ + "-bind", "auto:0x1234567890123456789012345678901234567890:80:tcp", + "publish", + }) + if got != "publish" { + t.Fatalf("daemonModeNameFromArgs() = %q, want publish", got) + } +} + +func TestResetTransientConfigClearsConfigFullValues(t *testing.T) { + cfg := newRootConfig() + cfg.ConfigFullValues = true + resetTransientConfig(cfg) + if cfg.ConfigFullValues { + t.Fatal("ConfigFullValues = true, want false") + } +} + +func TestResetSharedControlsForArgsClearsOnlyOverriddenLists(t *testing.T) { + cfg := newRootConfig() + cfg.SocksServerPort = 23104 + cfg.PublicPublishedPorts = config.StringValues{"80:80"} + cfg.SBinds = config.StringValues{"auto:0x1234567890123456789012345678901234567890:80:tcp"} + resetSharedControlsForArgs(cfg, []string{"publish", "-public", "8080:80"}) + if cfg.SocksServerPort != 23104 { + t.Fatalf("SocksServerPort = %d, want preserved 23104", cfg.SocksServerPort) + } + if len(cfg.SBinds) != 1 { + t.Fatalf("SBinds = %#v, want preserved", cfg.SBinds) + } + if len(cfg.PublicPublishedPorts) != 0 { + t.Fatalf("PublicPublishedPorts = %#v, want reset", cfg.PublicPublishedPorts) + } +} + func TestManagedFlagExternPortSupportsFilesSpec(t *testing.T) { port, err := managedFlagExternPort("-files", "8080,example.diode") if err != nil { diff --git a/cmd/diode/daemon_transport_unix.go b/cmd/diode/daemon_transport_unix.go index f28d8edd..c5d63bfe 100644 --- a/cmd/diode/daemon_transport_unix.go +++ b/cmd/diode/daemon_transport_unix.go @@ -114,6 +114,7 @@ func daemonRestartSelf(cmd string, startup daemonStartupSpec) error { if err != nil { return err } + prepareDaemonForRestart() r, w, err := os.Pipe() if err != nil { return err diff --git a/cmd/diode/daemon_transport_windows.go b/cmd/diode/daemon_transport_windows.go index f55283d6..9b0b0cd5 100644 --- a/cmd/diode/daemon_transport_windows.go +++ b/cmd/diode/daemon_transport_windows.go @@ -96,6 +96,7 @@ func daemonRestartSelf(cmd string, startup daemonStartupSpec) error { if err != nil { return err } + prepareDaemonForRestart() devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0600) if err != nil { return err diff --git a/cmd/diode/join.go b/cmd/diode/join.go index b7fe2d33..12c4c867 100644 --- a/cmd/diode/join.go +++ b/cmd/diode/join.go @@ -2739,9 +2739,6 @@ func joinHandler() (err error) { if err != nil { return } - if isDaemonApplyRequest() { - beginRuntimeMode("join") - } // Initial contract sync to apply perimeter before starting services if syncErr := runContractControllerOnce(cfg); syncErr != nil { @@ -2753,6 +2750,7 @@ func joinHandler() (err error) { return nil } if isDaemonApplyRequest() { + beginRuntimeMode("join") done := make(chan struct{}) app.SetModeDone(done) go func() { diff --git a/cmd/diode/mode_helpers.go b/cmd/diode/mode_helpers.go index ebbd69d6..565f7d43 100644 --- a/cmd/diode/mode_helpers.go +++ b/cmd/diode/mode_helpers.go @@ -5,6 +5,9 @@ func beginRuntimeMode(name string) { return } app.StopMode() + if daemonState != nil { + daemonState.clearModeSnapshot() + } app.BeginMode(name) } diff --git a/cmd/diode/ssh.go b/cmd/diode/ssh.go index 99e8c375..e2511f78 100644 --- a/cmd/diode/ssh.go +++ b/cmd/diode/ssh.go @@ -45,17 +45,8 @@ var ( var runtimeGOOS = runtime.GOOS func sshHandler() error { - return runSSHLikeTool(sshLikeToolOptions{ - commandName: sshCommandName, - toolName: "ssh", - validateLabel: "Invalid SSH target", - validateArgs: func(args []string) error { - if target := extractSSHTarget(args); target != "" { - return validateSSHTarget(target) - } - return nil - }, - }) + opts, _ := sshLikeOptionsForCommand(sshCommandName) + return runSSHLikeTool(opts) } // sshLikeToolOptions configures runSSHLikeTool for a specific OpenSSH-based @@ -122,25 +113,6 @@ func sshLikePassThroughArgs(commandName string, osArgs []string) ([]string, erro return nil, fmt.Errorf("%s command not found", commandName) } -func runSSHWithProxyAddr(proxyAddr string, sshArgs []string) error { - opts := sshLikeToolOptions{ - commandName: sshCommandName, - toolName: "ssh", - validateLabel: "Invalid SSH target", - validateArgs: func(args []string) error { - if target := extractSSHTarget(args); target != "" { - return validateSSHTarget(target) - } - return nil - }, - } - if err := opts.validateArgs(sshArgs); err != nil { - config.AppConfig.PrintError(opts.validateLabel, err) - return newExitStatusError(1, "%s", err.Error()) - } - return runSSHToolWithProxyAddr(proxyAddr, opts, sshArgs) -} - func runSSHToolWithProxyAddr(proxyAddr string, opts sshLikeToolOptions, passArgs []string) error { cfg := config.AppConfig toolName := opts.toolName @@ -185,12 +157,21 @@ func runSSHToolWithProxyAddr(proxyAddr string, opts sshLikeToolOptions, passArgs } func runSSHViaDaemonLease(commandArgs []string, resp daemonResponse) int { + return runSSHLikeViaDaemonLease(commandArgs, resp) +} + +func runSSHLikeViaDaemonLease(commandArgs []string, resp daemonResponse) int { if len(commandArgs) == 0 { - stderrln("missing ssh command arguments") + stderrln("missing ssh-like command arguments") return 1 } if err := ensureDaemonSSHForegroundLogger(); err != nil { - stderrln(fmt.Sprintf("could not initialize ssh command logger: %v", err)) + stderrln(fmt.Sprintf("could not initialize %s command logger: %v", commandArgs[0], err)) + return 1 + } + opts, ok := sshLikeOptionsForCommand(commandArgs[0]) + if !ok { + stderrln(fmt.Sprintf("unsupported daemon proxy command: %s", commandArgs[0])) return 1 } cfg := config.AppConfig @@ -200,12 +181,51 @@ func runSSHViaDaemonLease(commandArgs []string, resp daemonResponse) int { _ = releaseDaemonLease(resp.LeaseID) } }() - if err := runSSHWithProxyAddr(resp.ProxyAddr, normalizeSSHArgs(commandArgs[1:])); err != nil { + passArgs := normalizeSSHArgs(commandArgs[1:]) + if opts.validateArgs != nil { + if err := opts.validateArgs(passArgs); err != nil { + label := opts.validateLabel + if label == "" { + label = fmt.Sprintf("Invalid %s argument", opts.commandName) + } + cfg.PrintError(label, err) + return exitCodeFromError(newExitStatusError(1, "%s", err.Error())) + } + } + if err := runSSHToolWithProxyAddr(resp.ProxyAddr, opts, passArgs); err != nil { return exitCodeFromError(err) } return 0 } +func sshLikeOptionsForCommand(commandName string) (sshLikeToolOptions, bool) { + switch commandName { + case sshCommandName: + return sshLikeToolOptions{ + commandName: sshCommandName, + toolName: "ssh", + validateLabel: "Invalid SSH target", + validateArgs: func(args []string) error { + if target := extractSSHTarget(args); target != "" { + return validateSSHTarget(target) + } + return nil + }, + }, true + case scpCommandName: + return sshLikeToolOptions{ + commandName: scpCommandName, + toolName: "scp", + validateLabel: "Invalid scp target", + validateArgs: func(args []string) error { + return validateSCPArgs(args) + }, + }, true + default: + return sshLikeToolOptions{}, false + } +} + func ensureDaemonSSHForegroundLogger() error { cfg := config.AppConfig if cfg == nil { diff --git a/rpc/socks.go b/rpc/socks.go index 85448907..14e705b1 100644 --- a/rpc/socks.go +++ b/rpc/socks.go @@ -848,7 +848,7 @@ func (socksServer *Server) Start() error { go func() { buf := make([]byte, 2048) - for { + for !socksServer.Closed() { socksServer.handleUDP(buf) } }() @@ -857,8 +857,15 @@ func (socksServer *Server) Start() error { } func (socksServer *Server) handleUDP(packet []byte) { - n, addr, err := socksServer.udpconn.ReadFrom(packet) + udpconn := socksServer.udpconn + if udpconn == nil { + return + } + n, addr, err := udpconn.ReadFrom(packet) if err != nil { + if socksServer.Closed() { + return + } socksServer.logger.Error("handleUDP error: %v", err) return } @@ -1202,6 +1209,10 @@ func (socksServer *Server) Close() { socksServer.listener.Close() socksServer.listener = nil } + if socksServer.udpconn != nil { + socksServer.udpconn.Close() + socksServer.udpconn = nil + } socksServer.bindsMu.RLock() for _, bind := range socksServer.binds { if bind.tcp != nil { From 245661a7968ee41fcf6dc927acedd4c1ef82e490 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20AKG=C3=9CL?= <37873266+tuhalf@users.noreply.github.com> Date: Thu, 28 May 2026 14:37:10 +0200 Subject: [PATCH 13/15] Fix daemon ssh lease and busy handling --- cmd/diode/daemon.go | 64 +++++++++++++++++++++++++++++---------- cmd/diode/daemon_test.go | 65 ++++++++++++++++++++++++++++++++++++++++ cmd/diode/ssh.go | 46 ++++++++++++++++++---------- cmd/diode/ssh_test.go | 50 +++++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 31 deletions(-) diff --git a/cmd/diode/daemon.go b/cmd/diode/daemon.go index f0bc40d7..d0a785e9 100644 --- a/cmd/diode/daemon.go +++ b/cmd/diode/daemon.go @@ -476,6 +476,7 @@ func dispatchViaDaemonAttached(spec daemonStartupSpec, req daemonRequest) (bool, sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, daemonSignals()...) defer signal.Stop(sigCh) + stopErrCh := make(chan error, 1) stopRequested := false for { select { @@ -503,14 +504,18 @@ func dispatchViaDaemonAttached(spec daemonStartupSpec, req daemonRequest) (bool, return true, "", 1, fmt.Errorf("daemon stream ended before final response") } return true, "", 1, err + case err := <-stopErrCh: + if err != nil { + return true, "", 1, err + } case <-sigCh: if stopRequested { return true, "", 130, nil } stopRequested = true - if err := requestDaemonModeStop(); err != nil { - return true, "", 1, err - } + go func() { + stopErrCh <- requestDaemonModeStop() + }() } } } @@ -601,10 +606,16 @@ func handleDaemonConn(conn net.Conn) { func executeDaemonRequest(req daemonRequest) daemonResponse { if req.Kind == daemonRequestUpdate { + if !daemonExecMu.TryLock() { + return executeDaemonBusyResponse(req) + } + daemonExecMu.Unlock() return executeDaemonUpdateRequest(req) } - daemonExecMu.Lock() + if !daemonExecMu.TryLock() { + return executeDaemonBusyResponse(req) + } defer daemonExecMu.Unlock() resp := daemonResponse{Version: daemonProtocolVersion} @@ -654,6 +665,39 @@ func executeDaemonRequest(req daemonRequest) daemonResponse { return resp } +func executeDaemonBusyResponse(req daemonRequest) daemonResponse { + resp := daemonResponse{Version: daemonProtocolVersion} + if req.Kind == daemonRequestManage && len(req.Args) >= 2 && req.Args[0] == "daemon" { + switch req.Args[1] { + case "stop": + resp.Stdout = "Stopping diode daemon.\n" + if app != nil { + app.StopMode() + } + if daemonState != nil { + daemonState.clearModeSnapshot() + } + resp.Shutdown = true + return resp + case "mode-stop": + resp.Stdout = "Stopping diode daemon mode.\n" + if app != nil { + app.StopMode() + } + if daemonState != nil { + daemonState.clearModeSnapshot() + } + return resp + case "status": + resp.Stdout = "Daemon status : busy\nActive request : starting or stopping a daemon command\n" + return resp + } + } + resp.ExitCode = 1 + resp.Error = "daemon is busy starting or stopping another command; run `diode daemon stop` to reset it" + return resp +} + func executeDaemonUpdateRequest(req daemonRequest) daemonResponse { resp := daemonResponse{Version: daemonProtocolVersion} var stdout bytes.Buffer @@ -1461,17 +1505,7 @@ func daemonLeaseLocalProxy() (string, string, error) { if err := app.Start(); err != nil { return "", "", err } - cfg := config.AppConfig - socksCfg := rpc.Config{ - Addr: net.JoinHostPort("127.0.0.1", "0"), - FleetAddr: cfg.FleetAddr, - Blocklists: cfg.Blocklists(), - Allowlists: cfg.Allowlists, - EnableProxy: false, - ProxyServerAddr: cfg.ProxyServerAddr(), - Fallback: cfg.SocksFallback, - } - socksServer, err := rpc.NewSocksServer(socksCfg, app.clientManager) + socksServer, err := rpc.NewSocksServer(sshLocalSocksConfig(config.AppConfig), app.clientManager) if err != nil { return "", "", err } diff --git a/cmd/diode/daemon_test.go b/cmd/diode/daemon_test.go index 3297976c..8ca1219b 100644 --- a/cmd/diode/daemon_test.go +++ b/cmd/diode/daemon_test.go @@ -206,6 +206,71 @@ func TestRunDaemonManageModeStopLeavesDaemonRunning(t *testing.T) { } } +func TestExecuteDaemonBusyModeStopDoesNotWaitForExecLock(t *testing.T) { + cfg := newSharedControlTestConfig(t) + setupSharedControlTestEnv(t, cfg) + + origState := daemonState + daemonState = &runtimeDaemon{ + modeChange: make(chan struct{}), + activeMode: "publish", + activeArgs: []string{"publish", "-public", "80"}, + } + t.Cleanup(func() { + daemonState = origState + }) + app.BeginMode("publish") + + daemonExecMu.Lock() + resp := executeDaemonRequest(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "mode-stop"}, + }) + daemonExecMu.Unlock() + + if resp.ExitCode != 0 { + t.Fatalf("ExitCode = %d, Error = %q", resp.ExitCode, resp.Error) + } + if !strings.Contains(resp.Stdout, "Stopping diode daemon mode.") { + t.Fatalf("Stdout = %q, want mode-stop message", resp.Stdout) + } + if got := app.ActiveMode(); got != "" { + t.Fatalf("ActiveMode() = %q, want empty", got) + } + if status := daemonState.snapshotStatus(); status.ActiveMode != "" { + t.Fatalf("snapshot ActiveMode = %q, want empty", status.ActiveMode) + } +} + +func TestExecuteDaemonBusyLeaseReturnsActionableError(t *testing.T) { + cfg := newSharedControlTestConfig(t) + setupSharedControlTestEnv(t, cfg) + + origState := daemonState + daemonState = &runtimeDaemon{modeChange: make(chan struct{})} + t.Cleanup(func() { + daemonState = origState + }) + + daemonExecMu.Lock() + resp := executeDaemonRequest(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestLease, + Command: "ssh", + Args: []string{"ssh", "ubuntu@example.diode"}, + }) + daemonExecMu.Unlock() + + if resp.ExitCode == 0 { + t.Fatalf("ExitCode = 0, want busy failure") + } + if !strings.Contains(resp.Error, "daemon is busy") { + t.Fatalf("Error = %q, want busy error", resp.Error) + } +} + func TestStopModeTimesOutWaitingForDone(t *testing.T) { prevTimeout := modeStopWaitTimeout modeStopWaitTimeout = 5 * time.Millisecond diff --git a/cmd/diode/ssh.go b/cmd/diode/ssh.go index 3a685c6e..e3648ae1 100644 --- a/cmd/diode/ssh.go +++ b/cmd/diode/ssh.go @@ -174,13 +174,24 @@ func runSSHLikeViaDaemonLease(commandArgs []string, resp daemonResponse) int { stderrln(fmt.Sprintf("unsupported daemon proxy command: %s", commandArgs[0])) return 1 } + if resp.LeaseID != "" { + defer func() { + _ = releaseDaemonLease(resp.LeaseID) + }() + } + if resp.ExitCode != 0 { + return resp.ExitCode + } + if resp.Error != "" { + stderrln(resp.Error) + return 1 + } + if strings.TrimSpace(resp.ProxyAddr) == "" { + stderrln("daemon proxy lease did not expose an address") + return 1 + } cfg := config.AppConfig cfg.PrintLabel("Using diode daemon proxy", resp.ProxyAddr) - defer func() { - if resp.LeaseID != "" { - _ = releaseDaemonLease(resp.LeaseID) - } - }() passArgs := normalizeSSHArgs(commandArgs[1:]) if opts.validateArgs != nil { if err := opts.validateArgs(passArgs); err != nil { @@ -289,17 +300,7 @@ func subcommandPassThroughArgs(args []string, name string) ([]string, error) { // on this instance (rpc.Config); that is not the same as cfg.EnableSocksServer. func startSSHLocalSocksProxy() (string, func(), error) { cfg := config.AppConfig - socksCfg := rpc.Config{ - EnableSocks: true, - Addr: net.JoinHostPort("127.0.0.1", "0"), - FleetAddr: cfg.FleetAddr, - Blocklists: cfg.Blocklists(), - Allowlists: cfg.Allowlists, - EnableProxy: false, - ProxyServerAddr: cfg.ProxyServerAddr(), - Fallback: cfg.SocksFallback, - } - socksServer, err := rpc.NewSocksServer(socksCfg, app.clientManager) + socksServer, err := rpc.NewSocksServer(sshLocalSocksConfig(cfg), app.clientManager) if err != nil { return "", nil, err } @@ -326,6 +327,19 @@ func startSSHLocalSocksProxy() (string, func(), error) { return net.JoinHostPort(host, strconv.Itoa(tcpAddr.Port)), cleanup, nil } +func sshLocalSocksConfig(cfg *config.Config) rpc.Config { + return rpc.Config{ + EnableSocks: true, + Addr: net.JoinHostPort("127.0.0.1", "0"), + FleetAddr: cfg.FleetAddr, + Blocklists: cfg.Blocklists(), + Allowlists: cfg.Allowlists, + EnableProxy: false, + ProxyServerAddr: cfg.ProxyServerAddr(), + Fallback: cfg.SocksFallback, + } +} + func createEphemeralSSHIdentity() (string, func(), error) { sshKeygen, err := findOpenSSHTool("ssh-keygen") if err != nil { diff --git a/cmd/diode/ssh_test.go b/cmd/diode/ssh_test.go index c7510012..713ce7a8 100644 --- a/cmd/diode/ssh_test.go +++ b/cmd/diode/ssh_test.go @@ -202,6 +202,56 @@ func TestRunSSHViaDaemonLeaseInitializesForegroundLogger(t *testing.T) { } } +func TestSSHLocalSocksConfigEnablesListener(t *testing.T) { + cfg := newRootConfig() + got := sshLocalSocksConfig(cfg) + if !got.EnableSocks { + t.Fatal("sshLocalSocksConfig() did not enable SOCKS listener") + } + if got.Addr != net.JoinHostPort("127.0.0.1", "0") { + t.Fatalf("sshLocalSocksConfig() Addr = %q, want ephemeral localhost", got.Addr) + } +} + +func TestRunSSHViaDaemonLeaseStopsOnLeaseFailure(t *testing.T) { + origCfg := config.AppConfig + config.AppConfig = newRootConfig() + t.Cleanup(func() { + config.AppConfig = origCfg + }) + + code := runSSHViaDaemonLease([]string{"ssh", "ubuntu@miner2023.diode"}, daemonResponse{ + ExitCode: 1, + Error: "proxy lease did not expose an address", + }) + if code != 1 { + t.Fatalf("runSSHViaDaemonLease() exit code = %d, want 1", code) + } +} + +func TestRunSSHViaDaemonLeaseRejectsEmptyProxyAddr(t *testing.T) { + origCfg := config.AppConfig + origLookPath := lookPath + config.AppConfig = newRootConfig() + lookPathCalled := false + lookPath = func(string) (string, error) { + lookPathCalled = true + return "", errors.New("unexpected OpenSSH lookup") + } + t.Cleanup(func() { + config.AppConfig = origCfg + lookPath = origLookPath + }) + + code := runSSHViaDaemonLease([]string{"ssh", "ubuntu@miner2023.diode"}, daemonResponse{}) + if code != 1 { + t.Fatalf("runSSHViaDaemonLease() exit code = %d, want 1", code) + } + if lookPathCalled { + t.Fatal("runSSHViaDaemonLease() looked up OpenSSH tool despite empty proxy address") + } +} + func TestFindOpenSSHToolWindowsInstallHelp(t *testing.T) { origLookPath := lookPath origGOOS := runtimeGOOS From 6af9a6fb5265418d49098b6328cab1b11547bf67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20AKG=C3=9CL?= <37873266+tuhalf@users.noreply.github.com> Date: Thu, 28 May 2026 17:03:37 +0200 Subject: [PATCH 14/15] Clear daemon runtime state on mode stop --- cmd/diode/app.go | 1 + cmd/diode/daemon_test.go | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/cmd/diode/app.go b/cmd/diode/app.go index e077bee6..56c81dd3 100644 --- a/cmd/diode/app.go +++ b/cmd/diode/app.go @@ -627,6 +627,7 @@ func (dio *Diode) StopMode() { dio.modeMu.Unlock() if activeMode == "" { clientManager.GetPool().SetPublishedPorts(map[int]*config.Port{}) + dio.controlRuntime = controlRuntimeState{} } } } diff --git a/cmd/diode/daemon_test.go b/cmd/diode/daemon_test.go index 8ca1219b..00c5abda 100644 --- a/cmd/diode/daemon_test.go +++ b/cmd/diode/daemon_test.go @@ -289,6 +289,20 @@ func TestStopModeTimesOutWaitingForDone(t *testing.T) { } } +func TestStopModeClearsPublishedRuntimeState(t *testing.T) { + dio := NewDiode(newSharedControlTestConfig(t)) + dio.BeginMode("publish") + dio.controlRuntime.published = publishedControlState{ + public: []string{"80"}, + } + + dio.StopMode() + + if len(dio.controlRuntime.published.public) != 0 { + t.Fatalf("published runtime state = %#v, want cleared", dio.controlRuntime.published) + } +} + func TestStartPrintsIdentityOnSubsequentCalls(t *testing.T) { cfg := newSharedControlTestConfig(t) var stdout bytes.Buffer From 161cc479601c610ccb6aaabff5ea0ad8f154d49b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20AKG=C3=9CL?= <37873266+tuhalf@users.noreply.github.com> Date: Thu, 28 May 2026 17:59:19 +0200 Subject: [PATCH 15/15] Fix embedded ssh output race --- rpc/ssh_service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rpc/ssh_service.go b/rpc/ssh_service.go index 3ba5f218..f482df27 100644 --- a/rpc/ssh_service.go +++ b/rpc/ssh_service.go @@ -496,7 +496,7 @@ func handleSSHSessionStart(req *ssh.Request, proc *sshProcessHandle, localUser s if err != nil || !handled { return next, err } - go proxySSHProcessIO(channel, next) + proxySSHProcessIO(channel, next) go waitForProcess(next) return next, nil }