Skip to content

Commit a4d2d84

Browse files
committed
refactor: extract shared retry package from pollUntil
1 parent 206fe28 commit a4d2d84

6 files changed

Lines changed: 217 additions & 101 deletions

File tree

cmd/create.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cmd
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"strings"
@@ -11,6 +12,7 @@ import (
1112
"github.com/spf13/cobra"
1213

1314
"github.com/deevus/pixels/internal/cache"
15+
"github.com/deevus/pixels/internal/retry"
1416
"github.com/deevus/pixels/internal/ssh"
1517
tnc "github.com/deevus/pixels/internal/truenas"
1618
)
@@ -183,15 +185,15 @@ func runCreate(cmd *cobra.Command, args []string) error {
183185

184186
// Poll for IP — DHCP assignment takes a few seconds after start.
185187
logv(cmd, "Waiting for IP assignment...")
186-
for range 15 {
187-
instance, err = client.Virt.GetInstance(ctx, containerName(name))
188+
if err := retry.Poll(ctx, time.Second, 15*time.Second, func(ctx context.Context) (bool, error) {
189+
inst, err := client.Virt.GetInstance(ctx, containerName(name))
188190
if err != nil {
189-
return fmt.Errorf("refreshing instance: %w", err)
191+
return false, fmt.Errorf("refreshing instance: %w", err)
190192
}
191-
if resolveIP(instance) != "" {
192-
break
193-
}
194-
time.Sleep(time.Second)
193+
instance = inst
194+
return resolveIP(instance) != "", nil
195+
}); err != nil && !errors.Is(err, retry.ErrTimeout) {
196+
return err
195197
}
196198
}
197199

@@ -231,15 +233,15 @@ func runCreate(cmd *cobra.Command, args []string) error {
231233
}
232234
// Poll for IP — DHCP assignment takes a few seconds after restart.
233235
logv(cmd, "Waiting for IP assignment...")
234-
for range 15 {
235-
instance, err = client.Virt.GetInstance(ctx, containerName(name))
236+
if err := retry.Poll(ctx, time.Second, 15*time.Second, func(ctx context.Context) (bool, error) {
237+
inst, err := client.Virt.GetInstance(ctx, containerName(name))
236238
if err != nil {
237-
return fmt.Errorf("refreshing instance: %w", err)
238-
}
239-
if resolveIP(instance) != "" {
240-
break
239+
return false, fmt.Errorf("refreshing instance: %w", err)
241240
}
242-
time.Sleep(time.Second)
241+
instance = inst
242+
return resolveIP(instance) != "", nil
243+
}); err != nil && !errors.Is(err, retry.ErrTimeout) {
244+
return err
243245
}
244246
}
245247
}

cmd/destroy.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cmd
22

33
import (
44
"bufio"
5+
"context"
56
"fmt"
67
"os"
78
"strings"
@@ -11,6 +12,7 @@ import (
1112
"github.com/spf13/cobra"
1213

1314
"github.com/deevus/pixels/internal/cache"
15+
"github.com/deevus/pixels/internal/retry"
1416
)
1517

1618
func init() {
@@ -67,18 +69,10 @@ func runDestroy(cmd *cobra.Command, args []string) error {
6769

6870
// Retry delete — Incus sometimes needs a moment after stop for the
6971
// storage volume to be fully released.
70-
var deleteErr error
71-
for i := range 3 {
72-
if i > 0 {
73-
time.Sleep(2 * time.Second)
74-
}
75-
deleteErr = client.Virt.DeleteInstance(ctx, containerName(name))
76-
if deleteErr == nil {
77-
break
78-
}
79-
}
80-
if deleteErr != nil {
81-
return fmt.Errorf("deleting %s: %w", name, deleteErr)
72+
if err := retry.Do(ctx, 3, 2*time.Second, func(ctx context.Context) error {
73+
return client.Virt.DeleteInstance(ctx, containerName(name))
74+
}); err != nil {
75+
return fmt.Errorf("deleting %s: %w", name, err)
8276
}
8377

8478
cache.Delete(name)

internal/retry/retry.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package retry
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"time"
8+
)
9+
10+
// ErrTimeout is returned (wrapped) by Poll when the timeout expires before fn
// reports success. Match it with errors.Is.
var ErrTimeout = errors.New("poll timed out")

// Poll calls fn immediately and then once per interval until one of the
// following happens:
//
//   - fn returns (true, nil): Poll returns nil.
//   - fn returns a non-nil error: the error is treated as fatal and returned
//     unchanged, with no further calls to fn.
//   - ctx is done: Poll returns ctx.Err().
//   - timeout elapses: Poll returns an error wrapping ErrTimeout.
//
// A non-positive interval is rejected with an error before fn is called,
// because time.NewTicker panics on such values.
func Poll(ctx context.Context, interval, timeout time.Duration, fn func(ctx context.Context) (bool, error)) error {
	if interval <= 0 {
		return fmt.Errorf("retry: poll interval must be positive, got %s", interval)
	}

	// An explicit timer (rather than time.After) lets us stop it on return.
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		done, err := fn(ctx)
		if err != nil {
			return err
		}
		if done {
			return nil
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-deadline.C:
			return fmt.Errorf("%w after %s", ErrTimeout, timeout)
		case <-ticker.C:
			// select picks pseudo-randomly among ready cases, so a tick can
			// win over an already-cancelled context. Re-check here so fn is
			// not invoked again after cancellation.
			if err := ctx.Err(); err != nil {
				return err
			}
		}
	}
}
38+
39+
// Do calls fn up to attempts times, waiting delay between consecutive tries.
//
// It returns nil as soon as fn succeeds, or the error from the final attempt
// if every attempt fails. The wait between attempts is context-aware: if ctx
// is done while waiting, Do returns ctx.Err() without calling fn again.
//
// A non-positive attempts value yields an error without calling fn, so that
// "fn never ran" cannot be mistaken for success.
func Do(ctx context.Context, attempts int, delay time.Duration, fn func(ctx context.Context) error) error {
	if attempts <= 0 {
		return fmt.Errorf("retry: attempts must be positive, got %d", attempts)
	}

	var lastErr error
	for attempt := 0; attempt < attempts; attempt++ {
		if attempt > 0 {
			// Explicit timer so it can be stopped when the context wins.
			wait := time.NewTimer(delay)
			select {
			case <-ctx.Done():
				wait.Stop()
				return ctx.Err()
			case <-wait.C:
			}
		}

		lastErr = fn(ctx)
		if lastErr == nil {
			return nil
		}
	}
	return lastErr
}

internal/retry/retry_test.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package retry
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
)
9+
10+
// ---------------------------------------------------------------------------
11+
// Poll tests
12+
// ---------------------------------------------------------------------------
13+
14+
func TestPoll_SuccessOnFirstCheck(t *testing.T) {
15+
err := Poll(context.Background(), 10*time.Millisecond, time.Second, func(_ context.Context) (bool, error) {
16+
return true, nil
17+
})
18+
if err != nil {
19+
t.Fatalf("expected nil, got %v", err)
20+
}
21+
}
22+
23+
func TestPoll_SuccessAfterSeveralChecks(t *testing.T) {
24+
calls := 0
25+
err := Poll(context.Background(), 10*time.Millisecond, time.Second, func(_ context.Context) (bool, error) {
26+
calls++
27+
return calls >= 3, nil
28+
})
29+
if err != nil {
30+
t.Fatalf("expected nil, got %v", err)
31+
}
32+
if calls < 3 {
33+
t.Fatalf("expected at least 3 calls, got %d", calls)
34+
}
35+
}
36+
37+
func TestPoll_Timeout(t *testing.T) {
38+
err := Poll(context.Background(), 10*time.Millisecond, 50*time.Millisecond, func(_ context.Context) (bool, error) {
39+
return false, nil
40+
})
41+
if err == nil {
42+
t.Fatal("expected timeout error, got nil")
43+
}
44+
if !errors.Is(err, ErrTimeout) {
45+
t.Fatalf("expected ErrTimeout, got %v", err)
46+
}
47+
}
48+
49+
func TestPoll_ContextCancellation(t *testing.T) {
50+
ctx, cancel := context.WithCancel(context.Background())
51+
cancel()
52+
53+
err := Poll(ctx, 10*time.Millisecond, time.Second, func(_ context.Context) (bool, error) {
54+
return false, nil
55+
})
56+
if !errors.Is(err, context.Canceled) {
57+
t.Fatalf("expected context.Canceled, got %v", err)
58+
}
59+
}
60+
61+
func TestPoll_FatalErrorStopsPolling(t *testing.T) {
62+
fatal := errors.New("fatal failure")
63+
calls := 0
64+
err := Poll(context.Background(), 10*time.Millisecond, time.Second, func(_ context.Context) (bool, error) {
65+
calls++
66+
return false, fatal
67+
})
68+
if !errors.Is(err, fatal) {
69+
t.Fatalf("expected fatal error, got %v", err)
70+
}
71+
if calls != 1 {
72+
t.Fatalf("expected 1 call, got %d", calls)
73+
}
74+
}
75+
76+
// ---------------------------------------------------------------------------
77+
// Do tests
78+
// ---------------------------------------------------------------------------
79+
80+
func TestDo_SuccessOnFirstAttempt(t *testing.T) {
81+
err := Do(context.Background(), 3, 10*time.Millisecond, func(_ context.Context) error {
82+
return nil
83+
})
84+
if err != nil {
85+
t.Fatalf("expected nil, got %v", err)
86+
}
87+
}
88+
89+
func TestDo_SuccessAfterRetries(t *testing.T) {
90+
calls := 0
91+
err := Do(context.Background(), 5, 10*time.Millisecond, func(_ context.Context) error {
92+
calls++
93+
if calls < 3 {
94+
return errors.New("not yet")
95+
}
96+
return nil
97+
})
98+
if err != nil {
99+
t.Fatalf("expected nil, got %v", err)
100+
}
101+
if calls != 3 {
102+
t.Fatalf("expected 3 calls, got %d", calls)
103+
}
104+
}
105+
106+
func TestDo_AllAttemptsExhausted(t *testing.T) {
107+
lastErr := errors.New("persistent failure")
108+
err := Do(context.Background(), 3, 10*time.Millisecond, func(_ context.Context) error {
109+
return lastErr
110+
})
111+
if !errors.Is(err, lastErr) {
112+
t.Fatalf("expected last error, got %v", err)
113+
}
114+
}
115+
116+
func TestDo_ContextCancellationDuringDelay(t *testing.T) {
117+
ctx, cancel := context.WithCancel(context.Background())
118+
calls := 0
119+
err := Do(ctx, 5, time.Second, func(_ context.Context) error {
120+
calls++
121+
if calls == 1 {
122+
// Cancel during the delay before the next retry.
123+
cancel()
124+
}
125+
return errors.New("fail")
126+
})
127+
if !errors.Is(err, context.Canceled) {
128+
t.Fatalf("expected context.Canceled, got %v", err)
129+
}
130+
if calls != 1 {
131+
t.Fatalf("expected 1 call before cancellation, got %d", calls)
132+
}
133+
}

internal/ssh/ssh.go

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"os"
1010
"os/exec"
1111
"time"
12+
13+
"github.com/deevus/pixels/internal/retry"
1214
)
1315

1416
// WaitReady polls the host's SSH port until it accepts connections or the timeout expires.
@@ -71,33 +73,12 @@ func Output(ctx context.Context, host, user, keyPath string, command []string) (
7173

7274
// WaitProvisioned polls the remote host until /root/.devtools-provisioned exists.
7375
func WaitProvisioned(ctx context.Context, host, user, keyPath string, timeout time.Duration) error {
74-
return pollUntil(ctx, 2*time.Second, timeout, func(ctx context.Context) bool {
76+
return retry.Poll(ctx, 2*time.Second, timeout, func(ctx context.Context) (bool, error) {
7577
code, err := Exec(ctx, host, user, keyPath, []string{"sudo", "test", "-f", "/root/.devtools-provisioned"})
76-
return err == nil && code == 0
78+
return err == nil && code == 0, nil
7779
})
7880
}
7981

80-
// pollUntil calls checkFn at the given interval until it returns true or the
81-
// timeout/context expires.
82-
func pollUntil(ctx context.Context, interval, timeout time.Duration, checkFn func(ctx context.Context) bool) error {
83-
deadline := time.After(timeout)
84-
ticker := time.NewTicker(interval)
85-
defer ticker.Stop()
86-
87-
for {
88-
select {
89-
case <-ctx.Done():
90-
return ctx.Err()
91-
case <-deadline:
92-
return fmt.Errorf("timed out after %s", timeout)
93-
case <-ticker.C:
94-
if checkFn(ctx) {
95-
return nil
96-
}
97-
}
98-
}
99-
}
100-
10182
// TestAuth runs a quick SSH connection test (ssh ... true) to verify
10283
// key-based authentication works. Returns nil on success.
10384
func TestAuth(ctx context.Context, host, user, keyPath string) error {

0 commit comments

Comments
 (0)