Skip to content

Commit 560740a

Browse files
committed
Silence context.Canceled errors from client disconnects
Add web.ServerError and web.IsClientDisconnect helpers to centralize the check, replacing repetitive inline handling across all API and UI handlers. Also adds missing error logging in the popularity handler.
1 parent 93316a3 commit 560740a

File tree

9 files changed

+207
-19
lines changed

9 files changed

+207
-19
lines changed

internal/packages/handler.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package packages
22

33
import (
4-
"log/slog"
54
"net/http"
65

76
"pkgstatsd/internal/web"
@@ -30,8 +29,7 @@ func (h *Handler) HandleGet(w http.ResponseWriter, r *http.Request) {
3029

3130
pkg, err := h.repo.FindByName(r.Context(), name, startMonth, endMonth)
3231
if err != nil {
33-
slog.Error("failed to find package", "error", err)
34-
web.InternalServerError(w, "internal server error")
32+
web.ServerError(w, "failed to find package", err)
3533
return
3634
}
3735

@@ -59,8 +57,7 @@ func (h *Handler) HandleList(w http.ResponseWriter, r *http.Request) {
5957

6058
list, err := h.repo.FindAll(r.Context(), query, startMonth, endMonth, limit, offset)
6159
if err != nil {
62-
slog.Error("failed to list packages", "error", err)
63-
web.InternalServerError(w, "internal server error")
60+
web.ServerError(w, "failed to list packages", err)
6461
return
6562
}
6663

@@ -88,8 +85,7 @@ func (h *Handler) HandleSeries(w http.ResponseWriter, r *http.Request) {
8885

8986
list, err := h.repo.FindSeriesByName(r.Context(), name, startMonth, endMonth, limit, offset)
9087
if err != nil {
91-
slog.Error("failed to find package series", "error", err)
92-
web.InternalServerError(w, "internal server error")
88+
web.ServerError(w, "failed to find package series", err)
9389
return
9490
}
9591

internal/packages/handler_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,66 @@ func TestHandleList_PaginationValidCases(t *testing.T) {
615615
}
616616
}
617617

618+
func TestHandleGet_ContextCanceled(t *testing.T) {
619+
repo := &mockRepository{
620+
findByNameFunc: func(_ context.Context, _ string, _, _ int) (*PackagePopularity, error) {
621+
return nil, context.Canceled
622+
},
623+
}
624+
625+
mux := newTestMux(repo)
626+
req := httptest.NewRequest(http.MethodGet, "/api/packages/pacman", nil)
627+
rr := httptest.NewRecorder()
628+
mux.ServeHTTP(rr, req)
629+
630+
if rr.Code != http.StatusOK {
631+
t.Errorf("expected no response written (status 200), got %d", rr.Code)
632+
}
633+
if rr.Body.Len() != 0 {
634+
t.Errorf("expected empty body, got %q", rr.Body.String())
635+
}
636+
}
637+
638+
func TestHandleList_ContextCanceled(t *testing.T) {
639+
repo := &mockRepository{
640+
findAllFunc: func(_ context.Context, _ string, _, _, _, _ int) (*PackagePopularityList, error) {
641+
return nil, fmt.Errorf("count packages: %w", context.Canceled)
642+
},
643+
}
644+
645+
mux := newTestMux(repo)
646+
req := httptest.NewRequest(http.MethodGet, "/api/packages", nil)
647+
rr := httptest.NewRecorder()
648+
mux.ServeHTTP(rr, req)
649+
650+
if rr.Code != http.StatusOK {
651+
t.Errorf("expected no response written (status 200), got %d", rr.Code)
652+
}
653+
if rr.Body.Len() != 0 {
654+
t.Errorf("expected empty body, got %q", rr.Body.String())
655+
}
656+
}
657+
658+
func TestHandleSeries_ContextCanceled(t *testing.T) {
659+
repo := &mockRepository{
660+
findSeriesByNameFunc: func(_ context.Context, _ string, _, _, _, _ int) (*PackagePopularityList, error) {
661+
return nil, context.Canceled
662+
},
663+
}
664+
665+
mux := newTestMux(repo)
666+
req := httptest.NewRequest(http.MethodGet, "/api/packages/pacman/series", nil)
667+
rr := httptest.NewRecorder()
668+
mux.ServeHTTP(rr, req)
669+
670+
if rr.Code != http.StatusOK {
671+
t.Errorf("expected no response written (status 200), got %d", rr.Code)
672+
}
673+
if rr.Body.Len() != 0 {
674+
t.Errorf("expected empty body, got %q", rr.Body.String())
675+
}
676+
}
677+
618678
func TestHandleList_PaginationInvalidCases(t *testing.T) {
619679
tests := []struct {
620680
name string

internal/popularity/handler.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (h *Handler[T, L]) HandleGet(w http.ResponseWriter, r *http.Request) {
4444

4545
item, err := h.repo.FindByIdentifier(r.Context(), identifier, startMonth, endMonth)
4646
if err != nil {
47-
web.InternalServerError(w, "internal server error")
47+
web.ServerError(w, "failed to find item", err)
4848
return
4949
}
5050

@@ -72,7 +72,7 @@ func (h *Handler[T, L]) HandleList(w http.ResponseWriter, r *http.Request) {
7272

7373
list, err := h.repo.FindAll(r.Context(), query, startMonth, endMonth, limit, offset)
7474
if err != nil {
75-
web.InternalServerError(w, "internal server error")
75+
web.ServerError(w, "failed to list items", err)
7676
return
7777
}
7878

@@ -100,7 +100,7 @@ func (h *Handler[T, L]) HandleSeries(w http.ResponseWriter, r *http.Request) {
100100

101101
list, err := h.repo.FindSeries(r.Context(), identifier, startMonth, endMonth, limit, offset)
102102
if err != nil {
103-
web.InternalServerError(w, "internal server error")
103+
web.ServerError(w, "failed to find item series", err)
104104
return
105105
}
106106

internal/sitemap/handler.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ func (h *Handler) HandleSitemap(w http.ResponseWriter, r *http.Request) {
6262

6363
list, err := h.repo.FindAll(r.Context(), "", currentMonth, currentMonth, packageLimit, 0)
6464
if err != nil {
65+
if web.IsClientDisconnect(err) {
66+
return
67+
}
6568
slog.Error("failed to fetch packages for sitemap", "error", err)
6669
} else {
6770
for _, pkg := range list.PackagePopularities {

internal/sitemap/handler_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sitemap
22

33
import (
44
"context"
5+
"errors"
56
"net/http"
67
"net/http/httptest"
78
"strings"
@@ -13,6 +14,10 @@ import (
1314

1415
type mockRepo struct{}
1516

17+
type errorRepo struct {
18+
err error
19+
}
20+
1621
func (m *mockRepo) FindByName(_ context.Context, _ string, _, _ int) (*packages.PackagePopularity, error) {
1722
return nil, nil
1823
}
@@ -76,6 +81,57 @@ func TestHandleSitemap(t *testing.T) {
7681
}
7782
}
7883

84+
func (m *errorRepo) FindByName(_ context.Context, _ string, _, _ int) (*packages.PackagePopularity, error) {
85+
return nil, m.err
86+
}
87+
88+
func (m *errorRepo) FindSeriesByName(_ context.Context, _ string, _, _, _, _ int) (*packages.PackagePopularityList, error) {
89+
return nil, m.err
90+
}
91+
92+
func (m *errorRepo) FindAll(_ context.Context, _ string, _, _, _, _ int) (*packages.PackagePopularityList, error) {
93+
return nil, m.err
94+
}
95+
96+
func TestHandleSitemapReturnsEarlyOnClientDisconnect(t *testing.T) {
97+
handler := NewHandler(&errorRepo{err: context.Canceled})
98+
99+
req := httptest.NewRequest(http.MethodGet, "/sitemap.xml", nil)
100+
req.Host = "example.com"
101+
rr := httptest.NewRecorder()
102+
103+
handler.HandleSitemap(rr, req)
104+
105+
if rr.Code != http.StatusOK {
106+
t.Errorf("expected status 200, got %d", rr.Code)
107+
}
108+
if rr.Body.Len() != 0 {
109+
t.Errorf("expected empty body on client disconnect, got %d bytes", rr.Body.Len())
110+
}
111+
}
112+
113+
func TestHandleSitemapServesPartialOnError(t *testing.T) {
114+
handler := NewHandler(&errorRepo{err: errors.New("db error")})
115+
116+
req := httptest.NewRequest(http.MethodGet, "/sitemap.xml", nil)
117+
req.Host = "example.com"
118+
rr := httptest.NewRecorder()
119+
120+
handler.HandleSitemap(rr, req)
121+
122+
if rr.Code != http.StatusOK {
123+
t.Errorf("expected status 200, got %d", rr.Code)
124+
}
125+
126+
body := rr.Body.String()
127+
if !strings.Contains(body, "<loc>http://example.com/</loc>") {
128+
t.Error("expected static URLs in partial sitemap")
129+
}
130+
if strings.Contains(body, "<loc>http://example.com/packages/linux</loc>") {
131+
t.Error("expected no package URLs on error")
132+
}
133+
}
134+
79135
func TestLastDayOfMonth(t *testing.T) {
80136
tests := []struct {
81137
yearMonth int

internal/submit/handler.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package submit
22

33
import (
44
"fmt"
5-
"log/slog"
65
"net/http"
76
"net/netip"
87
"time"
@@ -39,8 +38,7 @@ func (h *Handler) HandleSubmit(w http.ResponseWriter, r *http.Request) {
3938
anonymizedIP := AnonymizeIP(clientIP)
4039
allowed, retryAfter, err := h.limiter.Allow(r.Context(), anonymizedIP)
4140
if err != nil {
42-
slog.Error("rate limit check failed", "error", err)
43-
web.InternalServerError(w, "internal server error")
41+
web.ServerError(w, "rate limit check failed", err)
4442
return
4543
}
4644

@@ -72,8 +70,7 @@ func (h *Handler) HandleSubmit(w http.ResponseWriter, r *http.Request) {
7270
mirrorURL := FilterMirrorURL(req.Pacman.Mirror)
7371

7472
if err := h.repo.SaveSubmission(r.Context(), req, mirrorURL); err != nil {
75-
slog.Error("failed to save submission", "error", err)
76-
web.InternalServerError(w, "failed to save submission")
73+
web.ServerError(w, "failed to save submission", err)
7774
return
7875
}
7976

internal/ui/layout/render.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package layout
22

33
import (
4-
"context"
5-
"errors"
64
"log/slog"
75
"net/http"
86

97
"github.com/a-h/templ"
8+
9+
"pkgstatsd/internal/web"
1010
)
1111

1212
func Render(w http.ResponseWriter, r *http.Request, page Page, content templ.Component) {
@@ -24,12 +24,14 @@ func Render(w http.ResponseWriter, r *http.Request, page Page, content templ.Com
2424
}
2525

2626
if err := Base(page, content).Render(r.Context(), w); err != nil {
27-
slog.Error("failed to render page", "error", err)
27+
if !web.IsClientDisconnect(err) {
28+
slog.Error("failed to render page", "error", err)
29+
}
2830
}
2931
}
3032

3133
func ServerError(w http.ResponseWriter, msg string, err error) {
32-
if errors.Is(err, context.Canceled) {
34+
if web.IsClientDisconnect(err) {
3335
return
3436
}
3537
slog.Error(msg, "error", err)

internal/web/error.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package web
22

33
import (
4+
"context"
45
"encoding/json"
6+
"errors"
57
"log/slog"
68
"net/http"
79
"strconv"
@@ -42,6 +44,21 @@ func InternalServerError(w http.ResponseWriter, detail string) {
4244
WriteError(w, http.StatusInternalServerError, detail)
4345
}
4446

47+
// ServerError logs the error and writes a 500 response, unless the error
48+
// is due to a canceled context (client disconnect), in which case it's a no-op.
49+
func ServerError(w http.ResponseWriter, msg string, err error) {
50+
if errors.Is(err, context.Canceled) {
51+
return
52+
}
53+
slog.Error(msg, "error", err)
54+
InternalServerError(w, "internal server error")
55+
}
56+
57+
// IsClientDisconnect reports whether the error is due to a canceled context.
58+
func IsClientDisconnect(err error) bool {
59+
return errors.Is(err, context.Canceled)
60+
}
61+
4562
func TooManyRequests(w http.ResponseWriter, detail string, retryAfterSeconds int) {
4663
w.Header().Set("Retry-After", strconv.Itoa(retryAfterSeconds))
4764
WriteError(w, http.StatusTooManyRequests, detail)

internal/web/error_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package web
22

33
import (
4+
"context"
45
"encoding/json"
6+
"errors"
7+
"fmt"
58
"net/http"
69
"net/http/httptest"
710
"testing"
@@ -50,4 +53,58 @@ func TestErrorResponses(t *testing.T) {
5053
t.Errorf("got Retry-After %s", rr.Header().Get("Retry-After"))
5154
}
5255
})
56+
57+
t.Run("ServerError writes 500 for real errors", func(t *testing.T) {
58+
rr := httptest.NewRecorder()
59+
ServerError(rr, "db failed", errors.New("connection refused"))
60+
if rr.Code != http.StatusInternalServerError {
61+
t.Errorf("got %d, want %d", rr.Code, http.StatusInternalServerError)
62+
}
63+
if ct := rr.Header().Get("Content-Type"); ct != "application/problem+json" {
64+
t.Errorf("got Content-Type %q", ct)
65+
}
66+
if cc := rr.Header().Get("Cache-Control"); cc != "no-store" {
67+
t.Errorf("got Cache-Control %q", cc)
68+
}
69+
})
70+
71+
t.Run("ServerError is no-op for context.Canceled", func(t *testing.T) {
72+
rr := httptest.NewRecorder()
73+
ServerError(rr, "should not log", context.Canceled)
74+
if rr.Code != http.StatusOK {
75+
t.Errorf("got %d, want %d (no response written)", rr.Code, http.StatusOK)
76+
}
77+
if rr.Body.Len() != 0 {
78+
t.Errorf("expected empty body, got %q", rr.Body.String())
79+
}
80+
})
81+
82+
t.Run("ServerError is no-op for wrapped context.Canceled", func(t *testing.T) {
83+
rr := httptest.NewRecorder()
84+
ServerError(rr, "should not log", fmt.Errorf("query: %w", context.Canceled))
85+
if rr.Code != http.StatusOK {
86+
t.Errorf("got %d, want %d (no response written)", rr.Code, http.StatusOK)
87+
}
88+
})
89+
}
90+
91+
func TestIsClientDisconnect(t *testing.T) {
92+
tests := []struct {
93+
name string
94+
err error
95+
want bool
96+
}{
97+
{"context.Canceled", context.Canceled, true},
98+
{"wrapped context.Canceled", fmt.Errorf("scan: %w", context.Canceled), true},
99+
{"generic error", errors.New("timeout"), false},
100+
{"nil", nil, false},
101+
}
102+
103+
for _, tt := range tests {
104+
t.Run(tt.name, func(t *testing.T) {
105+
if got := IsClientDisconnect(tt.err); got != tt.want {
106+
t.Errorf("IsClientDisconnect(%v) = %v, want %v", tt.err, got, tt.want)
107+
}
108+
})
109+
}
53110
}

0 commit comments

Comments
 (0)