Skip to content
This repository was archived by the owner on Apr 15, 2026. It is now read-only.

Commit dc0d403

Browse files
committed
Implement graceful shutdown with proper prediction completion
This implements a comprehensive graceful shutdown mechanism that waits for in-flight predictions to complete before stopping runners and the service.

Key changes:

**Runner-level graceful shutdown:**
- Add shutdownWhenIdle atomic flag and readyForShutdown channel to Runner
- GracefulShutdown() signals runners to shut down when idle
- updateStatus() automatically closes readyForShutdown when becoming READY with no pending predictions
- Add nil check with warning for test compatibility

**Handler-level prediction rejection:**
- Add gracefulShutdown atomic flag to reject new predictions during shutdown
- Handler.Stop() sets flag and waits for manager shutdown
- Predict() returns 503 Service Unavailable during shutdown

**Manager-level coordinated shutdown:**
- Manager.Stop() signals all runners for graceful shutdown
- Use WaitGroup.Go() for independent parallel runner shutdowns
- Respect RunnerShutdownGracePeriod timeout before force stopping
- Wait on runner.readyForShutdown channel or timeout

**Service-level errgroup coordination:**
- Fix errgroup goroutines to exit on shutdown signal
- Add shutdown case to force shutdown monitor goroutine
- Signal handler already had proper shutdown case
- Add contextcheck nolint for long-lived errgroup context

**Test coverage:**
- Add E2E test for 503 rejection of new predictions during shutdown
- Verify graceful shutdown waits for in-flight predictions
- Test service properly stops after shutdown completes

This restores the graceful shutdown behavior from commit 575d218 that was lost during the server refactor, ensuring predictions complete naturally during the grace period rather than being immediately force-killed.
1 parent b4f70de commit dc0d403

8 files changed

Lines changed: 521 additions & 118 deletions

File tree

internal/runner/manager.go

Lines changed: 34 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@ import (
1717
"time"
1818

1919
"go.uber.org/zap"
20-
"golang.org/x/sync/errgroup"
2120

2221
"github.com/replicate/cog-runtime/internal/config"
2322
"github.com/replicate/cog-runtime/internal/webhook"
@@ -747,43 +746,48 @@ func (m *Manager) Stop() error {
747746
log.Info("stopping runner manager")
748747

749748
m.mu.Lock()
750-
defer m.mu.Unlock()
751-
752-
// Stop all runners
753-
for i, runner := range m.runners {
749+
runnerList := make([]*Runner, 0, len(m.runners))
750+
for _, runner := range m.runners {
754751
if runner != nil {
755-
log.Infow("stopping runner", "name", runner.runnerCtx.id, "slot", i)
756-
if err := runner.Stop(); err != nil {
757-
log.Errorw("error stopping runner", "name", runner.runnerCtx.id, "error", err)
758-
if stopErr == nil {
759-
stopErr = err
760-
}
761-
}
752+
runnerList = append(runnerList, runner)
762753
}
763754
}
755+
m.mu.Unlock()
764756

765-
// Wait for runners to stop concurrently
766-
eg := errgroup.Group{}
767-
for i, runner := range m.runners {
768-
if runner != nil {
769-
name := runner.runnerCtx.id
770-
eg.Go(func() error {
771-
log.Infow("waiting for runner to stop", "name", name, "slot", i)
772-
runner.WaitForStop()
773-
return nil
774-
})
775-
}
757+
// Signal all runners for graceful shutdown
758+
for _, runner := range runnerList {
759+
runner.GracefulShutdown()
776760
}
777761

778-
if err := eg.Wait(); err != nil {
779-
log.Errorw("error waiting for runners to stop", "error", err)
780-
if stopErr == nil {
781-
stopErr = err
782-
}
783-
} else {
784-
log.Info("all runners stopped successfully")
762+
// Wait for runners to become idle or timeout using WaitGroup
763+
gracePeriod := m.cfg.RunnerShutdownGracePeriod
764+
log.Infow("grace period configuration", "grace_period", gracePeriod)
765+
graceCtx, cancel := context.WithTimeout(m.ctx, gracePeriod)
766+
defer cancel()
767+
768+
var wg sync.WaitGroup
769+
for _, runner := range runnerList {
770+
wg.Go(func() {
771+
log.Debugw("waiting for runner to become idle", "name", runner.runnerCtx.id, "grace_period", gracePeriod)
772+
// Wait for this runner to become idle OR timeout
773+
select {
774+
case <-runner.readyForShutdown:
775+
log.Infow("runner became idle naturally", "name", runner.runnerCtx.id)
776+
case <-graceCtx.Done():
777+
log.Warnw("grace period expired for runner", "name", runner.runnerCtx.id, "context_err", graceCtx.Err())
778+
}
779+
780+
// Always try to stop, handle errors independently
781+
if err := runner.Stop(); err != nil {
782+
log.Errorw("failed to stop runner gracefully", "name", runner.runnerCtx.id, "error", err)
783+
}
784+
})
785785
}
786786

787+
// Wait for all runners to complete shutdown (success or failure)
788+
wg.Wait()
789+
790+
log.Info("all runners stopped successfully")
787791
close(m.stopped)
788792
})
789793

internal/runner/runner.go

Lines changed: 46 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ import (
1313
"regexp"
1414
"strings"
1515
"sync"
16+
"sync/atomic"
1617
"syscall"
1718
"time"
1819

@@ -303,6 +304,8 @@ type Runner struct {
303304
procedureHash string
304305
mu sync.RWMutex
305306
stopped chan bool
307+
shutdownWhenIdle atomic.Bool
308+
readyForShutdown chan struct{} // closed when idle and ready to be stopped
306309
setupComplete chan struct{} // closed on first READY after setup
307310
webhookSender webhook.Sender
308311
logCaptureComplete chan struct{} // closed when both stdout/stderr capture complete
@@ -345,6 +348,34 @@ func (r *Runner) WaitForStop() {
345348
<-r.stopped
346349
}
347350

351+
func (r *Runner) GracefulShutdown() {
352+
log := r.logger.Sugar()
353+
if !r.shutdownWhenIdle.CompareAndSwap(false, true) {
354+
log.Debugw("graceful shutdown already initiated", "runner_id", r.runnerCtx.id)
355+
return
356+
}
357+
358+
r.mu.RLock()
359+
shouldSignal := (r.status == StatusReady && len(r.pending) == 0)
360+
r.mu.RUnlock()
361+
362+
log.Debugw("graceful shutdown initiated", "runner_id", r.runnerCtx.id, "status", r.status, "pending_count", len(r.pending), "should_signal", shouldSignal)
363+
364+
if shouldSignal {
365+
if r.readyForShutdown == nil {
366+
log.Warnw("readyForShutdown channel is nil, cannot signal shutdown readiness", "runner_id", r.runnerCtx.id)
367+
} else {
368+
select {
369+
case <-r.readyForShutdown:
370+
log.Debugw("readyForShutdown already closed", "runner_id", r.runnerCtx.id)
371+
default:
372+
log.Debugw("closing readyForShutdown channel", "runner_id", r.runnerCtx.id)
373+
close(r.readyForShutdown)
374+
}
375+
}
376+
}
377+
}
378+
348379
func (r *Runner) Start(ctx context.Context) error {
349380
log := r.logger.Sugar()
350381
r.mu.Lock()
@@ -806,6 +837,17 @@ func (r *Runner) updateStatus(statusStr string) error {
806837
return err
807838
}
808839
r.status = status
840+
841+
// Close readyForShutdown channel when idle and shutdown requested
842+
if status == StatusReady && r.shutdownWhenIdle.Load() && len(r.pending) == 0 {
843+
select {
844+
case <-r.readyForShutdown:
845+
// Already closed
846+
default:
847+
close(r.readyForShutdown)
848+
}
849+
}
850+
809851
return nil
810852
}
811853

@@ -909,14 +951,14 @@ func (r *Runner) updateSetupResult() {
909951
switch r.setupResult.Status {
910952
case SetupSucceeded:
911953
r.status = StatusReady
912-
log.Debug("setup succeeded", "status", r.status.String())
954+
log.Debugw("setup succeeded", "status", r.status.String())
913955
case SetupFailed:
914956
r.status = StatusSetupFailed
915-
log.Debug("setup failed", "status", r.status.String())
957+
log.Debugw("setup failed", "status", r.status.String())
916958
default:
917959
r.setupResult.Status = SetupFailed
918960
r.status = StatusSetupFailed
919-
log.Debug("unknown setup status, defaulting to failed", "status", r.status.String())
961+
log.Debugw("unknown setup status, defaulting to failed", "status", r.status.String())
920962
}
921963
}
922964

@@ -972,6 +1014,7 @@ func NewRunner(ctx context.Context, ctxCancel context.CancelFunc, runnerCtx Runn
9721014
verifyFn: verifyProcessGroupTerminated,
9731015
cleanupSlot: make(chan struct{}, 1),
9741016
stopped: make(chan bool),
1017+
readyForShutdown: make(chan struct{}),
9751018
setupComplete: make(chan struct{}),
9761019
logCaptureComplete: make(chan struct{}),
9771020
cleanupTimeout: cleanupTimeout,

internal/server/mux.go

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@ func NewServeMux(handler *Handler, useProcedureMode bool) *http.ServeMux {
1818
serveMux.HandleFunc("GET /{$}", handler.Root)
1919
serveMux.HandleFunc("GET /health-check", handler.HealthCheck)
2020
serveMux.HandleFunc("GET /openapi.json", handler.OpenAPI)
21-
serveMux.HandleFunc("POST /shutdown", handler.Shutdown)
2221

2322
if useProcedureMode {
2423
serveMux.HandleFunc("POST /procedures", handler.Predict)

internal/server/server.go

Lines changed: 22 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ import (
1010
"net/http"
1111
"os"
1212
"path"
13+
"sync/atomic"
1314
"time"
1415

1516
"go.uber.org/zap"
@@ -37,22 +38,21 @@ type IPC struct {
3738
}
3839

3940
type Handler struct {
40-
cfg config.Config
41-
shutdown context.CancelFunc
42-
startedAt time.Time
43-
runnerManager *runner.Manager
41+
cfg config.Config
42+
startedAt time.Time
43+
runnerManager *runner.Manager
44+
gracefulShutdown atomic.Bool
4445

4546
cwd string
4647

4748
logger *zap.Logger
4849
}
4950

50-
func NewHandler(ctx context.Context, cfg config.Config, shutdown context.CancelFunc, baseLogger *zap.Logger) (*Handler, error) {
51+
func NewHandler(ctx context.Context, cfg config.Config, baseLogger *zap.Logger) (*Handler, error) {
5152
runnerManager := runner.NewManager(ctx, cfg, baseLogger)
5253

5354
h := &Handler{
5455
cfg: cfg,
55-
shutdown: shutdown,
5656
startedAt: time.Now(),
5757
runnerManager: runnerManager,
5858
cwd: cfg.WorkingDirectory,
@@ -133,30 +133,21 @@ func (h *Handler) OpenAPI(w http.ResponseWriter, r *http.Request) {
133133
h.writeBytes(w, []byte(schema))
134134
}
135135

136-
func (h *Handler) Shutdown(w http.ResponseWriter, r *http.Request) {
137-
err := h.Stop()
138-
if err != nil {
139-
http.Error(w, err.Error(), http.StatusInternalServerError)
140-
} else {
141-
w.WriteHeader(http.StatusOK)
142-
}
143-
}
144-
145136
// ForceKillAll immediately force-kills all runners (for test cleanup)
146137
func (h *Handler) ForceKillAll() {
147138
h.runnerManager.ForceKillAll()
148139
}
149140

150141
func (h *Handler) Stop() error {
151-
// Stop the runner manager and handle shutdown in background
152-
go func() {
153-
log := h.logger.Sugar()
154-
if err := h.runnerManager.Stop(); err != nil {
155-
log.Errorw("failed to stop runner manager", "error", err)
156-
os.Exit(1)
157-
}
158-
h.shutdown()
159-
}()
142+
// Set graceful shutdown flag to reject new predictions
143+
h.gracefulShutdown.Store(true)
144+
145+
// Stop the runner manager synchronously
146+
log := h.logger.Sugar()
147+
if err := h.runnerManager.Stop(); err != nil {
148+
log.Errorw("failed to stop runner manager", "error", err)
149+
return err
150+
}
160151
return nil
161152
}
162153

@@ -207,6 +198,13 @@ func (h *Handler) HandleIPC(w http.ResponseWriter, r *http.Request) {
207198

208199
func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
209200
log := h.logger.Sugar()
201+
202+
// Reject new predictions during graceful shutdown
203+
if h.gracefulShutdown.Load() {
204+
http.Error(w, "server shutting down", http.StatusServiceUnavailable)
205+
return
206+
}
207+
210208
if r.Header.Get("Content-Type") != "application/json" {
211209
http.Error(w, "invalid content type", http.StatusUnsupportedMediaType)
212210
return

0 commit comments

Comments
 (0)