diff --git a/README.md b/README.md index 66e6c3b8..fd630bf0 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,8 @@ If the publish succeeds, the client logs the public gateway URL: - `http://.diode.link/` for public port `80` - `https://.diode.link:/` for some public high ports such as `8000-8100` +`publish` stays attached by default and streams daemon output in the foreground. Press `Ctrl-C` to stop the active publish mode while leaving the daemon running idle. Use `./diode -d publish ...` if you want the command to detach and return immediately. + ### 3. Browse Diode services through a local SOCKS proxy ```bash @@ -139,6 +141,7 @@ Top-level commands: - `config`: inspect or change local stored values - `mcp`: run the client as an MCP server over stdin/stdout - `gateway`: run a public HTTP/HTTPS gateway +- `daemon`: inspect or manage the background daemon - `time`: query consensus time - `version`: print build version @@ -156,6 +159,46 @@ Not: ./diode publish -debug=true -public 80:80 ``` +Daemon-routed long-running commands behave like `docker run`: the CLI uses the daemon, stays attached, streams daemon output, and exits when interrupted or when the active mode exits. This applies to: + +- `publish` +- `gateway` +- `socksd` +- `join` +- `files` + +Use root `-d` / `--detach` to keep the daemon mode running in the background: + +```bash +./diode -d publish -public 8080:80 +./diode -d socksd +``` + +`-d` is only valid for those long-running daemon apply modes. One-off daemon-routed commands such as `query`, `fetch`, `config`, `ssh`, `update`, `push`, and `pull` return after their task finishes and reject `-d`. + +`-no-daemon` bypasses daemon routing entirely and keeps the command running in the local CLI process. + +Daemon management commands: + +```bash +./diode daemon status +./diode daemon restart +./diode daemon stop +./diode daemon ports remove 80 443 +./diode daemon ports clear +``` + +`daemon status` reports whether the daemon is running, the active mode, active arguments, published ports, bind rules, SOCKS state, and config API state. `daemon ports remove` and `daemon ports clear` only manage the current daemon-owned publish/files mode; they do not edit local config. + +Daemon instances are scoped by wallet database path. The default `diode daemon ...` commands manage the daemon for the default `-dbpath`; passing a different root `-dbpath` manages a separate daemon for that wallet: + +```bash +./diode -dbpath ./wallet-a.db -d publish -public 8080:80 +./diode -dbpath ./wallet-b.db -d socksd -socksd_port 1082 +./diode -dbpath ./wallet-a.db daemon status +./diode -dbpath ./wallet-b.db daemon stop +``` + If you plan to contribute code, see [CONTRIBUTING.md](CONTRIBUTING.md). ## Publishing Ports @@ -695,9 +738,11 @@ Common global flags you will actually use: - `-configpath `: load YAML config from a file - `-dbpath `: change the database path +- `-d` / `--detach`: run long-running daemon modes in the background instead of staying attached - `-diodeaddrs `: override the bootstrap RPC peers - `-bind :::`: create a local forward through Diode - `-maxports `: cap concurrent ports per device +- `-no-daemon`: run in the current process instead of routing through the daemon - `-tray=true`: show the tray icon on supported platforms - `-update=false`: disable auto-update on startup - `-debug=true`: enable debug logging diff --git a/cmd/diode/app.go b/cmd/diode/app.go index 7437c227..56c81dd3 100644 --- a/cmd/diode/app.go +++ b/cmd/diode/app.go @@ -4,7 +4,9 @@ package main import ( + "flag" "fmt" + "io" "math/rand" "net" "net/http" @@ -35,7 +37,8 @@ var ( PreRun: prepareDiode, PostRun: cleanDiode, } - bootDiodeAddrs = [6]string{ + modeStopWaitTimeout = 5 * time.Second + bootDiodeAddrs = [6]string{ "diode://0xceca2f8cf1983b4cf0c1ba51fd382c2bc37aba58@us1.prenet.diode.io:41046", "diode://0x7e4cd38d266902444dc9c8f7c0aa716a32497d0b@us2.prenet.diode.io:41046", "diode://0x68e0bafdda9ef323f692fc080d612718c941d120@as1.prenet.diode.io:41046", @@ -45,35 +48,48 @@ var ( } ) -func init() { +func registerRootFlags(fs *flag.FlagSet, cfg *config.Config) { + fs.StringVar(&cfg.DBPath, "dbpath", util.DefaultDBPath(), "file path to db file") + fs.IntVar(&cfg.RetryTimes, "retrytimes", 3, "retry times to connect the remote rpc server") + fs.DurationVar(&cfg.EdgeE2ETimeout, "e2etimeout", 15*time.Second, "timeout seconds for edge e2e handshake") + fs.BoolVar(&cfg.EnableUpdate, "update", true, "enable update when start diode") + fs.BoolVar(&cfg.EnableMetrics, "metrics", false, "enable metrics stats") + fs.BoolVar(&cfg.EnableTray, "tray", false, "show a system tray icon") + fs.BoolVar(&cfg.DisableDaemon, "no-daemon", false, "run this command in standalone mode instead of using the diode daemon") + fs.BoolVar(&cfg.DetachDaemon, "d", false, "run daemon mode in the background") + fs.BoolVar(&cfg.DetachDaemon, "detach", false, "run daemon mode in the background") + fs.BoolVar(&cfg.BlockquickDowngrade, "bqdowngrade", false, "reset blockquick window after repeated validation failures") + registerSharedControlFlags(fs, cfg, "debug", "api", "apiaddr") + fs.IntVar(&cfg.RlimitNofile, "rlimit_nofile", 0, "specify the file descriptor numbers that can be opened by this process") + registerSharedControlFlags(fs, cfg, "logfilepath", "logstats", "logtarget", "logdatetime") + fs.StringVar(&cfg.ConfigFilePath, "configpath", "", "yaml file path to config file") + fs.StringVar(&cfg.CPUProfile, "cpuprofile", "", "file path for cpu profiling") + fs.StringVar(&cfg.MEMProfile, "memprofile", "", "file path for memory profiling") + fs.IntVar(&cfg.PProfPort, "pprofport", 0, "localhost port for pprof for memory debugging") + fs.StringVar(&cfg.BlockProfile, "blockprofile", "", "file path for block profiling") + fs.IntVar(&cfg.BlockProfileRate, "blockprofilerate", 1, "the fraction of goroutine blocking events that are reported in the blocking profile") + fs.StringVar(&cfg.MutexProfile, "mutexprofile", "", "file path for mutex profiling") + fs.IntVar(&cfg.MutexProfileRate, "mutexprofilerate", 1, "the fraction of mutex contention events that are reported in the mutex profile") + + fs.String("fleet", "", "fleet contract address (0x...) for this invocation only; use 'diode config -set fleet=0x...' to persist") + + fs.DurationVar(&cfg.RemoteRPCTimeout, "timeout", 5*time.Second, "timeout seconds to connect to the remote rpc server") + fs.DurationVar(&cfg.RetryWait, "retrywait", 1*time.Second, "wait seconds before next retry") + registerSharedControlFlags(fs, cfg, "diodeaddrs", "blockdomains", "blocklists", "allowlists", "bind", "resolvecachetime", "bnscachetime") + fs.IntVar(&cfg.MaxPortsPerDevice, "maxports", 0, "maximum concurrent ports per device (0 = unlimited)") +} + +func newRootConfig() *config.Config { cfg := &config.Config{} - diodeCmd.Flag.StringVar(&cfg.DBPath, "dbpath", util.DefaultDBPath(), "file path to db file") - diodeCmd.Flag.IntVar(&cfg.RetryTimes, "retrytimes", 3, "retry times to connect the remote rpc server") - diodeCmd.Flag.DurationVar(&cfg.EdgeE2ETimeout, "e2etimeout", 15*time.Second, "timeout seconds for edge e2e handshake") - // should put to httpd or other command - diodeCmd.Flag.BoolVar(&cfg.EnableUpdate, "update", true, "enable update when start diode") - diodeCmd.Flag.BoolVar(&cfg.EnableMetrics, "metrics", false, "enable metrics stats") - diodeCmd.Flag.BoolVar(&cfg.EnableTray, "tray", false, "show a system tray icon") - diodeCmd.Flag.BoolVar(&cfg.BlockquickDowngrade, "bqdowngrade", false, "reset blockquick window after repeated validation failures") - registerSharedControlFlags(&diodeCmd.Flag, cfg, "debug", "api", "apiaddr") - diodeCmd.Flag.IntVar(&cfg.RlimitNofile, "rlimit_nofile", 0, "specify the file descriptor numbers that can be opened by this process") - registerSharedControlFlags(&diodeCmd.Flag, cfg, "logfilepath", "logstats", "logtarget", "logdatetime") - diodeCmd.Flag.StringVar(&cfg.ConfigFilePath, "configpath", "", "yaml file path to config file") - diodeCmd.Flag.StringVar(&cfg.CPUProfile, "cpuprofile", "", "file path for cpu profiling") - // diodeCmd.Flag.IntVar(&cfg.CPUProfileRate, "cpuprofilerate", 100, "the CPU profiling rate to hz samples per second") - diodeCmd.Flag.StringVar(&cfg.MEMProfile, "memprofile", "", "file path for memory profiling") - diodeCmd.Flag.IntVar(&cfg.PProfPort, "pprofport", 0, "localhost port for pprof for memory debugging") - diodeCmd.Flag.StringVar(&cfg.BlockProfile, "blockprofile", "", "file path for block profiling") - diodeCmd.Flag.IntVar(&cfg.BlockProfileRate, "blockprofilerate", 1, "the fraction of goroutine blocking events that are reported in the blocking profile") - diodeCmd.Flag.StringVar(&cfg.MutexProfile, "mutexprofile", "", "file path for mutex profiling") - diodeCmd.Flag.IntVar(&cfg.MutexProfileRate, "mutexprofilerate", 1, "the fraction of mutex contention events that are reported in the mutex profile") - - diodeCmd.Flag.String("fleet", "", "fleet contract address (0x...) for this invocation only; use 'diode config -set fleet=0x...' to persist") - - diodeCmd.Flag.DurationVar(&cfg.RemoteRPCTimeout, "timeout", 5*time.Second, "timeout seconds to connect to the remote rpc server") - diodeCmd.Flag.DurationVar(&cfg.RetryWait, "retrywait", 1*time.Second, "wait seconds before next retry") - registerSharedControlFlags(&diodeCmd.Flag, cfg, "diodeaddrs", "blockdomains", "blocklists", "allowlists", "bind", "resolvecachetime", "bnscachetime") - diodeCmd.Flag.IntVar(&cfg.MaxPortsPerDevice, "maxports", 0, "maximum concurrent ports per device (0 = unlimited)") + fs := flag.NewFlagSet("diode-root-defaults", flag.ContinueOnError) + fs.SetOutput(io.Discard) + registerRootFlags(fs, cfg) + return cfg +} + +func init() { + cfg := newRootConfig() + registerRootFlags(&diodeCmd.Flag, cfg) config.AppConfig = cfg // Add diode commands diodeCmd.AddSubCommand(bnsCmd) @@ -213,6 +229,13 @@ type Diode struct { deferals []func() closeCh chan struct{} cmd *command.Command + startMu sync.Mutex + started bool + modeMu sync.Mutex + activeMode string + modeDeferals []func() + modeStopCh chan struct{} + modeDoneCh chan struct{} } // NewDiode return diode application @@ -256,7 +279,7 @@ func (dio *Diode) Init() error { shouldUpdateDiode = diff.Hours() >= 24 } if shouldUpdateDiode { - doUpdate() + _, _ = doUpdate(updateRestartStandalone) } } @@ -384,12 +407,32 @@ func (dio *Diode) Defer(deferal func()) { dio.deferals = append(dio.deferals, deferal) } +// ModeDefer registers cleanup tied to the active daemon mode. +func (dio *Diode) ModeDefer(deferal func()) { + dio.modeDeferals = append(dio.modeDeferals, deferal) +} + +func (dio *Diode) SetCommand(cmd *command.Command) { + dio.cmd = cmd +} + +func (dio *Diode) resolveCommand() (*command.Command, error) { + if dio.cmd != nil { + return dio.cmd, nil + } + dio.cmd = diodeCmd.SubCommand() + if dio.cmd == nil { + return nil, fmt.Errorf("could not determine command to start") + } + return dio.cmd, nil +} + // Start the diode application func (dio *Diode) Start() error { cfg := dio.config - dio.cmd = diodeCmd.SubCommand() - if dio.cmd == nil { - return fmt.Errorf("could not determine command to start") + cmd, err := dio.resolveCommand() + if err != nil { + return err } if err := dio.loadPersistedSharedControls(); err != nil { return err @@ -397,14 +440,21 @@ func (dio *Diode) Start() error { if err := applyFleetCLIOverride(&diodeCmd.Flag, dio.config); err != nil { return err } + + dio.startMu.Lock() + firstStart := !dio.started cfg.PrintLabel("Client address", cfg.ClientAddr.HexString()) cfg.PrintLabel("Fleet address", cfg.FleetAddr.HexString()) - dio.clientManager.Start() + if firstStart { + dio.clientManager.Start() + dio.started = true + } + dio.startMu.Unlock() // socksd waits for a validated client inside Start(); reconcile the SOCKS listener here so // local integration tests (and scripts like ci_test.sh) can probe the port while the // network handshake is still in progress. - if dio.cmd.Name == "socksd" { + if cmd.Name == "socksd" { patch := ControlPatch{} patch.Add("socksd", "socksd", true) result := dio.ApplyControlPatch(patch, controlPatchApplyOptions{Reconcile: true}) @@ -413,11 +463,11 @@ func (dio *Diode) Start() error { } } - if dio.cmd.Type == command.EmptyConnectionCommand { + if cmd.Type == command.EmptyConnectionCommand { return nil } - isOneOffCommand := dio.cmd.Type == command.OneOffCommand + isOneOffCommand := cmd.Type == command.OneOffCommand //onlyNeedOne := dio.cmd.SingleConnection || isOneOffCommand if len(dio.config.RemoteRPCAddrs) < 1 { @@ -498,15 +548,107 @@ func (dio *Diode) Wait() { // go func() { // listen to signal sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) sig := <-sigChan switch sig { - case syscall.SIGINT: + case syscall.SIGINT, syscall.SIGTERM: dio.Close() } // }() } +func (dio *Diode) BeginMode(mode string) { + dio.modeMu.Lock() + defer dio.modeMu.Unlock() + dio.activeMode = mode + if dio.modeStopCh == nil { + dio.modeStopCh = make(chan struct{}) + } +} + +func (dio *Diode) ActiveMode() string { + dio.modeMu.Lock() + defer dio.modeMu.Unlock() + return dio.activeMode +} + +func (dio *Diode) ModeStopChan() <-chan struct{} { + dio.modeMu.Lock() + defer dio.modeMu.Unlock() + if dio.modeStopCh == nil { + dio.modeStopCh = make(chan struct{}) + } + return dio.modeStopCh +} + +func (dio *Diode) SetModeDone(done chan struct{}) { + dio.modeMu.Lock() + dio.modeDoneCh = done + dio.modeMu.Unlock() +} + +func (dio *Diode) StopMode() { + dio.modeMu.Lock() + stopCh := dio.modeStopCh + doneCh := dio.modeDoneCh + modeDeferals := dio.modeDeferals + socksServer := dio.socksServer + proxyServer := dio.proxyServer + configAPIServer := dio.configAPIServer + clientManager := dio.clientManager + dio.modeStopCh = nil + dio.modeDoneCh = nil + dio.modeDeferals = nil + dio.activeMode = "" + dio.socksServer = nil + dio.proxyServer = nil + dio.configAPIServer = nil + dio.modeMu.Unlock() + + if stopCh != nil { + close(stopCh) + } + cleanupMode := func() { + for _, fun := range modeDeferals { + fun() + } + if socksServer != nil { + socksServer.Close() + } + if proxyServer != nil { + proxyServer.Close() + } + if configAPIServer != nil { + configAPIServer.Close() + } + if clientManager != nil { + dio.modeMu.Lock() + activeMode := dio.activeMode + dio.modeMu.Unlock() + if activeMode == "" { + clientManager.GetPool().SetPublishedPorts(map[int]*config.Port{}) + dio.controlRuntime = controlRuntimeState{} + } + } + } + if doneCh != nil { + select { + case <-doneCh: + cleanupMode() + case <-time.After(modeStopWaitTimeout): + if dio.config != nil && dio.config.Logger != nil { + dio.config.Logger.Warn("Timed out waiting for mode to finish") + } + go func() { + <-doneCh + cleanupMode() + }() + } + return + } + cleanupMode() +} + // Closed returns the whether diode application has been closed func (dio *Diode) isClosed(closedCh <-chan struct{}) bool { select { @@ -528,6 +670,7 @@ func (dio *Diode) Close() { defer dio.mu.Unlock() dio.cd.Do(func() { cfg := config.AppConfig + dio.StopMode() for _, fun := range dio.deferals { fun() } diff --git a/cmd/diode/cli_status.go b/cmd/diode/cli_status.go new file mode 100644 index 00000000..0c5d0a16 --- /dev/null +++ b/cmd/diode/cli_status.go @@ -0,0 +1,23 @@ +package main + +import "fmt" + +type exitStatusError struct { + code int + msg string +} + +func (e *exitStatusError) Error() string { + return e.msg +} + +func (e *exitStatusError) Status() int { + return e.code +} + +func newExitStatusError(code int, format string, args ...interface{}) error { + return &exitStatusError{ + code: code, + msg: fmt.Sprintf(format, args...), + } +} diff --git a/cmd/diode/config_server.go b/cmd/diode/config_server.go index 59b0423d..f3d0a8d4 100644 --- a/cmd/diode/config_server.go +++ b/cmd/diode/config_server.go @@ -505,6 +505,10 @@ func (configAPIServer *ConfigAPIServer) apiHandleFunc() func(w http.ResponseWrit return } configAPIServer.successResponse(w, "ok") + if daemonState != nil { + daemonState.baseConfig = sanitizedDaemonBaseConfig(configAPIServer.appConfig) + configAPIServer.appConfig.Logger.Info("Updated config in daemon mode without process restart") + } return } configAPIServer.notFoundError(w) diff --git a/cmd/diode/control_shared.go b/cmd/diode/control_shared.go index 1747335e..4b51c211 100644 --- a/cmd/diode/control_shared.go +++ b/cmd/diode/control_shared.go @@ -1290,22 +1290,45 @@ func (dio *Diode) loadPersistedSharedControls() error { collect(&diodeCmd.Flag) collect(&dio.cmd.Flag) + if err := loadPersistedSharedControlsInto(dio.config, overrides); err != nil { + return err + } + + dio.controlsLoaded = true + return nil +} + +func loadPersistedSharedControlsInto(cfg *config.Config, overrides map[string]bool) error { + if cfg == nil || cfg.LoadFromFile || db.DB == nil { + return nil + } + effects := controlEffect(0) for _, key := range persistedSharedControlKeys { - if overrides[key] { + if overrides != nil && overrides[key] { continue } value, err := db.DB.Get(sharedControlStorageKey(key)) if err != nil { continue } - if _, err := applySharedControlValue(dio.config, key, string(value)); err != nil { + if _, err := applySharedControlValue(cfg, key, string(value)); err != nil { return fmt.Errorf("could not load persisted %s: %w", key, err) } + if spec, ok := sharedControlSpec(key); ok { + effects |= spec.Effects + } } - - config.NormalizeResolveCache(dio.config) - - dio.controlsLoaded = true + if effects&controlEffectServices != 0 { + if err := syncConfigBindsFromSBinds(cfg); err != nil { + return err + } + } + if effects&controlEffectPublished != 0 { + if err := rebuildPublishedPortState(cfg); err != nil { + return err + } + } + config.NormalizeResolveCache(cfg) return nil } diff --git a/cmd/diode/daemon.go b/cmd/diode/daemon.go new file mode 100644 index 00000000..d0a785e9 --- /dev/null +++ b/cmd/diode/daemon.go @@ -0,0 +1,1590 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net" + "os" + "os/signal" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/diodechain/diode_client/command" + "github.com/diodechain/diode_client/config" + "github.com/diodechain/diode_client/rpc" + "github.com/diodechain/diode_client/util" +) + +const ( + daemonCommandName = "__daemon__" + envDaemonReadyFD = "DIODE_DAEMON_READY_FD" + envDaemonStartupSpec = "DIODE_DAEMON_STARTUP_SPEC" + envDaemonRestoreArgs = "DIODE_DAEMON_RESTORE_ARGS" + daemonProtocolVersion = 1 + daemonRequestRunTask = "run_task" + daemonRequestApplyMode = "apply_mode" + daemonRequestLease = "lease_local_proxy" + daemonRequestRelease = "release_local_proxy" + daemonRequestUpdate = "update" + daemonRequestManage = "manage" + daemonStreamStdout = "stdout" + daemonStreamStderr = "stderr" + daemonStreamFinal = "final" + daemonStreamError = "error" +) + +var ( + daemonCmd = &command.Command{ + Name: daemonCommandName, + Run: daemonHandler, + Type: command.DaemonCommand, + Hidden: true, + SkipParentHooks: true, + } + daemonExecMu sync.Mutex + activeDaemonReqKind string + activeDaemonReqMu sync.Mutex + daemonState *runtimeDaemon + daemonStartupFlagNames = map[string]bool{ + "-dbpath": true, + "-retrytimes": true, + "-e2etimeout": true, + "-update": true, + "-metrics": true, + "-tray": true, + "-bqdowngrade": true, + "-debug": true, + "-api": true, + "-apiaddr": true, + "-rlimit_nofile": true, + "-logfilepath": true, + "-logstats": true, + "-logtarget": true, + "-logdatetime": true, + "-configpath": true, + "-cpuprofile": true, + "-memprofile": true, + "-pprofport": true, + "-blockprofile": true, + "-blockprofilerate": true, + "-mutexprofile": true, + "-mutexprofilerate": true, + "-timeout": true, + "-retrywait": true, + "-diodeaddrs": true, + "-blockdomains": true, + "-blocklists": true, + "-allowlists": true, + "-resolvecachetime": true, + "-bnscachetime": true, + "-maxports": true, + "-no-daemon": true, + "-d": true, + "-detach": true, + } + localBypassCommands = map[string]bool{"": true, "version": true, "mcp": true, "ssh-proxy": true, daemonCommandName: true} + daemonApplyModeCmds = map[string]bool{"publish": true, "gateway": true, "socksd": true, "join": true, "files": true} + daemonRunnableCmds = map[string]bool{"query": true, "time": true, "fetch": true, "token": true, "bns": true, "config": true, "reset": true, "push": true, "pull": true, "publish": true, "gateway": true, "socksd": true, "join": true, "files": true, "ssh": true, "scp": true, "update": true} +) + +type daemonStartupSpec struct { + DBPath string `json:"dbpath"` + RetryTimes int `json:"retrytimes"` + EdgeE2ETimeout time.Duration `json:"e2etimeout"` + EnableUpdate bool `json:"update"` + EnableMetrics bool `json:"metrics"` + EnableTray bool `json:"tray"` + BlockquickDowngrade bool `json:"bqdowngrade"` + Debug bool `json:"debug"` + EnableAPIServer bool `json:"api"` + APIServerAddr string `json:"apiaddr"` + RlimitNofile int `json:"rlimit_nofile"` + LogFilePath string `json:"logfilepath"` + LogStats time.Duration `json:"logstats"` + LogTarget string `json:"logtarget"` + LogDateTime bool `json:"logdatetime"` + ConfigFilePath string `json:"configpath"` + CPUProfile string `json:"cpuprofile"` + MEMProfile string `json:"memprofile"` + PProfPort int `json:"pprofport"` + BlockProfile string `json:"blockprofile"` + BlockProfileRate int `json:"blockprofilerate"` + MutexProfile string `json:"mutexprofile"` + MutexProfileRate int `json:"mutexprofilerate"` + RemoteRPCTimeout time.Duration `json:"timeout"` + RetryWait time.Duration `json:"retrywait"` + RemoteRPCAddrs config.StringValues `json:"diodeaddrs"` + SBlockdomains config.StringValues `json:"blockdomains"` + SBlocklists config.StringValues `json:"blocklists"` + SAllowlists config.StringValues `json:"allowlists"` + ResolveCacheTime time.Duration `json:"resolvecachetime"` + BnsCacheTime time.Duration `json:"bnscachetime"` + MaxPortsPerDevice int `json:"maxports"` +} + +type daemonMetadata struct { + PID int `json:"pid"` + SocketPath string `json:"socket_path"` + StartupSpec daemonStartupSpec `json:"startup_spec"` +} + +type daemonRequest struct { + Version int `json:"version"` + Kind string `json:"kind"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + LeaseID string `json:"lease_id,omitempty"` + Attach bool `json:"attach,omitempty"` +} + +type daemonResponse struct { + Version int `json:"version"` + Stdout string `json:"stdout,omitempty"` + Stderr string `json:"stderr,omitempty"` + ExitCode int `json:"exit_code"` + Error string `json:"error,omitempty"` + ProxyAddr string `json:"proxy_addr,omitempty"` + LeaseID string `json:"lease_id,omitempty"` + ModeActive bool `json:"mode_active,omitempty"` + RestartPath string `json:"-"` + Shutdown bool `json:"-"` +} + +type daemonStreamFrame struct { + Type string `json:"type"` + Data string `json:"data,omitempty"` + ExitCode int `json:"exit_code,omitempty"` + Error string `json:"error,omitempty"` +} + +type runtimeDaemon struct { + socketPath string + metaPath string + listener net.Listener + startup daemonStartupSpec + baseConfig config.Config + leasesMu sync.Mutex + leases map[string]*rpc.Server + stateMu sync.Mutex + modeChange chan struct{} + activeMode string + activeArgs []string + ports map[int]*config.Port + binds []config.Bind + socksAddr string + socksOn bool + gatewayAddr string + gatewayOn bool + secureGatewayAddr string + secureGatewayOn bool + secureGatewayAdditionalAddrs []string + apiAddr string + apiOn bool +} + +func init() { + diodeCmd.AddSubCommand(daemonCmd) +} + +func daemonHandler() error { + cfg := config.AppConfig + cfg.DisableDaemon = true + startupSpec := daemonStartupSpecFromConfig(cfg) + if raw := os.Getenv(envDaemonStartupSpec); raw != "" { + if err := json.Unmarshal([]byte(raw), &startupSpec); err != nil { + return err + } + applyDaemonStartupSpec(cfg, startupSpec) + } + if err := prepareDiode(); err != nil { + return err + } + defer cleanDiode() + if err := loadPersistedSharedControlsInto(cfg, nil); err != nil { + return err + } + if app != nil { + app.controlsLoaded = true + } + + socketPath, metaPath, err := daemonPaths() + if err != nil { + return err + } + _ = os.Remove(socketPath) + ln, err := daemonListen(socketPath) + if err != nil { + return err + } + + daemonState = &runtimeDaemon{ + socketPath: socketPath, + metaPath: metaPath, + listener: ln, + startup: startupSpec, + baseConfig: sanitizedDaemonBaseConfig(cfg), + leases: map[string]*rpc.Server{}, + modeChange: make(chan struct{}), + } + app.Defer(func() { + daemonState.closeLeases() + _ = ln.Close() + cleanupDaemonTransport(socketPath) + _ = os.Remove(metaPath) + }) + + if err := writeDaemonMetadata(metaPath, daemonMetadata{ + PID: os.Getpid(), + SocketPath: socketPath, + StartupSpec: daemonState.startup, + }); err != nil { + return err + } + restoreArgs, err := daemonRestoreArgsFromEnv() + if err != nil { + return err + } + if len(restoreArgs) > 0 { + if err := runDaemonCommandAsKind(daemonRequestApplyMode, restoreArgs); err != nil { + logDaemonInternalError("Couldn't restore daemon mode after restart", err) + } else { + daemonState.updateModeSnapshot(daemonModeNameFromArgs(restoreArgs), restoreArgs, config.AppConfig) + } + } + if err := signalDaemonReady(); err != nil { + return err + } + go serveDaemon(ln) + sigCtx, stop := signal.NotifyContext(context.Background(), daemonSignals()...) + defer stop() + select { + case <-sigCtx.Done(): + case <-app.closeCh: + } + app.Close() + return nil +} + +func maybeHandleDaemonCLI(args []string) (bool, int) { + inv, err := parseRootInvocation(args) + if err != nil { + return false, 0 + } + if inv.command == "daemon" { + applyDaemonStartupSpec(config.AppConfig, inv.startupSpec) + return handleDaemonManagerCLI(inv.commandArgs) + } + if inv.disableDaemon || inv.help || localBypassCommands[inv.command] || !daemonRunnableCmds[inv.command] { + return false, 0 + } + + applyDaemonStartupSpec(config.AppConfig, inv.startupSpec) + + req, err := daemonRequestForInvocation(inv) + if err != nil { + stderrln(err.Error()) + return true, exitCodeFromError(err) + } + + if req.Attach { + handled, reason, exitCode, err := dispatchViaDaemonAttached(inv.startupSpec, req) + if !handled { + if reason != "" { + stderrln(reason) + } + return false, 0 + } + if err != nil { + stderrln(err.Error()) + return true, 1 + } + return true, exitCode + } + + resp, handled, reason, err := dispatchViaDaemon(inv.startupSpec, req) + if !handled { + if reason != "" { + stderrln(reason) + } + return false, 0 + } + if err != nil { + stderrln(err.Error()) + return true, 1 + } + writeDaemonResponse(resp) + if req.Kind == daemonRequestLease { + return true, runSSHViaDaemonLease(inv.commandArgs, resp) + } + if req.Kind == daemonRequestApplyMode && resp.ExitCode == 0 && resp.ModeActive { + stdoutf("Daemon mode active: %s\n", inv.command) + stdoutln("Use `diode daemon status` to inspect or manage the running daemon.") + } + return true, resp.ExitCode +} + +func writeDaemonResponse(resp daemonResponse) { + if resp.Stdout != "" { + _, _ = io.WriteString(stdoutWriter(), resp.Stdout) + } + if resp.Stderr != "" { + _, _ = io.WriteString(stderrWriter(), resp.Stderr) + } + if resp.ExitCode != 0 && resp.Error != "" { + stderrln(resp.Error) + } +} + +func daemonRequestForInvocation(inv rootInvocation) (daemonRequest, error) { + req := daemonRequest{ + Version: daemonProtocolVersion, + Command: inv.command, + Args: inv.execArgs, + } + switch inv.command { + case "ssh", "scp": + req.Kind = daemonRequestLease + case "update": + req.Kind = daemonRequestUpdate + default: + if daemonApplyModeCmds[inv.command] { + req.Kind = daemonRequestApplyMode + } else { + req.Kind = daemonRequestRunTask + } + } + if inv.detachDaemon && req.Kind != daemonRequestApplyMode { + return req, newExitStatusError(2, "-d is only supported for daemon mode commands") + } + if req.Kind == daemonRequestApplyMode && !inv.detachDaemon { + req.Attach = true + } + return req, nil +} + +type rootInvocation struct { + command string + commandArgs []string + execArgs []string + help bool + disableDaemon bool + detachDaemon bool + startupSpec daemonStartupSpec +} + +func parseRootInvocation(args []string) (rootInvocation, error) { + cfg := newRootConfig() + fs := flag.NewFlagSet("diode-root-parse", flag.ContinueOnError) + fs.SetOutput(io.Discard) + registerRootFlags(fs, cfg) + err := fs.Parse(args) + if err == flag.ErrHelp { + return rootInvocation{help: true}, nil + } + if err != nil { + return rootInvocation{}, err + } + rest := fs.Args() + commandName := "" + execArgs := append([]string{}, args...) + if len(rest) > 0 { + commandName = rest[0] + } + if commandName == "" && containsHelpArg(args) { + return rootInvocation{help: true, disableDaemon: cfg.DisableDaemon, detachDaemon: cfg.DetachDaemon, startupSpec: daemonStartupSpecFromConfig(cfg)}, nil + } + if commandName == "" { + commandName = "publish" + rest = append([]string{commandName}, rest...) + execArgs = append(execArgs, commandName) + } + if len(rest) > 1 && containsHelpArg(rest[1:]) { + return rootInvocation{help: true, disableDaemon: cfg.DisableDaemon, detachDaemon: cfg.DetachDaemon, startupSpec: daemonStartupSpecFromConfig(cfg)}, nil + } + return rootInvocation{ + command: commandName, + commandArgs: rest, + execArgs: execArgs, + help: false, + disableDaemon: cfg.DisableDaemon, + detachDaemon: cfg.DetachDaemon, + startupSpec: daemonStartupSpecFromConfig(cfg), + }, nil +} + +func containsHelpArg(args []string) bool { + for _, arg := range args { + switch arg { + case "--help", "-help", "-h": + return true + } + } + return false +} + +func dispatchViaDaemon(spec daemonStartupSpec, req daemonRequest) (daemonResponse, bool, string, error) { + conn, handled, reason, err := openDaemonConnection(spec) + if !handled || err != nil { + return daemonResponse{}, handled, reason, err + } + defer conn.Close() + if err := json.NewEncoder(conn).Encode(req); err != nil { + return daemonResponse{}, true, "", err + } + var resp daemonResponse + if err := json.NewDecoder(conn).Decode(&resp); err != nil { + return daemonResponse{}, true, "", err + } + return resp, true, "", nil +} + +func dispatchViaDaemonAttached(spec daemonStartupSpec, req daemonRequest) (bool, string, int, error) { + conn, handled, reason, err := openDaemonConnection(spec) + if !handled || err != nil { + return handled, reason, 0, err + } + defer conn.Close() + + if err := json.NewEncoder(conn).Encode(req); err != nil { + return true, "", 1, err + } + + frameCh := make(chan daemonStreamFrame, 1) + errCh := make(chan error, 1) + go func() { + dec := json.NewDecoder(conn) + for { + var frame daemonStreamFrame + if err := dec.Decode(&frame); err != nil { + errCh <- err + return + } + frameCh <- frame + if frame.Type == daemonStreamFinal || frame.Type == daemonStreamError { + return + } + } + }() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, daemonSignals()...) + defer signal.Stop(sigCh) + stopErrCh := make(chan error, 1) + stopRequested := false + for { + select { + case frame := <-frameCh: + switch frame.Type { + case daemonStreamStdout: + _, _ = io.WriteString(stdoutWriter(), frame.Data) + case daemonStreamStderr: + _, _ = io.WriteString(stderrWriter(), frame.Data) + case daemonStreamError: + if frame.Error != "" { + return true, "", exitCodeOrDefault(frame.ExitCode, 1), fmt.Errorf("%s", frame.Error) + } + return true, "", exitCodeOrDefault(frame.ExitCode, 1), fmt.Errorf("daemon stream failed") + case daemonStreamFinal: + if frame.Error != "" { + _, _ = io.WriteString(stderrWriter(), frame.Error+"\n") + } + return true, "", frame.ExitCode, nil + default: + return true, "", 1, fmt.Errorf("unknown daemon stream frame: %s", frame.Type) + } + case err := <-errCh: + if err == io.EOF { + return true, "", 1, fmt.Errorf("daemon stream ended before final response") + } + return true, "", 1, err + case err := <-stopErrCh: + if err != nil { + return true, "", 1, err + } + case <-sigCh: + if stopRequested { + return true, "", 130, nil + } + stopRequested = true + go func() { + stopErrCh <- requestDaemonModeStop() + }() + } + } +} + +func exitCodeOrDefault(code int, fallback int) int { + if code != 0 { + return code + } + return fallback +} + +func openDaemonConnection(spec daemonStartupSpec) (net.Conn, bool, string, error) { + meta, metaErr := readDaemonMetadata() + if metaErr == nil && !reflect.DeepEqual(meta.StartupSpec, spec) { + if _, err := dialDaemon(meta.SocketPath); err == nil { + return nil, false, "running daemon is incompatible with this invocation; using standalone mode. Run `diode daemon restart` to reload the daemon with the current binary and flags.", nil + } + cleanupDaemonArtifacts(meta.SocketPath, metaPathFromSocket(meta.SocketPath)) + metaErr = os.ErrNotExist + } + + socketPath := "" + if metaErr == nil { + socketPath = meta.SocketPath + } + conn, err := dialDaemon(socketPath) + if err != nil { + cleanupDaemonArtifacts(socketPath, metaPathFromSocket(socketPath)) + if err := spawnDaemon(spec); err != nil { + return nil, true, "", err + } + meta, err = readDaemonMetadata() + if err != nil { + return nil, true, "", err + } + conn, err = dialDaemon(meta.SocketPath) + if err != nil { + return nil, true, "", err + } + } + return conn, true, "", nil +} + +func serveDaemon(ln net.Listener) { + for { + conn, err := ln.Accept() + if err != nil { + if app.Closed() { + return + } + time.Sleep(50 * time.Millisecond) + continue + } + go handleDaemonConn(conn) + } +} + +func handleDaemonConn(conn net.Conn) { + defer conn.Close() + defer func() { + if r := recover(); r != nil { + logDaemonInternalError("Panic handling daemon connection", fmt.Errorf("%v", r)) + } + }() + var req daemonRequest + if err := json.NewDecoder(conn).Decode(&req); err != nil { + _ = json.NewEncoder(conn).Encode(daemonResponse{Version: daemonProtocolVersion, ExitCode: 1, Error: err.Error()}) + return + } + if req.Kind == daemonRequestApplyMode && req.Attach { + executeDaemonAttachedRequest(conn, req) + return + } + resp := executeDaemonRequest(req) + if err := json.NewEncoder(conn).Encode(resp); err != nil { + logDaemonInternalError("Couldn't encode daemon response", err) + return + } + if resp.Shutdown { + go app.Close() + } + if resp.RestartPath != "" { + if err := daemonRestartSelf(resp.RestartPath, daemonState.startup); err != nil { + logDaemonInternalError("Couldn't restart daemon after update", err) + } + } +} + +func executeDaemonRequest(req daemonRequest) daemonResponse { + if req.Kind == daemonRequestUpdate { + if !daemonExecMu.TryLock() { + return executeDaemonBusyResponse(req) + } + daemonExecMu.Unlock() + return executeDaemonUpdateRequest(req) + } + + if !daemonExecMu.TryLock() { + return executeDaemonBusyResponse(req) + } + defer daemonExecMu.Unlock() + + resp := daemonResponse{Version: daemonProtocolVersion} + if daemonState == nil { + resp.ExitCode = 1 + resp.Error = "daemon state is not initialized" + return resp + } + + switch req.Kind { + case daemonRequestLease: + addr, leaseID, err := daemonLeaseLocalProxy() + if err != nil { + resp.ExitCode = 1 + resp.Error = err.Error() + return resp + } + resp.ProxyAddr = addr + resp.LeaseID = leaseID + return resp + case daemonRequestRelease: + if err := daemonReleaseLocalProxy(req.LeaseID); err != nil { + resp.ExitCode = 1 + resp.Error = err.Error() + } + return resp + case daemonRequestManage: + manageResp := daemonResponse{Version: daemonProtocolVersion} + buffered := executeDaemonBufferedRequest(req.Kind, false, func() (string, error) { + return "", runDaemonManage(req.Args, &manageResp) + }) + manageResp.Stdout = buffered.Stdout + manageResp.Stderr = buffered.Stderr + manageResp.ExitCode = buffered.ExitCode + manageResp.Error = buffered.Error + return manageResp + } + if req.Kind == daemonRequestApplyMode && req.Command == "publish" { + req.Args = mergeImplicitPublishArgs(req.Args) + } + resp = executeDaemonBufferedRequest(req.Kind, daemonRequestPersistsBaseConfig(req), func() (string, error) { + return "", runDaemonCommandArgs(req.Args) + }) + if req.Kind == daemonRequestApplyMode && resp.ExitCode == 0 { + resp.ModeActive = updateDaemonModeSnapshotIfActive(req.Command, req.Args) + } + return resp +} + +func executeDaemonBusyResponse(req daemonRequest) daemonResponse { + resp := daemonResponse{Version: daemonProtocolVersion} + if req.Kind == daemonRequestManage && len(req.Args) >= 2 && req.Args[0] == "daemon" { + switch req.Args[1] { + case "stop": + resp.Stdout = "Stopping diode daemon.\n" + if app != nil { + app.StopMode() + } + if daemonState != nil { + daemonState.clearModeSnapshot() + } + resp.Shutdown = true + return resp + case "mode-stop": + resp.Stdout = "Stopping diode daemon mode.\n" + if app != nil { + app.StopMode() + } + if daemonState != nil { + daemonState.clearModeSnapshot() + } + return resp + case "status": + resp.Stdout = "Daemon status : busy\nActive request : starting or stopping a daemon command\n" + return resp + } + } + resp.ExitCode = 1 + resp.Error = "daemon is busy starting or stopping another command; run `diode daemon stop` to reset it" + return resp +} + +func executeDaemonUpdateRequest(req daemonRequest) daemonResponse { + resp := daemonResponse{Version: daemonProtocolVersion} + var stdout bytes.Buffer + var stderr bytes.Buffer + + daemonExecMu.Lock() + if daemonState == nil { + daemonExecMu.Unlock() + resp.ExitCode = 1 + resp.Error = "daemon state is not initialized" + return resp + } + cfg := cloneDaemonConfig(&daemonState.baseConfig) + resetTransientConfig(&cfg) + cfg.StdoutWriter = &stdout + cfg.StderrWriter = &stderr + daemonExecMu.Unlock() + + restartPath, err := runDaemonUpdateWithConfig(req.Args, &cfg) + + resp.Stdout = stdout.String() + resp.Stderr = stderr.String() + resp.RestartPath = restartPath + if err != nil { + resp.ExitCode = exitCodeFromError(err) + resp.Error = err.Error() + if resp.ExitCode == 0 { + resp.ExitCode = 1 + } + } else { + resp.ExitCode = 0 + } + return resp +} + +func daemonRequestPersistsBaseConfig(req daemonRequest) bool { + if req.Kind != daemonRequestRunTask { + return false + } + return req.Command == "config" || req.Command == "reset" +} + +func executeDaemonBufferedRequest(kind string, persistBaseConfig bool, fn func() (string, error)) daemonResponse { + resp := daemonResponse{Version: daemonProtocolVersion} + var stdout bytes.Buffer + var stderr bytes.Buffer + *config.AppConfig = cloneDaemonConfig(&daemonState.baseConfig) + resetTransientConfig(config.AppConfig) + resetRequestGlobals() + config.AppConfig.StdoutWriter = &stdout + config.AppConfig.StderrWriter = &stderr + + activeDaemonReqMu.Lock() + activeDaemonReqKind = kind + activeDaemonReqMu.Unlock() + restartPath, err := fn() + activeDaemonReqMu.Lock() + activeDaemonReqKind = "" + activeDaemonReqMu.Unlock() + + resp.Stdout = stdout.String() + resp.Stderr = stderr.String() + resp.RestartPath = restartPath + if err != nil { + resp.ExitCode = exitCodeFromError(err) + resp.Error = err.Error() + if resp.ExitCode == 0 { + resp.ExitCode = 1 + } + } else { + resp.ExitCode = 0 + } + if persistBaseConfig { + daemonState.baseConfig = sanitizedDaemonBaseConfig(config.AppConfig) + } + return resp +} + +func executeDaemonAttachedRequest(conn net.Conn, req daemonRequest) { + var frameMu sync.Mutex + enc := json.NewEncoder(conn) + sendFrame := func(frame daemonStreamFrame) error { + frameMu.Lock() + defer frameMu.Unlock() + return enc.Encode(frame) + } + + if daemonState == nil { + _ = sendFrame(daemonStreamFrame{Type: daemonStreamError, ExitCode: 1, Error: "daemon state is not initialized"}) + return + } + if req.Command == "publish" { + req.Args = mergeImplicitPublishArgs(req.Args) + } + + daemonExecMu.Lock() + *config.AppConfig = cloneDaemonConfig(&daemonState.baseConfig) + resetTransientConfig(config.AppConfig) + resetRequestGlobals() + config.AppConfig.StdoutWriter = daemonFrameWriter{typ: daemonStreamStdout, send: sendFrame} + config.AppConfig.StderrWriter = daemonFrameWriter{typ: daemonStreamStderr, send: sendFrame} + + activeDaemonReqMu.Lock() + activeDaemonReqKind = daemonRequestApplyMode + activeDaemonReqMu.Unlock() + err := runDaemonCommandArgs(req.Args) + activeDaemonReqMu.Lock() + activeDaemonReqKind = "" + activeDaemonReqMu.Unlock() + + exitCode := exitCodeFromError(err) + if err != nil && exitCode == 0 { + exitCode = 1 + } + if err == nil { + _ = updateDaemonModeSnapshotIfActive(req.Command, req.Args) + } + daemonExecMu.Unlock() + + if err != nil { + _ = sendFrame(daemonStreamFrame{Type: daemonStreamFinal, ExitCode: exitCode, Error: err.Error()}) + return + } + if app.ActiveMode() != req.Command { + _ = sendFrame(daemonStreamFrame{Type: daemonStreamFinal, ExitCode: 0}) + return + } + + daemonState.waitForModeExit(req.Command, req.Args, app.closeCh) + _ = sendFrame(daemonStreamFrame{Type: daemonStreamFinal, ExitCode: 0}) +} + +func updateDaemonModeSnapshotIfActive(command string, args []string) bool { + if app == nil || daemonState == nil { + return false + } + switch activeMode := app.ActiveMode(); activeMode { + case command: + daemonState.updateModeSnapshot(command, args, config.AppConfig) + return true + case "": + daemonState.clearModeSnapshot() + } + return false +} + +type daemonFrameWriter struct { + typ string + send func(daemonStreamFrame) error +} + +func (w daemonFrameWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if err := w.send(daemonStreamFrame{Type: w.typ, Data: string(p)}); err != nil { + return 0, err + } + return len(p), nil +} + +func runDaemonCommandArgs(args []string) error { + if len(args) == 0 { + return newExitStatusError(2, "missing command") + } + resetSharedControlsForArgs(config.AppConfig, args) + if err := diodeCmd.Flag.Parse(args); err != nil { + return err + } + if err := refreshRequestDerivedConfig(config.AppConfig); err != nil { + return err + } + subCmd := diodeCmd.SubCommand() + if subCmd == nil { + return newExitStatusError(2, "unknown command: %s", args[0]) + } + app.SetCommand(subCmd) + rootArgs := diodeCmd.Flag.Args() + if len(rootArgs) > 1 { + if !subCmd.PassThroughArgs { + if err := subCmd.Flag.Parse(rootArgs[1:]); err != nil { + return err + } + } + } else if !subCmd.PassThroughArgs { + _ = subCmd.Flag.Parse([]string{}) + } + return subCmd.Run() +} + +func resetSharedControlsForArgs(cfg *config.Config, args []string) { + if cfg == nil { + return + } + for _, item := range parseRootExecItems(args) { + if !strings.HasPrefix(item.flagName, "-") { + continue + } + name := strings.TrimLeft(item.flagName, "-") + for _, key := range sharedControlFlagKeys(name) { + resetSharedControlValue(cfg, key) + } + } +} + +func refreshRequestDerivedConfig(cfg *config.Config) error { + if cfg == nil { + return nil + } + cfg.SBinds = dedupeStringValues(cfg.SBinds) + cfg.Binds = make([]config.Bind, 0, len(cfg.SBinds)) + for _, str := range cfg.SBinds { + bind, err := parseBind(str) + if err != nil { + return err + } + cfg.Binds = append(cfg.Binds, *bind) + } + if len(cfg.SAllowlists) == 0 { + cfg.Allowlists = nil + return nil + } + cfg.Allowlists = make(map[util.Address]bool, len(cfg.SAllowlists)) + for _, raw := range cfg.SAllowlists { + addr, err := util.DecodeAddress(raw) + if err != nil { + return err + } + cfg.Allowlists[addr] = true + } + return nil +} + +func dedupeStringValues(values config.StringValues) config.StringValues { + if len(values) < 2 { + return values + } + out := make(config.StringValues, 0, len(values)) + for _, value := range values { + if !util.StringsContain(out, value) { + out = append(out, value) + } + } + return out +} + +func mergeImplicitPublishArgs(args []string) []string { + if daemonState == nil { + return args + } + pre, post, ok := splitPublishExecArgs(args) + if !ok || len(post) > 0 { + return args + } + mode, existingArgs := daemonModeArgs() + if mode != "publish" || len(existingArgs) == 0 { + return args + } + existingPre, existingPost, ok := splitPublishExecArgs(existingArgs) + if !ok || len(existingPost) == 0 { + return args + } + merged := mergeImplicitPublishPreArgs(existingPre, pre) + merged = append(merged, "publish") + merged = append(merged, existingPost...) + return merged +} + +func splitPublishExecArgs(args []string) (pre []string, post []string, ok bool) { + for i, arg := range args { + if arg == "publish" { + return append([]string{}, args[:i]...), append([]string{}, args[i+1:]...), true + } + } + return nil, nil, false +} + +func sanitizeModeArgs(mode string, args []string) []string { + if len(args) == 0 { + return nil + } + cmdIdx := -1 + for i, arg := range args { + if arg == mode { + cmdIdx = i + break + } + } + if cmdIdx < 0 { + return append([]string{}, args...) + } + preItems := parseRootExecItems(args[:cmdIdx]) + sanitized := make([]string, 0, len(args)) + for _, item := range preItems { + if daemonStartupFlagNames[item.flagName] { + continue + } + sanitized = append(sanitized, item.args...) + } + sanitized = append(sanitized, args[cmdIdx:]...) + return sanitized +} + +func mergeImplicitPublishPreArgs(existingPre, currentPre []string) []string { + items := append(parseRootExecItems(existingPre), parseRootExecItems(currentPre)...) + if len(items) == 0 { + return nil + } + bindValues := make(config.StringValues, 0) + out := make([]rootExecItem, 0, len(items)) + for _, item := range items { + if item.flagName == "-bind" { + if item.value == "" || util.StringsContain(bindValues, item.value) { + continue + } + bindValues = append(bindValues, item.value) + } + out = append(out, item) + } + merged := make([]string, 0, len(existingPre)+len(currentPre)) + for _, item := range out { + merged = append(merged, item.args...) + } + return merged +} + +type rootExecItem struct { + flagName string + value string + args []string +} + +func parseRootExecItems(args []string) []rootExecItem { + items := make([]rootExecItem, 0, len(args)) + for i := 0; i < len(args); i++ { + arg := args[i] + if !strings.HasPrefix(arg, "-") || arg == "-" { + items = append(items, rootExecItem{args: []string{arg}}) + continue + } + flagName := arg + value := "" + itemArgs := []string{arg} + if idx := strings.Index(arg, "="); idx >= 0 { + flagName = arg[:idx] + value = arg[idx+1:] + } else if i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { + value = args[i+1] + itemArgs = append(itemArgs, args[i+1]) + i++ + } + items = append(items, rootExecItem{ + flagName: flagName, + value: value, + args: itemArgs, + }) + } + return items +} + +func isDaemonApplyRequest() bool { + activeDaemonReqMu.Lock() + defer activeDaemonReqMu.Unlock() + return activeDaemonReqKind == daemonRequestApplyMode +} + +func cloneDaemonConfig(cfg *config.Config) config.Config { + cp := *cfg + cp.RemoteRPCAddrs = append(config.StringValues{}, cfg.RemoteRPCAddrs...) + cp.SBlockdomains = append(config.StringValues{}, cfg.SBlockdomains...) + cp.SBlocklists = append(config.StringValues{}, cfg.SBlocklists...) + cp.SAllowlists = append(config.StringValues{}, cfg.SAllowlists...) + cp.SBinds = append(config.StringValues{}, cfg.SBinds...) + cp.PublicPublishedPorts = append(config.StringValues{}, cfg.PublicPublishedPorts...) + cp.ProtectedPublishedPorts = append(config.StringValues{}, cfg.ProtectedPublishedPorts...) + cp.PrivatePublishedPorts = append(config.StringValues{}, cfg.PrivatePublishedPorts...) + cp.SSHPublishedServices = append(config.StringValues{}, cfg.SSHPublishedServices...) + cp.ConfigDelete = append(config.StringValues{}, cfg.ConfigDelete...) + cp.ConfigSet = append(config.StringValues{}, cfg.ConfigSet...) + return cp +} + +func sanitizedDaemonBaseConfig(cfg *config.Config) config.Config { + cp := cloneDaemonConfig(cfg) + resetTransientConfig(&cp) + return cp +} + +func resetTransientConfig(cfg *config.Config) { + cfg.StdoutWriter = nil + cfg.StderrWriter = nil + cfg.DisableDaemon = false + cfg.DetachDaemon = false + cfg.QueryAddress = "" + cfg.ConfigUnsafe = false + cfg.ConfigList = false + cfg.ConfigFullValues = false + cfg.ConfigDelete = nil + cfg.ConfigSet = nil + cfg.BNSForce = false + cfg.BNSRegister = "" + cfg.BNSUnregister = "" + cfg.BNSTransfer = "" + cfg.BNSLookup = "" + cfg.BNSAccount = "" + cfg.Experimental = false +} + +func resetRequestGlobals() { + enableStaticServer = false + scfg.RootDirectory = "" + scfg.Host = "127.0.0.1" + scfg.Port = 8080 + scfg.Indexed = false + publishFileSpecs = nil + publishFileFileroot = "" + filesFileroot = "" + edgeACME = false + edgeACMEEmail = "" + edgeACMEAddtlCerts = "" + if fetchCfg != nil { + *fetchCfg = fetchConfig{Method: "GET"} + } + if tokenCfg != nil { + *tokenCfg = tokenConfig{Gas: "21000"} + } + dryRun = false + network = "mainnet" + contractAddress = "" + oasisClient = nil + wantWireGuard = false + wgSuffix = "" +} + +func exitCodeFromError(err error) int { + type statusError interface{ Status() int } + type codeError interface{ Code() int } + if err == nil { + return 0 + } + if se, ok := err.(statusError); ok { + return se.Status() + } + if ce, ok := err.(codeError); ok { + return ce.Code() + } + return 1 +} + +func daemonStartupSpecFromConfig(cfg *config.Config) daemonStartupSpec { + return daemonStartupSpec{ + DBPath: canonicalDaemonDBPath(cfg.DBPath), + RetryTimes: cfg.RetryTimes, + EdgeE2ETimeout: cfg.EdgeE2ETimeout, + EnableUpdate: cfg.EnableUpdate, + EnableMetrics: cfg.EnableMetrics, + EnableTray: cfg.EnableTray, + BlockquickDowngrade: cfg.BlockquickDowngrade, + Debug: cfg.Debug, + EnableAPIServer: cfg.EnableAPIServer, + APIServerAddr: cfg.APIServerAddr, + RlimitNofile: cfg.RlimitNofile, + LogFilePath: cfg.LogFilePath, + LogStats: cfg.LogStats, + LogTarget: cfg.LogTarget, + LogDateTime: cfg.LogDateTime, + ConfigFilePath: cfg.ConfigFilePath, + CPUProfile: cfg.CPUProfile, + MEMProfile: cfg.MEMProfile, + PProfPort: cfg.PProfPort, + BlockProfile: cfg.BlockProfile, + BlockProfileRate: cfg.BlockProfileRate, + MutexProfile: cfg.MutexProfile, + MutexProfileRate: cfg.MutexProfileRate, + RemoteRPCTimeout: cfg.RemoteRPCTimeout, + RetryWait: cfg.RetryWait, + RemoteRPCAddrs: append(config.StringValues{}, cfg.RemoteRPCAddrs...), + SBlockdomains: append(config.StringValues{}, cfg.SBlockdomains...), + SBlocklists: append(config.StringValues{}, cfg.SBlocklists...), + SAllowlists: append(config.StringValues{}, cfg.SAllowlists...), + ResolveCacheTime: cfg.ResolveCacheTime, + BnsCacheTime: cfg.BnsCacheTime, + MaxPortsPerDevice: cfg.MaxPortsPerDevice, + } +} + +func daemonModeNameFromArgs(args []string) string { + inv, err := parseRootInvocation(args) + if err == nil && inv.command != "" { + return inv.command + } + if len(args) == 0 { + return "" + } + return args[0] +} + +func applyDaemonStartupSpec(cfg *config.Config, spec daemonStartupSpec) { + cfg.DBPath = spec.DBPath + cfg.RetryTimes = spec.RetryTimes + cfg.EdgeE2ETimeout = spec.EdgeE2ETimeout + cfg.EnableUpdate = spec.EnableUpdate + cfg.EnableMetrics = spec.EnableMetrics + cfg.EnableTray = spec.EnableTray + cfg.BlockquickDowngrade = spec.BlockquickDowngrade + cfg.Debug = spec.Debug + cfg.EnableAPIServer = spec.EnableAPIServer + cfg.APIServerAddr = spec.APIServerAddr + cfg.RlimitNofile = spec.RlimitNofile + cfg.LogFilePath = spec.LogFilePath + cfg.LogStats = spec.LogStats + cfg.LogTarget = spec.LogTarget + cfg.LogDateTime = spec.LogDateTime + cfg.ConfigFilePath = spec.ConfigFilePath + cfg.CPUProfile = spec.CPUProfile + cfg.MEMProfile = spec.MEMProfile + cfg.PProfPort = spec.PProfPort + cfg.BlockProfile = spec.BlockProfile + cfg.BlockProfileRate = spec.BlockProfileRate + cfg.MutexProfile = spec.MutexProfile + cfg.MutexProfileRate = spec.MutexProfileRate + cfg.RemoteRPCTimeout = spec.RemoteRPCTimeout + cfg.RetryWait = spec.RetryWait + cfg.RemoteRPCAddrs = append(config.StringValues{}, spec.RemoteRPCAddrs...) + cfg.SBlockdomains = append(config.StringValues{}, spec.SBlockdomains...) + cfg.SBlocklists = append(config.StringValues{}, spec.SBlocklists...) + cfg.SAllowlists = append(config.StringValues{}, spec.SAllowlists...) + cfg.ResolveCacheTime = spec.ResolveCacheTime + cfg.BnsCacheTime = spec.BnsCacheTime + cfg.MaxPortsPerDevice = spec.MaxPortsPerDevice +} + +func readDaemonMetadata() (daemonMetadata, error) { + socketPath, metaPath, err := daemonPaths() + if err != nil { + return daemonMetadata{}, err + } + _ = socketPath + buf, err := os.ReadFile(metaPath) + if err != nil { + return daemonMetadata{}, err + } + var meta daemonMetadata + if err := json.Unmarshal(buf, &meta); err != nil { + return daemonMetadata{}, err + } + return meta, nil +} + +func writeDaemonMetadata(path string, meta daemonMetadata) error { + buf, err := json.Marshal(meta) + if err != nil { + return err + } + return os.WriteFile(path, buf, 0600) +} + +func cleanupDaemonArtifacts(socketPath, metaPath string) { + if socketPath != "" { + cleanupDaemonTransport(socketPath) + } + if metaPath != "" { + _ = os.Remove(metaPath) + } +} + +func signalDaemonReady() error { + fdStr := strings.TrimSpace(os.Getenv(envDaemonReadyFD)) + if fdStr == "" { + return nil + } + fd, err := strconv.Atoi(fdStr) + if err != nil { + return err + } + f := os.NewFile(uintptr(fd), "daemon-ready") + if f == nil { + return fmt.Errorf("invalid daemon ready file descriptor") + } + defer f.Close() + _, err = f.Write([]byte{1}) + return err +} + +func daemonRestartEnv(spec daemonStartupSpec) ([]string, error) { + specBytes, err := json.Marshal(spec) + if err != nil { + return nil, err + } + env := make([]string, 0, len(os.Environ())+1) + for _, item := range os.Environ() { + if strings.HasPrefix(item, envDaemonReadyFD+"=") || + strings.HasPrefix(item, envDaemonStartupSpec+"=") || + strings.HasPrefix(item, envDaemonRestoreArgs+"=") { + continue + } + env = append(env, item) + } + env = append(env, fmt.Sprintf("%s=%s", envDaemonStartupSpec, string(specBytes))) + if restoreArgs := daemonRestoreArgsForRestart(); len(restoreArgs) > 0 { + restoreBytes, err := json.Marshal(restoreArgs) + if err != nil { + return nil, err + } + env = append(env, fmt.Sprintf("%s=%s", envDaemonRestoreArgs, string(restoreBytes))) + } + return env, nil +} + +func prepareDaemonForRestart() { + if app != nil { + app.StopMode() + } + if daemonState != nil { + daemonState.clearModeSnapshot() + } +} + +func daemonRestoreArgsFromEnv() ([]string, error) { + raw := strings.TrimSpace(os.Getenv(envDaemonRestoreArgs)) + if raw == "" { + return nil, nil + } + var args []string + if err := json.Unmarshal([]byte(raw), &args); err != nil { + return nil, err + } + return args, nil +} + +func logDaemonInternalError(msg string, err error) { + cfg := config.AppConfig + if cfg != nil && cfg.Logger != nil { + cfg.Logger.Error("%s: %v", msg, err) + return + } + fmt.Fprintf(os.Stderr, "%s: %v\n", msg, err) +} + +func (rd *runtimeDaemon) updateModeSnapshot(mode string, args []string, cfg *config.Config) { + if rd == nil || cfg == nil { + return + } + rd.stateMu.Lock() + defer rd.stateMu.Unlock() + rd.activeMode = mode + rd.activeArgs = sanitizeModeArgs(mode, args) + rd.ports = cloneDaemonPortMap(cfg.PublishedPorts) + rd.binds = append([]config.Bind{}, cfg.Binds...) + gatewayStatus := gatewayStatusFromConfig(cfg) + rd.socksOn = gatewayStatus.SocksEnabled + rd.socksAddr = gatewayStatus.SocksAddr + rd.gatewayOn = gatewayStatus.GatewayEnabled + rd.gatewayAddr = gatewayStatus.GatewayAddr + rd.secureGatewayOn = gatewayStatus.SecureGatewayEnabled + rd.secureGatewayAddr = gatewayStatus.SecureGatewayAddr + rd.secureGatewayAdditionalAddrs = append([]string{}, gatewayStatus.SecureGatewayAdditionalAddrs...) + rd.apiOn = cfg.EnableAPIServer + rd.apiAddr = cfg.APIServerAddr + rd.notifyModeChangedLocked() +} + +func (rd *runtimeDaemon) notifyModeChangedLocked() { + if rd.modeChange != nil { + close(rd.modeChange) + } + rd.modeChange = make(chan struct{}) +} + +func (rd *runtimeDaemon) clearModeSnapshot() { + if rd == nil { + return + } + rd.stateMu.Lock() + defer rd.stateMu.Unlock() + rd.activeMode = "" + rd.activeArgs = nil + rd.ports = map[int]*config.Port{} + rd.binds = nil + rd.socksOn = false + rd.socksAddr = "" + rd.gatewayOn = false + rd.gatewayAddr = "" + rd.secureGatewayOn = false + rd.secureGatewayAddr = "" + rd.secureGatewayAdditionalAddrs = nil + rd.apiOn = false + rd.apiAddr = "" + rd.notifyModeChangedLocked() +} + +func (rd *runtimeDaemon) waitForModeExit(mode string, args []string, done <-chan struct{}) { + if rd == nil { + return + } + wantArgs := sanitizeModeArgs(mode, args) + for { + rd.stateMu.Lock() + activeMode := rd.activeMode + activeArgs := append([]string{}, rd.activeArgs...) + modeChange := rd.modeChange + rd.stateMu.Unlock() + + if activeMode == "" || activeMode != mode || !equalStringSlices(activeArgs, wantArgs) { + return + } + select { + case <-modeChange: + case <-done: + return + } + } +} + +func equalStringSlices(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func daemonRestoreArgsForRestart() []string { + if daemonState == nil { + return nil + } + daemonState.stateMu.Lock() + defer daemonState.stateMu.Unlock() + if daemonState.activeMode == "" || len(daemonState.activeArgs) == 0 { + return nil + } + return append([]string{}, daemonState.activeArgs...) +} + +func (rd *runtimeDaemon) snapshotStatus() daemonRuntimeStatus { + status := daemonRuntimeStatus{} + if rd == nil { + return status + } + rd.stateMu.Lock() + defer rd.stateMu.Unlock() + status.ActiveMode = rd.activeMode + status.ActiveArgs = append([]string{}, rd.activeArgs...) + status.PublishedPorts = cloneDaemonPortMap(rd.ports) + status.Binds = append([]config.Bind{}, rd.binds...) + status.SocksEnabled = rd.socksOn + status.SocksAddr = rd.socksAddr + status.GatewayEnabled = rd.gatewayOn + status.GatewayAddr = rd.gatewayAddr + status.SecureGatewayEnabled = rd.secureGatewayOn + status.SecureGatewayAddr = rd.secureGatewayAddr + status.SecureGatewayAdditionalAddrs = append([]string{}, rd.secureGatewayAdditionalAddrs...) + status.APIEnabled = rd.apiOn + status.APIAddr = rd.apiAddr + return status +} + +func cloneDaemonPortMap(in map[int]*config.Port) map[int]*config.Port { + if len(in) == 0 { + return map[int]*config.Port{} + } + out := make(map[int]*config.Port, len(in)) + for k, v := range in { + if v == nil { + out[k] = nil + continue + } + cp := *v + cp.Allowlist = cloneAddressSet(v.Allowlist) + cp.BnsAllowlist = cloneStringBoolSet(v.BnsAllowlist) + cp.DriveAllowList = cloneAddressSet(v.DriveAllowList) + cp.DriveMemberAllowList = cloneAddressSet(v.DriveMemberAllowList) + out[k] = &cp + } + return out +} + +func cloneAddressSet(in map[util.Address]bool) map[util.Address]bool { + if len(in) == 0 { + return map[util.Address]bool{} + } + out := make(map[util.Address]bool, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneStringBoolSet(in map[string]bool) map[string]bool { + if len(in) == 0 { + return map[string]bool{} + } + out := make(map[string]bool, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func daemonLeaseLocalProxy() (string, string, error) { + if err := app.Start(); err != nil { + return "", "", err + } + socksServer, err := rpc.NewSocksServer(sshLocalSocksConfig(config.AppConfig), app.clientManager) + if err != nil { + return "", "", err + } + if err := socksServer.Start(); err != nil { + return "", "", err + } + addr := socksServer.Addr() + if addr == nil { + socksServer.Close() + return "", "", fmt.Errorf("proxy lease did not expose an address") + } + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + socksServer.Close() + return "", "", fmt.Errorf("unexpected proxy lease address type: %T", addr) + } + host := tcpAddr.IP.String() + if host == "" || host == "::" || host == "0.0.0.0" { + host = "127.0.0.1" + } + leaseID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), tcpAddr.Port) + daemonState.leasesMu.Lock() + daemonState.leases[leaseID] = socksServer + daemonState.leasesMu.Unlock() + return net.JoinHostPort(host, strconv.Itoa(tcpAddr.Port)), leaseID, nil +} + +func daemonReleaseLocalProxy(leaseID string) error { + if daemonState == nil { + return nil + } + daemonState.leasesMu.Lock() + socksServer := daemonState.leases[leaseID] + delete(daemonState.leases, leaseID) + daemonState.leasesMu.Unlock() + if socksServer != nil { + socksServer.Close() + } + return nil +} + +func (rt *runtimeDaemon) closeLeases() { + rt.leasesMu.Lock() + defer rt.leasesMu.Unlock() + for leaseID, server := range rt.leases { + delete(rt.leases, leaseID) + if server != nil { + server.Close() + } + } +} + +func releaseDaemonLease(leaseID string) error { + if leaseID == "" { + return nil + } + meta, err := readDaemonMetadata() + if err != nil { + return err + } + conn, err := dialDaemon(meta.SocketPath) + if err != nil { + return err + } + defer conn.Close() + req := daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestRelease, + LeaseID: leaseID, + } + if err := json.NewEncoder(conn).Encode(req); err != nil { + return err + } + var resp daemonResponse + if err := json.NewDecoder(conn).Decode(&resp); err != nil { + return err + } + if resp.Error != "" { + return fmt.Errorf("%s", resp.Error) + } + return nil +} diff --git a/cmd/diode/daemon_manage.go b/cmd/diode/daemon_manage.go new file mode 100644 index 00000000..639cb2d0 --- /dev/null +++ b/cmd/diode/daemon_manage.go @@ -0,0 +1,630 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "sort" + "strconv" + "strings" + "time" + + "github.com/diodechain/diode_client/command" + "github.com/diodechain/diode_client/config" +) + +type daemonRuntimeStatus struct { + ActiveMode string + ActiveArgs []string + PublishedPorts map[int]*config.Port + Binds []config.Bind + SocksEnabled bool + SocksAddr string + GatewayEnabled bool + GatewayAddr string + SecureGatewayEnabled bool + SecureGatewayAddr string + SecureGatewayAdditionalAddrs []string + APIEnabled bool + APIAddr string +} + +var ( + daemonManageCmd = &command.Command{ + Name: "daemon", + HelpText: ` Inspect and manage the running diode daemon.`, + ExampleText: " diode daemon status\n diode daemon stop\n diode daemon mode-stop\n diode daemon ports remove 80 443\n diode daemon ports clear", + Run: daemonManageHandler, + Type: command.EmptyConnectionCommand, + PassThroughArgs: true, + } +) + +func init() { + diodeCmd.AddSubCommand(daemonManageCmd) +} + +func daemonManageHandler() error { + return nil +} + +func handleDaemonManagerCLI(args []string) (bool, int) { + if len(args) == 0 || args[0] != "daemon" { + return false, 0 + } + subArgs := args[1:] + if len(subArgs) == 0 { + subArgs = []string{"status"} + } + switch subArgs[0] { + case "status": + return true, runDaemonManagerStatus() + case "stop": + return true, runDaemonManagerAction([]string{"daemon", "stop"}) + case "restart": + return true, runDaemonManagerRestart() + case "mode-stop": + return true, runDaemonManagerAction([]string{"daemon", "mode-stop"}) + case "mode": + if len(subArgs) == 2 && subArgs[1] == "stop" { + return true, runDaemonManagerAction([]string{"daemon", "mode-stop"}) + } + stderrln("usage: diode daemon [status|stop|restart|mode-stop|ports]") + stderrln(" diode daemon mode stop") + stderrln(" diode daemon ports [remove|clear]") + return true, 2 + case "ports": + return true, runDaemonManagerPorts(subArgs[1:]) + default: + stderrln("usage: diode daemon [status|stop|restart|mode-stop|ports]") + stderrln(" diode daemon mode stop") + stderrln(" diode daemon ports [remove|clear]") + return true, 2 + } +} + +func runDaemonManagerStatus() int { + resp, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "status"}, + }) + if err != nil { + stderrln(err.Error()) + return 1 + } + if !running { + stdoutln("Daemon status: not running") + return 0 + } + writeDaemonResponse(resp) + return resp.ExitCode +} + +func runDaemonManagerPorts(args []string) int { + if len(args) == 0 { + stderrln("usage: diode daemon ports [remove|clear]") + return 2 + } + switch args[0] { + case "remove", "rm": + if len(args) < 2 { + stderrln("usage: diode daemon ports remove [...]") + return 2 + } + reqArgs := []string{"daemon", "ports", "remove"} + reqArgs = append(reqArgs, args[1:]...) + return runDaemonManagerAction(reqArgs) + case "clear": + return runDaemonManagerAction([]string{"daemon", "ports", "clear"}) + default: + stderrln("usage: diode daemon ports [remove|clear]") + return 2 + } +} + +func runDaemonManagerAction(args []string) int { + resp, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: args, + }) + if err != nil { + stderrln(err.Error()) + return 1 + } + if !running { + stdoutln("Daemon status: not running") + return 1 + } + writeDaemonResponse(resp) + if len(args) >= 2 && args[0] == "daemon" && args[1] == "stop" && resp.ExitCode == 0 { + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + time.Sleep(100 * time.Millisecond) + _, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "status"}, + }) + if err == nil && !running { + return 0 + } + } + stderrln("daemon stop timed out waiting for the daemon to exit") + return 1 + } + return resp.ExitCode +} + +func runDaemonManagerRestart() int { + resp, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "restart"}, + }) + if err != nil { + stderrln(err.Error()) + return 1 + } + if !running { + stdoutln("Daemon status: not running") + return 1 + } + writeDaemonResponse(resp) + if resp.ExitCode != 0 { + return resp.ExitCode + } + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + time.Sleep(200 * time.Millisecond) + statusResp, ok, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "status"}, + }) + if err == nil && ok && statusResp.ExitCode == 0 { + stdoutln("Daemon restarted.") + return 0 + } + } + stderrln("daemon restart timed out waiting for the daemon to come back") + return 1 +} + +func dispatchToRunningDaemon(req daemonRequest) (daemonResponse, bool, error) { + var lastMeta daemonMetadata + for attempt := 0; attempt < 10; attempt++ { + meta, err := readDaemonMetadata() + if err != nil { + if os.IsNotExist(err) { + if attempt == 9 { + return daemonResponse{}, false, nil + } + time.Sleep(100 * time.Millisecond) + continue + } + return daemonResponse{}, false, err + } + lastMeta = meta + conn, err := dialDaemon(meta.SocketPath) + if err != nil { + if attempt == 9 { + cleanupDaemonArtifacts(meta.SocketPath, metaPathFromSocket(meta.SocketPath)) + return daemonResponse{}, false, nil + } + time.Sleep(100 * time.Millisecond) + continue + } + defer conn.Close() + if err := json.NewEncoder(conn).Encode(req); err != nil { + return daemonResponse{}, true, err + } + var resp daemonResponse + if err := json.NewDecoder(conn).Decode(&resp); err != nil { + return daemonResponse{}, true, err + } + return resp, true, nil + } + if lastMeta.SocketPath != "" { + cleanupDaemonArtifacts(lastMeta.SocketPath, metaPathFromSocket(lastMeta.SocketPath)) + } + return daemonResponse{}, false, nil +} + +func requestDaemonModeStop() error { + resp, running, err := dispatchToRunningDaemon(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "mode-stop"}, + }) + if err != nil || !running { + return err + } + if resp.ExitCode != 0 { + if resp.Error != "" { + return fmt.Errorf("%s", resp.Error) + } + return newExitStatusError(resp.ExitCode, "daemon mode stop failed") + } + return nil +} + +func runDaemonManage(args []string, resp *daemonResponse) error { + if len(args) == 0 || args[0] != "daemon" { + return newExitStatusError(2, "missing daemon action") + } + action := "status" + if len(args) > 1 { + action = args[1] + } + switch action { + case "status": + renderDaemonStatus() + return nil + case "stop": + stdoutln("Stopping diode daemon.") + app.StopMode() + daemonState.clearModeSnapshot() + if resp != nil { + resp.Shutdown = true + } + return nil + case "restart": + exePath, err := os.Executable() + if err != nil { + exePath = os.Args[0] + } + stdoutln("Restarting diode daemon.") + if resp != nil { + resp.RestartPath = exePath + } + return nil + case "mode-stop": + stdoutln("Stopping diode daemon mode.") + app.StopMode() + daemonState.clearModeSnapshot() + return nil + case "ports": + if len(args) < 3 { + return newExitStatusError(2, "usage: diode daemon ports [remove|clear]") + } + switch args[2] { + case "remove", "rm": + if len(args) < 4 { + return newExitStatusError(2, "usage: diode daemon ports remove [...]") + } + ports := make([]int, 0, len(args)-3) + for _, raw := range args[3:] { + port, err := strconv.Atoi(raw) + if err != nil || port < 1 || port > 65535 { + return newExitStatusError(2, "invalid port: %s", raw) + } + ports = append(ports, port) + } + return daemonRemoveManagedPorts(ports) + case "clear": + return daemonClearManagedPorts() + default: + return newExitStatusError(2, "usage: diode daemon ports [remove|clear]") + } + default: + return newExitStatusError(2, "unknown daemon action: %s", action) + } +} + +func renderDaemonStatus() { + status := daemonState.snapshotStatus() + cfg := config.AppConfig + + cfg.PrintLabel("Daemon status", "running") + cfg.PrintLabel("PID", strconv.Itoa(os.Getpid())) + cfg.PrintLabel("Socket", daemonState.socketPath) + mode := status.ActiveMode + if mode == "" { + mode = "none" + } + cfg.PrintLabel("Active mode", mode) + cfg.PrintLabel("Client address", cfg.ClientAddr.HexString()) + cfg.PrintLabel("Fleet address", cfg.FleetAddr.HexString()) + if cfg.ClientName != "" { + cfg.PrintLabel("Client name", cfg.ClientName+".diode") + } + if status.SocksEnabled { + cfg.PrintLabel("SOCKS proxy", status.SocksAddr) + } else { + cfg.PrintLabel("SOCKS proxy", "disabled") + } + if status.GatewayEnabled { + cfg.PrintLabel("HTTP gateway", status.GatewayAddr) + } else { + cfg.PrintLabel("HTTP gateway", "disabled") + } + if status.SecureGatewayEnabled { + cfg.PrintLabel("HTTPS gateway", status.SecureGatewayAddr) + } else { + cfg.PrintLabel("HTTPS gateway", "disabled") + } + if len(status.SecureGatewayAdditionalAddrs) > 0 { + cfg.PrintLabel("HTTPS gateways", strings.Join(status.SecureGatewayAdditionalAddrs, ", ")) + } + if status.APIEnabled { + cfg.PrintLabel("Config API", status.APIAddr) + } else { + cfg.PrintLabel("Config API", "disabled") + } + if len(status.ActiveArgs) > 0 { + cfg.PrintLabel("Mode args", strings.Join(status.ActiveArgs, " ")) + } + if len(status.PublishedPorts) > 0 { + renderPublishedPortMap(cfg, status.PublishedPorts) + } else { + cfg.PrintLabel("Published ports", "none") + } + if len(status.Binds) > 0 { + renderBindMap(cfg, status.Binds) + } else { + cfg.PrintLabel("Binds", "none") + } +} + +func portAllowlistStrings(port *config.Port) []string { + if port == nil { + return nil + } + addrs := make([]string, 0, len(port.Allowlist)+len(port.BnsAllowlist)+len(port.DriveAllowList)+len(port.DriveMemberAllowList)) + for addr := range port.Allowlist { + addrs = append(addrs, addr.HexString()) + } + for bnsName := range port.BnsAllowlist { + addrs = append(addrs, bnsName) + } + for drive := range port.DriveAllowList { + addrs = append(addrs, drive.HexString()) + } + for driveMember := range port.DriveMemberAllowList { + addrs = append(addrs, driveMember.HexString()) + } + sort.Strings(addrs) + return addrs +} + +func daemonRemoveManagedPorts(ports []int) error { + mode, args := daemonModeArgs() + switch mode { + case "": + return newExitStatusError(1, "daemon has no active mode") + case "publish": + return daemonReapplyPublishWithoutPorts(args, ports) + case "files": + current := daemonState.snapshotStatus() + if len(current.PublishedPorts) != 1 { + return newExitStatusError(1, "files mode is in an unexpected state") + } + for _, port := range ports { + if _, ok := current.PublishedPorts[port]; ok { + app.StopMode() + daemonState.clearModeSnapshot() + stdoutf("Removed published port: %d\n", port) + stdoutln("Files mode stopped because no published ports remain.") + return nil + } + } + return newExitStatusError(1, "requested port is not active in files mode") + default: + return newExitStatusError(1, "port removal is only supported for publish/files modes; current mode is %s", mode) + } +} + +func daemonClearManagedPorts() error { + mode, _ := daemonModeArgs() + switch mode { + case "": + stdoutln("Daemon has no active mode.") + return nil + case "publish", "files": + app.StopMode() + daemonState.clearModeSnapshot() + stdoutln("Removed all published ports and stopped the active publish mode.") + return nil + default: + return newExitStatusError(1, "clearing published ports is only supported for publish/files modes; current mode is %s", mode) + } +} + +func daemonModeArgs() (string, []string) { + daemonState.stateMu.Lock() + defer daemonState.stateMu.Unlock() + return daemonState.activeMode, append([]string{}, daemonState.activeArgs...) +} + +func daemonReapplyPublishWithoutPorts(args []string, ports []int) error { + portSet := make(map[int]bool, len(ports)) + for _, port := range ports { + portSet[port] = true + } + newArgs, removed, err := filterPublishCommandArgs(args, portSet) + if err != nil { + return err + } + if len(removed) == 0 { + return newExitStatusError(1, "none of the requested ports are currently configured in publish mode") + } + sort.Ints(removed) + if countPublishManagedFlags(newArgs) == 0 && len(config.AppConfig.Binds) == 0 && !publishArgsHaveRootBinds(newArgs) { + app.StopMode() + daemonState.clearModeSnapshot() + stdoutf("Removed published ports: %s\n", joinPorts(removed)) + stdoutln("No published ports remain; publish mode stopped.") + return nil + } + if err := runDaemonCommandAsKind(daemonRequestApplyMode, newArgs); err != nil { + return err + } + daemonState.updateModeSnapshot("publish", newArgs, config.AppConfig) + stdoutf("Removed published ports: %s\n", joinPorts(removed)) + stdoutln("Publish mode was reapplied successfully.") + return nil +} + +func filterPublishCommandArgs(args []string, removePorts map[int]bool) ([]string, []int, error) { + pre, post, ok := splitPublishExecArgs(args) + if !ok { + return nil, nil, newExitStatusError(1, "daemon is not tracking a publish command") + } + filtered := append([]string{}, pre...) + filtered = append(filtered, "publish") + removed := make(map[int]bool) + for i := 0; i < len(post); i++ { + arg := post[i] + flagName, inlineValue, matched := parseManagedPublishFlag(arg) + if !matched { + filtered = append(filtered, arg) + continue + } + value := inlineValue + if value == "" { + if i+1 >= len(post) { + return nil, nil, newExitStatusError(2, "flag %s is missing a value", flagName) + } + i++ + value = post[i] + } + externPort, err := managedFlagExternPort(flagName, value) + if err != nil { + return nil, nil, err + } + if removePorts[externPort] { + removed[externPort] = true + continue + } + if inlineValue == "" { + filtered = append(filtered, flagName, value) + } else { + filtered = append(filtered, flagName+"="+value) + } + } + out := make([]int, 0, len(removed)) + for port := range removed { + out = append(out, port) + } + return filtered, out, nil +} + +func parseManagedPublishFlag(arg string) (string, string, bool) { + for _, name := range []string{"-public", "-protected", "-private", "-sshd", "-files"} { + if arg == name { + return name, "", true + } + prefix := name + "=" + if strings.HasPrefix(arg, prefix) { + return name, strings.TrimPrefix(arg, prefix), true + } + } + return "", "", false +} + +func managedFlagExternPort(flagName, value string) (int, error) { + switch flagName { + case "-public", "-protected", "-private": + return extractExternPortFromPortSpec(value) + case "-files": + portSpec, _, err := expandFilesSpec(value) + if err != nil { + return 0, err + } + return extractExternPortFromPortSpec(portSpec) + case "-sshd": + head := sshServicePattern.FindStringSubmatch(strings.TrimSpace(strings.Split(value, ",")[0])) + if len(head) != 4 { + return 0, newExitStatusError(2, "invalid ssh publish spec: %s", value) + } + port, err := strconv.Atoi(head[2]) + if err != nil { + return 0, newExitStatusError(2, "invalid ssh publish spec: %s", value) + } + return port, nil + default: + return 0, newExitStatusError(2, "unsupported publish flag: %s", flagName) + } +} + +func extractExternPortFromPortSpec(value string) (int, error) { + head := strings.TrimSpace(strings.Split(value, ",")[0]) + match := portPattern.FindStringSubmatch(head) + if len(match) != 8 { + return 0, newExitStatusError(2, "invalid publish spec: %s", value) + } + srcPort, err := strconv.Atoi(match[3]) + if err != nil { + return 0, err + } + if match[5] == "" { + return srcPort, nil + } + toPort, err := strconv.Atoi(match[5]) + if err != nil { + return 0, err + } + return toPort, nil +} + +func countPublishManagedFlags(args []string) int { + _, post, ok := splitPublishExecArgs(args) + if !ok { + return 0 + } + count := 0 + for i := 0; i < len(post); i++ { + flagName, inlineValue, matched := parseManagedPublishFlag(post[i]) + if !matched { + continue + } + count++ + if inlineValue == "" && i+1 < len(post) { + i++ + } + _ = flagName + } + return count +} + +func publishArgsHaveRootBinds(args []string) bool { + pre, _, ok := splitPublishExecArgs(args) + if !ok { + return false + } + for _, item := range parseRootExecItems(pre) { + if item.flagName == "-bind" { + return true + } + } + return false +} + +func joinPorts(ports []int) string { + items := make([]string, 0, len(ports)) + for _, port := range ports { + items = append(items, strconv.Itoa(port)) + } + return strings.Join(items, ", ") +} + +func runDaemonCommandAsKind(kind string, args []string) error { + activeDaemonReqMu.Lock() + prev := activeDaemonReqKind + activeDaemonReqKind = kind + activeDaemonReqMu.Unlock() + defer func() { + activeDaemonReqMu.Lock() + activeDaemonReqKind = prev + activeDaemonReqMu.Unlock() + }() + return runDaemonCommandArgs(args) +} diff --git a/cmd/diode/daemon_paths.go b/cmd/diode/daemon_paths.go new file mode 100644 index 00000000..a54dc3c8 --- /dev/null +++ b/cmd/diode/daemon_paths.go @@ -0,0 +1,46 @@ +package main + +import ( + "crypto/sha256" + "encoding/hex" + "os" + "path/filepath" + + "github.com/diodechain/diode_client/config" + "github.com/diodechain/diode_client/util" +) + +func daemonPathID() string { + dbPath := "" + if config.AppConfig != nil { + dbPath = config.AppConfig.DBPath + } + seed := canonicalDaemonDBPath(dbPath) + if base, err := os.UserConfigDir(); err == nil { + seed = filepath.Clean(base) + "\x00" + seed + } + sum := sha256.Sum256([]byte(seed)) + return hex.EncodeToString(sum[:8]) +} + +func canonicalDaemonDBPath(dbPath string) string { + if dbPath == "" { + dbPath = util.DefaultDBPath() + } + if abs, err := filepath.Abs(dbPath); err == nil { + dbPath = abs + } + return filepath.Clean(dbPath) +} + +func daemonPathDir() (string, error) { + base, err := os.UserConfigDir() + if err != nil { + return "", err + } + dir := filepath.Join(base, "diode", "daemons", daemonPathID()) + if err := os.MkdirAll(dir, 0700); err != nil { + return "", err + } + return dir, nil +} diff --git a/cmd/diode/daemon_test.go b/cmd/diode/daemon_test.go new file mode 100644 index 00000000..00c5abda --- /dev/null +++ b/cmd/diode/daemon_test.go @@ -0,0 +1,774 @@ +package main + +import ( + "bytes" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/diodechain/diode_client/config" +) + +func TestParseRootInvocationExtractsCommandAndStartupFlags(t *testing.T) { + inv, err := parseRootInvocation([]string{ + "-debug=true", + "-dbpath", "/tmp/diode.db", + "query", + "-address", "0xabc", + }) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + if inv.command != "query" { + t.Fatalf("command = %q, want query", inv.command) + } + if len(inv.commandArgs) != 3 || inv.commandArgs[0] != "query" { + t.Fatalf("commandArgs = %#v, want query subcommand args", inv.commandArgs) + } + if !inv.startupSpec.Debug { + t.Fatalf("startupSpec.Debug = false, want true") + } + if inv.startupSpec.DBPath != "/tmp/diode.db" { + t.Fatalf("startupSpec.DBPath = %q, want /tmp/diode.db", inv.startupSpec.DBPath) + } +} + +func TestParseRootInvocationDetectsHelpAndNoDaemon(t *testing.T) { + inv, err := parseRootInvocation([]string{"-no-daemon", "publish", "--help"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + if !inv.help { + t.Fatalf("help = false, want true") + } + if !inv.disableDaemon { + t.Fatalf("disableDaemon = false, want true") + } +} + +func TestParseRootInvocationDetectsDetach(t *testing.T) { + inv, err := parseRootInvocation([]string{"-d", "publish", "-public", "80"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + if !inv.detachDaemon { + t.Fatal("detachDaemon = false, want true") + } +} + +func TestDaemonRequestForInvocationAttachesApplyModeByDefault(t *testing.T) { + inv, err := parseRootInvocation([]string{"publish", "-public", "80"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + req, err := daemonRequestForInvocation(inv) + if err != nil { + t.Fatalf("daemonRequestForInvocation() error = %v", err) + } + if req.Kind != daemonRequestApplyMode { + t.Fatalf("request kind = %q, want %q", req.Kind, daemonRequestApplyMode) + } + if !req.Attach { + t.Fatal("request Attach = false, want true") + } +} + +func TestDaemonRequestForInvocationDetachedApplyMode(t *testing.T) { + inv, err := parseRootInvocation([]string{"-d", "publish", "-public", "80"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + req, err := daemonRequestForInvocation(inv) + if err != nil { + t.Fatalf("daemonRequestForInvocation() error = %v", err) + } + if req.Kind != daemonRequestApplyMode { + t.Fatalf("request kind = %q, want %q", req.Kind, daemonRequestApplyMode) + } + if req.Attach { + t.Fatal("request Attach = true, want false") + } +} + +func TestDaemonRequestForInvocationRoutesSCPThroughLease(t *testing.T) { + inv, err := parseRootInvocation([]string{"scp", "./a", "host.diode:/tmp/a"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + req, err := daemonRequestForInvocation(inv) + if err != nil { + t.Fatalf("daemonRequestForInvocation() error = %v", err) + } + if req.Kind != daemonRequestLease { + t.Fatalf("request kind = %q, want %q", req.Kind, daemonRequestLease) + } +} + +func TestDaemonRequestForInvocationRejectsDetachedOneOff(t *testing.T) { + inv, err := parseRootInvocation([]string{"-d", "query", "-address", "0xabc"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + _, err = daemonRequestForInvocation(inv) + if err == nil { + t.Fatal("daemonRequestForInvocation() error = nil, want detach rejection") + } + if !strings.Contains(err.Error(), "-d is only supported") { + t.Fatalf("daemonRequestForInvocation() error = %q, want detach rejection", err.Error()) + } +} + +func TestDaemonStartupSpecCanonicalizesDBPath(t *testing.T) { + cfg := newRootConfig() + cfg.DBPath = filepath.Join(".", "relative-wallet.db") + spec := daemonStartupSpecFromConfig(cfg) + want, err := filepath.Abs(cfg.DBPath) + if err != nil { + t.Fatalf("filepath.Abs() error = %v", err) + } + want = filepath.Clean(want) + if spec.DBPath != want { + t.Fatalf("startup DBPath = %q, want %q", spec.DBPath, want) + } +} + +func TestDaemonPathsAreScopedByDBPath(t *testing.T) { + prevCfg := config.AppConfig + t.Cleanup(func() { + config.AppConfig = prevCfg + }) + + dir := t.TempDir() + cfgA := newRootConfig() + cfgA.DBPath = dir + "/wallet-a.db" + config.AppConfig = cfgA + socketA, metaA, err := daemonPaths() + if err != nil { + t.Fatalf("daemonPaths(wallet-a) error = %v", err) + } + + cfgB := newRootConfig() + cfgB.DBPath = dir + "/wallet-b.db" + config.AppConfig = cfgB + socketB, metaB, err := daemonPaths() + if err != nil { + t.Fatalf("daemonPaths(wallet-b) error = %v", err) + } + + if socketA == socketB { + t.Fatalf("socket path should differ for different dbpaths: %q", socketA) + } + if metaA == metaB { + t.Fatalf("metadata path should differ for different dbpaths: %q", metaA) + } +} + +func TestParseRootInvocationDefaultsToPublishForRootFlagsOnly(t *testing.T) { + inv, err := parseRootInvocation([]string{"-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80"}) + if err != nil { + t.Fatalf("parseRootInvocation() error = %v", err) + } + if inv.command != "publish" { + t.Fatalf("command = %q, want publish", inv.command) + } + if len(inv.commandArgs) != 1 || inv.commandArgs[0] != "publish" { + t.Fatalf("commandArgs = %#v, want implicit publish only", inv.commandArgs) + } + if len(inv.execArgs) != 3 || inv.execArgs[0] != "-bind" || inv.execArgs[2] != "publish" { + t.Fatalf("execArgs = %#v, want root flags plus implicit publish", inv.execArgs) + } +} + +func TestRunDaemonManageModeStopLeavesDaemonRunning(t *testing.T) { + cfg := newSharedControlTestConfig(t) + setupSharedControlTestEnv(t, cfg) + + origState := daemonState + daemonState = &runtimeDaemon{ + modeChange: make(chan struct{}), + activeMode: "publish", + activeArgs: []string{"publish", "-public", "80"}, + } + t.Cleanup(func() { + daemonState = origState + }) + app.BeginMode("publish") + + if err := runDaemonManage([]string{"daemon", "mode-stop"}, &daemonResponse{}); err != nil { + t.Fatalf("runDaemonManage(mode-stop) error = %v", err) + } + if app.Closed() { + t.Fatal("mode-stop closed the daemon app") + } + if status := daemonState.snapshotStatus(); status.ActiveMode != "" { + t.Fatalf("ActiveMode = %q, want empty", status.ActiveMode) + } +} + +func TestExecuteDaemonBusyModeStopDoesNotWaitForExecLock(t *testing.T) { + cfg := newSharedControlTestConfig(t) + setupSharedControlTestEnv(t, cfg) + + origState := daemonState + daemonState = &runtimeDaemon{ + modeChange: make(chan struct{}), + activeMode: "publish", + activeArgs: []string{"publish", "-public", "80"}, + } + t.Cleanup(func() { + daemonState = origState + }) + app.BeginMode("publish") + + daemonExecMu.Lock() + resp := executeDaemonRequest(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestManage, + Command: "daemon", + Args: []string{"daemon", "mode-stop"}, + }) + daemonExecMu.Unlock() + + if resp.ExitCode != 0 { + t.Fatalf("ExitCode = %d, Error = %q", resp.ExitCode, resp.Error) + } + if !strings.Contains(resp.Stdout, "Stopping diode daemon mode.") { + t.Fatalf("Stdout = %q, want mode-stop message", resp.Stdout) + } + if got := app.ActiveMode(); got != "" { + t.Fatalf("ActiveMode() = %q, want empty", got) + } + if status := daemonState.snapshotStatus(); status.ActiveMode != "" { + t.Fatalf("snapshot ActiveMode = %q, want empty", status.ActiveMode) + } +} + +func TestExecuteDaemonBusyLeaseReturnsActionableError(t *testing.T) { + cfg := newSharedControlTestConfig(t) + setupSharedControlTestEnv(t, cfg) + + origState := daemonState + daemonState = &runtimeDaemon{modeChange: make(chan struct{})} + t.Cleanup(func() { + daemonState = origState + }) + + daemonExecMu.Lock() + resp := executeDaemonRequest(daemonRequest{ + Version: daemonProtocolVersion, + Kind: daemonRequestLease, + Command: "ssh", + Args: []string{"ssh", "ubuntu@example.diode"}, + }) + daemonExecMu.Unlock() + + if resp.ExitCode == 0 { + t.Fatalf("ExitCode = 0, want busy failure") + } + if !strings.Contains(resp.Error, "daemon is busy") { + t.Fatalf("Error = %q, want busy error", resp.Error) + } +} + +func TestStopModeTimesOutWaitingForDone(t *testing.T) { + prevTimeout := modeStopWaitTimeout + modeStopWaitTimeout = 5 * time.Millisecond + t.Cleanup(func() { + modeStopWaitTimeout = prevTimeout + }) + + dio := &Diode{config: &config.Config{}} + dio.modeStopCh = make(chan struct{}) + dio.modeDoneCh = make(chan struct{}) + + start := time.Now() + dio.StopMode() + if time.Since(start) > time.Second { + t.Fatal("StopMode blocked waiting for mode done channel") + } +} + +func TestStopModeClearsPublishedRuntimeState(t *testing.T) { + dio := NewDiode(newSharedControlTestConfig(t)) + dio.BeginMode("publish") + dio.controlRuntime.published = publishedControlState{ + public: []string{"80"}, + } + + dio.StopMode() + + if len(dio.controlRuntime.published.public) != 0 { + t.Fatalf("published runtime state = %#v, want cleared", dio.controlRuntime.published) + } +} + +func TestStartPrintsIdentityOnSubsequentCalls(t *testing.T) { + cfg := newSharedControlTestConfig(t) + var stdout bytes.Buffer + cfg.StdoutWriter = &stdout + + dio := &Diode{ + config: cfg, + cmd: daemonManageCmd, + controlsLoaded: true, + started: true, + } + + if err := dio.Start(); err != nil { + t.Fatalf("first Start() error = %v", err) + } + if got := strings.Count(stdout.String(), "Client address"); got != 1 { + t.Fatalf("first Start() client address lines = %d, want 1\n%s", got, stdout.String()) + } + + stdout.Reset() + if err := dio.Start(); err != nil { + t.Fatalf("second Start() error = %v", err) + } + if got := strings.Count(stdout.String(), "Client address"); got != 1 { + t.Fatalf("second Start() client address lines = %d, want 1\n%s", got, stdout.String()) + } +} + +func TestRenderDaemonStatusIncludesGatewayListeners(t *testing.T) { + prevCfg := config.AppConfig + prevState := daemonState + t.Cleanup(func() { + config.AppConfig = prevCfg + daemonState = prevState + }) + + cfg := newSharedControlTestConfig(t) + var stdout bytes.Buffer + cfg.StdoutWriter = &stdout + config.AppConfig = cfg + daemonState = &runtimeDaemon{ + socketPath: "/tmp/daemon.sock", + modeChange: make(chan struct{}), + activeMode: "gateway", + activeArgs: []string{"gateway", "-httpd_port", "18080"}, + socksOn: true, + socksAddr: "127.0.0.1:18080", + gatewayOn: true, + gatewayAddr: "127.0.0.1:18081", + secureGatewayOn: true, + secureGatewayAddr: "127.0.0.1:18443", + secureGatewayAdditionalAddrs: []string{"127.0.0.1:18444"}, + ports: map[int]*config.Port{}, + } + + renderDaemonStatus() + + out := stdout.String() + for _, want := range []string{ + "Active mode", + "gateway", + "SOCKS proxy", + "127.0.0.1:18080", + "HTTP gateway", + "127.0.0.1:18081", + "HTTPS gateway", + "127.0.0.1:18443", + "HTTPS gateways", + "127.0.0.1:18444", + } { + if !strings.Contains(out, want) { + t.Fatalf("status output missing %q\n%s", want, out) + } + } +} + +func TestSanitizedDaemonBaseConfigResetsRequestOnlyState(t *testing.T) { + cfg := newRootConfig() + cfg.ConfigList = true + cfg.QueryAddress = "0x1" + cfg.EnableProxyServer = true + cfg.SocksServerPort = 9999 + cfg.PublicPublishedPorts = config.StringValues{"80:80"} + cfg.BNSLookup = "example" + cfg.StdoutWriter = testingLoggerWriter{t} + cfg.PublishedPorts = map[int]*config.Port{80: {}} + + sanitized := sanitizedDaemonBaseConfig(cfg) + if sanitized.ConfigList { + t.Fatalf("ConfigList = true, want false") + } + if sanitized.QueryAddress != "" { + t.Fatalf("QueryAddress = %q, want empty", sanitized.QueryAddress) + } + if !sanitized.EnableProxyServer { + t.Fatalf("EnableProxyServer = false, want true") + } + if sanitized.SocksServerPort != 9999 { + t.Fatalf("SocksServerPort = %d, want 9999", sanitized.SocksServerPort) + } + if len(sanitized.PublicPublishedPorts) != 1 { + t.Fatalf("PublicPublishedPorts = %#v, want preserved", sanitized.PublicPublishedPorts) + } + if sanitized.BNSLookup != "" { + t.Fatalf("BNSLookup = %q, want empty", sanitized.BNSLookup) + } + if sanitized.StdoutWriter != nil || sanitized.StderrWriter != nil { + t.Fatalf("stdout/stderr writers should be cleared") + } + if len(sanitized.PublishedPorts) != 1 { + t.Fatalf("PublishedPorts = %#v, want preserved", sanitized.PublishedPorts) + } +} + +func TestDaemonBufferedRequestPersistsBaseConfigOnlyWhenRequested(t *testing.T) { + prevCfg := config.AppConfig + prevState := daemonState + t.Cleanup(func() { + config.AppConfig = prevCfg + daemonState = prevState + }) + + base := newRootConfig() + base.Debug = false + config.AppConfig = newRootConfig() + daemonState = &runtimeDaemon{baseConfig: *base} + + resp := executeDaemonBufferedRequest(daemonRequestRunTask, false, func() (string, error) { + config.AppConfig.Debug = true + return "", nil + }) + if resp.ExitCode != 0 { + t.Fatalf("executeDaemonBufferedRequest() exit = %d err = %q", resp.ExitCode, resp.Error) + } + if daemonState.baseConfig.Debug { + t.Fatal("baseConfig.Debug changed for non-persistent request") + } + + resp = executeDaemonBufferedRequest(daemonRequestRunTask, true, func() (string, error) { + config.AppConfig.Debug = true + return "", nil + }) + if resp.ExitCode != 0 { + t.Fatalf("executeDaemonBufferedRequest(persist) exit = %d err = %q", resp.ExitCode, resp.Error) + } + if !daemonState.baseConfig.Debug { + t.Fatal("baseConfig.Debug did not change for persistent request") + } +} + +func TestResetRequestGlobalsClearsPublishFileGlobals(t *testing.T) { + publishFileSpecs = config.StringValues{"8080", "9090"} + publishFileFileroot = "/tmp/files" + t.Cleanup(func() { + publishFileSpecs = nil + publishFileFileroot = "" + }) + + resetRequestGlobals() + if len(publishFileSpecs) != 0 { + t.Fatalf("publishFileSpecs = %#v, want empty", publishFileSpecs) + } + if publishFileFileroot != "" { + t.Fatalf("publishFileFileroot = %q, want empty", publishFileFileroot) + } +} + +func TestDaemonStartupSpecFromConfigCopiesRootScopedValues(t *testing.T) { + cfg := newRootConfig() + cfg.RemoteRPCTimeout = 7 * time.Second + cfg.RetryWait = 2 * time.Second + cfg.ResolveCacheTime = 5 * time.Minute + cfg.BnsCacheTime = 6 * time.Minute + cfg.LogStats = 30 * time.Second + cfg.LogTarget = "collector:1234" + cfg.SBlocklists = config.StringValues{"0x1"} + + spec := daemonStartupSpecFromConfig(cfg) + if spec.RemoteRPCTimeout != 7*time.Second { + t.Fatalf("RemoteRPCTimeout = %v, want 7s", spec.RemoteRPCTimeout) + } + if spec.RetryWait != 2*time.Second { + t.Fatalf("RetryWait = %v, want 2s", spec.RetryWait) + } + if spec.ResolveCacheTime != 5*time.Minute { + t.Fatalf("ResolveCacheTime = %v, want 5m", spec.ResolveCacheTime) + } + if spec.BnsCacheTime != 6*time.Minute { + t.Fatalf("BnsCacheTime = %v, want 6m", spec.BnsCacheTime) + } + if spec.LogStats != 30*time.Second { + t.Fatalf("LogStats = %v, want 30s", spec.LogStats) + } + if spec.LogTarget != "collector:1234" { + t.Fatalf("LogTarget = %q, want collector:1234", spec.LogTarget) + } + if len(spec.SBlocklists) != 1 || spec.SBlocklists[0] != "0x1" { + t.Fatalf("SBlocklists = %#v, want [0x1]", spec.SBlocklists) + } +} + +func TestApplyDaemonStartupSpecCopiesLogAndCacheValues(t *testing.T) { + cfg := newRootConfig() + applyDaemonStartupSpec(cfg, daemonStartupSpec{ + LogStats: 15 * time.Second, + LogTarget: "collector:1234", + ResolveCacheTime: 4 * time.Minute, + BnsCacheTime: 3 * time.Minute, + }) + if cfg.LogStats != 15*time.Second { + t.Fatalf("LogStats = %v, want 15s", cfg.LogStats) + } + if cfg.LogTarget != "collector:1234" { + t.Fatalf("LogTarget = %q, want collector:1234", cfg.LogTarget) + } + if cfg.ResolveCacheTime != 4*time.Minute { + t.Fatalf("ResolveCacheTime = %v, want 4m", cfg.ResolveCacheTime) + } + if cfg.BnsCacheTime != 3*time.Minute { + t.Fatalf("BnsCacheTime = %v, want 3m", cfg.BnsCacheTime) + } +} + +func TestDaemonRestartEnvReplacesDaemonSpecificVars(t *testing.T) { + t.Setenv(envDaemonReadyFD, "3") + t.Setenv(envDaemonStartupSpec, `{"debug":false}`) + daemonState = nil + + env, err := daemonRestartEnv(daemonStartupSpec{Debug: true}) + if err != nil { + t.Fatalf("daemonRestartEnv() error = %v", err) + } + + startupVars := 0 + for _, item := range env { + if strings.HasPrefix(item, envDaemonReadyFD+"=") { + t.Fatalf("daemon restart env still contains %s: %q", envDaemonReadyFD, item) + } + if strings.HasPrefix(item, envDaemonStartupSpec+"=") { + startupVars++ + if !strings.Contains(item, `"debug":true`) { + t.Fatalf("startup spec env = %q, want debug=true", item) + } + } + } + if startupVars != 1 { + t.Fatalf("startup spec env vars = %d, want 1", startupVars) + } +} + +func TestDaemonRestartEnvPreservesActiveModeArgs(t *testing.T) { + prev := daemonState + defer func() { daemonState = prev }() + daemonState = &runtimeDaemon{ + activeMode: "publish", + activeArgs: []string{"publish", "-public", "80"}, + } + + env, err := daemonRestartEnv(daemonStartupSpec{Debug: true}) + if err != nil { + t.Fatalf("daemonRestartEnv() error = %v", err) + } + + restoreVars := 0 + for _, item := range env { + if strings.HasPrefix(item, envDaemonRestoreArgs+"=") { + restoreVars++ + if !strings.Contains(item, `"publish"`) || !strings.Contains(item, `"-public"`) || !strings.Contains(item, `"80"`) { + t.Fatalf("restore args env = %q", item) + } + } + } + if restoreVars != 1 { + t.Fatalf("restore args env vars = %d, want 1", restoreVars) + } +} + +func TestFilterPublishCommandArgsRemovesRequestedPorts(t *testing.T) { + args := []string{ + "publish", + "-public", "80:80", + "-private=127.0.0.1:22:2222,0x1234567890123456789012345678901234567890", + "-sshd", "private:2022:ubuntu,0x1234567890123456789012345678901234567890", + "-socksd=true", + } + filtered, removed, err := filterPublishCommandArgs(args, map[int]bool{80: true, 2022: true}) + if err != nil { + t.Fatalf("filterPublishCommandArgs() error = %v", err) + } + if got := strings.Join(filtered, " "); got != "publish -private=127.0.0.1:22:2222,0x1234567890123456789012345678901234567890 -socksd=true" { + t.Fatalf("filtered = %q", got) + } + if len(removed) != 2 { + t.Fatalf("removed = %#v, want 2 items", removed) + } + gotRemoved := map[int]bool{removed[0]: true, removed[1]: true} + if !gotRemoved[80] || !gotRemoved[2022] { + t.Fatalf("removed = %#v, want ports 80 and 2022", removed) + } +} + +func TestFilterPublishCommandArgsPreservesRootBinds(t *testing.T) { + args := []string{ + "-bind", "auto:0x1234567890123456789012345678901234567890:80:tcp", + "publish", + "-public", "80:80", + } + filtered, removed, err := filterPublishCommandArgs(args, map[int]bool{80: true}) + if err != nil { + t.Fatalf("filterPublishCommandArgs() error = %v", err) + } + want := []string{ + "-bind", "auto:0x1234567890123456789012345678901234567890:80:tcp", + "publish", + } + if strings.Join(filtered, "\x00") != strings.Join(want, "\x00") { + t.Fatalf("filtered = %#v, want %#v", filtered, want) + } + if len(removed) != 1 || removed[0] != 80 { + t.Fatalf("removed = %#v, want [80]", removed) + } + if !publishArgsHaveRootBinds(filtered) { + t.Fatal("publishArgsHaveRootBinds() = false, want true") + } + if got := countPublishManagedFlags(filtered); got != 0 { + t.Fatalf("countPublishManagedFlags() = %d, want 0", got) + } +} + +func TestDaemonModeNameFromArgsFindsImplicitPublish(t *testing.T) { + got := daemonModeNameFromArgs([]string{ + "-bind", "auto:0x1234567890123456789012345678901234567890:80:tcp", + "publish", + }) + if got != "publish" { + t.Fatalf("daemonModeNameFromArgs() = %q, want publish", got) + } +} + +func TestResetTransientConfigClearsConfigFullValues(t *testing.T) { + cfg := newRootConfig() + cfg.ConfigFullValues = true + resetTransientConfig(cfg) + if cfg.ConfigFullValues { + t.Fatal("ConfigFullValues = true, want false") + } +} + +func TestResetSharedControlsForArgsClearsOnlyOverriddenLists(t *testing.T) { + cfg := newRootConfig() + cfg.SocksServerPort = 23104 + cfg.PublicPublishedPorts = config.StringValues{"80:80"} + cfg.SBinds = config.StringValues{"auto:0x1234567890123456789012345678901234567890:80:tcp"} + resetSharedControlsForArgs(cfg, []string{"publish", "-public", "8080:80"}) + if cfg.SocksServerPort != 23104 { + t.Fatalf("SocksServerPort = %d, want preserved 23104", cfg.SocksServerPort) + } + if len(cfg.SBinds) != 1 { + t.Fatalf("SBinds = %#v, want preserved", cfg.SBinds) + } + if len(cfg.PublicPublishedPorts) != 0 { + t.Fatalf("PublicPublishedPorts = %#v, want reset", cfg.PublicPublishedPorts) + } +} + +func TestManagedFlagExternPortSupportsFilesSpec(t *testing.T) { + port, err := managedFlagExternPort("-files", "8080,example.diode") + if err != nil { + t.Fatalf("managedFlagExternPort(-files) error = %v", err) + } + if port != 8080 { + t.Fatalf("port = %d, want 8080", port) + } +} + +func TestRefreshRequestDerivedConfigDedupesBinds(t *testing.T) { + cfg := newRootConfig() + cfg.SBinds = config.StringValues{ + "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + } + if err := refreshRequestDerivedConfig(cfg); err != nil { + t.Fatalf("refreshRequestDerivedConfig() error = %v", err) + } + if len(cfg.SBinds) != 1 { + t.Fatalf("SBinds = %#v, want one unique bind", cfg.SBinds) + } + if len(cfg.Binds) != 1 { + t.Fatalf("Binds = %#v, want one parsed bind", cfg.Binds) + } +} + +func TestMergeImplicitPublishArgsPreservesExistingPublishFlagsAndDedupesBinds(t *testing.T) { + prev := daemonState + defer func() { daemonState = prev }() + daemonState = &runtimeDaemon{ + activeMode: "publish", + activeArgs: []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", "-public", "80", + }, + } + + got := mergeImplicitPublishArgs([]string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", + }) + + want := []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", "-public", "80", + } + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Fatalf("mergeImplicitPublishArgs() = %#v, want %#v", got, want) + } +} + +func TestMergeImplicitPublishArgsAppendsNewBindToExistingPublishState(t *testing.T) { + prev := daemonState + defer func() { daemonState = prev }() + daemonState = &runtimeDaemon{ + activeMode: "publish", + activeArgs: []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", "-public", "80", + }, + } + + got := mergeImplicitPublishArgs([]string{ + "-bind", "8081:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", + }) + + want := []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "-bind", "8081:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", "-public", "80", + } + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Fatalf("mergeImplicitPublishArgs() = %#v, want %#v", got, want) + } +} + +func TestSanitizeModeArgsRemovesStartupFlagsButKeepsModeFlags(t *testing.T) { + got := sanitizeModeArgs("publish", []string{ + "-update=false", + "-dbpath", "/tmp/test.db", + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", + "-public", "80", + }) + want := []string{ + "-bind", "8080:0x8911295322a1b94539e258e46f18e33acf21b48a:80", + "publish", + "-public", "80", + } + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Fatalf("sanitizeModeArgs() = %#v, want %#v", got, want) + } +} + +type testingLoggerWriter struct { + t *testing.T +} + +func (w testingLoggerWriter) Write(p []byte) (int, error) { + w.t.Helper() + return len(p), nil +} diff --git a/cmd/diode/daemon_transport_unix.go b/cmd/diode/daemon_transport_unix.go new file mode 100644 index 00000000..c5d63bfe --- /dev/null +++ b/cmd/diode/daemon_transport_unix.go @@ -0,0 +1,160 @@ +//go:build !windows + +package main + +import ( + "encoding/json" + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "syscall" + "time" +) + +func daemonPaths() (string, string, error) { + dir, err := daemonPathDir() + if err != nil { + return "", "", err + } + socketPath := filepath.Join(dir, "daemon.sock") + return socketPath, metaPathFromSocket(socketPath), nil +} + +func metaPathFromSocket(socketPath string) string { + if socketPath == "" { + socketPath, _, _ = daemonPaths() + } + return filepath.Join(filepath.Dir(socketPath), "daemon.json") +} + +func daemonListen(socketPath string) (net.Listener, error) { + ln, err := net.Listen("unix", socketPath) + if err != nil { + return nil, err + } + if err := os.Chmod(socketPath, 0600); err != nil { + _ = ln.Close() + _ = os.Remove(socketPath) + return nil, err + } + return ln, nil +} + +func dialDaemon(socketPath string) (net.Conn, error) { + if socketPath == "" { + path, _, err := daemonPaths() + if err != nil { + return nil, err + } + socketPath = path + } + return net.DialTimeout("unix", socketPath, 500*time.Millisecond) +} + +func cleanupDaemonTransport(socketPath string) { + _ = os.Remove(socketPath) +} + +func daemonSignals() []os.Signal { + return []os.Signal{os.Interrupt, syscall.SIGTERM} +} + +func spawnDaemon(spec daemonStartupSpec) error { + specBytes, err := json.Marshal(spec) + if err != nil { + return err + } + r, w, err := os.Pipe() + if err != nil { + return err + } + defer r.Close() + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0600) + if err != nil { + _ = w.Close() + return err + } + defer devNull.Close() + + cmd := exec.Command(os.Args[0], daemonCommandName) + cmd.Stdin = devNull + cmd.Stdout = devNull + cmd.Stderr = devNull + cmd.ExtraFiles = []*os.File{w} + cmd.Env = append(os.Environ(), + fmt.Sprintf("%s=3", envDaemonReadyFD), + fmt.Sprintf("%s=%s", envDaemonStartupSpec, string(specBytes)), + ) + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + if err := cmd.Start(); err != nil { + _ = w.Close() + return err + } + _ = w.Close() + + done := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, err := r.Read(buf) + done <- err + }() + select { + case err := <-done: + return err + case <-time.After(10 * time.Second): + _ = cmd.Process.Kill() + return fmt.Errorf("timed out waiting for daemon startup") + } +} + +func daemonRestartSelf(cmd string, startup daemonStartupSpec) error { + env, err := daemonRestartEnv(startup) + if err != nil { + return err + } + prepareDaemonForRestart() + r, w, err := os.Pipe() + if err != nil { + return err + } + defer r.Close() + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0600) + if err != nil { + _ = w.Close() + return err + } + defer devNull.Close() + + child := exec.Command(cmd, daemonCommandName) + child.Stdin = devNull + child.Stdout = devNull + child.Stderr = devNull + child.ExtraFiles = []*os.File{w} + child.Env = append(env, fmt.Sprintf("%s=3", envDaemonReadyFD)) + child.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + if err := child.Start(); err != nil { + _ = w.Close() + return err + } + _ = w.Close() + + done := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, err := r.Read(buf) + done <- err + }() + select { + case err := <-done: + if err != nil { + return err + } + case <-time.After(10 * time.Second): + _ = child.Process.Kill() + return fmt.Errorf("timed out waiting for restarted daemon startup") + } + os.Exit(0) + return nil +} diff --git a/cmd/diode/daemon_transport_windows.go b/cmd/diode/daemon_transport_windows.go new file mode 100644 index 00000000..9b0b0cd5 --- /dev/null +++ b/cmd/diode/daemon_transport_windows.go @@ -0,0 +1,128 @@ +//go:build windows + +package main + +import ( + "encoding/json" + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "syscall" + "time" + + "github.com/Microsoft/go-winio" +) + +func daemonPaths() (string, string, error) { + dir, err := daemonPathDir() + if err != nil { + return "", "", err + } + socketPath := `\\.\pipe\diode-client-` + daemonPathID() + return socketPath, filepath.Join(dir, "daemon.json"), nil +} + +func metaPathFromSocket(socketPath string) string { + dir, err := daemonPathDir() + if err != nil { + return "daemon.json" + } + return filepath.Join(dir, "daemon.json") +} + +func daemonListen(socketPath string) (net.Listener, error) { + return winio.ListenPipe(socketPath, &winio.PipeConfig{ + SecurityDescriptor: "D:P(A;;GA;;;SY)(A;;GA;;;BA)(A;;GA;;;OW)", + }) +} + +func dialDaemon(socketPath string) (net.Conn, error) { + if socketPath == "" { + path, _, err := daemonPaths() + if err != nil { + return nil, err + } + socketPath = path + } + timeout := 500 * time.Millisecond + return winio.DialPipe(socketPath, &timeout) +} + +func cleanupDaemonTransport(socketPath string) {} + +func daemonSignals() []os.Signal { + return []os.Signal{os.Interrupt} +} + +func spawnDaemon(spec daemonStartupSpec) error { + specBytes, err := json.Marshal(spec) + if err != nil { + return err + } + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0600) + if err != nil { + return err + } + defer devNull.Close() + + cmd := exec.Command(os.Args[0], daemonCommandName) + cmd.Stdin = devNull + cmd.Stdout = devNull + cmd.Stderr = devNull + cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", envDaemonStartupSpec, string(specBytes))) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + if err := cmd.Start(); err != nil { + return err + } + + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + meta, err := readDaemonMetadata() + if err == nil && meta.PID == cmd.Process.Pid { + if _, err := dialDaemon(meta.SocketPath); err == nil { + return nil + } + } + time.Sleep(100 * time.Millisecond) + } + _ = cmd.Process.Kill() + return fmt.Errorf("timed out waiting for daemon startup") +} + +func daemonRestartSelf(cmd string, startup daemonStartupSpec) error { + env, err := daemonRestartEnv(startup) + if err != nil { + return err + } + prepareDaemonForRestart() + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0600) + if err != nil { + return err + } + defer devNull.Close() + + child := exec.Command(cmd, daemonCommandName) + child.Stdin = devNull + child.Stdout = devNull + child.Stderr = devNull + child.Env = env + child.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + if err := child.Start(); err != nil { + return err + } + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + meta, err := readDaemonMetadata() + if err == nil && meta.PID == child.Process.Pid { + if _, err := dialDaemon(meta.SocketPath); err == nil { + os.Exit(0) + return nil + } + } + time.Sleep(100 * time.Millisecond) + } + _ = child.Process.Kill() + return fmt.Errorf("timed out waiting for restarted daemon startup") +} diff --git a/cmd/diode/diode.go b/cmd/diode/diode.go index 11d71a8b..a5928961 100644 --- a/cmd/diode/diode.go +++ b/cmd/diode/diode.go @@ -27,6 +27,10 @@ func main() { os.Exit(0) } + if handled, exitCode := maybeHandleDaemonCLI(os.Args[1:]); handled { + os.Exit(exitCode) + } + cfg := config.AppConfig err := diodeCmd.Execute() if err != nil { diff --git a/cmd/diode/fetch.go b/cmd/diode/fetch.go index 59edeea5..604ab2e6 100644 --- a/cmd/diode/fetch.go +++ b/cmd/diode/fetch.go @@ -66,12 +66,12 @@ func (fp *fetchProgress) Read(p []byte) (int, error) { if fp.read == 0 { if fp.contentLength > 0 { fp.pointSize = float64(fp.contentLength) / 60 - fmt.Printf("Downloading %d bytes into '%s'.\n", fp.contentLength, fp.name) - fmt.Println("[------------------------------------------------------------]") - fmt.Printf("[") + stdoutf("Downloading %d bytes into '%s'.\n", fp.contentLength, fp.name) + stdoutln("[------------------------------------------------------------]") + stdoutf("[") } else { fp.pointSize = 0 - fmt.Printf("Downloading into '%s'.\n", fp.name) + stdoutf("Downloading into '%s'.\n", fp.name) } } n, err := fp.Reader.Read(p) @@ -79,13 +79,13 @@ func (fp *fetchProgress) Read(p []byte) (int, error) { if fp.pointSize > 0 { for int64(float64(fp.read)/fp.pointSize) > fp.points { - fmt.Printf("#") + stdoutf("#") fp.points++ } } if err == io.EOF && fp.contentLength > 0 && fp.read == fp.contentLength { - fmt.Printf("] Done!\n") + stdoutf("] Done!\n") } return n, err @@ -187,19 +187,27 @@ func fetchHandler() (err error) { defer resp.Body.Close() var f *os.File + var out io.Writer if len(fetchCfg.Output) > 0 { f, err = os.OpenFile(fetchCfg.Output, os.O_CREATE|os.O_WRONLY, 0600) if err != nil { return } + out = f } else if fetchCfg.Verbose { - f = os.Stdout + out = stdoutWriter() src = resp.Body } - if f != nil { - io.Copy(f, src) - f.Close() + if out != nil { + _, err = io.Copy(out, src) + if f != nil { + f.Close() + } + if err != nil { + return + } + return } return } diff --git a/cmd/diode/files.go b/cmd/diode/files.go index d1550a35..b6f9ffcb 100644 --- a/cmd/diode/files.go +++ b/cmd/diode/files.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "net/http" - "os" "strconv" "strings" "time" @@ -84,9 +83,9 @@ func filesHandler() error { cfg := config.AppConfig spec := strings.TrimSpace(filesCmd.Flag.Arg(0)) if spec == "" { - fmt.Fprintln(os.Stderr, "usage: diode files [-fileroot ] ") - fmt.Fprintln(os.Stderr, " files-spec: for public, or ,... for private") - os.Exit(2) + stderrln("usage: diode files [-fileroot ] ") + stderrln(" files-spec: for public, or ,... for private") + return newExitStatusError(2, "missing files publish spec") } portStr, mode, err := expandFilesSpec(spec) @@ -106,12 +105,13 @@ func filesHandler() error { if err := app.Start(); err != nil { return err } + beginRuntimeMode("files") cleanup, err := startFileListener(p, filesFileroot) if err != nil { return err } - app.Defer(cleanup) + registerRuntimeCleanup(cleanup) publicPorts, privatePorts, protectedPorts, err := appendPublishedControlDefinition(nil, nil, nil, p) if err != nil { @@ -137,6 +137,9 @@ func filesHandler() error { if err := app.ReconcileControlServices(); err != nil { return err } + if isDaemonApplyRequest() { + return nil + } app.Wait() return nil diff --git a/cmd/diode/gateway.go b/cmd/diode/gateway.go index cf47f657..5d2e17c2 100644 --- a/cmd/diode/gateway.go +++ b/cmd/diode/gateway.go @@ -5,6 +5,7 @@ package main import ( "fmt" + "strings" "github.com/diodechain/diode_client/command" "github.com/diodechain/diode_client/config" @@ -38,6 +39,7 @@ func init() { } func gatewayHandler() (err error) { + beginRuntimeMode("gateway") err = app.Start() if err != nil { return @@ -48,6 +50,41 @@ func gatewayHandler() (err error) { if result.HasValidationErrors() { return fmt.Errorf("couldn't apply gateway controls: %v", result.ValidationErrors) } + printGatewayBanner(config.AppConfig) + if isDaemonApplyRequest() { + return nil + } app.Wait() return } + +func printGatewayBanner(cfg *config.Config) { + status := gatewayStatusFromConfig(cfg) + if status.GatewayEnabled { + cfg.PrintLabel("HTTP gateway", status.GatewayAddr) + } + if status.SecureGatewayEnabled { + cfg.PrintLabel("HTTPS gateway", status.SecureGatewayAddr) + } + if len(status.SecureGatewayAdditionalAddrs) > 0 { + cfg.PrintLabel("HTTPS gateways", strings.Join(status.SecureGatewayAdditionalAddrs, ", ")) + } + if status.SocksEnabled { + cfg.PrintLabel("SOCKS proxy", status.SocksAddr) + } +} + +func gatewayStatusFromConfig(cfg *config.Config) daemonRuntimeStatus { + status := daemonRuntimeStatus{ + SocksEnabled: cfg.EnableSocksServer, + SocksAddr: cfg.SocksServerAddr(), + GatewayEnabled: cfg.EnableProxyServer, + GatewayAddr: cfg.ProxyServerAddr(), + SecureGatewayEnabled: cfg.EnableSProxyServer, + SecureGatewayAddr: cfg.SProxyServerAddr(), + } + for _, port := range cfg.SProxyAdditionalPorts() { + status.SecureGatewayAdditionalAddrs = append(status.SecureGatewayAdditionalAddrs, cfg.SProxyServerAddrForPort(port)) + } + return status +} diff --git a/cmd/diode/join.go b/cmd/diode/join.go index 6c124d88..12c4c867 100644 --- a/cmd/diode/join.go +++ b/cmd/diode/join.go @@ -2657,6 +2657,28 @@ func runContractController(cfg *config.Config) error { } } +func runContractControllerUntil(cfg *config.Config, stopCh <-chan struct{}) error { + for { + select { + case <-stopCh: + cfg.Logger.Info("Join controller stopped") + return nil + default: + } + + if err := contractSync(cfg); err != nil { + cfg.Logger.Warn("Perimeter contract sync failed: %v", err) + } + + select { + case <-stopCh: + cfg.Logger.Info("Join controller stopped") + return nil + case <-time.After(30 * time.Second): + } + } +} + func joinHandler() (err error) { cfg := config.AppConfig cfg.Logger.Warn("join command is still BETA, parameters may change") @@ -2727,6 +2749,16 @@ func joinHandler() (err error) { if dryRun { return nil } + if isDaemonApplyRequest() { + beginRuntimeMode("join") + done := make(chan struct{}) + app.SetModeDone(done) + go func() { + defer close(done) + _ = runContractControllerUntil(cfg, app.ModeStopChan()) + }() + return nil + } return runContractController(cfg) } diff --git a/cmd/diode/mode_helpers.go b/cmd/diode/mode_helpers.go new file mode 100644 index 00000000..565f7d43 --- /dev/null +++ b/cmd/diode/mode_helpers.go @@ -0,0 +1,23 @@ +package main + +func beginRuntimeMode(name string) { + if !isDaemonApplyRequest() { + return + } + app.StopMode() + if daemonState != nil { + daemonState.clearModeSnapshot() + } + app.BeginMode(name) +} + +func registerRuntimeCleanup(cleanup func()) { + if cleanup == nil { + return + } + if isDaemonApplyRequest() { + app.ModeDefer(cleanup) + return + } + app.Defer(cleanup) +} diff --git a/cmd/diode/output_helpers.go b/cmd/diode/output_helpers.go new file mode 100644 index 00000000..42c7150c --- /dev/null +++ b/cmd/diode/output_helpers.go @@ -0,0 +1,37 @@ +package main + +import ( + "fmt" + "io" + "os" + + "github.com/diodechain/diode_client/config" +) + +func stdoutWriter() io.Writer { + cfg := config.AppConfig + if cfg != nil && cfg.StdoutWriter != nil { + return cfg.StdoutWriter + } + return os.Stdout +} + +func stderrWriter() io.Writer { + cfg := config.AppConfig + if cfg != nil && cfg.StderrWriter != nil { + return cfg.StderrWriter + } + return os.Stderr +} + +func stdoutf(format string, args ...interface{}) { + fmt.Fprintf(stdoutWriter(), format, args...) +} + +func stdoutln(args ...interface{}) { + fmt.Fprintln(stdoutWriter(), args...) +} + +func stderrln(args ...interface{}) { + fmt.Fprintln(stderrWriter(), args...) +} diff --git a/cmd/diode/publish.go b/cmd/diode/publish.go index a92a1e79..4b4add18 100644 --- a/cmd/diode/publish.go +++ b/cmd/diode/publish.go @@ -6,7 +6,6 @@ package main import ( "fmt" "net" - "os" "regexp" "strconv" "strings" @@ -275,6 +274,7 @@ func publishHandler() (err error) { if err != nil { return } + beginRuntimeMode("publish") publicPorts := cloneStrings(cfg.PublicPublishedPorts) privatePorts := cloneStrings(cfg.PrivatePublishedPorts) protectedPorts := cloneStrings(cfg.ProtectedPublishedPorts) @@ -307,7 +307,7 @@ func publishHandler() (err error) { if err != nil { return } - app.Defer(cleanup) + registerRuntimeCleanup(cleanup) publicPorts, privatePorts, protectedPorts, err = appendPublishedControlDefinition(publicPorts, privatePorts, protectedPorts, p) if err != nil { return err @@ -339,7 +339,7 @@ func publishHandler() (err error) { return } }() - app.Defer(func() { + registerRuntimeCleanup(func() { // Since we didn't use ListenAndServe, call // ln.Close() instead of staticServer.Close() ln.Close() @@ -358,12 +358,12 @@ func publishHandler() (err error) { } if len(publicPorts) == 0 && len(privatePorts) == 0 && len(protectedPorts) == 0 && len(cfg.SSHPublishedServices) == 0 && len(cfg.Binds) == 0 { - fmt.Println() - fmt.Println("ERROR: Can't run publish without any arguments!") - fmt.Println(" HINT: Try 'diode publish -public 8080:80' to publish a local port") - fmt.Println(" HINT: Check our docs to learn more about publishing ports: https://diode.io/docs/getting-started.html") - fmt.Println(" HINT: Or run 'diode --help' to see all commands") - os.Exit(2) + stdoutln() + stdoutln("ERROR: Can't run publish without any arguments!") + stdoutln(" HINT: Try 'diode publish -public 8080:80' to publish a local port") + stdoutln(" HINT: Check our docs to learn more about publishing ports: https://diode.io/docs/getting-started.html") + stdoutln(" HINT: Or run 'diode --help' to see all commands") + return newExitStatusError(2, "publish requires at least one published port or bind") } patch := ControlPatch{} @@ -378,6 +378,9 @@ func publishHandler() (err error) { if err := app.ReconcileControlServices(); err != nil { return err } + if isDaemonApplyRequest() { + return nil + } for { app.Wait() if !app.Closed() { diff --git a/cmd/diode/publish_render.go b/cmd/diode/publish_render.go new file mode 100644 index 00000000..fbb85560 --- /dev/null +++ b/cmd/diode/publish_render.go @@ -0,0 +1,58 @@ +package main + +import ( + "fmt" + "sort" + "strings" + + "github.com/diodechain/diode_client/config" +) + +func renderPublishedPortMap(cfg *config.Config, ports map[int]*config.Port) { + if len(ports) == 0 { + return + } + cfg.PrintInfo("") + name := cfg.ClientAddr.HexString() + if cfg.ClientName != "" { + name = cfg.ClientName + } + + keys := make([]int, 0, len(ports)) + for key := range ports { + keys = append(keys, key) + } + sort.Ints(keys) + + for _, key := range keys { + port := ports[key] + if port.Mode == config.PublicPublishedMode { + if port.To == httpPort { + cfg.PrintLabel("HTTP Gateway Enabled", fmt.Sprintf("http://%s.diode.link/", name)) + } + if (8000 <= port.To && port.To <= 8100) || (8400 <= port.To && port.To <= 8500) { + cfg.PrintLabel("HTTP Gateway Enabled", fmt.Sprintf("https://%s.diode.link:%d/", name, port.To)) + } + } + } + + cfg.PrintLabel("Port ", " ") + for _, key := range keys { + port := ports[key] + cfg.PrintLabel( + fmt.Sprintf("Port %12s", publishedPortDisplayHost(port)), + fmt.Sprintf("%8d %10s %s %s", port.To, config.ModeName(port.Mode), config.ProtocolName(port.Protocol), strings.Join(portAllowlistStrings(port), ",")), + ) + } +} + +func renderBindMap(cfg *config.Config, binds []config.Bind) { + if len(binds) == 0 { + return + } + cfg.PrintInfo("") + cfg.PrintLabel("Bind ", " ") + for _, bind := range binds { + cfg.PrintLabel(fmt.Sprintf("Port %5d", bind.LocalPort), fmt.Sprintf("%5s %11s:%d", config.ProtocolName(bind.Protocol), bind.To, bind.ToPort)) + } +} diff --git a/cmd/diode/pushpull.go b/cmd/diode/pushpull.go index 628bf121..3ec0121d 100644 --- a/cmd/diode/pushpull.go +++ b/cmd/diode/pushpull.go @@ -135,9 +135,9 @@ func parsePeerPort(s string) (host string, port int, err error) { func pushHandler() error { args := pushCmd.Flag.Args() if len(args) < 2 { - fmt.Fprintln(os.Stderr, "usage: diode push :") - fmt.Fprintln(os.Stderr, " diode push ::") - os.Exit(2) + stderrln("usage: diode push :") + stderrln(" diode push ::") + return newExitStatusError(2, "missing push arguments") } if len(args) > 2 { return fmt.Errorf("too many arguments (quote :: as one token)") @@ -208,15 +208,15 @@ func pushHandler() error { } return fmt.Errorf("push failed: %s", msg) } - config.AppConfig.Logger.Info("push ok: %s -> %s", localPath, urlStr) + config.AppConfig.PrintInfo(fmt.Sprintf("push ok: %s -> %s", localPath, urlStr)) return nil } func pullHandler() error { args := pullCmd.Flag.Args() if len(args) < 1 { - fmt.Fprintln(os.Stderr, "usage: diode pull :: []") - os.Exit(2) + stderrln("usage: diode pull :: []") + return newExitStatusError(2, "missing pull arguments") } if len(args) > 2 { return fmt.Errorf("too many arguments (quote paths that contain spaces)") @@ -288,7 +288,7 @@ func pullHandler() error { if cerr != nil { return cerr } - config.AppConfig.Logger.Info("pull ok: %s -> ./%s", urlStr, base) + config.AppConfig.PrintInfo(fmt.Sprintf("pull ok: %s -> ./%s", urlStr, base)) return nil } @@ -311,6 +311,6 @@ func pullHandler() error { if cerr != nil { return cerr } - config.AppConfig.Logger.Info("pull ok: %s -> %s", urlStr, dest) + config.AppConfig.PrintInfo(fmt.Sprintf("pull ok: %s -> %s", urlStr, dest)) return nil } diff --git a/cmd/diode/query.go b/cmd/diode/query.go index 2792277f..4432715e 100644 --- a/cmd/diode/query.go +++ b/cmd/diode/query.go @@ -33,8 +33,7 @@ func queryHandler() (err error) { cfg := config.AppConfig if cfg.QueryAddress == "" { - cfg.PrintError("Failed to query", fmt.Errorf("-address argument is required")) - return + return newExitStatusError(2, "query requires -address") } err = app.Start() diff --git a/cmd/diode/query_test.go b/cmd/diode/query_test.go new file mode 100644 index 00000000..8fded16f --- /dev/null +++ b/cmd/diode/query_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "strings" + "testing" + + "github.com/diodechain/diode_client/config" +) + +func TestQueryHandlerRequiresAddress(t *testing.T) { + cfg := newSharedControlTestConfig(t) + origCfg := config.AppConfig + config.AppConfig = cfg + t.Cleanup(func() { + config.AppConfig = origCfg + }) + + err := queryHandler() + if err == nil { + t.Fatal("queryHandler() error = nil, want missing address error") + } + if !strings.Contains(err.Error(), "requires -address") { + t.Fatalf("queryHandler() error = %q, want missing address error", err.Error()) + } + statusErr, ok := err.(interface{ Status() int }) + if !ok { + t.Fatalf("queryHandler() error type %T does not expose Status()", err) + } + if statusErr.Status() != 2 { + t.Fatalf("queryHandler() status = %d, want 2", statusErr.Status()) + } +} diff --git a/cmd/diode/socksd.go b/cmd/diode/socksd.go index a3c45888..29c9ffe0 100644 --- a/cmd/diode/socksd.go +++ b/cmd/diode/socksd.go @@ -26,10 +26,14 @@ func init() { } func socksdHandler() (err error) { + beginRuntimeMode("socksd") err = app.Start() if err != nil { return } + if isDaemonApplyRequest() { + return nil + } app.Wait() return } diff --git a/cmd/diode/ssh.go b/cmd/diode/ssh.go index 16438788..e3648ae1 100644 --- a/cmd/diode/ssh.go +++ b/cmd/diode/ssh.go @@ -45,17 +45,8 @@ var ( var runtimeGOOS = runtime.GOOS func sshHandler() error { - return runSSHLikeTool(sshLikeToolOptions{ - commandName: sshCommandName, - toolName: "ssh", - validateLabel: "Invalid SSH target", - validateArgs: func(args []string) error { - if target := extractSSHTarget(args); target != "" { - return validateSSHTarget(target) - } - return nil - }, - }) + opts, _ := sshLikeOptionsForCommand(sshCommandName) + return runSSHLikeTool(opts) } // sshLikeToolOptions configures runSSHLikeTool for a specific OpenSSH-based @@ -79,44 +70,25 @@ type sshLikeToolOptions struct { // identity and a ProxyCommand that tunnels via `diode ssh-proxy`. func runSSHLikeTool(opts sshLikeToolOptions) error { cfg := config.AppConfig - toolName := opts.toolName - if toolName == "" { - toolName = opts.commandName - } cfg.Logger.Warn("%s command is still BETA, parameters may change", opts.commandName) if err := app.Start(); err != nil { cfg.PrintError("Could not start local Diode client", err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) } proxyAddr, cleanupProxy, err := startSSHLocalSocksProxy() if err != nil { cfg.PrintError("Could not start local Diode SOCKS proxy", err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) } defer cleanupProxy() cfg.PrintLabel("Using local diode client", proxyAddr) - diodeExe, err := os.Executable() + passArgs, err := sshLikePassThroughArgs(opts.commandName, os.Args) if err != nil { - cfg.PrintError("Could not determine diode executable path", err) - os.Exit(1) - } - - os_args := os.Args - cmdIndex := -1 - for i, arg := range os_args { - if arg == opts.commandName { - cmdIndex = i - break - } + cfg.PrintError(err.Error(), err) + return newExitStatusError(1, "%s", err.Error()) } - if cmdIndex == -1 { - msg := fmt.Sprintf("%s command not found", opts.commandName) - cfg.PrintError(msg, errors.New(msg)) - os.Exit(1) - } - passArgs := normalizeSSHArgs(os_args[cmdIndex+1:]) if opts.validateArgs != nil { if err := opts.validateArgs(passArgs); err != nil { @@ -125,21 +97,45 @@ func runSSHLikeTool(opts sshLikeToolOptions) error { label = fmt.Sprintf("Invalid %s argument", opts.commandName) } cfg.PrintError(label, err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) + } + } + + return runSSHToolWithProxyAddr(proxyAddr, opts, passArgs) +} + +func sshLikePassThroughArgs(commandName string, osArgs []string) ([]string, error) { + for i, arg := range osArgs { + if arg == commandName { + return normalizeSSHArgs(osArgs[i+1:]), nil } } + return nil, fmt.Errorf("%s command not found", commandName) +} + +func runSSHToolWithProxyAddr(proxyAddr string, opts sshLikeToolOptions, passArgs []string) error { + cfg := config.AppConfig + toolName := opts.toolName + if toolName == "" { + toolName = opts.commandName + } + diodeExe, err := os.Executable() + if err != nil { + cfg.PrintError("Could not determine diode executable path", err) + return newExitStatusError(1, "%s", err.Error()) + } identityFile, cleanup, err := createEphemeralSSHIdentity() if err != nil { cfg.PrintError("Could not create ephemeral ssh identity", err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) } defer cleanup() toolPath, err := findOpenSSHTool(toolName) if err != nil { cfg.PrintError(fmt.Sprintf("%s not found", toolName), err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) } args := buildSSHLikeToolArgs(runtimeGOOS, diodeExe, proxyAddr, identityFile, passArgs) @@ -152,14 +148,114 @@ func runSSHLikeTool(opts sshLikeToolOptions) error { if err := cmd.Run(); err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { - os.Exit(exitErr.ExitCode()) + return newExitStatusError(exitErr.ExitCode(), "%s exited with status %d", toolName, exitErr.ExitCode()) } cfg.PrintError(fmt.Sprintf("Could not execute %s", toolName), err) - os.Exit(1) + return newExitStatusError(1, "%s", err.Error()) } return nil } +func runSSHViaDaemonLease(commandArgs []string, resp daemonResponse) int { + return runSSHLikeViaDaemonLease(commandArgs, resp) +} + +func runSSHLikeViaDaemonLease(commandArgs []string, resp daemonResponse) int { + if len(commandArgs) == 0 { + stderrln("missing ssh-like command arguments") + return 1 + } + if err := ensureDaemonSSHForegroundLogger(); err != nil { + stderrln(fmt.Sprintf("could not initialize %s command logger: %v", commandArgs[0], err)) + return 1 + } + opts, ok := sshLikeOptionsForCommand(commandArgs[0]) + if !ok { + stderrln(fmt.Sprintf("unsupported daemon proxy command: %s", commandArgs[0])) + return 1 + } + if resp.LeaseID != "" { + defer func() { + _ = releaseDaemonLease(resp.LeaseID) + }() + } + if resp.ExitCode != 0 { + return resp.ExitCode + } + if resp.Error != "" { + stderrln(resp.Error) + return 1 + } + if strings.TrimSpace(resp.ProxyAddr) == "" { + stderrln("daemon proxy lease did not expose an address") + return 1 + } + cfg := config.AppConfig + cfg.PrintLabel("Using diode daemon proxy", resp.ProxyAddr) + passArgs := normalizeSSHArgs(commandArgs[1:]) + if opts.validateArgs != nil { + if err := opts.validateArgs(passArgs); err != nil { + label := opts.validateLabel + if label == "" { + label = fmt.Sprintf("Invalid %s argument", opts.commandName) + } + cfg.PrintError(label, err) + return exitCodeFromError(newExitStatusError(1, "%s", err.Error())) + } + } + if err := runSSHToolWithProxyAddr(resp.ProxyAddr, opts, passArgs); err != nil { + return exitCodeFromError(err) + } + return 0 +} + +func sshLikeOptionsForCommand(commandName string) (sshLikeToolOptions, bool) { + switch commandName { + case sshCommandName: + return sshLikeToolOptions{ + commandName: sshCommandName, + toolName: "ssh", + validateLabel: "Invalid SSH target", + validateArgs: func(args []string) error { + if target := extractSSHTarget(args); target != "" { + return validateSSHTarget(target) + } + return nil + }, + }, true + case scpCommandName: + return sshLikeToolOptions{ + commandName: scpCommandName, + toolName: "scp", + validateLabel: "Invalid scp target", + validateArgs: func(args []string) error { + return validateSCPArgs(args) + }, + }, true + default: + return sshLikeToolOptions{}, false + } +} + +func ensureDaemonSSHForegroundLogger() error { + cfg := config.AppConfig + if cfg == nil { + cfg = newRootConfig() + config.AppConfig = cfg + } + if cfg.Logger != nil { + return nil + } + if cfg.LogFilePath != "" { + cfg.LogMode = config.LogToFile + } else { + cfg.LogMode = config.LogToConsole + } + logger, err := config.NewLogger(cfg) + cfg.Logger = &logger + return err +} + // buildSSHLikeToolArgs builds the argv (excluding argv[0]) for an OpenSSH // tool launched by diode. The user's pass-through args are placed *after* // the diode-injected flags so positional arguments (e.g. scp's source/ @@ -204,17 +300,7 @@ func subcommandPassThroughArgs(args []string, name string) ([]string, error) { // on this instance (rpc.Config); that is not the same as cfg.EnableSocksServer. func startSSHLocalSocksProxy() (string, func(), error) { cfg := config.AppConfig - socksCfg := rpc.Config{ - EnableSocks: true, - Addr: net.JoinHostPort("127.0.0.1", "0"), - FleetAddr: cfg.FleetAddr, - Blocklists: cfg.Blocklists(), - Allowlists: cfg.Allowlists, - EnableProxy: false, - ProxyServerAddr: cfg.ProxyServerAddr(), - Fallback: cfg.SocksFallback, - } - socksServer, err := rpc.NewSocksServer(socksCfg, app.clientManager) + socksServer, err := rpc.NewSocksServer(sshLocalSocksConfig(cfg), app.clientManager) if err != nil { return "", nil, err } @@ -241,6 +327,19 @@ func startSSHLocalSocksProxy() (string, func(), error) { return net.JoinHostPort(host, strconv.Itoa(tcpAddr.Port)), cleanup, nil } +func sshLocalSocksConfig(cfg *config.Config) rpc.Config { + return rpc.Config{ + EnableSocks: true, + Addr: net.JoinHostPort("127.0.0.1", "0"), + FleetAddr: cfg.FleetAddr, + Blocklists: cfg.Blocklists(), + Allowlists: cfg.Allowlists, + EnableProxy: false, + ProxyServerAddr: cfg.ProxyServerAddr(), + Fallback: cfg.SocksFallback, + } +} + func createEphemeralSSHIdentity() (string, func(), error) { sshKeygen, err := findOpenSSHTool("ssh-keygen") if err != nil { diff --git a/cmd/diode/ssh_test.go b/cmd/diode/ssh_test.go index 400b44d4..713ce7a8 100644 --- a/cmd/diode/ssh_test.go +++ b/cmd/diode/ssh_test.go @@ -185,6 +185,73 @@ func TestBuildSSHLikeToolArgsKeepsUserArgsLast(t *testing.T) { } } +func TestRunSSHViaDaemonLeaseInitializesForegroundLogger(t *testing.T) { + origCfg := config.AppConfig + cfg := newRootConfig() + config.AppConfig = cfg + t.Cleanup(func() { + config.AppConfig = origCfg + }) + + code := runSSHViaDaemonLease([]string{"ssh", "badhost"}, daemonResponse{ProxyAddr: "127.0.0.1:1"}) + if code == 0 { + t.Fatal("runSSHViaDaemonLease() exit code = 0, want validation failure") + } + if cfg.Logger == nil { + t.Fatal("runSSHViaDaemonLease() did not initialize foreground logger") + } +} + +func TestSSHLocalSocksConfigEnablesListener(t *testing.T) { + cfg := newRootConfig() + got := sshLocalSocksConfig(cfg) + if !got.EnableSocks { + t.Fatal("sshLocalSocksConfig() did not enable SOCKS listener") + } + if got.Addr != net.JoinHostPort("127.0.0.1", "0") { + t.Fatalf("sshLocalSocksConfig() Addr = %q, want ephemeral localhost", got.Addr) + } +} + +func TestRunSSHViaDaemonLeaseStopsOnLeaseFailure(t *testing.T) { + origCfg := config.AppConfig + config.AppConfig = newRootConfig() + t.Cleanup(func() { + config.AppConfig = origCfg + }) + + code := runSSHViaDaemonLease([]string{"ssh", "ubuntu@miner2023.diode"}, daemonResponse{ + ExitCode: 1, + Error: "proxy lease did not expose an address", + }) + if code != 1 { + t.Fatalf("runSSHViaDaemonLease() exit code = %d, want 1", code) + } +} + +func TestRunSSHViaDaemonLeaseRejectsEmptyProxyAddr(t *testing.T) { + origCfg := config.AppConfig + origLookPath := lookPath + config.AppConfig = newRootConfig() + lookPathCalled := false + lookPath = func(string) (string, error) { + lookPathCalled = true + return "", errors.New("unexpected OpenSSH lookup") + } + t.Cleanup(func() { + config.AppConfig = origCfg + lookPath = origLookPath + }) + + code := runSSHViaDaemonLease([]string{"ssh", "ubuntu@miner2023.diode"}, daemonResponse{}) + if code != 1 { + t.Fatalf("runSSHViaDaemonLease() exit code = %d, want 1", code) + } + if lookPathCalled { + t.Fatal("runSSHViaDaemonLease() looked up OpenSSH tool despite empty proxy address") + } +} + func TestFindOpenSSHToolWindowsInstallHelp(t *testing.T) { origLookPath := lookPath origGOOS := runtimeGOOS diff --git a/cmd/diode/token.go b/cmd/diode/token.go index 4d0221dc..3a4d8763 100644 --- a/cmd/diode/token.go +++ b/cmd/diode/token.go @@ -67,8 +67,7 @@ func parseUnitAndValue(src string) (val int, unit string) { func tokenHandler() (err error) { if tokenCfg.CheckBalance { - showBalance() - return + return showBalance() } valWei, _ := parseUnitAndValue(tokenCfg.Value) diff --git a/cmd/diode/update.go b/cmd/diode/update.go index b9082034..6cdf0a0f 100644 --- a/cmd/diode/update.go +++ b/cmd/diode/update.go @@ -2,9 +2,8 @@ package main import ( "fmt" + "io" "os" - "os/exec" - "path" "path/filepath" "runtime" "time" @@ -38,12 +37,29 @@ func writeLastUpdateAt() { } func updateHandler() (err error) { - doUpdate() - return + _, err = doUpdate(updateRestartStandalone) + return err } -func doUpdate() int { - cfg := config.AppConfig +type updateRestartMode int + +const ( + updateRestartStandalone updateRestartMode = iota + updateRestartDeferred +) + +func runDaemonUpdateWithConfig(args []string, cfg *config.Config) (string, error) { + if len(args) == 0 || args[0] != "update" { + return "", newExitStatusError(2, "missing update command") + } + return doUpdateWithConfig(cfg, updateRestartDeferred) +} + +func doUpdate(restartMode updateRestartMode) (string, error) { + return doUpdateWithConfig(config.AppConfig, restartMode) +} + +func doUpdateWithConfig(cfg *config.Config, restartMode updateRestartMode) (string, error) { m := &update.Manager{ Command: "diode", Store: &github.Store{ @@ -57,47 +73,57 @@ func doUpdate() int { m.Command += ".exe" } - tarball, ok, err := download(m) + tarball, ok, err := download(cfg, m) if !ok { // Will recheck for an update in 24 hours go func() { time.Sleep(time.Hour * 24) - doUpdate() + _, _ = doUpdate(updateRestartStandalone) }() if err == nil { writeLastUpdateAt() } - return 0 - } - - // searching for binary in path - bin, err := exec.LookPath(m.Command) - if err != nil { - // just update local file - bin = os.Args[0] - } - - // find the real path of execute file if the file was symlink - binExe, err := filepath.EvalSymlinks(bin) - if err != nil { - binExe = bin + if err != nil { + return "", newExitStatusError(1, "%s", err.Error()) + } + return "", nil } - dir := filepath.Dir(binExe) + dir := updateInstallDir() if err := m.InstallTo(tarball, dir); err != nil { cfg.PrintError("Error installing", err) - return 129 + return "", newExitStatusError(129, "%s", err.Error()) } - cmd := path.Join(dir, m.Command) - fmt.Printf("Updated, restarting %s...\n", cmd) + cmd := filepath.Join(dir, m.Command) + fmt.Fprintf(updateOutputWriter(cfg), "Updated, restarting %s...\n", cmd) writeLastUpdateAt() + if restartMode == updateRestartDeferred { + return cmd, nil + } update.Restart(cmd) - return 0 + return "", nil +} + +func updateInstallDir() string { + bin, err := os.Executable() + if err != nil || bin == "" { + bin = os.Args[0] + } + return updateInstallDirFromExecutable(bin, filepath.EvalSymlinks) } -func download(m *update.Manager) (string, bool, error) { - cfg := config.AppConfig +func updateInstallDirFromExecutable(bin string, evalSymlinks func(string) (string, error)) string { + if abs, err := filepath.Abs(bin); err == nil { + bin = abs + } + if resolved, err := evalSymlinks(bin); err == nil { + bin = resolved + } + return filepath.Dir(bin) +} + +func download(cfg *config.Config, m *update.Manager) (string, bool, error) { ansi.HideCursor() defer ansi.ShowCursor() @@ -127,7 +153,7 @@ func download(m *update.Manager) (string, bool, error) { } // whitespace - fmt.Println() + fmt.Fprintln(updateOutputWriter(cfg)) // download tarball to a tmp dir tarball, err := a.DownloadProxy(progress.Reader) @@ -137,3 +163,10 @@ func download(m *update.Manager) (string, bool, error) { return tarball, true, nil } + +func updateOutputWriter(cfg *config.Config) io.Writer { + if cfg != nil && cfg.StdoutWriter != nil { + return cfg.StdoutWriter + } + return os.Stdout +} diff --git a/cmd/diode/update_test.go b/cmd/diode/update_test.go new file mode 100644 index 00000000..d8c9ceb8 --- /dev/null +++ b/cmd/diode/update_test.go @@ -0,0 +1,45 @@ +package main + +import ( + "errors" + "path/filepath" + "testing" +) + +func TestUpdateInstallDirFromExecutable(t *testing.T) { + executable := filepath.Join(t.TempDir(), "bin", "diode") + + got := updateInstallDirFromExecutable(executable, func(path string) (string, error) { + return path, nil + }) + + if got != filepath.Dir(executable) { + t.Fatalf("updateInstallDirFromExecutable() = %q, want %q", got, filepath.Dir(executable)) + } +} + +func TestUpdateInstallDirFromExecutableResolvesSymlink(t *testing.T) { + tmp := t.TempDir() + link := filepath.Join(tmp, "link", "diode") + target := filepath.Join(tmp, "target", "diode") + + got := updateInstallDirFromExecutable(link, func(path string) (string, error) { + return target, nil + }) + + if got != filepath.Dir(target) { + t.Fatalf("updateInstallDirFromExecutable() = %q, want %q", got, filepath.Dir(target)) + } +} + +func TestUpdateInstallDirFromExecutableFallsBackWhenResolveFails(t *testing.T) { + executable := filepath.Join(t.TempDir(), "diode") + + got := updateInstallDirFromExecutable(executable, func(path string) (string, error) { + return "", errors.New("not a symlink") + }) + + if got != filepath.Dir(executable) { + t.Fatalf("updateInstallDirFromExecutable() = %q, want %q", got, filepath.Dir(executable)) + } +} diff --git a/config/flag.go b/config/flag.go index 68d45358..05b5330f 100644 --- a/config/flag.go +++ b/config/flag.go @@ -5,6 +5,7 @@ package config import ( "fmt" + "io" "os" "strconv" "strings" @@ -44,6 +45,8 @@ type Config struct { EnableUpdate bool `yaml:"update,omitempty" json:"update,omitempty"` EnableMetrics bool `yaml:"metrics,omitempty" json:"metrics,omitempty"` EnableTray bool `yaml:"tray,omitempty" json:"tray,omitempty"` + DisableDaemon bool `yaml:"-" json:"-"` + DetachDaemon bool `yaml:"-" json:"-"` BlockquickDowngrade bool `yaml:"bqdowngrade,omitempty" json:"bqdowngrade,omitempty"` RemoteRPCAddrs StringValues `yaml:"diodeaddrs,omitempty" json:"diodeaddrs,omitempty"` RemoteRPCTimeout time.Duration `yaml:"timeout,omitempty" json:"timeout,omitempty"` @@ -105,6 +108,8 @@ type Config struct { LogMode int `yaml:"-" json:"-"` LogDateTime bool `yaml:"logdatetime,omitempty" json:"logdatetime,omitempty"` Logger *Logger `yaml:"-" json:"-"` + StdoutWriter io.Writer `yaml:"-" json:"-"` + StderrWriter io.Writer `yaml:"-" json:"-"` LogTargetTo string `yaml:"-" json:"-"` // parsed device (BNS or hex) for implicit bind LogTargetPort int `yaml:"-" json:"-"` LogTargetRemote interface{} `yaml:"-" json:"-"` // zapcore.WriteSyncer; set before ReloadLogger @@ -261,18 +266,27 @@ func (cfg *Config) SetBlocklists(blocklists map[Address]bool) { func (cfg *Config) PrintLabel(label string, value string) { msg := fmt.Sprintf("%-20s : %-42s", label, value) msg = strings.Replace(msg, "%", "%%", -1) + if cfg.StdoutWriter != nil { + fmt.Fprintln(cfg.StdoutWriter, msg) + } cfg.Logger.Info("%s", msg) } func (cfg *Config) PrintError(msg string, err error) { + text := fmt.Sprintf("%s: ", msg) if err != nil { - cfg.Logger.Error(fmt.Sprintf("%s: %s", msg, err.Error())) - } else { - cfg.Logger.Error(fmt.Sprintf("%s: ", msg)) + text = fmt.Sprintf("%s: %s", msg, err.Error()) + } + if cfg.StderrWriter != nil { + fmt.Fprintln(cfg.StderrWriter, text) } + cfg.Logger.Error(text) } func (cfg *Config) PrintInfo(msg string) { + if cfg.StdoutWriter != nil { + fmt.Fprintln(cfg.StdoutWriter, msg) + } cfg.Logger.Info("%s", msg) } diff --git a/docs/manual-cli-test-plan.md b/docs/manual-cli-test-plan.md new file mode 100644 index 00000000..1ac6a802 --- /dev/null +++ b/docs/manual-cli-test-plan.md @@ -0,0 +1,675 @@ +# Manual CLI Test Plan + +This plan covers the full `cmd/diode` CLI surface, including daemon-by-default behavior, multi-wallet access checks, file transfer, and the hidden/internal command paths that are exercised indirectly. + +## Scope + +Public commands covered: + +- top-level `--help` +- `version` +- `config` +- `query` +- `time` +- `fetch` +- `publish` +- `gateway` +- `socksd` +- `files` +- `push` +- `pull` +- `ssh` +- `daemon` +- `join` +- `bns` +- `token` +- `reset` +- `update` +- `mcp` + +Hidden/internal coverage: + +- `ssh-proxy` +- `__daemon__` + +## Lab Setup + +1. Build the lab and fixture environment: + +```bash +./scripts/manual/setup_cli_lab.sh up +source ./.manual/cli-lab/env.sh +``` + +2. The env file exports: + +- `DIODE_OWNER_DB` +- `DIODE_PEER_DB` +- `DIODE_VIEWER_DB` +- `DIODE_OWNER_ADDR` +- `DIODE_PEER_ADDR` +- `DIODE_VIEWER_ADDR` +- `DIODE_OWNER_HTTP_ROOT` +- `DIODE_PEER_HTTP_ROOT` +- `DIODE_OWNER_FILES_ROOT` +- `DIODE_PEER_FILES_ROOT` +- `DIODE_OUTPUT_DIR` + +3. The env file also defines wrapper helpers: + +- `downer ...` + Runs owner wallet commands with `-no-daemon`. +- `dpeer ...` + Runs peer wallet commands with `-no-daemon`. +- `dviewer ...` + Runs viewer wallet commands with `-no-daemon`. +- `ddaemon ...` + Runs owner wallet commands with daemon mode enabled. +- `dstopdaemon` + Stops and cleans the global daemon transport files. + +4. The lab also starts two local HTTP fixtures: + +- `http://127.0.0.1:18080/` backed by `owner-public-fixture` +- `http://127.0.0.1:18081/` backed by `peer-public-fixture` + +## Execution Rules + +- Use `ddaemon` only for daemon coverage and single-wallet daemon-managed runtime tests. +- Use `downer`, `dpeer`, and `dviewer` for multi-wallet tests because the daemon is global per user and cannot safely represent multiple `-dbpath` identities at once. +- Before any daemon section, run `dstopdaemon`. +- When a test starts a long-running command in one terminal, keep that terminal open and run the verification command in a second terminal. +- Record for each test: + - exact command + - exit code + - stdout/stderr + - whether the daemon state or runtime state changed as expected + +## Wallet Roles + +- Owner wallet: + Hosts public/private ports and file listeners. +- Peer wallet: + Allowed client for private access tests. +- Viewer wallet: + Unauthorized client for private access tests. + +## Baseline Checks + +### Top-Level Help + +Command: + +```bash +./diode --help +./diode publish --help +./diode daemon --help || true +``` + +Verify: + +- help exits `0` +- help does not start the daemon +- help lists the expected public commands + +### Version + +Command: + +```bash +./diode version +``` + +Verify: + +- exits `0` +- prints OS/arch/CPU +- does not require a running daemon + +## Config and Identity + +### `config` + +Commands: + +```bash +downer config -list +downer config -set manual_key=manual_value +downer config -list +downer config -delete manual_key +downer config -list +downer config -unsafe -list +``` + +Verify: + +- first `config -list` creates the wallet DB if it does not exist +- output includes `
` and the owner address matches `DIODE_OWNER_ADDR` +- `manual_key` appears after `-set` +- `manual_key` disappears after `-delete` +- `-unsafe` prints the private key material only for the current wallet + +### `reset` + +Prerequisites: + +- funded owner wallet on the target network +- operator accepts destructive wallet/fleet change + +Commands: + +```bash +downer reset +downer reset -experimental +``` + +Verify: + +- deploys a new fleet contract +- prints new fleet address +- persists the fleet into the owner DB +- a second `reset` against an already initialized wallet reports that the client is already initialized + +## Read-Only Network Commands + +### `query` + +Commands: + +```bash +downer query -address "$DIODE_OWNER_ADDR" +downer query -address "$DIODE_PEER_ADDR" +``` + +Verify: + +- exits `0` +- prints account type when decodable +- prints one or more device tickets or a clear resolution failure +- device ticket output includes fleet, server, block, and validation status fields + +### `time` + +Command: + +```bash +downer time +``` + +Verify: + +- exits `0` +- prints minimum and maximum blockchain consensus time +- values are plausible and `maximum >= minimum` + +## Daemon Lifecycle and Dispatch + +### Implicit Daemon Startup + +Commands: + +```bash +dstopdaemon +ddaemon publish -public 18080:18080 +ddaemon daemon status +``` + +Verify: + +- first `publish` autostarts the hidden daemon +- `publish` returns quickly and does not remain attached to the foreground +- `daemon status` shows `Active mode: publish` +- `daemon status` shows the old-style published port map for port `18080` + +### Root-Flag Implicit Publish + +Commands: + +```bash +ddaemon -bind 19090:$DIODE_OWNER_ADDR:18080 +ddaemon daemon status +ddaemon -bind 19090:$DIODE_OWNER_ADDR:18080 +ddaemon daemon status +``` + +Verify: + +- root-only `-bind` is treated as implicit `publish` +- the first bind adds one bind entry +- the second identical bind does not duplicate +- `daemon status` still shows the existing `-public 18080:18080` published port table + +### `daemon status` + +Command: + +```bash +ddaemon daemon status +``` + +Verify: + +- shows PID and socket path +- shows active mode and mode args +- shows published port map and bind map when configured +- reports SOCKS/API state correctly + +### `daemon restart` + +Commands: + +```bash +ddaemon daemon restart +ddaemon daemon status +``` + +Verify: + +- exits `0` +- daemon comes back within the timeout +- active mode and published ports survive restart + +### `daemon ports remove` + +Commands: + +```bash +ddaemon daemon ports remove 18080 +ddaemon daemon status +``` + +Verify: + +- removes only the requested published port +- leaves unrelated binds intact +- if no published ports remain, mode is stopped cleanly + +### `daemon ports clear` + +Commands: + +```bash +ddaemon publish -public 18080:18080 +ddaemon daemon ports clear +ddaemon daemon status +``` + +Verify: + +- clears published ports +- stops the active publish/files mode +- `daemon status` returns `Active mode: none` + +### `daemon stop` + +Commands: + +```bash +ddaemon daemon stop +ddaemon daemon status +``` + +Verify: + +- stop exits `0` +- subsequent `daemon status` reports `not running` + +## Publish, Fetch, and Multi-Wallet Access + +### Public Publish + +Terminal A: + +```bash +dstopdaemon +ddaemon publish -public 18080:18080 +``` + +Terminal B: + +```bash +dpeer fetch -url "http://$DIODE_OWNER_ADDR.diode.link:18080/" -output "$DIODE_OUTPUT_DIR/public-owner.html" +cat "$DIODE_OUTPUT_DIR/public-owner.html" +``` + +Verify: + +- `publish` prints the old-style port map +- `fetch` exits `0` +- downloaded body contains `owner-public-fixture` + +### Private Publish With Allowlist + +Terminal A: + +```bash +ddaemon publish -private "18081:18081,$DIODE_PEER_ADDR" +``` + +Terminal B: + +```bash +dpeer fetch -url "http://$DIODE_OWNER_ADDR.diode.link:18081/" -output "$DIODE_OUTPUT_DIR/private-peer.html" +cat "$DIODE_OUTPUT_DIR/private-peer.html" +``` + +Terminal C: + +```bash +dviewer fetch -url "http://$DIODE_OWNER_ADDR.diode.link:18081/" -output "$DIODE_OUTPUT_DIR/private-viewer.html" +``` + +Verify: + +- peer wallet succeeds and sees `peer-public-fixture` only if owner fixture on `18081` is running +- viewer wallet fails with a clear access error or connection failure +- `daemon status` shows the allowlisted private port + +Note: + +- if you want a dedicated private-only fixture on `18081`, start one manually: + +```bash +python3 -m http.server 18081 --bind 127.0.0.1 --directory "$DIODE_PEER_HTTP_ROOT" +``` + +### `fetch` + +Commands: + +```bash +dpeer fetch -url "http://$DIODE_OWNER_ADDR.diode.link:18080/health.json" -output "$DIODE_OUTPUT_DIR/health.json" +dpeer fetch -method GET -header "accept: application/json" -url "http://$DIODE_OWNER_ADDR.diode.link:18080/health.json" -output "$DIODE_OUTPUT_DIR/health-header.json" +``` + +Verify: + +- both commands exit `0` +- output files exist +- response body matches the hosted JSON fixture + +## Files, Push, and Pull + +### `files` + +Terminal A: + +```bash +dstopdaemon +ddaemon files -fileroot "$DIODE_OWNER_FILES_ROOT" 18180 +``` + +Terminal B: + +```bash +dpeer pull "$DIODE_OWNER_ADDR:18180:missing.txt" "$DIODE_OUTPUT_DIR/missing.txt" +``` + +Verify: + +- `files` prints the file port banner +- missing file pull fails with a non-2xx error + +### `push` + +Command: + +```bash +dpeer push "$DIODE_SAMPLE_UPLOAD" "$DIODE_OWNER_ADDR:18180:uploads/sample-upload.txt" +``` + +Verify: + +- exits `0` +- owner filesystem now contains `$DIODE_OWNER_FILES_ROOT/uploads/sample-upload.txt` +- file content matches `manual-upload-payload` + +### `pull` + +Command: + +```bash +dviewer pull "$DIODE_OWNER_ADDR:18180:uploads/sample-upload.txt" "$DIODE_OUTPUT_DIR/pulled-upload.txt" +cat "$DIODE_OUTPUT_DIR/pulled-upload.txt" +``` + +Verify: + +- exits `0` +- downloaded file exists +- downloaded content matches the uploaded content + +## Proxy and Gateway Modes + +### `socksd` + +Terminal A: + +```bash +downer publish -public 18080:18080 +``` + +Terminal B: + +```bash +dpeer socksd -socksd_host 127.0.0.1 -socksd_port 19082 +``` + +Terminal C: + +```bash +curl --socks5-hostname 127.0.0.1:19082 "http://$DIODE_OWNER_ADDR.diode.link:18080/" +``` + +Verify: + +- SOCKS listener binds to `127.0.0.1:19082` +- curl returns `owner-public-fixture` +- stopping `socksd` tears down the listener cleanly + +### `gateway` + +Terminal A: + +```bash +downer publish -public 18080:18080 +``` + +Terminal B: + +```bash +dpeer gateway -httpd_host 127.0.0.1 -httpd_port 19080 +``` + +Terminal C: + +```bash +curl -H "Host: $DIODE_OWNER_ADDR.diode.link:18080" http://127.0.0.1:19080/ +``` + +Verify: + +- gateway listener binds to `127.0.0.1:19080` +- response body is the owner fixture +- with `-secure` and valid certs, HTTPS listener also starts and serves the same destination + +## SSH + +Prerequisites: + +- OpenSSH `ssh` and `ssh-keygen` installed on the client machine +- owner machine has a local UNIX account to expose via `-sshd` + +Terminal A: + +```bash +dstopdaemon +ddaemon publish -sshd "public:2222:$USER" +``` + +Terminal B: + +```bash +dpeer ssh "$USER@$DIODE_OWNER_ADDR.diode" -p 2222 +``` + +Verify: + +- `ssh` prints the local/daemon proxy address it is using +- the connection launches the system `ssh` client +- login succeeds or reaches normal SSH host key / auth prompts +- `ps` or SSH verbose output shows the hidden `ssh-proxy` path is being used + +Hidden coverage: + +- `ssh-proxy` is considered covered when `diode ssh` succeeds through its ProxyCommand path + +## Blockchain Write Commands + +### `token` + +Prerequisites: + +- funded owner wallet + +Commands: + +```bash +downer token -balance +downer token -to "$DIODE_PEER_ADDR" -value 1wei -gasprice 1gwei +dpeer token -balance +``` + +Verify: + +- balance command exits `0` and prints the wallet balance +- transfer command submits successfully +- peer balance or state changes after confirmation + +### `bns` + +Prerequisites: + +- funded owner wallet +- unique BNS name available for testing + +Commands: + +```bash +downer bns -lookup your-test-name +downer bns -register "your-test-name=$DIODE_OWNER_ADDR" +dviewer bns -lookup your-test-name +downer bns -account your-test-name +downer bns -transfer "your-test-name=$DIODE_PEER_ADDR" +downer bns -unregister your-test-name +``` + +Verify: + +- lookup returns owner and mapped address after registration +- account lookup returns nonce, code, and balance data +- transfer updates the reported owner +- unregister removes the mapping + +### `reset` + +This is already covered above under Config and Identity because it mutates fleet identity. + +## Join + +Prerequisites: + +- valid perimeter contract address +- any required Oasis local or remote environment variables for the selected `-network` +- if WireGuard coverage is required, local WireGuard tooling and permissions + +Commands: + +```bash +downer join -dry +downer join -network mainnet +downer join -network testnet +downer join -wireguard -dry +``` + +Verify: + +- `-dry` validates and prints contract-derived state without starting the daemon loop +- normal join starts the long-lived reconcile loop +- contract-driven published ports, binds, SOCKS, and WireGuard state are applied +- switching away from `join` by running `ddaemon publish ...` stops the join mode cleanly + +## MCP + +Prerequisites: + +- MCP inspector or another stdio MCP client available + +Suggested command: + +```bash +./diode -update=false mcp +``` + +Verify with your MCP client: + +- server starts on stdio and stays attached +- tool list includes version, client info, query address, file push/pull, and deploy tools when enabled +- `-mcp-preset=minimal` reduces the exposed tool set +- `-mcp-tool=...` or the corresponding env var filters tools as expected + +## Update + +Prerequisites: + +- use an official release build, not a local `development` binary, for a meaningful update test + +Commands: + +```bash +./diode update +ddaemon update +``` + +Verify: + +- standalone update either reports `No updates` or installs a newer release and restarts +- daemon-routed update returns the update output to the CLI +- after daemon update, `ddaemon daemon status` works again and the daemon resumes the previous active mode + +## Internal Command Coverage + +### `__daemon__` + +Do not invoke directly in normal manual testing. + +Verify indirectly through: + +- daemon autostart from `ddaemon publish ...` +- `daemon status` +- `daemon restart` +- daemon metadata/socket presence + +### `ssh-proxy` + +Do not invoke directly in normal manual testing. + +Verify indirectly through: + +- successful `diode ssh ...` +- ProxyCommand execution in SSH verbose logs + +## Cleanup + +Commands: + +```bash +dstopdaemon +./scripts/manual/setup_cli_lab.sh down +``` + +Verify: + +- daemon is stopped +- fixture HTTP servers are stopped +- no unexpected long-running `diode` test processes remain diff --git a/go.mod b/go.mod index 3afbf2d0..2ae2214c 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/diodechain/diode_client go 1.25.9 require ( + github.com/Microsoft/go-winio v0.6.2 github.com/caddyserver/certmagic v0.19.2 github.com/creack/pty v1.1.24 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc diff --git a/go.sum b/go.sum index ebe39b4b..d4df1f17 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,8 @@ github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ= github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA= github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/VictoriaMetrics/fastcache v1.12.1 h1:i0mICQuojGDL3KblA7wUNlY5lOK6a4bwt3uRKnkZU40= diff --git a/rpc/socks.go b/rpc/socks.go index b484c98b..3aa9d417 100644 --- a/rpc/socks.go +++ b/rpc/socks.go @@ -907,8 +907,7 @@ func (socksServer *Server) startSocksListeners() error { if socksServer.Closed() { return } - err := socksServer.handleUDP(udp, buf) - if err != nil && socksServer.Closed() { + if err := socksServer.handleUDP(udp, buf); err != nil { return } } @@ -953,8 +952,14 @@ func (socksServer *Server) Start() error { } func (socksServer *Server) handleUDP(udpconn net.PacketConn, packet []byte) error { + if udpconn == nil { + return nil + } n, addr, err := udpconn.ReadFrom(packet) if err != nil { + if socksServer.Closed() { + return err + } socksServer.logger.Error("handleUDP error: %v", err) return err } diff --git a/rpc/ssh_service.go b/rpc/ssh_service.go index 3ba5f218..f482df27 100644 --- a/rpc/ssh_service.go +++ b/rpc/ssh_service.go @@ -496,7 +496,7 @@ func handleSSHSessionStart(req *ssh.Request, proc *sshProcessHandle, localUser s if err != nil || !handled { return next, err } - go proxySSHProcessIO(channel, next) + proxySSHProcessIO(channel, next) go waitForProcess(next) return next, nil } diff --git a/scripts/manual/setup_cli_lab.sh b/scripts/manual/setup_cli_lab.sh new file mode 100755 index 00000000..b00da9cf --- /dev/null +++ b/scripts/manual/setup_cli_lab.sh @@ -0,0 +1,258 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +ACTION="${1:-up}" +LAB_DIR="${2:-${DIODE_MANUAL_LAB_DIR:-$ROOT_DIR/.manual/cli-lab}}" +BIN_PATH="$ROOT_DIR/diode" +RUN_DIR="$LAB_DIR/run" +WALLET_DIR="$LAB_DIR/wallets" +HTTP_DIR="$LAB_DIR/http" +FILES_DIR="$LAB_DIR/files" +OUT_DIR="$LAB_DIR/out" +ENV_PATH="$LAB_DIR/env.sh" + +OWNER_DB="$WALLET_DIR/owner/private.db" +PEER_DB="$WALLET_DIR/peer/private.db" +VIEWER_DB="$WALLET_DIR/viewer/private.db" + +OWNER_HTTP_PORT=18080 +PEER_HTTP_PORT=18081 + +OWNER_HTTP_ROOT="$HTTP_DIR/owner" +PEER_HTTP_ROOT="$HTTP_DIR/peer" +OWNER_FILES_ROOT="$FILES_DIR/owner" +PEER_FILES_ROOT="$FILES_DIR/peer" + +usage() { + cat </dev/null 2>&1; then + echo "missing required command: $name" >&2 + exit 1 + fi +} + +stop_pidfile() { + local pidfile="$1" + if [[ ! -f "$pidfile" ]]; then + return + fi + local pid="" + pid="$(cat "$pidfile" 2>/dev/null || true)" + if [[ -n "$pid" ]] && kill -0 "$pid" >/dev/null 2>&1; then + kill "$pid" >/dev/null 2>&1 || true + for _ in $(seq 1 20); do + if ! kill -0 "$pid" >/dev/null 2>&1; then + break + fi + sleep 0.1 + done + if kill -0 "$pid" >/dev/null 2>&1; then + kill -9 "$pid" >/dev/null 2>&1 || true + fi + fi + rm -f "$pidfile" +} + +start_http_server() { + local name="$1" + local root="$2" + local port="$3" + local pidfile="$RUN_DIR/$name.pid" + local logfile="$RUN_DIR/$name.log" + + stop_pidfile "$pidfile" + nohup python3 -m http.server "$port" --bind 127.0.0.1 --directory "$root" >"$logfile" 2>&1 & + local pid=$! + echo "$pid" >"$pidfile" + + for _ in $(seq 1 20); do + if ! kill -0 "$pid" >/dev/null 2>&1; then + echo "failed to start $name fixture server, see $logfile" >&2 + exit 1 + fi + if python3 - </dev/null 2>&1 +import socket +s = socket.socket() +s.settimeout(0.2) +try: + s.connect(("127.0.0.1", $port)) +except OSError: + raise SystemExit(1) +finally: + s.close() +PY + then + return + fi + sleep 0.2 + done + echo "timed out waiting for $name fixture server on 127.0.0.1:$port" >&2 + exit 1 +} + +wallet_address() { + local dbpath="$1" + local output + output="$("$BIN_PATH" -update=false -no-daemon -dbpath "$dbpath" config -list 2>&1)" + printf '%s\n' "$output" | awk -F: '/
/{gsub(/[[:space:]]/, "", $2); print $2; exit}' +} + +write_lab_files() { + mkdir -p "$OWNER_HTTP_ROOT" "$PEER_HTTP_ROOT" "$OWNER_FILES_ROOT" "$PEER_FILES_ROOT" "$OUT_DIR" "$RUN_DIR" + + cat >"$OWNER_HTTP_ROOT/index.html" <<'EOF' +owner-public-fixture +EOF + cat >"$OWNER_HTTP_ROOT/health.json" <<'EOF' +{"service":"owner","status":"ok"} +EOF + cat >"$PEER_HTTP_ROOT/index.html" <<'EOF' +peer-public-fixture +EOF + cat >"$PEER_HTTP_ROOT/health.json" <<'EOF' +{"service":"peer","status":"ok"} +EOF + cat >"$LAB_DIR/sample-upload.txt" <<'EOF' +manual-upload-payload +EOF +} + +write_env() { + local owner_addr="$1" + local peer_addr="$2" + local viewer_addr="$3" + + cat >"$ENV_PATH" </dev/null 2>&1 || true + rm -f "\$HOME/.config/diode/daemon.sock" "\$HOME/.config/diode/daemon.sock.json" +} +EOF +} + +do_up() { + require_cmd go + require_cmd python3 + + mkdir -p "$LAB_DIR" + write_lab_files + + ( + cd "$ROOT_DIR" + go build -o ./diode ./cmd/diode + ) + + local owner_addr peer_addr viewer_addr + owner_addr="$(wallet_address "$OWNER_DB")" + peer_addr="$(wallet_address "$PEER_DB")" + viewer_addr="$(wallet_address "$VIEWER_DB")" + + if [[ -z "$owner_addr" || -z "$peer_addr" || -z "$viewer_addr" ]]; then + echo "failed to derive one or more wallet addresses" >&2 + exit 1 + fi + + start_http_server "owner-http" "$OWNER_HTTP_ROOT" "$OWNER_HTTP_PORT" + start_http_server "peer-http" "$PEER_HTTP_ROOT" "$PEER_HTTP_PORT" + write_env "$owner_addr" "$peer_addr" "$viewer_addr" + + cat </dev/null || true)" + if [[ -n "$pid" ]] && kill -0 "$pid" >/dev/null 2>&1; then + echo "$name: running (pid $pid)" + continue + fi + fi + echo "$name: stopped" + done + if [[ -f "$ENV_PATH" ]]; then + echo "env: $ENV_PATH" + else + echo "env: missing" + fi +} + +case "$ACTION" in +up) + do_up + ;; +down) + do_down + ;; +status) + do_status + ;; +help|-h|--help) + usage + ;; +*) + usage >&2 + exit 2 + ;; +esac