Skip to content

Commit 939cb54

Browse files
committed
Fix runtime_config API semantics and validation
1 parent 7868d49 commit 939cb54

6 files changed

Lines changed: 498 additions & 37 deletions

File tree

pkg/api/v1/workload_service.go

Lines changed: 118 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,20 @@ import (
77
"context"
88
"errors"
99
"fmt"
10-
"log/slog"
10+
"strings"
1111
"time"
1212

13-
regtypes "github.com/stacklok/toolhive-core/registry/types"
13+
nameref "github.com/google/go-containerregistry/pkg/name"
1414
groupval "github.com/stacklok/toolhive-core/validation/group"
1515
httpval "github.com/stacklok/toolhive-core/validation/http"
1616
"github.com/stacklok/toolhive/pkg/auth/remote"
1717
"github.com/stacklok/toolhive/pkg/config"
1818
"github.com/stacklok/toolhive/pkg/container/runtime"
19+
"github.com/stacklok/toolhive/pkg/container/templates"
1920
"github.com/stacklok/toolhive/pkg/groups"
21+
"github.com/stacklok/toolhive/pkg/logger"
2022
"github.com/stacklok/toolhive/pkg/networking"
23+
regtypes "github.com/stacklok/toolhive/pkg/registry/registry"
2124
"github.com/stacklok/toolhive/pkg/runner"
2225
"github.com/stacklok/toolhive/pkg/runner/retriever"
2326
"github.com/stacklok/toolhive/pkg/secrets"
@@ -32,6 +35,24 @@ const (
3235
imageRetrievalTimeout = 10 * time.Minute
3336
)
3437

38+
func isValidRuntimePackageName(pkg string) bool {
39+
if pkg == "" {
40+
return false
41+
}
42+
for i, r := range pkg {
43+
switch {
44+
case r >= 'a' && r <= 'z':
45+
case r >= 'A' && r <= 'Z':
46+
case r >= '0' && r <= '9':
47+
case r == '.', r == '_':
48+
case (r == '+' || r == '-') && i > 0:
49+
default:
50+
return false
51+
}
52+
}
53+
return true
54+
}
55+
3556
// WorkloadService handles business logic for workload operations
3657
type WorkloadService struct {
3758
workloadManager workloads.Manager
@@ -73,13 +94,13 @@ func (s *WorkloadService) CreateWorkloadFromRequest(ctx context.Context, req *cr
7394

7495
// Save the workload state
7596
if err := runConfig.SaveState(ctx); err != nil {
76-
slog.Error("failed to save workload config", "error", err)
97+
logger.Errorf("Failed to save workload config: %v", err)
7798
return nil, fmt.Errorf("failed to save workload config: %w", err)
7899
}
79100

80101
// Start workload
81102
if err := s.workloadManager.RunWorkloadDetached(ctx, runConfig); err != nil {
82-
slog.Error("failed to start workload", "error", err)
103+
logger.Errorf("Failed to start workload: %v", err)
83104
return nil, fmt.Errorf("failed to start workload: %w", err)
84105
}
85106

@@ -91,7 +112,7 @@ func (s *WorkloadService) UpdateWorkloadFromRequest(ctx context.Context, name st
91112
// If ProxyPort is 0, reuse the existing port
92113
if req.ProxyPort == 0 && existingPort > 0 {
93114
req.ProxyPort = existingPort
94-
slog.Debug("reusing existing port", "port", existingPort, "name", name)
115+
logger.Debugf("Reusing existing port %d for workload %s", existingPort, name)
95116
}
96117

97118
// Build the full run config
@@ -162,7 +183,11 @@ func (s *WorkloadService) BuildFullRunConfig(
162183
var imageURL string
163184
var imageMetadata *regtypes.ImageMetadata
164185
var serverMetadata regtypes.ServerMetadata
165-
var registryProxyPort int
186+
runtimeConfigOverride := runtimeConfigFromRequest(req)
187+
retrievalRuntimeConfig, err := runtimeConfigForImageBuild(req, runtimeConfigOverride)
188+
if err != nil {
189+
return nil, fmt.Errorf("%w: %w", retriever.ErrInvalidRunConfig, err)
190+
}
166191

167192
if req.URL != "" {
168193
// Configure remote authentication if OAuth config is provided
@@ -181,8 +206,8 @@ func (s *WorkloadService) BuildFullRunConfig(
181206
req.Image,
182207
"", // We do not let the user specify a CA cert path here.
183208
retriever.VerifyImageWarn,
184-
"", // TODO Add support for registry groups lookups for API
185-
nil, // No runtime override from API (yet)
209+
"", // TODO Add support for registry groups lookups for API
210+
retrievalRuntimeConfig,
186211
)
187212
if err != nil {
188213
// Check if the error is due to context timeout
@@ -193,12 +218,7 @@ func (s *WorkloadService) BuildFullRunConfig(
193218
return nil, fmt.Errorf("failed to retrieve MCP server image: %w", err)
194219
}
195220

196-
if remoteServerMetadata, ok := serverMetadata.(*regtypes.RemoteServerMetadata); ok && remoteServerMetadata != nil {
197-
// Use registry proxy port if not set by request
198-
if req.ProxyPort == 0 && remoteServerMetadata.ProxyPort > 0 {
199-
registryProxyPort = remoteServerMetadata.ProxyPort
200-
}
201-
221+
if remoteServerMetadata, ok := serverMetadata.(*regtypes.RemoteServerMetadata); ok {
202222
if remoteServerMetadata.OAuthConfig != nil {
203223
// Default resource: user-provided > registry metadata > derived from remote URL
204224
resource := req.OAuthConfig.Resource
@@ -234,11 +254,8 @@ func (s *WorkloadService) BuildFullRunConfig(
234254
}
235255
}
236256
}
237-
// Handle server metadata - API only supports container servers.
238-
// Use type assertion with nil check to guard against typed nil pointers.
239-
if md, ok := serverMetadata.(*regtypes.ImageMetadata); ok && md != nil {
240-
imageMetadata = md
241-
}
257+
// Handle server metadata - API only supports container servers
258+
imageMetadata, _ = serverMetadata.(*regtypes.ImageMetadata)
242259
}
243260

244261
// Build RunConfig
@@ -281,6 +298,11 @@ func (s *WorkloadService) BuildFullRunConfig(
281298
runner.WithTelemetryConfigFromFlags("", false, false, false, "", 0.0, nil, false, nil, false),
282299
}
283300

301+
// Runtime overrides only apply to protocol-scheme image builds.
302+
if runtimeConfigOverride != nil && req.URL == "" {
303+
options = append(options, runner.WithRuntimeConfig(runtimeConfigOverride))
304+
}
305+
284306
// Add header forward configuration if specified
285307
if req.HeaderForward != nil {
286308
if len(req.HeaderForward.AddPlaintextHeaders) > 0 {
@@ -291,11 +313,6 @@ func (s *WorkloadService) BuildFullRunConfig(
291313
}
292314
}
293315

294-
// Use registry proxy port for remote servers if not set by request
295-
if registryProxyPort > 0 {
296-
options = append(options, runner.WithRegistryProxyPort(registryProxyPort))
297-
}
298-
299316
// Add existing port if provided (for update operations)
300317
if existingPort > 0 {
301318
options = append(options, runner.WithExistingPort(existingPort))
@@ -305,10 +322,8 @@ func (s *WorkloadService) BuildFullRunConfig(
305322
transportType := "streamable-http"
306323
if req.Transport != "" {
307324
transportType = req.Transport
308-
} else if md, ok := serverMetadata.(*regtypes.ImageMetadata); ok && md != nil {
309-
if t := md.GetTransport(); t != "" {
310-
transportType = t
311-
}
325+
} else if serverMetadata != nil {
326+
transportType = serverMetadata.GetTransport()
312327
}
313328

314329
// Configure middleware from flags
@@ -330,7 +345,7 @@ func (s *WorkloadService) BuildFullRunConfig(
330345

331346
runConfig, err := runner.NewRunConfigBuilder(ctx, imageMetadata, req.EnvVars, &runner.DetachedEnvVarValidator{}, options...)
332347
if err != nil {
333-
slog.Error("failed to build run config", "error", err)
348+
logger.Errorf("Failed to build run config: %v", err)
334349
return nil, fmt.Errorf("%w: Failed to build run config: %w", retriever.ErrInvalidRunConfig, err)
335350
}
336351

@@ -377,6 +392,80 @@ func createRequestToRemoteAuthConfig(
377392
return remoteAuthConfig
378393
}
379394

395+
func runtimeConfigFromRequest(req *createRequest) *templates.RuntimeConfig {
396+
if req == nil || req.RuntimeConfig == nil {
397+
return nil
398+
}
399+
400+
runtimeConfig := &templates.RuntimeConfig{}
401+
if builderImage := strings.TrimSpace(req.RuntimeConfig.BuilderImage); builderImage != "" {
402+
runtimeConfig.BuilderImage = builderImage
403+
}
404+
if len(req.RuntimeConfig.AdditionalPackages) > 0 {
405+
for _, pkg := range req.RuntimeConfig.AdditionalPackages {
406+
if trimmedPkg := strings.TrimSpace(pkg); trimmedPkg != "" {
407+
runtimeConfig.AdditionalPackages = append(runtimeConfig.AdditionalPackages, trimmedPkg)
408+
}
409+
}
410+
}
411+
if runtimeConfig.BuilderImage == "" && len(runtimeConfig.AdditionalPackages) == 0 {
412+
return nil
413+
}
414+
415+
return runtimeConfig
416+
}
417+
418+
func validateRuntimeConfig(runtimeConfig *templates.RuntimeConfig) error {
419+
if runtimeConfig == nil {
420+
return nil
421+
}
422+
423+
if runtimeConfig.BuilderImage != "" {
424+
if _, err := nameref.ParseReference(runtimeConfig.BuilderImage); err != nil {
425+
return fmt.Errorf("runtime_config.builder_image must be a valid container image reference")
426+
}
427+
}
428+
429+
for _, pkg := range runtimeConfig.AdditionalPackages {
430+
if !isValidRuntimePackageName(pkg) {
431+
return fmt.Errorf("runtime_config.additional_packages contains invalid package name %q", pkg)
432+
}
433+
}
434+
435+
return nil
436+
}
437+
438+
func runtimeConfigForImageBuild(req *createRequest, runtimeConfigOverride *templates.RuntimeConfig) (*templates.RuntimeConfig, error) {
439+
if runtimeConfigOverride == nil || req == nil {
440+
return nil, nil
441+
}
442+
if err := validateRuntimeConfig(runtimeConfigOverride); err != nil {
443+
return nil, err
444+
}
445+
if req.URL != "" || !runner.IsImageProtocolScheme(req.Image) {
446+
return nil, fmt.Errorf("runtime_config is only supported for protocol-scheme images")
447+
}
448+
449+
transportType, _, err := runner.ParseProtocolScheme(req.Image)
450+
if err != nil {
451+
return nil, err
452+
}
453+
454+
baseConfig := runner.GetBaseRuntimeConfig(transportType)
455+
merged := &templates.RuntimeConfig{
456+
BuilderImage: baseConfig.BuilderImage,
457+
AdditionalPackages: append([]string{}, baseConfig.AdditionalPackages...),
458+
}
459+
if runtimeConfigOverride.BuilderImage != "" {
460+
merged.BuilderImage = runtimeConfigOverride.BuilderImage
461+
}
462+
if len(runtimeConfigOverride.AdditionalPackages) > 0 {
463+
merged.AdditionalPackages = append(merged.AdditionalPackages, runtimeConfigOverride.AdditionalPackages...)
464+
}
465+
466+
return merged, nil
467+
}
468+
380469
// GetWorkloadNamesFromRequest gets workload names from either the names field or group
381470
func (s *WorkloadService) GetWorkloadNamesFromRequest(ctx context.Context, req bulkOperationRequest) ([]string, error) {
382471
if len(req.Names) > 0 {

0 commit comments

Comments
 (0)