Skip to content

Commit eeb000a

Browse files
committed
fix(handler): drain channels after timeout to capture final measurement
1 parent 772318a commit eeb000a

2 files changed

Lines changed: 114 additions & 16 deletions

File tree

internal/handler/handler.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,29 @@ func (h *Handler) upgradeAndRunMeasurement(kind model.TestDirection, rw http.Res
258258
for {
259259
select {
260260
case <-timeout.Done():
261-
// If the test has timed out count it as a success and return.
261+
// The test has timed out. Before returning, drain any
262+
// remaining measurements from the sender/receiver
263+
// goroutines. The sender sends a final Measure() on
264+
// ctx.Done() which may still be in-flight.
265+
drainTimeout := time.After(500 * time.Millisecond)
266+
for draining := true; draining; {
267+
select {
268+
case m := <-senderCh:
269+
if kind == model.DirectionDownload && m.CC != "" {
270+
archivalData.CCAlgorithm = m.CC
271+
}
272+
archivalData.ServerMeasurements = append(
273+
archivalData.ServerMeasurements, m.Measurement)
274+
case m := <-receiverCh:
275+
if kind == model.DirectionUpload && m.CC != "" {
276+
archivalData.CCAlgorithm = m.CC
277+
}
278+
archivalData.ClientMeasurements = append(
279+
archivalData.ClientMeasurements, m.Measurement)
280+
case <-drainTimeout:
281+
draining = false
282+
}
283+
}
262284
testsTotal.WithLabelValues(string(kind), "ok-timeout").Inc()
263285
return
264286
case m := <-senderCh:

internal/handler/handler_test.go

Lines changed: 91 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package handler_test
22

33
import (
44
"context"
5+
"encoding/json"
6+
"math"
57
"net"
68
"net/http"
79
"net/http/httptest"
810
"net/url"
911
"os"
12+
"path/filepath"
1013
"strings"
1114
"testing"
1215
"time"
@@ -83,13 +86,7 @@ func TestHandler_Upload(t *testing.T) {
8386
drain(t, timeout, senderCh, receiverCh, errCh)
8487

8588
// Check that the output JSON file has been created.
86-
files, err := os.ReadDir(tempDir)
87-
if err != nil {
88-
t.Fatalf("reading output folder failed: %v", err)
89-
}
90-
if len(files) != 1 {
91-
t.Fatalf("invalid number of files in output folder")
92-
}
89+
waitForArchivalFile(t, tempDir, 2*time.Second)
9390
}
9491

9592
func TestHandler_Download(t *testing.T) {
@@ -124,19 +121,13 @@ func TestHandler_Download(t *testing.T) {
124121
}
125122

126123
proto := throughput1.New(conn)
127-
timeout, cancel := context.WithTimeout(context.Background(), 1*time.Second)
124+
timeout, cancel := context.WithTimeout(context.Background(), 2*time.Second)
128125
defer cancel()
129126
senderCh, receiverCh, errCh := proto.ReceiverLoop(timeout)
130127
drain(t, timeout, senderCh, receiverCh, errCh)
131128

132129
// Check that the output JSON file has been created.
133-
files, err := os.ReadDir(tempDir)
134-
if err != nil {
135-
t.Fatalf("reading output folder failed: %v", err)
136-
}
137-
if len(files) != 1 {
138-
t.Fatalf("invalid number of files in output folder")
139-
}
130+
waitForArchivalFile(t, tempDir, 2*time.Second)
140131
}
141132

142133
func TestHandler_DownloadInvalidCC(t *testing.T) {
@@ -169,6 +160,29 @@ func TestHandler_DownloadInvalidCC(t *testing.T) {
169160
}
170161
}
171162

163+
// waitForArchivalFile polls until at least one JSON file appears in the
164+
// directory tree, or the timeout is exceeded. The drain loop in the handler
165+
// delays the deferred writeResult, so we need to poll.
166+
func waitForArchivalFile(t *testing.T, dir string, timeout time.Duration) string {
167+
t.Helper()
168+
deadline := time.Now().Add(timeout)
169+
for time.Now().Before(deadline) {
170+
var found string
171+
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
172+
if err == nil && !info.IsDir() && filepath.Ext(path) == ".json" {
173+
found = path
174+
}
175+
return nil
176+
})
177+
if found != "" {
178+
return found
179+
}
180+
time.Sleep(50 * time.Millisecond)
181+
}
182+
t.Fatalf("no archival JSON file found in %s within %v", dir, timeout)
183+
return ""
184+
}
185+
172186
// Utility function to drain sender/receiver channels in tests.
173187
func drain(t *testing.T, timeout context.Context, senderCh,
174188
receiverCh <-chan model.WireMeasurement, errCh <-chan error) {
@@ -189,6 +203,68 @@ func drain(t *testing.T, timeout context.Context, senderCh,
189203
}
190204
}
191205

206+
func TestHandler_DownloadFinalMeasurement(t *testing.T) {
207+
tempDir := t.TempDir()
208+
h := handler.New(tempDir)
209+
210+
server := setupTestServer(tempDir, http.HandlerFunc(h.Download))
211+
server.Start()
212+
defer server.Close()
213+
214+
u, err := url.Parse(server.URL)
215+
rtx.Must(err, "cannot get server URL")
216+
u.Scheme = "ws"
217+
q := u.Query()
218+
q.Add("mid", "test-mid")
219+
q.Add("streams", "1")
220+
q.Add("duration", "500")
221+
u.RawQuery = q.Encode()
222+
223+
dialer := setupTestWSDialer(u)
224+
225+
headers := http.Header{}
226+
headers.Add("Sec-WebSocket-Protocol", spec.SecWebSocketProtocol)
227+
228+
conn, _, err := dialer.Dial(u.String(), headers)
229+
if err != nil {
230+
t.Fatalf("websocket dial failed: %v", err)
231+
}
232+
proto := throughput1.New(conn)
233+
timeout, cancel := context.WithTimeout(context.Background(), 2*time.Second)
234+
defer cancel()
235+
senderCh, receiverCh, errCh := proto.ReceiverLoop(timeout)
236+
drain(t, timeout, senderCh, receiverCh, errCh)
237+
238+
// Wait for the archival JSON file to be written.
239+
jsonFile := waitForArchivalFile(t, tempDir, 2*time.Second)
240+
241+
data, err := os.ReadFile(jsonFile)
242+
if err != nil {
243+
t.Fatalf("failed to read archival file: %v", err)
244+
}
245+
246+
var result model.Throughput1Result
247+
if err := json.Unmarshal(data, &result); err != nil {
248+
t.Fatalf("failed to unmarshal archival data: %v", err)
249+
}
250+
251+
if len(result.ServerMeasurements) == 0 {
252+
t.Fatalf("expected at least one server measurement")
253+
}
254+
255+
// The last server measurement's ElapsedTime should be close to the
256+
// requested duration (500ms = 500_000 microseconds). Allow 100ms
257+
// tolerance.
258+
last := result.ServerMeasurements[len(result.ServerMeasurements)-1]
259+
requestedDurationUs := int64(500_000) // 500ms in microseconds
260+
toleranceUs := int64(100_000) // 100ms
261+
diff := int64(math.Abs(float64(last.ElapsedTime - requestedDurationUs)))
262+
if diff > toleranceUs {
263+
t.Errorf("last ServerMeasurement.ElapsedTime = %d us, want within %d us of %d us (diff = %d us)",
264+
last.ElapsedTime, toleranceUs, requestedDurationUs, diff)
265+
}
266+
}
267+
192268
func TestHandler_Validation(t *testing.T) {
193269
// This string exceeds the maximum metadata key length.
194270
longKey := strings.Repeat("longkey", 10)

0 commit comments

Comments
 (0)