Skip to content

Commit 010110c

Browse files
authored
Bug fixes for endpoints (#155)
1 parent 78cb82f commit 010110c

5 files changed

Lines changed: 160 additions & 14 deletions

File tree

pkg/cmd/global.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
kong "github.com/alecthomas/kong"
1515
client "github.com/mutablelogic/go-client"
1616
server "github.com/mutablelogic/go-server"
17+
types "github.com/mutablelogic/go-server/pkg/types"
1718
metric "go.opentelemetry.io/otel/metric"
1819
trace "go.opentelemetry.io/otel/trace"
1920
)
@@ -95,7 +96,7 @@ func (g *global) Meter() metric.Meter {
9596
// ClientEndpoint returns the HTTP endpoint URL and client options derived
9697
// from the global HTTP flags.
9798
func (g *global) ClientEndpoint() (string, []client.ClientOpt, error) {
98-
scheme := "http"
99+
scheme := types.SchemeInsecure
99100
host, port, err := net.SplitHostPort(g.HTTP.Addr)
100101
if err != nil {
101102
return "", nil, err
@@ -107,8 +108,14 @@ func (g *global) ClientEndpoint() (string, []client.ClientOpt, error) {
107108
if err != nil {
108109
return "", nil, err
109110
}
110-
if portn == 443 {
111-
scheme = "https"
111+
hostaddr := net.JoinHostPort(host, strconv.FormatUint(portn, 10))
112+
switch portn {
113+
case 80:
114+
scheme = types.SchemeInsecure
115+
hostaddr = host
116+
case 443:
117+
scheme = types.SchemeSecure
118+
hostaddr = host
112119
}
113120
opts := []client.ClientOpt{}
114121
if g.Debug || g.Verbose {
@@ -120,7 +127,7 @@ func (g *global) ClientEndpoint() (string, []client.ClientOpt, error) {
120127
if g.HTTP.Timeout > 0 {
121128
opts = append(opts, client.OptTimeout(g.HTTP.Timeout))
122129
}
123-
return fmt.Sprintf("%s://%s%s", scheme, net.JoinHostPort(host, strconv.FormatUint(portn, 10)), g.HTTP.Prefix), opts, nil
130+
return fmt.Sprintf("%s://%s%s", scheme, hostaddr, g.HTTP.Prefix), opts, nil
124131
}
125132

126133
func (g *global) URL() *url.URL {

pkg/cmd/server.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ import (
1616
errgroup "golang.org/x/sync/errgroup"
1717
)
1818

19-
// RegisterFunc is called after the router is created but before the server
20-
// starts listening. Use it to add routes and wire up handlers.
19+
// RegisterFunc is called after the server URL has been resolved but before the
20+
// server starts serving requests. Use it to add routes and wire up handlers.
2121
type RegisterFunc func(*httprouter.Router) error
2222

2323
///////////////////////////////////////////////////////////////////////////////
@@ -87,11 +87,19 @@ func (s *RunServer) Run(ctx server.Cmd) error {
8787
tlsCfg = &tls.Config{ServerName: s.TLS.ServerName}
8888
}
8989

90-
// Create a new server and start listening. The server will run until the context is cancelled.
90+
// Create a new server. The listener is bound before registration so callbacks
91+
// can rely on the final advertised URL, including :0 port allocation.
9192
srv, err := httpserver.New(ctx.HTTPAddr(), tlsCfg, serverOpts...)
9293
if err != nil {
9394
return fmt.Errorf("httpserver: %w", err)
94-
} else if url := srv.URL(); url != nil {
95+
}
96+
97+
// Bind to the server's address to ensure the advertised URL is final before
98+
// route registration runs.
99+
if err := srv.Listen(); err != nil {
100+
return err
101+
}
102+
if url := srv.URL(); url != nil {
95103
url.Path = types.NormalisePath(ctx.HTTPPrefix())
96104
ctx.(*global).url = url
97105
}
@@ -126,11 +134,6 @@ func (s *RunServer) Run(ctx server.Cmd) error {
126134
return fmt.Errorf("catchall: %w", err)
127135
}
128136

129-
// Bind to the server's address to ensure it's available
130-
if err := srv.Listen(); err != nil {
131-
return err
132-
}
133-
134137
// Set the server URL in the OpenAPI spec now that the bound address is known.
135138
// When a TLS server name is configured (e.g. behind a reverse proxy) it is
136139
// used as the public hostname; otherwise the bound listen address is used.

pkg/cmd/server_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"log/slog"
66
"os"
7+
"strings"
78
"testing"
89
"time"
910

@@ -128,3 +129,94 @@ func Test_ClientEndpoint_Verbose(t *testing.T) {
128129
t.Error("expected opts when Verbose=true, got none")
129130
}
130131
}
132+
133+
func Test_RunServer_AdvertisedURL(t *testing.T) {
134+
ctx, cancel := context.WithCancel(context.Background())
135+
defer cancel()
136+
137+
g := &global{
138+
ctx: ctx,
139+
logger: slog.New(slog.NewTextHandler(os.Stderr, nil)),
140+
}
141+
g.HTTP.Addr = ":0"
142+
g.HTTP.Prefix = "/api"
143+
144+
s := &RunServer{}
145+
var registerURL string
146+
s.Register(func(_ *httprouter.Router) error {
147+
if url := g.URL(); url != nil {
148+
registerURL = url.String()
149+
}
150+
return nil
151+
})
152+
153+
done := make(chan error, 1)
154+
go func() {
155+
done <- s.Run(g)
156+
}()
157+
158+
time.Sleep(100 * time.Millisecond)
159+
if registerURL == "" {
160+
t.Fatal("expected URL to be available during route registration")
161+
}
162+
if !strings.HasPrefix(registerURL, "http://localhost:") {
163+
t.Fatalf("register URL = %q, want localhost prefix", registerURL)
164+
}
165+
if strings.Contains(registerURL, ":0/") {
166+
t.Fatalf("register URL = %q, want bound port instead of :0", registerURL)
167+
}
168+
169+
if url := g.URL(); url == nil {
170+
t.Fatal("expected URL after listen")
171+
} else if strings.Contains(url.String(), ":0/") {
172+
t.Fatalf("final URL = %q, want bound port", url.String())
173+
}
174+
175+
cancel()
176+
if err := <-done; err != nil {
177+
t.Fatal(err)
178+
}
179+
}
180+
181+
func Test_RunServer_AdvertisedURL_TLSName(t *testing.T) {
182+
ctx, cancel := context.WithCancel(context.Background())
183+
defer cancel()
184+
185+
g := &global{
186+
ctx: ctx,
187+
logger: slog.New(slog.NewTextHandler(os.Stderr, nil)),
188+
}
189+
g.HTTP.Addr = ":0"
190+
g.HTTP.Prefix = "/api"
191+
192+
s := &RunServer{}
193+
s.TLS.ServerName = "auth.example.com"
194+
var registerURL string
195+
s.Register(func(_ *httprouter.Router) error {
196+
if url := g.URL(); url != nil {
197+
registerURL = url.String()
198+
}
199+
return nil
200+
})
201+
202+
done := make(chan error, 1)
203+
go func() {
204+
done <- s.Run(g)
205+
}()
206+
207+
time.Sleep(100 * time.Millisecond)
208+
if registerURL != "https://auth.example.com/api" {
209+
t.Fatalf("register URL = %q, want %q", registerURL, "https://auth.example.com/api")
210+
}
211+
212+
if url := g.URL(); url == nil {
213+
t.Fatal("expected URL after listen")
214+
} else if got := url.String(); got != "https://auth.example.com/api" {
215+
t.Fatalf("final URL = %q, want %q", got, "https://auth.example.com/api")
216+
}
217+
218+
cancel()
219+
if err := <-done; err != nil {
220+
t.Fatal(err)
221+
}
222+
}

pkg/httpserver/httpserver.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"errors"
77
"net"
88
"net/http"
9+
"net/netip"
910
"net/url"
1011
"strings"
1112
"sync"
@@ -129,7 +130,7 @@ func (server *server) URL() *url.URL {
129130
var url url.URL
130131
url.Scheme = defaultListenPortHttp
131132
url.Path = "/"
132-
url.Host = server.Addr()
133+
url.Host = normalizeURLHost(server.Addr())
133134
if server.serverName != "" {
134135
url.Scheme = defaultListenPortHttps
135136
url.Host = server.serverName
@@ -203,3 +204,23 @@ func (server *server) shutdown() error {
203204
return err
204205
}
205206
}
207+
208+
func normalizeURLHost(addr string) string {
209+
host, port, err := net.SplitHostPort(addr)
210+
if err != nil {
211+
return addr
212+
}
213+
if host == "" || isUnspecifiedHost(host) {
214+
host = defaultListenHost
215+
}
216+
return net.JoinHostPort(host, port)
217+
}
218+
219+
func isUnspecifiedHost(host string) bool {
220+
host = strings.TrimPrefix(strings.TrimSuffix(strings.TrimSpace(host), "]"), "[")
221+
if host == "" {
222+
return true
223+
}
224+
addr, err := netip.ParseAddr(host)
225+
return err == nil && addr.IsUnspecified()
226+
}

pkg/httpserver/httpserver_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,29 @@ func Test_Listen_002(t *testing.T) {
361361
assert.Equal("localhost:9999", s.Addr())
362362
}
363363

364+
func Test_URL_001(t *testing.T) {
365+
assert := assert.New(t)
366+
367+
// Empty listen host should advertise localhost rather than a bare :port URL.
368+
s, err := httpserver.New(":8084", nil)
369+
assert.NoError(err)
370+
if assert.NotNil(s.URL()) {
371+
assert.Equal("http://localhost:8084/", s.URL().String())
372+
}
373+
}
374+
375+
func Test_URL_002(t *testing.T) {
376+
assert := assert.New(t)
377+
378+
// Wildcard listener addresses should advertise localhost for local development.
379+
s, err := httpserver.New(":0", nil)
380+
assert.NoError(err)
381+
assert.NoError(s.Listen())
382+
if assert.NotNil(s.URL()) {
383+
assert.Contains(s.URL().String(), "http://localhost:")
384+
}
385+
}
386+
364387
///////////////////////////////////////////////////////////////////////////////
365388
// TESTS - RUN
366389

0 commit comments

Comments
 (0)