From 76fbb2e096bfcfdc09722c9e7a94a04d0d8db665 Mon Sep 17 00:00:00 2001 From: Spadav Date: Tue, 2 Jun 2026 12:56:48 +0100 Subject: [PATCH] Add Docker container discovery --- core/cli.go | 5 +- core/commands.go | 4 +- core/commands_test.go | 10 +++ core/config.go | 1 + core/config_test.go | 11 +++ core/controller.go | 7 +- core/discover.go | 17 +++- core/discover_docker.go | 142 ++++++++++++++++++++++++++++++++ core/discover_docker_windows.go | 19 +++++ core/discover_unix.go | 24 ++++++ core/discover_unix_test.go | 75 +++++++++++++++++ core/doctor.go | 18 ++++ core/output.go | 12 ++- core/poll_test.go | 2 +- core/proxy.go | 43 ++++++---- core/store.go | 6 +- core/store_test.go | 23 +++++- 17 files changed, 384 insertions(+), 35 deletions(-) create mode 100644 core/commands_test.go create mode 100644 core/discover_docker.go create mode 100644 core/discover_docker_windows.go diff --git a/core/cli.go b/core/cli.go index da95d59..838f172 100644 --- a/core/cli.go +++ b/core/cli.go @@ -71,6 +71,7 @@ Flags (defaults come from ~/.tailscale-proxy/config.json if present): tailnet host resolve the public funnel name (persists) --log-requests Log each proxied request (default on) --quiet Disable per-request logging + --docker Also query Docker API for containers (default off) -h, --help Show this help Press Ctrl-C to stop — the Serve/Funnel entry is reset automatically on exit. @@ -93,6 +94,7 @@ type startOpts struct { forwardHost bool quiet bool acceptDNS string + docker bool } // modeOf returns the exposure mode for the private flag. @@ -125,6 +127,7 @@ func cmdStart(argv []string) int { fs.BoolVar(&o.forwardHost, "forward-host", cfg.ForwardHost, "forward the public host to apps (X-Forwarded-Host/Proto); default presents a local request") fs.StringVar(&o.acceptDNS, "accept-dns", cfg.AcceptDNS, "optionally set Tailscale MagicDNS (true|false) on start; default unset = leave it alone") fs.BoolVar(&o.quiet, "quiet", false, "disable per-request logging") + fs.BoolVar(&o.docker, "docker", cfg.Docker, "also query Docker API for containers (default off)") fs.BoolVar(&o.bg, "bg", false, "run detached in background") var fg bool fs.BoolVar(&fg, "fg", false, "run in foreground (default)") @@ -186,7 +189,7 @@ func cmdStart(argv []string) int { fmt.Printf("set tailscale accept-dns=%s (persists after exit; revert with: tailscale set --accept-dns=%s)\n", o.acceptDNS, revert) } - dcfg := discoverConfig{rng: rng, all: o.all, runtimes: parseRuntimes(o.runtimesRaw)} + dcfg := discoverConfig{rng: rng, all: o.all, runtimes: parseRuntimes(o.runtimesRaw), docker: o.docker} disc := newDiscoverer(runner) if !printChecks(runDoctor(runner, disc, dcfg, mode)) && !o.proxyOnly { diff --git a/core/commands.go b/core/commands.go index b849b39..9e50379 100644 --- a/core/commands.go +++ b/core/commands.go @@ -21,6 +21,7 @@ func cmdConfigure(argv []string) int { fs.IntVar(&cfg.DeregisterCycles, "deregister-cycles", cfg.DeregisterCycles, "missing scans before removal") fs.BoolVar(&cfg.LogRequests, "log-requests", cfg.LogRequests, "log each proxied request") fs.BoolVar(&cfg.ForwardHost, "forward-host", cfg.ForwardHost, "forward the public host to apps") + fs.BoolVar(&cfg.Docker, "docker", cfg.Docker, "also query Docker API for containers") fs.StringVar(&cfg.AcceptDNS, "accept-dns", cfg.AcceptDNS, "set Tailscale MagicDNS (true|false) on start; empty = leave it alone") if err := fs.Parse(argv); err != nil { if err == flag.ErrHelp { @@ -94,10 +95,11 @@ func queryConfig(argv []string) (Mode, discoverConfig, int) { runtimesRaw := fs.String("runtimes", cfg.Runtimes, "comma-separated runtimes") private := fs.Bool("private", cfg.Private, "private (Serve) mode") httpsPort := fs.Int("https-port", cfg.HTTPSPort, "public/tailnet HTTPS port") + docker := fs.Bool("docker", cfg.Docker, "also query Docker API for containers") _ = fs.Parse(argv) rng, err := parsePortRange(*portsRaw) if err != nil { rng = PortRange{Lo: 3000, Hi: 5000} } - return modeOf(*private), discoverConfig{rng: rng, all: *all, runtimes: parseRuntimes(*runtimesRaw)}, *httpsPort + return modeOf(*private), discoverConfig{rng: rng, all: *all, runtimes: parseRuntimes(*runtimesRaw), docker: *docker}, *httpsPort } diff --git a/core/commands_test.go b/core/commands_test.go new file mode 100644 index 0000000..5dadb78 --- /dev/null +++ b/core/commands_test.go @@ -0,0 +1,10 @@ +package core + +import "testing" + +func TestQueryConfigParsesDockerFlag(t *testing.T) { + _, dcfg, _ := queryConfig([]string{"--docker"}) + if !dcfg.docker { + t.Fatal("queryConfig should enable Docker discovery from --docker") + } +} diff --git a/core/config.go b/core/config.go index 916ee24..99d51c5 100644 --- a/core/config.go +++ b/core/config.go @@ -21,6 +21,7 @@ type Config struct { DeregisterCycles int `json:"deregisterCycles"` // missing scans before removal ForwardHost bool `json:"forwardHost"` // forward the external host to the app AcceptDNS string `json:"acceptDns"` // "" = leave Tailscale DNS alone; "true"/"false" = set on start + Docker bool `json:"docker"` // also query Docker API for containers } // defaultConfig returns the built-in defaults. diff --git a/core/config_test.go b/core/config_test.go index b47cca8..e71e2ca 100644 --- a/core/config_test.go +++ b/core/config_test.go @@ -33,6 +33,9 @@ func TestDefaultConfig(t *testing.T) { if cfg.Private { t.Error("Private = true, want false") } + if cfg.Docker { + t.Error("Docker = true, want false") + } } func TestLoadConfigFrom_missingReturnsDefaults(t *testing.T) { @@ -66,6 +69,7 @@ func TestSaveAndLoadRoundTrip(t *testing.T) { HTTPSPort: 8443, LogRequests: false, DeregisterCycles: 10, + Docker: true, } if err := saveConfigTo(path, original); err != nil { @@ -84,6 +88,13 @@ func TestSaveAndLoadRoundTrip(t *testing.T) { } } +func TestOptionsFromConfigIncludesDocker(t *testing.T) { + opts := OptionsFromConfig(Config{Docker: true}) + if !opts.Docker { + t.Fatal("OptionsFromConfig should copy Docker") + } +} + func TestLoadConfigFrom_partialOverlaysDefaults(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "config.json") diff --git a/core/controller.go b/core/controller.go index a722fd6..8afed71 100644 --- a/core/controller.go +++ b/core/controller.go @@ -24,6 +24,7 @@ type Options struct { DeregisterCycles int ForwardHost bool LogRequests bool + Docker bool ProxyOnly bool // run the proxy only; skip the Serve/Funnel entry } @@ -33,7 +34,7 @@ func OptionsFromConfig(c Config) Options { Ports: c.Ports, All: c.All, Runtimes: c.Runtimes, Private: c.Private, Bind: c.Bind, Port: c.Port, Interval: c.Interval, HTTPSPort: c.HTTPSPort, DeregisterCycles: c.DeregisterCycles, ForwardHost: c.ForwardHost, - LogRequests: c.LogRequests, + LogRequests: c.LogRequests, Docker: c.Docker, } } @@ -122,7 +123,7 @@ func (c *Controller) Start(o Options) error { mode := modeOf(o.Private) disc := newDiscoverer(runner) - dcfg := discoverConfig{rng: rng, all: o.All, runtimes: parseRuntimes(o.Runtimes)} + dcfg := discoverConfig{rng: rng, all: o.All, runtimes: parseRuntimes(o.Runtimes), docker: o.Docker} store := NewRouteStore(func() ([]Service, []Duplicate, error) { return disc.Discover(dcfg) }, o.DeregisterCycles) _, _, _, _ = store.refresh() @@ -342,6 +343,6 @@ func Doctor(o Options) []Check { rng = PortRange{Lo: 3000, Hi: 5000} } r := execRunner{} - dcfg := discoverConfig{rng: rng, all: o.All, runtimes: parseRuntimes(o.Runtimes)} + dcfg := discoverConfig{rng: rng, all: o.All, runtimes: parseRuntimes(o.Runtimes), docker: o.Docker} return runDoctor(r, newDiscoverer(r), dcfg, modeOf(o.Private)) } diff --git a/core/discover.go b/core/discover.go index 6e1963b..506e503 100644 --- a/core/discover.go +++ b/core/discover.go @@ -13,7 +13,8 @@ import ( // Service is one discovered listening dev server. type Service struct { Slug string // URL path segment - Port int // listening port (127.0.0.1) + Port int // upstream port + Host string // upstream host/IP; empty means 127.0.0.1 Runtime string // node|bun|deno or "" (unknown) Dir string // working directory (may be "") PID int @@ -29,6 +30,7 @@ type discoverConfig struct { rng PortRange all bool runtimes map[string]bool // nil = all known web runtimes + docker bool // also query Docker API for containers } // listener is a raw OS-level listening socket (pre-classification). Comm is the @@ -36,6 +38,7 @@ type discoverConfig struct { // process's own name from `ps` (e.g. "http-server", or a "go-build" temp path). type listener struct { Port int + Host string PID int Comm string PsComm string @@ -135,6 +138,9 @@ func projectRootDir(dir string) string { if dir == "" || dir == "/" { return "" } + if !filepath.IsAbs(dir) { + return dir + } d := dir for { for _, m := range projectMarkers { @@ -178,7 +184,11 @@ type Duplicate struct { // serviceOf builds a Service from a raw listener under a given slug. func serviceOf(l listener, slug string) Service { - return Service{Slug: slug, Port: l.Port, Runtime: runtimeOf(l), Dir: l.Cwd, PID: l.PID} + host := l.Host + if host == "" { + host = "127.0.0.1" + } + return Service{Slug: slug, Port: l.Port, Host: host, Runtime: runtimeOf(l), Dir: l.Cwd, PID: l.PID} } // projectBaseSlug derives the clean project slug from a working directory, or @@ -326,6 +336,9 @@ func (d *Discoverer) Discover(cfg discoverConfig) ([]Service, []Duplicate, error if err != nil { return nil, nil, err } + if cfg.docker { + ls = d.mergeDockerListeners(ls, cfg.rng) + } svcs, dups := buildServices(ls, cfg.all, cfg.runtimes) return svcs, dups, nil } diff --git a/core/discover_docker.go b/core/discover_docker.go new file mode 100644 index 0000000..0723c81 --- /dev/null +++ b/core/discover_docker.go @@ -0,0 +1,142 @@ +//go:build !windows + +package core + +import ( + "context" + "encoding/json" + "io" + "net" + "net/http" + "os" + "strings" + "time" +) + +const dockerSocketPath = "/var/run/docker.sock" + +type dockerPortBinding struct { + PublicPort int `json:"PublicPort"` + PrivatePort int `json:"PrivatePort"` + IP string `json:"IP"` +} + +type dockerNetworkInfo struct { + IPAddress string `json:"IPAddress"` +} + +type dockerContainerInfo struct { + Names []string `json:"Names"` + Ports []dockerPortBinding `json:"Ports"` + NetworkSettings struct { + Networks map[string]dockerNetworkInfo `json:"Networks"` + } `json:"NetworkSettings"` +} + +// dockerListeners queries the Docker API for running containers and returns +// listeners for each port found. Returns nil (not an error) if Docker is +// unavailable — the caller (discover_unix.go) treats this as "no docker +// listeners" rather than a failure. +func (d *Discoverer) dockerListeners(rng PortRange) []listener { + if _, err := os.Stat(dockerSocketPath); err != nil { + return nil + } + + client := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", dockerSocketPath) + }, + }, + } + + resp, err := client.Get("http://localhost/containers/json") + if err != nil { + return nil + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil + } + + return parseDockerListeners(body, rng) +} + +func parseDockerListeners(body []byte, rng PortRange) []listener { + var containers []dockerContainerInfo + if err := json.Unmarshal(body, &containers); err != nil { + return nil + } + + var result []listener + for ci, c := range containers { + if len(c.Names) == 0 || len(c.Ports) == 0 { + continue + } + name := strings.TrimPrefix(c.Names[0], "/") + containerIP := firstContainerIP(c.NetworkSettings.Networks) + + for pi, p := range c.Ports { + port := p.PublicPort + host := "127.0.0.1" + if port == 0 { + if containerIP == "" { + continue + } + port = p.PrivatePort + host = containerIP + } + if !rng.contains(port) { + continue + } + result = append(result, listener{ + Port: port, + Host: host, + PID: syntheticDockerPID(ci, pi), + Comm: "docker", + Cwd: name, + }) + } + } + + return result +} + +func firstContainerIP(networks map[string]dockerNetworkInfo) string { + for _, n := range networks { + if n.IPAddress != "" { + return n.IPAddress + } + } + return "" +} + +func syntheticDockerPID(containerIndex, portIndex int) int { + return -1 - containerIndex*1000 - portIndex +} + +// dockerAvailable checks if the Docker socket is accessible and the API responds. +func dockerAvailable() bool { + if _, err := os.Stat(dockerSocketPath); err != nil { + return false + } + + client := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", dockerSocketPath) + }, + }, + } + + resp, err := client.Get("http://localhost/version") + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode == http.StatusOK +} diff --git a/core/discover_docker_windows.go b/core/discover_docker_windows.go new file mode 100644 index 0000000..7b241f1 --- /dev/null +++ b/core/discover_docker_windows.go @@ -0,0 +1,19 @@ +//go:build windows + +package core + +// dockerListeners is a no-op on Windows — Docker API over Unix socket +// is not available. Returns nil so the caller treats it as "no docker listeners". +func (d *Discoverer) dockerListeners(rng PortRange) []listener { + return nil +} + +// mergeDockerListeners is a no-op on Windows. +func (d *Discoverer) mergeDockerListeners(lsofListeners []listener, rng PortRange) []listener { + return lsofListeners +} + +// dockerAvailable always returns false on Windows since we rely on Unix sockets. +func dockerAvailable() bool { + return false +} diff --git a/core/discover_unix.go b/core/discover_unix.go index c99e217..1448c02 100644 --- a/core/discover_unix.go +++ b/core/discover_unix.go @@ -51,6 +51,30 @@ func (d *Discoverer) listeners(rng PortRange) ([]listener, error) { return ls, nil } +// mergeDockerListeners appends Docker-discovered listeners to lsof results, +// deduplicating by port (lsof takes priority). +func (d *Discoverer) mergeDockerListeners(lsofListeners []listener, rng PortRange) []listener { + dockerLs := d.dockerListeners(rng) + if len(dockerLs) == 0 { + return lsofListeners + } + + // Build a set of ports already covered by lsof. + covered := make(map[int]bool, len(lsofListeners)) + for _, l := range lsofListeners { + covered[l.Port] = true + } + + for _, dl := range dockerLs { + if !covered[dl.Port] { + lsofListeners = append(lsofListeners, dl) + covered[dl.Port] = true + } + } + + return lsofListeners +} + // parseLsofListeners parses `lsof -Fpcn` output, deduping per (pid,port). func parseLsofListeners(out string, rng PortRange) []listener { var res []listener diff --git a/core/discover_unix_test.go b/core/discover_unix_test.go index d6c976a..cc4195a 100644 --- a/core/discover_unix_test.go +++ b/core/discover_unix_test.go @@ -51,3 +51,78 @@ func TestParseLsofCwd(t *testing.T) { t.Errorf("4231 cwd = %q", m[4231]) } } + +func TestParseDockerListeners_hostBoundAndInternalFallback(t *testing.T) { + body := []byte(`[ + { + "Names":["/web"], + "Ports":[ + {"PrivatePort":3000,"PublicPort":49153}, + {"PrivatePort":8080,"PublicPort":0} + ], + "NetworkSettings":{"Networks":{"bridge":{"IPAddress":"172.17.0.2"}}} + }, + { + "Names":["/worker"], + "Ports":[{"PrivatePort":9000,"PublicPort":0}], + "NetworkSettings":{"Networks":{"bridge":{"IPAddress":""}}} + } + ]`) + + ls := parseDockerListeners(body, PortRange{Lo: 3000, Hi: 50000}) + if len(ls) != 2 { + t.Fatalf("got %d listeners: %+v", len(ls), ls) + } + if ls[0].Port != 49153 || ls[0].Host != "127.0.0.1" || ls[0].Cwd != "web" { + t.Fatalf("host-bound listener wrong: %+v", ls[0]) + } + if ls[1].Port != 8080 || ls[1].Host != "172.17.0.2" || ls[1].Cwd != "web" { + t.Fatalf("internal fallback listener wrong: %+v", ls[1]) + } + if ls[0].PID == ls[1].PID { + t.Fatalf("docker ports need distinct synthetic PIDs, got %+v", ls) + } +} + +func TestDockerListenersBuildServices_keepsMultipleContainerPorts(t *testing.T) { + ls := []listener{ + {Port: 3000, Host: "172.17.0.2", PID: -1, Comm: "docker", Cwd: "web"}, + {Port: 8080, Host: "172.17.0.2", PID: -2, Comm: "docker", Cwd: "web"}, + } + + svcs, dups := buildServices(ls, false, nil) + if len(svcs) != 2 { + t.Fatalf("got %d services: %+v", len(svcs), svcs) + } + if len(dups) != 1 { + t.Fatalf("got %d duplicate groups: %+v", len(dups), dups) + } + if svcs[0].Host != "172.17.0.2" || svcs[1].Host != "172.17.0.2" { + t.Fatalf("service hosts not preserved: %+v", svcs) + } +} + +func TestDockerListenerContainerNamesRemainDistinctProjects(t *testing.T) { + ls := []listener{ + {Port: 3000, Host: "127.0.0.1", PID: -1, Comm: "docker", Cwd: "web-api"}, + {Port: 8090, Host: "127.0.0.1", PID: -2, Comm: "docker", Cwd: "model-server"}, + } + + svcs, dups := buildServices(ls, false, nil) + if len(svcs) != 2 { + t.Fatalf("got %d services: %+v", len(svcs), svcs) + } + if len(dups) != 0 { + t.Fatalf("distinct containers should not be duplicates: %+v", dups) + } + bySlug := map[string]Service{} + for _, svc := range svcs { + bySlug[svc.Slug] = svc + } + if _, ok := bySlug["web-api"]; !ok { + t.Fatalf("missing web-api slug: %+v", svcs) + } + if _, ok := bySlug["model-server"]; !ok { + t.Fatalf("missing model-server slug: %+v", svcs) + } +} diff --git a/core/doctor.go b/core/doctor.go index 466b095..98daf98 100644 --- a/core/doctor.go +++ b/core/doctor.go @@ -25,6 +25,11 @@ const ( func runDoctor(r Runner, disc *Discoverer, cfg discoverConfig, mode Mode) []Check { var checks []Check + // Docker check (only when --docker is used). + if cfg.docker { + checks = append(checks, dockerCheck()) + } + verOut, _, verErr := r.Run("tailscale", "version") if verErr != nil { checks = append(checks, Check{ @@ -98,6 +103,19 @@ func runDoctor(r Runner, disc *Discoverer, cfg discoverConfig, mode Mode) []Chec return checks } +// dockerCheck returns a Check for Docker socket availability. +func dockerCheck() Check { + if dockerAvailable() { + return Check{Name: "docker available", OK: true, Detail: "socket accessible"} + } + return Check{ + Name: "docker available", + OK: false, + Detail: "Docker socket not accessible at /var/run/docker.sock", + Fix: "Start Docker or ensure the socket is available", + } +} + // firstLine returns the first line of s, trimmed. func firstLine(s string) string { if i := strings.IndexByte(s, '\n'); i >= 0 { diff --git a/core/output.go b/core/output.go index a2d53d4..5eeb181 100644 --- a/core/output.go +++ b/core/output.go @@ -47,7 +47,7 @@ func printStartHeader(o startOpts, mode Mode, rng PortRange, cfgPath string, exi func printServiceURLs(snap map[string]Service, node string, httpsPort int) { base := publicBase(node, httpsPort) for _, slug := range sortedSlugs(snap) { - fmt.Printf(" %s/%s/ → 127.0.0.1:%d\n", base, slug, snap[slug].Port) + fmt.Printf(" %s/%s/ → %s\n", base, slug, targetDisplay(snap[slug])) } } @@ -75,7 +75,7 @@ func printDiscovered(dcfg discoverConfig, mode Mode, httpsPort int) int { } for _, slug := range sortedSlugs(snap) { s := snap[slug] - fmt.Printf(" %-26s %-6s :%d pid %d %s\n", slug, runtimeOr(s.Runtime), s.Port, s.PID, dirOr(s.Dir)) + fmt.Printf(" %-26s %-6s %-21s pid %d %s\n", slug, runtimeOr(s.Runtime), targetDisplay(s), s.PID, dirOr(s.Dir)) if nerr == nil { fmt.Printf(" %s/%s/\n", publicBase(node, httpsPort), slug) } @@ -98,7 +98,7 @@ func printDuplicateNotes(dups []Duplicate) { if i == 0 { tag = " [main]" } - fmt.Printf(" /%s/ → :%d (%s, pid %d)%s\n", m.Slug, m.Port, runtimeOr(m.Runtime), m.PID, tag) + fmt.Printf(" /%s/ → %s (%s, pid %d)%s\n", m.Slug, targetDisplay(m), runtimeOr(m.Runtime), m.PID, tag) } } } @@ -126,11 +126,15 @@ func dirOr(dir string) string { return dir } +func targetDisplay(s Service) string { + return s.upstreamHost() + ":" + strconv.Itoa(s.Port) +} + // portList renders a project's services on one line: "/slug/ :port(runtime)". func portList(d Duplicate) string { parts := make([]string, 0, len(d.Members)) for _, m := range d.Members { - parts = append(parts, "/"+m.Slug+"/ :"+strconv.Itoa(m.Port)+"("+runtimeOr(m.Runtime)+")") + parts = append(parts, "/"+m.Slug+"/ "+targetDisplay(m)+"("+runtimeOr(m.Runtime)+")") } return strings.Join(parts, ", ") } diff --git a/core/poll_test.go b/core/poll_test.go index 0d47d39..d6c64e4 100644 --- a/core/poll_test.go +++ b/core/poll_test.go @@ -24,7 +24,7 @@ func TestPoll_picksUpChanges(t *testing.T) { stage.Store(1) deadline := time.After(2 * time.Second) for { - if p, ok := store.lookup("x"); ok && p == 9 { + if svc, ok := store.lookup("x"); ok && svc.Port == 9 { return } select { diff --git a/core/proxy.go b/core/proxy.go index a701e1e..2cb78cb 100644 --- a/core/proxy.go +++ b/core/proxy.go @@ -50,7 +50,7 @@ func writeIndex(w http.ResponseWriter, store *RouteStore, status int) { if rt == "" { rt = "?" } - fmt.Fprintf(w, " /%s/ → 127.0.0.1:%d (%s)\n", s, svc.Port, rt) + fmt.Fprintf(w, " /%s/ → %s:%d (%s)\n", s, svc.upstreamHost(), svc.Port, rt) } } @@ -58,14 +58,21 @@ type ctxKey int const targetKey ctxKey = 0 +func (s Service) upstreamHost() string { + if s.Host != "" { + return s.Host + } + return "127.0.0.1" +} + // target is the resolved upstream for a single request. type target struct { - port int // upstream port on the loopback interface + host string + port int // upstream port path string // rewritten path with the matched segment stripped } -// dialHost is the reliable IPv4 loopback address we connect to. -func (t target) dialHost() string { return "127.0.0.1:" + strconv.Itoa(t.port) } +func (t target) dialHost() string { return t.host + ":" + strconv.Itoa(t.port) } // hostHeader is the Host the app sees. We use "localhost" (not the raw IP) // because dev servers, CORS origins, and cookies are keyed to how developers @@ -96,7 +103,7 @@ func newHandler(store *RouteStore, logRequests, forwardHost bool) http.Handler { Rewrite: func(pr *httputil.ProxyRequest) { tgt := pr.In.Context().Value(targetKey).(target) pr.Out.URL.Scheme = "http" - pr.Out.URL.Host = tgt.dialHost() // connect to 127.0.0.1 + pr.Out.URL.Host = tgt.dialHost() pr.Out.URL.Path = tgt.path pr.Out.URL.RawQuery = pr.In.URL.RawQuery // Present the request as "localhost:" so it's indistinguishable @@ -127,9 +134,9 @@ func newHandler(store *RouteStore, logRequests, forwardHost bool) http.Handler { start := time.Now() rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} - port, path, ok := resolveRoute(store, r, w) + svc, path, ok := resolveRoute(store, r, w) if ok { - tgt := target{port: port, path: path} + tgt := target{host: svc.upstreamHost(), port: svc.Port, path: path} ctx := context.WithValue(r.Context(), targetKey, tgt) proxy.ServeHTTP(rec, r.WithContext(ctx)) } else { @@ -137,7 +144,7 @@ func newHandler(store *RouteStore, logRequests, forwardHost bool) http.Handler { } if logRequests { - logRequest(r, rec.status, port, time.Since(start)) + logRequest(r, rec.status, svc, time.Since(start)) } }) } @@ -151,26 +158,26 @@ const routeCookie = "tsp_route" // resolveRoute determines the upstream port and rewritten path for a request. // First path segment matching a slug wins (prefix stripped, affinity cookie set). // Otherwise it falls back to the affinity cookie and forwards the full path. -// Returns (port, path, ok). -func resolveRoute(store *RouteStore, r *http.Request, w http.ResponseWriter) (int, string, bool) { +// Returns (service, path, ok). +func resolveRoute(store *RouteStore, r *http.Request, w http.ResponseWriter) (Service, string, bool) { seg, rest := splitFirstSegment(r.URL.Path) if seg != "" { - if port, ok := store.lookup(seg); ok { + if svc, ok := store.lookup(seg); ok { // Remember this app for subsequent prefix-less requests. http.SetCookie(w, &http.Cookie{ Name: routeCookie, Value: seg, Path: "/", SameSite: http.SameSiteLaxMode, }) - return port, rest, true + return svc, rest, true } } // Prefix-less request (asset/API/HMR): follow the affinity cookie, forwarding // the full original path unchanged. if c, err := r.Cookie(routeCookie); err == nil && c.Value != "" { - if port, ok := store.lookup(c.Value); ok { - return port, r.URL.Path, true + if svc, ok := store.lookup(c.Value); ok { + return svc, r.URL.Path, true } } - return 0, "", false + return Service{}, "", false } // statusRecorder captures the response status code while preserving streaming @@ -200,10 +207,10 @@ func (s *statusRecorder) Flush() { } // logRequest prints one nicely formatted request line. -func logRequest(r *http.Request, status, port int, dur time.Duration) { +func logRequest(r *http.Request, status int, svc Service, dur time.Duration) { target := "—" - if port > 0 { - target = "127.0.0.1:" + strconv.Itoa(port) + if svc.Port > 0 { + target = svc.upstreamHost() + ":" + strconv.Itoa(svc.Port) } log.Printf("%s %-6s %s → %s (%s)", colorStatus(status), r.Method, r.URL.Path, target, dur.Round(time.Millisecond)) diff --git a/core/store.go b/core/store.go index 342908a..02636a5 100644 --- a/core/store.go +++ b/core/store.go @@ -30,11 +30,11 @@ func NewRouteStore(discover func() ([]Service, []Duplicate, error), deregisterCy } } -func (s *RouteStore) lookup(slug string) (int, bool) { +func (s *RouteStore) lookup(slug string) (Service, bool) { s.mu.RLock() defer s.mu.RUnlock() svc, ok := s.services[slug] - return svc.Port, ok + return svc, ok } func (s *RouteStore) snapshot() map[string]Service { @@ -77,7 +77,7 @@ func (s *RouteStore) refresh() (added, repointed []Service, removed []string, er switch { case !ok: added = append(added, svc) - case prev.Port != svc.Port: + case prev.Port != svc.Port || prev.Host != svc.Host: repointed = append(repointed, svc) } s.services[slug] = svc // register or update diff --git a/core/store_test.go b/core/store_test.go index 1caa5dd..f0ca690 100644 --- a/core/store_test.go +++ b/core/store_test.go @@ -20,14 +20,33 @@ func TestRouteStore_addAndLookup(t *testing.T) { if len(repointed) != 0 || len(removed) != 0 { t.Errorf("expected no repointed/removed, got %v %v", repointed, removed) } - if port, ok := store.lookup("a"); !ok || port != 1 { - t.Errorf("lookup(a) = (%d, %v), want (1, true)", port, ok) + if svc, ok := store.lookup("a"); !ok || svc.Port != 1 { + t.Errorf("lookup(a) = (%+v, %v), want port 1 and true", svc, ok) } if store.snapshot()["a"].Runtime != "node" { t.Errorf("snapshot[a].Runtime wrong") } } +func TestRouteStore_repointsWhenHostChanges(t *testing.T) { + stage := 0 + store := NewRouteStore(func() ([]Service, []Duplicate, error) { + stage++ + if stage == 1 { + return []Service{{Slug: "a", Host: "127.0.0.1", Port: 3000}}, nil, nil + } + return []Service{{Slug: "a", Host: "172.17.0.2", Port: 3000}}, nil, nil + }, 1) + + if added, repointed, _, _ := store.refresh(); len(added) != 1 || len(repointed) != 0 { + t.Fatalf("initial refresh added/repointed = %v/%v", added, repointed) + } + _, repointed, _, _ := store.refresh() + if len(repointed) != 1 || repointed[0].Host != "172.17.0.2" { + t.Fatalf("host change should repoint, got %+v", repointed) + } +} + func TestRouteStore_debounceDeregister(t *testing.T) { var empty bool store := NewRouteStore(func() ([]Service, []Duplicate, error) {