Skip to content

Commit 7d26ce6

Browse files
committed
Align scale-to-zero internal terminology with OpenAPI spec
Rename refcounted hold methods to Acquire/Release so that Disable/Enable can carry the idempotent persistent-toggle semantics defined by the /scaletozero/{disable,enable} API. Split the low-level direct toggle out into a separate Toggler interface (unikraftCloudToggler) wrapped by DebouncedController.
1 parent 101e27c commit 7d26ce6

8 files changed

Lines changed: 260 additions & 216 deletions

File tree

server/cmd/api/api/api.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type ApiService struct {
4949

5050
// DevTools upstream manager (Chromium supervisord log tailer)
5151
upstreamMgr *devtoolsproxy.UpstreamManager
52-
stz scaletozero.PinnedController
52+
stz scaletozero.Controller
5353

5454
// inputMu serializes input-related operations (mouse, keyboard, screenshot)
5555
inputMu sync.Mutex
@@ -96,7 +96,7 @@ func New(
9696
recordManager recorder.RecordManager,
9797
factory recorder.FFmpegRecorderFactory,
9898
upstreamMgr *devtoolsproxy.UpstreamManager,
99-
stz scaletozero.PinnedController,
99+
stz scaletozero.Controller,
100100
nekoAuthClient *nekoclient.AuthClient,
101101
captureSession *capturesession.CaptureSession,
102102
eventStream *events.EventStream,

server/cmd/api/api/scaletozero.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ import (
88
)
99

1010
func (s *ApiService) DisableScaleToZero(ctx context.Context, _ oapi.DisableScaleToZeroRequestObject) (oapi.DisableScaleToZeroResponseObject, error) {
11-
if err := s.stz.Pin(ctx); err != nil {
11+
if err := s.stz.Disable(ctx); err != nil {
1212
logger.FromContext(ctx).Error("failed to disable scale-to-zero", "err", err)
1313
return oapi.DisableScaleToZero500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to disable scale-to-zero"}}, nil
1414
}
1515
return oapi.DisableScaleToZero204Response{}, nil
1616
}
1717

1818
func (s *ApiService) EnableScaleToZero(ctx context.Context, _ oapi.EnableScaleToZeroRequestObject) (oapi.EnableScaleToZeroResponseObject, error) {
19-
if err := s.stz.Unpin(ctx); err != nil {
19+
if err := s.stz.Enable(ctx); err != nil {
2020
logger.FromContext(ctx).Error("failed to enable scale-to-zero", "err", err)
2121
return oapi.EnableScaleToZero500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to enable scale-to-zero"}}, nil
2222
}

server/cmd/api/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func main() {
5151
// ensure ffmpeg is available
5252
mustFFmpeg()
5353

54-
stz := scaletozero.NewDebouncedControllerWithCooldown(scaletozero.NewUnikraftCloudController(), config.ScaleToZeroCooldown)
54+
stz := scaletozero.NewDebouncedControllerWithCooldown(scaletozero.NewUnikraftCloudToggler(), config.ScaleToZeroCooldown)
5555
r := chi.NewRouter()
5656
r.Use(
5757
chiMiddleware.Logger,

server/lib/recorder/ffmpeg.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ func (fr *FFmpegRecorder) Start(ctx context.Context) error {
183183
return fmt.Errorf("recording already in progress")
184184
}
185185

186-
if err := fr.stz.Disable(ctx); err != nil {
186+
if err := fr.stz.Acquire(ctx); err != nil {
187187
fr.mu.Unlock()
188-
return fmt.Errorf("failed to disable scale-to-zero: %w", err)
188+
return fmt.Errorf("failed to acquire scale-to-zero hold: %w", err)
189189
}
190190

191191
// ensure internal state
@@ -196,7 +196,7 @@ func (fr *FFmpegRecorder) Start(ctx context.Context) error {
196196

197197
args, err := ffmpegArgs(fr.params, fr.outputPath)
198198
if err != nil {
199-
_ = fr.stz.Enable(context.WithoutCancel(ctx))
199+
_ = fr.stz.Release(context.WithoutCancel(ctx))
200200
fr.cmd = nil
201201
close(fr.exited)
202202
fr.mu.Unlock()
@@ -214,7 +214,7 @@ func (fr *FFmpegRecorder) Start(ctx context.Context) error {
214214
fr.mu.Unlock()
215215

216216
if err := cmd.Start(); err != nil {
217-
_ = fr.stz.Enable(context.WithoutCancel(ctx))
217+
_ = fr.stz.Release(context.WithoutCancel(ctx))
218218
fr.mu.Lock()
219219
fr.ffmpegErr = err
220220
fr.cmd = nil // reset cmd on failure to start so IsRecording() remains correct
@@ -238,7 +238,7 @@ func (fr *FFmpegRecorder) Start(ctx context.Context) error {
238238

239239
// Stop gracefully stops the recording using a multi-phase shutdown process.
240240
func (fr *FFmpegRecorder) Stop(ctx context.Context) error {
241-
defer fr.stz.Enable(context.WithoutCancel(ctx))
241+
defer fr.stz.Release(context.WithoutCancel(ctx))
242242

243243
// Use singleflight to prevent concurrent Stop() calls from sending multiple SIGINTs
244244
// to ffmpeg, which causes immediate abort without proper file closure.
@@ -281,7 +281,7 @@ func (fr *FFmpegRecorder) WaitForFinalization(ctx context.Context) error {
281281
func (fr *FFmpegRecorder) ForceStop(ctx context.Context) error {
282282
log := logger.FromContext(ctx)
283283

284-
defer fr.stz.Enable(context.WithoutCancel(ctx))
284+
defer fr.stz.Release(context.WithoutCancel(ctx))
285285
shutdownErr := fr.shutdownInPhases(ctx, []shutdownPhase{
286286
{"kill", []syscall.Signal{syscall.SIGKILL}, 100 * time.Millisecond, "immediate kill"},
287287
})
@@ -530,7 +530,7 @@ func ffmpegArgs(params FFmpegRecordingParams, outputPath string) ([]string, erro
530530
// update the internal state accordingly. It also triggers finalization to add proper duration
531531
// metadata for recordings that exit naturally (max duration, max file size, etc.).
532532
func (fr *FFmpegRecorder) waitForCommand(ctx context.Context) {
533-
defer fr.stz.Enable(context.WithoutCancel(ctx))
533+
defer fr.stz.Release(context.WithoutCancel(ctx))
534534

535535
log := logger.FromContext(ctx)
536536

server/lib/scaletozero/middleware.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ import (
88
"github.com/kernel/kernel-images/server/lib/logger"
99
)
1010

11-
// Middleware returns a standard net/http middleware that disables scale-to-zero
12-
// at the start of each request and re-enables it after the handler completes.
13-
// Connections from loopback addresses are ignored and do not affect the
14-
// scale-to-zero state.
11+
// Middleware returns a standard net/http middleware that acquires a
12+
// scale-to-zero hold at the start of each request and releases it after the
13+
// handler completes. Connections from loopback addresses are ignored and do
14+
// not affect the scale-to-zero state.
1515
func Middleware(ctrl Controller) func(http.Handler) http.Handler {
1616
return func(next http.Handler) http.Handler {
1717
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -20,12 +20,12 @@ func Middleware(ctrl Controller) func(http.Handler) http.Handler {
2020
return
2121
}
2222

23-
if err := ctrl.Disable(r.Context()); err != nil {
24-
logger.FromContext(r.Context()).Error("failed to disable scale-to-zero", "error", err)
25-
http.Error(w, "failed to disable scale-to-zero", http.StatusInternalServerError)
23+
if err := ctrl.Acquire(r.Context()); err != nil {
24+
logger.FromContext(r.Context()).Error("failed to acquire scale-to-zero hold", "error", err)
25+
http.Error(w, "failed to acquire scale-to-zero hold", http.StatusInternalServerError)
2626
return
2727
}
28-
defer ctrl.Enable(context.WithoutCancel(r.Context()))
28+
defer ctrl.Release(context.WithoutCancel(r.Context()))
2929

3030
next.ServeHTTP(w, r)
3131
})

server/lib/scaletozero/middleware_test.go

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,44 @@
11
package scaletozero
22

33
import (
4+
"context"
45
"net/http"
56
"net/http/httptest"
7+
"sync"
68
"testing"
79

810
"github.com/stretchr/testify/assert"
911
"github.com/stretchr/testify/require"
1012
)
1113

12-
func TestMiddlewareDisablesAndEnablesForExternalAddr(t *testing.T) {
14+
type mockController struct {
15+
mu sync.Mutex
16+
acquireCalls int
17+
releaseCalls int
18+
acquireErr error
19+
releaseErr error
20+
}
21+
22+
func (m *mockController) Acquire(ctx context.Context) error {
23+
m.mu.Lock()
24+
defer m.mu.Unlock()
25+
m.acquireCalls++
26+
return m.acquireErr
27+
}
28+
29+
func (m *mockController) Release(ctx context.Context) error {
30+
m.mu.Lock()
31+
defer m.mu.Unlock()
32+
m.releaseCalls++
33+
return m.releaseErr
34+
}
35+
36+
func (m *mockController) Disable(ctx context.Context) error { return nil }
37+
func (m *mockController) Enable(ctx context.Context) error { return nil }
38+
39+
func TestMiddlewareAcquiresAndReleasesForExternalAddr(t *testing.T) {
1340
t.Parallel()
14-
mock := &mockScaleToZeroer{}
41+
mock := &mockController{}
1542
handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1643
w.WriteHeader(http.StatusOK)
1744
}))
@@ -23,8 +50,8 @@ func TestMiddlewareDisablesAndEnablesForExternalAddr(t *testing.T) {
2350
handler.ServeHTTP(rec, req)
2451

2552
assert.Equal(t, http.StatusOK, rec.Code)
26-
assert.Equal(t, 1, mock.disableCalls)
27-
assert.Equal(t, 1, mock.enableCalls)
53+
assert.Equal(t, 1, mock.acquireCalls)
54+
assert.Equal(t, 1, mock.releaseCalls)
2855
}
2956

3057
func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) {
@@ -41,7 +68,7 @@ func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) {
4168
for _, tc := range loopbackAddrs {
4269
t.Run(tc.name, func(t *testing.T) {
4370
t.Parallel()
44-
mock := &mockScaleToZeroer{}
71+
mock := &mockController{}
4572
var called bool
4673
handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4774
called = true
@@ -56,15 +83,15 @@ func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) {
5683

5784
assert.True(t, called, "handler should still be called")
5885
assert.Equal(t, http.StatusOK, rec.Code)
59-
assert.Equal(t, 0, mock.disableCalls, "should not disable for loopback addr")
60-
assert.Equal(t, 0, mock.enableCalls, "should not enable for loopback addr")
86+
assert.Equal(t, 0, mock.acquireCalls, "should not acquire for loopback addr")
87+
assert.Equal(t, 0, mock.releaseCalls, "should not release for loopback addr")
6188
})
6289
}
6390
}
6491

65-
func TestMiddlewareDisableError(t *testing.T) {
92+
func TestMiddlewareAcquireError(t *testing.T) {
6693
t.Parallel()
67-
mock := &mockScaleToZeroer{disableErr: assert.AnError}
94+
mock := &mockController{acquireErr: assert.AnError}
6895
var called bool
6996
handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
7097
called = true
@@ -76,9 +103,9 @@ func TestMiddlewareDisableError(t *testing.T) {
76103

77104
handler.ServeHTTP(rec, req)
78105

79-
assert.False(t, called, "handler should not be called on disable error")
106+
assert.False(t, called, "handler should not be called on acquire error")
80107
assert.Equal(t, http.StatusInternalServerError, rec.Code)
81-
assert.Equal(t, 0, mock.enableCalls)
108+
assert.Equal(t, 0, mock.releaseCalls)
82109
}
83110

84111
func TestIsLoopbackAddr(t *testing.T) {

0 commit comments

Comments
 (0)