11package scaletozero
22
33import (
4+ "context"
45 "net/http"
56 "net/http/httptest"
7+ "sync"
68 "testing"
79
810 "github.com/stretchr/testify/assert"
911 "github.com/stretchr/testify/require"
1012)
1113
12- func TestMiddlewareDisablesAndEnablesForExternalAddr (t * testing.T ) {
14+ type mockController struct {
15+ mu sync.Mutex
16+ acquireCalls int
17+ releaseCalls int
18+ acquireErr error
19+ releaseErr error
20+ }
21+
22+ func (m * mockController ) Acquire (ctx context.Context ) error {
23+ m .mu .Lock ()
24+ defer m .mu .Unlock ()
25+ m .acquireCalls ++
26+ return m .acquireErr
27+ }
28+
29+ func (m * mockController ) Release (ctx context.Context ) error {
30+ m .mu .Lock ()
31+ defer m .mu .Unlock ()
32+ m .releaseCalls ++
33+ return m .releaseErr
34+ }
35+
36+ func (m * mockController ) Disable (ctx context.Context ) error { return nil }
37+ func (m * mockController ) Enable (ctx context.Context ) error { return nil }
38+
39+ func TestMiddlewareAcquiresAndReleasesForExternalAddr (t * testing.T ) {
1340 t .Parallel ()
14- mock := & mockScaleToZeroer {}
41+ mock := & mockController {}
1542 handler := Middleware (mock )(http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1643 w .WriteHeader (http .StatusOK )
1744 }))
@@ -23,8 +50,8 @@ func TestMiddlewareDisablesAndEnablesForExternalAddr(t *testing.T) {
2350 handler .ServeHTTP (rec , req )
2451
2552 assert .Equal (t , http .StatusOK , rec .Code )
26- assert .Equal (t , 1 , mock .disableCalls )
27- assert .Equal (t , 1 , mock .enableCalls )
53+ assert .Equal (t , 1 , mock .acquireCalls )
54+ assert .Equal (t , 1 , mock .releaseCalls )
2855}
2956
3057func TestMiddlewareSkipsLoopbackAddrs (t * testing.T ) {
@@ -41,7 +68,7 @@ func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) {
4168 for _ , tc := range loopbackAddrs {
4269 t .Run (tc .name , func (t * testing.T ) {
4370 t .Parallel ()
44- mock := & mockScaleToZeroer {}
71+ mock := & mockController {}
4572 var called bool
4673 handler := Middleware (mock )(http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
4774 called = true
@@ -56,15 +83,15 @@ func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) {
5683
5784 assert .True (t , called , "handler should still be called" )
5885 assert .Equal (t , http .StatusOK , rec .Code )
59- assert .Equal (t , 0 , mock .disableCalls , "should not disable for loopback addr" )
60- assert .Equal (t , 0 , mock .enableCalls , "should not enable for loopback addr" )
86+ assert .Equal (t , 0 , mock .acquireCalls , "should not acquire for loopback addr" )
87+ assert .Equal (t , 0 , mock .releaseCalls , "should not release for loopback addr" )
6188 })
6289 }
6390}
6491
65- func TestMiddlewareDisableError (t * testing.T ) {
92+ func TestMiddlewareAcquireError (t * testing.T ) {
6693 t .Parallel ()
67- mock := & mockScaleToZeroer { disableErr : assert .AnError }
94+ mock := & mockController { acquireErr : assert .AnError }
6895 var called bool
6996 handler := Middleware (mock )(http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
7097 called = true
@@ -76,9 +103,9 @@ func TestMiddlewareDisableError(t *testing.T) {
76103
77104 handler .ServeHTTP (rec , req )
78105
79- assert .False (t , called , "handler should not be called on disable error" )
106+ assert .False (t , called , "handler should not be called on acquire error" )
80107 assert .Equal (t , http .StatusInternalServerError , rec .Code )
81- assert .Equal (t , 0 , mock .enableCalls )
108+ assert .Equal (t , 0 , mock .releaseCalls )
82109}
83110
84111func TestIsLoopbackAddr (t * testing.T ) {
0 commit comments