Commit 2b2d168

Merge branch 'main' into jerm/vmcp-o11y
2 parents: b0bedaa + 92922c6

6 files changed: +139, -98 lines

cmd/thv/app/run.go

Lines changed: 14 additions & 25 deletions
@@ -7,9 +7,7 @@ import (
     "net"
     "net/url"
     "os"
-    "os/signal"
     "strings"
-    "syscall"
     "time"
 
     "github.com/spf13/cobra"
@@ -126,7 +124,7 @@ func init() {
     AddOIDCFlags(runCmd)
 }
 
-func cleanupAndWait(workloadManager workloads.Manager, name string, cancel context.CancelFunc, errCh <-chan error) {
+func cleanupAndWait(workloadManager workloads.Manager, name string) {
     cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 30*time.Second)
     defer cleanupCancel()
 
@@ -138,13 +136,6 @@ func cleanupAndWait(workloadManager workloads.Manager, name string, cancel conte
             logger.Warnf("DeleteWorkloads group error for %q: %v", name, err)
         }
     }
-
-    cancel()
-    select {
-    case <-errCh:
-    case <-time.After(5 * time.Second):
-        logger.Warnf("Timeout waiting for workload to stop")
-    }
 }
 
 // nolint:gocyclo // This function is complex by design
@@ -304,28 +295,26 @@ func getworkloadDefaultName(_ context.Context, serverOrImage string) string {
 }
 
 func runForeground(ctx context.Context, workloadManager workloads.Manager, runnerConfig *runner.RunConfig) error {
-    ctx, cancel := context.WithCancel(ctx)
-    defer cancel()
-
-    sigCh := make(chan os.Signal, 1)
-    signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
-    defer signal.Stop(sigCh)
 
     errCh := make(chan error, 1)
     go func() {
         errCh <- workloadManager.RunWorkload(ctx, runnerConfig)
     }()
 
-    select {
-    case sig := <-sigCh:
-        if !process.IsDetached() {
-            logger.Infof("Received signal: %v, stopping server %q", sig, runnerConfig.BaseName)
-            cleanupAndWait(workloadManager, runnerConfig.BaseName, cancel, errCh)
-        }
-        return nil
-    case err := <-errCh:
-        return err
+    // workloadManager.RunWorkload will block until the context is cancelled
+    // or an unrecoverable error is returned. In either case, it will stop the server.
+    // We wait until workloadManager.RunWorkload exits before deleting the workload,
+    // so stopping and deleting don't race.
+    //
+    // There's room for improvement in the factoring here.
+    // Shutdown and cancellation logic is unnecessarily spread across two goroutines.
+    err := <-errCh
+    if !process.IsDetached() {
+        logger.Infof("RunWorkload Exited. Error: %v, stopping server %q", err, runnerConfig.BaseName)
+        cleanupAndWait(workloadManager, runnerConfig.BaseName)
     }
+    return err
+
 }
 
 func validateGroup(ctx context.Context, workloadsManager workloads.Manager, serverOrImage string) error {
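Note on the pattern above: runForeground no longer installs its own signal handler. It waits for RunWorkload to return, whether because the root context was cancelled or because of an unrecoverable error, and only then deletes the workload, so stopping and deleting cannot race. A minimal standalone sketch of that sequencing, with hypothetical runWorkload and cleanup stand-ins rather than the real ToolHive APIs:

package main

import (
	"context"
	"fmt"
	"time"
)

// runWorkload stands in for workloadManager.RunWorkload: it blocks until the
// caller's context is cancelled or an unrecoverable error occurs.
func runWorkload(ctx context.Context) error {
	<-ctx.Done()
	return ctx.Err()
}

// cleanup stands in for cleanupAndWait: it uses a fresh timeout context so the
// deletion is not cut short by the already-cancelled parent context.
func cleanup(name string) {
	cleanupCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	_ = cleanupCtx // a real implementation would pass this to DeleteWorkloads
	fmt.Printf("deleted workload %q\n", name)
}

func runForeground(ctx context.Context, name string) error {
	errCh := make(chan error, 1)
	go func() { errCh <- runWorkload(ctx) }()

	// Wait for the workload to exit before deleting it, so stopping and
	// deleting never run concurrently.
	err := <-errCh
	cleanup(name)
	return err
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel() // simulate Ctrl+C propagated from main's signal handler
	}()
	fmt.Println("runForeground returned:", runForeground(ctx, "example"))
}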

cmd/thv/main.go

Lines changed: 10 additions & 5 deletions
@@ -2,6 +2,7 @@
 package main
 
 import (
+    "context"
     "os"
     "os/signal"
     "syscall"
@@ -12,7 +13,6 @@ import (
     "github.com/stacklok/toolhive/cmd/thv/app"
     "github.com/stacklok/toolhive/pkg/client"
     "github.com/stacklok/toolhive/pkg/container"
-    "github.com/stacklok/toolhive/pkg/container/runtime"
     "github.com/stacklok/toolhive/pkg/lockfile"
     "github.com/stacklok/toolhive/pkg/logger"
     "github.com/stacklok/toolhive/pkg/migration"
@@ -23,7 +23,7 @@ func main() {
     logger.Initialize()
 
     // Setup signal handling for graceful cleanup
-    setupSignalHandler()
+    ctx := setupSignalHandler()
 
     // Clean up stale lock files on startup
     cleanupStaleLockFiles()
@@ -47,8 +47,10 @@ func main() {
         migration.CheckAndPerformDefaultGroupMigration()
     }
 
+    cmd := app.NewRootCmd(!app.IsCompletionCommand(os.Args))
+
     // Skip update check for completion command or if we are running in kubernetes
-    if err := app.NewRootCmd(!app.IsCompletionCommand(os.Args) && !runtime.IsKubernetesRuntime()).Execute(); err != nil {
+    if err := cmd.ExecuteContext(ctx); err != nil {
         // Clean up any remaining lock files on error exit
         lockfile.CleanupAllLocks()
         os.Exit(1)
@@ -59,16 +61,19 @@ func main() {
 }
 
 // setupSignalHandler configures signal handling to ensure lock files are cleaned up
-func setupSignalHandler() {
+func setupSignalHandler() context.Context {
     sigCh := make(chan os.Signal, 1)
     signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT)
 
+    ctx, cancel := context.WithCancel(context.Background())
    go func() {
        <-sigCh
        logger.Debugf("Received signal, cleaning up lock files...")
        lockfile.CleanupAllLocks()
-        os.Exit(0)
+        cancel()
    }()
+
+    return ctx
 }
 
 // cleanupStaleLockFiles removes stale lock files from known directories on startup
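The change above moves process-wide signal handling into main: setupSignalHandler now returns a context that is cancelled on SIGINT/SIGTERM (after cleaning up lock files), and ExecuteContext threads it into every subcommand via cmd.Context(). A minimal sketch of the same wiring, using the standard-library signal.NotifyContext helper instead of the hand-rolled goroutine and omitting the lock-file cleanup:

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"

	"github.com/spf13/cobra"
)

func main() {
	// Cancel the context when SIGINT or SIGTERM arrives; this mirrors what
	// setupSignalHandler does manually (minus the lock-file cleanup).
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	root := &cobra.Command{
		Use: "demo",
		RunE: func(cmd *cobra.Command, _ []string) error {
			// Every command sees the same cancellable context via cmd.Context().
			<-cmd.Context().Done()
			fmt.Println("shutting down")
			return nil
		},
	}

	// ExecuteContext propagates ctx to cmd.Context() in all (sub)commands.
	if err := root.ExecuteContext(ctx); err != nil {
		os.Exit(1)
	}
}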

pkg/runner/runner.go

Lines changed: 8 additions & 11 deletions
@@ -8,9 +8,7 @@ import (
     "fmt"
     "net/http"
     "os"
-    "os/signal"
     "strings"
-    "syscall"
     "time"
 
     "golang.org/x/oauth2"
@@ -317,16 +315,19 @@ func (r *Runner) Run(ctx context.Context) error {
 
     // Define a function to stop the MCP server
     stopMCPServer := func(reason string) {
+        // Use a background context to avoid cancellation of the main context.
+        cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 1*time.Minute)
+        defer cleanupCancel()
         logger.Infof("Stopping MCP server: %s", reason)
 
         // Stop the transport (which also stops the container, monitoring, and handles removal)
         logger.Infof("Stopping %s transport...", r.Config.Transport)
-        if err := transportHandler.Stop(ctx); err != nil {
+        if err := transportHandler.Stop(cleanupCtx); err != nil {
            logger.Warnf("Warning: Failed to stop transport: %v", err)
        }
 
        // Cleanup telemetry provider
-        if err := r.Cleanup(ctx); err != nil {
+        if err := r.Cleanup(cleanupCtx); err != nil {
            logger.Warnf("Warning: Failed to cleanup telemetry: %v", err)
        }
 
@@ -335,7 +336,7 @@ func (r *Runner) Run(ctx context.Context) error {
        if err := process.RemovePIDFile(r.Config.BaseName); err != nil {
            logger.Warnf("Warning: Failed to remove PID file: %v", err)
        }
-        if err := r.statusManager.ResetWorkloadPID(ctx, r.Config.BaseName); err != nil {
+        if err := r.statusManager.ResetWorkloadPID(cleanupCtx, r.Config.BaseName); err != nil {
            logger.Warnf("Warning: Failed to reset workload %s PID: %v", r.Config.ContainerName, err)
        }
 
@@ -354,10 +355,6 @@ func (r *Runner) Run(ctx context.Context) error {
        logger.Info("Press Ctrl+C to stop or wait for container to exit")
    }
 
-    // Set up signal handling
-    sigCh := make(chan os.Signal, 1)
-    signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
-
    // Create a done channel to signal when the server has been stopped
    doneCh := make(chan struct{})
 
@@ -399,8 +396,8 @@ func (r *Runner) Run(ctx context.Context) error {
 
    // Wait for either a signal or the done channel to be closed
    select {
-    case sig := <-sigCh:
-        stopMCPServer(fmt.Sprintf("Received signal %s", sig))
+    case <-ctx.Done():
+        stopMCPServer("Context cancelled")
    case <-doneCh:
        // The transport has already been stopped (likely by the container exit)
        // Clean up the PID file and state
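With ctx.Done() now acting as the shutdown trigger, ctx is already cancelled by the time stopMCPServer runs, so any cleanup call that honoured it would fail immediately; the new cleanupCtx gives shutdown its own one-minute budget instead. A small self-contained sketch of that pattern, with illustrative names rather than the Runner API:

package main

import (
	"context"
	"log"
	"time"
)

// shutdown illustrates the stopMCPServer pattern: the parent context is already
// cancelled when shutdown begins, so cleanup runs on a detached context with
// its own deadline rather than on the cancelled parent.
func shutdown(parent context.Context, stop func(context.Context) error) {
	cleanupCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
	defer cancel()

	log.Printf("parent context state: %v", parent.Err()) // context canceled
	if err := stop(cleanupCtx); err != nil {
		log.Printf("warning: failed to stop cleanly: %v", err)
	}
}

func main() {
	parent, cancel := context.WithCancel(context.Background())
	cancel() // the shutdown trigger has already fired

	shutdown(parent, func(ctx context.Context) error {
		// The cleanup context is still live even though the parent is cancelled.
		return ctx.Err()
	})
}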

pkg/workloads/manager.go

Lines changed: 58 additions & 27 deletions
@@ -872,7 +872,7 @@ func (d *DefaultManager) DeleteWorkloads(_ context.Context, names []string) (*er
 }
 
 // RestartWorkloads restarts the specified workloads by name.
-func (d *DefaultManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) {
+func (d *DefaultManager) RestartWorkloads(ctx context.Context, names []string, foreground bool) (*errgroup.Group, error) {
     // Validate all workload names to prevent path traversal attacks
     for _, name := range names {
         if err := types.ValidateWorkloadName(name); err != nil {
@@ -884,7 +884,7 @@ func (d *DefaultManager) RestartWorkloads(_ context.Context, names []string, for
 
     for _, name := range names {
         group.Go(func() error {
-            return d.restartSingleWorkload(name, foreground)
+            return d.restartSingleWorkload(ctx, name, foreground)
         })
     }
 
@@ -943,39 +943,59 @@ func (d *DefaultManager) updateSingleWorkload(workloadName string, newConfig *ru
 }
 
 // restartSingleWorkload handles the restart logic for a single workload
-func (d *DefaultManager) restartSingleWorkload(name string, foreground bool) error {
-    // Create a child context with a longer timeout
-    childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout)
-    defer cancel()
+func (d *DefaultManager) restartSingleWorkload(ctx context.Context, name string, foreground bool) error {
 
     // First, try to load the run configuration to check if it's a remote workload
-    runConfig, err := runner.LoadState(childCtx, name)
+    runConfig, err := runner.LoadState(ctx, name)
     if err != nil {
         // If we can't load the state, it might be a container workload or the workload doesn't exist
         // Try to restart it as a container workload
-        return d.restartContainerWorkload(childCtx, name, foreground)
+        return d.restartContainerWorkload(ctx, name, foreground)
     }
 
     // Check if this is a remote workload
     if runConfig.RemoteURL != "" {
-        return d.restartRemoteWorkload(childCtx, name, runConfig, foreground)
+        return d.restartRemoteWorkload(ctx, name, runConfig, foreground)
     }
 
     // This is a container-based workload
-    return d.restartContainerWorkload(childCtx, name, foreground)
+    return d.restartContainerWorkload(ctx, name, foreground)
 }
 
 // restartRemoteWorkload handles restarting a remote workload
+// It blocks until the context is cancelled or there is already a supervisor process running.
 func (d *DefaultManager) restartRemoteWorkload(
     ctx context.Context,
     name string,
     runConfig *runner.RunConfig,
     foreground bool,
 ) error {
+    mcpRunner, err := d.maybeSetupRemoteWorkload(ctx, name, runConfig)
+    if err != nil {
+        return fmt.Errorf("failed to setup remote workload: %w", err)
+    }
+
+    if mcpRunner == nil {
+        return nil
+    }
+
+    return d.startWorkload(ctx, name, mcpRunner, foreground)
+}
+
+// maybeSetupRemoteWorkload is the startup steps for a remote workload.
+// A runner may not be returned if the workload is already running and supervised.
+func (d *DefaultManager) maybeSetupRemoteWorkload(
+    ctx context.Context,
+    name string,
+    runConfig *runner.RunConfig,
+) (*runner.Runner, error) {
+    ctx, cancel := context.WithTimeout(ctx, AsyncOperationTimeout)
+    defer cancel()
+
     // Get workload status using the status manager
     workload, err := d.statuses.GetWorkload(ctx, name)
     if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) {
-        return err
+        return nil, err
     }
 
     // If workload is already running, check if the supervisor process is healthy
@@ -986,7 +1006,7 @@ func (d *DefaultManager) restartRemoteWorkload(
     if supervisorAlive {
         // Workload is running and healthy - preserve old behavior (no-op)
         logger.Infof("Remote workload %s is already running", name)
-        return nil
+        return nil, nil
     }
 
     // Supervisor is dead/missing - we need to clean up and restart to fix the damaged state
@@ -1015,7 +1035,7 @@ func (d *DefaultManager) restartRemoteWorkload(
     // Load runner configuration from state
     mcpRunner, err := d.loadRunnerFromState(ctx, runConfig.BaseName)
     if err != nil {
-        return fmt.Errorf("failed to load state for %s: %v", runConfig.BaseName, err)
+        return nil, fmt.Errorf("failed to load state for %s: %v", runConfig.BaseName, err)
     }
 
     // Set status to starting
@@ -1024,16 +1044,31 @@ func (d *DefaultManager) restartRemoteWorkload(
     }
 
     logger.Infof("Loaded configuration from state for %s", runConfig.BaseName)
+    return mcpRunner, nil
+}
+
+// restartContainerWorkload handles restarting a container-based workload.
+// It blocks until the context is cancelled or there is already a supervisor process running.
+func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error {
+    workloadName, mcpRunner, err := d.maybeSetupContainerWorkload(ctx, name)
+    if err != nil {
+        return fmt.Errorf("failed to setup container workload: %w", err)
+    }
+
+    if mcpRunner == nil {
+        return nil
+    }
 
-    // Start the remote workload using the loaded runner
-    // Use background context to avoid timeout cancellation - same reasoning as container workloads
-    return d.startWorkload(context.Background(), name, mcpRunner, foreground)
+    return d.startWorkload(ctx, workloadName, mcpRunner, foreground)
 }
 
-// restartContainerWorkload handles restarting a container-based workload
+// maybeSetupContainerWorkload is the startup steps for a container-based workload.
+// A runner may not be returned if the workload is already running and supervised.
 //
 //nolint:gocyclo // Complexity is justified - handles multiple restart scenarios and edge cases
-func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error {
+func (d *DefaultManager) maybeSetupContainerWorkload(ctx context.Context, name string) (string, *runner.Runner, error) {
+    ctx, cancel := context.WithTimeout(ctx, AsyncOperationTimeout)
+    defer cancel()
     // Get container info to resolve partial names and extract proper workload name
     var containerName string
     var workloadName string
@@ -1057,7 +1092,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
     // Get workload status using the status manager
     workload, err := d.statuses.GetWorkload(ctx, name)
     if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) {
-        return err
+        return "", nil, err
     }
 
     // Check if workload is running and healthy (including supervisor process)
@@ -1068,7 +1103,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
     if supervisorAlive {
         // Workload is running and healthy - preserve old behavior (no-op)
         logger.Infof("Container %s is already running", containerName)
-        return nil
+        return "", nil, nil
     }
 
     // Supervisor is dead/missing - we need to clean up and restart to fix the damaged state
@@ -1107,7 +1142,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
            if statusErr := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusError, err.Error()); statusErr != nil {
                logger.Warnf("Failed to set workload %s status to error: %v", workloadName, statusErr)
            }
-            return fmt.Errorf("failed to stop container %s: %v", containerName, err)
+            return "", nil, fmt.Errorf("failed to stop container %s: %v", containerName, err)
        }
        logger.Infof("Container %s stopped", containerName)
    }
@@ -1126,7 +1161,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
    // Load runner configuration from state
    mcpRunner, err := d.loadRunnerFromState(ctx, workloadName)
    if err != nil {
-        return fmt.Errorf("failed to load state for %s: %v", workloadName, err)
+        return "", nil, fmt.Errorf("failed to load state for %s: %v", workloadName, err)
    }
 
    // Set workload status to starting - use the workload name for status operations
@@ -1135,11 +1170,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
    }
    logger.Infof("Loaded configuration from state for %s", workloadName)
 
-    // Start the workload with background context to avoid timeout cancellation
-    // The ctx with AsyncOperationTimeout is only for the restart setup operations,
-    // but the actual workload should run indefinitely with its own lifecycle management
-    // Use workload name for user-facing operations
-    return d.startWorkload(context.Background(), workloadName, mcpRunner, foreground)
+    return workloadName, mcpRunner, nil
 }
 
 // startWorkload starts the workload in either foreground or background mode
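The restart path is now split in two: a maybeSetup* phase that inherits the caller's context but bounds itself with AsyncOperationTimeout, and may return a nil runner to mean "already running and supervised", and a startWorkload phase that runs on the caller's long-lived context instead of context.Background(). A stripped-down sketch of that shape, with hypothetical names where only the structure follows the commit:

package main

import (
	"context"
	"fmt"
	"time"
)

const asyncOperationTimeout = 30 * time.Second // stand-in for AsyncOperationTimeout

type runner struct{ name string }

// maybeSetup bounds the setup steps with a timeout derived from the caller's
// context and returns (nil, nil) when the workload is already supervised.
func maybeSetup(ctx context.Context, name string, alreadyRunning bool) (*runner, error) {
	setupCtx, cancel := context.WithTimeout(ctx, asyncOperationTimeout)
	defer cancel()

	if err := setupCtx.Err(); err != nil {
		return nil, err
	}
	if alreadyRunning {
		return nil, nil // healthy supervisor: restart is a no-op
	}
	return &runner{name: name}, nil
}

// restart wires the two phases together: setup is short-lived, start inherits
// the caller's context and blocks for the workload's lifetime.
func restart(ctx context.Context, name string, alreadyRunning bool) error {
	r, err := maybeSetup(ctx, name, alreadyRunning)
	if err != nil {
		return fmt.Errorf("failed to setup workload: %w", err)
	}
	if r == nil {
		return nil
	}
	return start(ctx, r)
}

func start(ctx context.Context, r *runner) error {
	fmt.Printf("starting %s\n", r.name)
	<-ctx.Done() // runs until the caller cancels, not until the setup timeout
	return ctx.Err()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Already-running workload: restart is a no-op and returns immediately.
	fmt.Println(restart(ctx, "example", true))
}

Because the workload now runs on the caller's context rather than context.Background(), cancelling the root context (for example on SIGTERM) reaches restarted workloads too, which is what the old code explicitly avoided.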
