diff --git a/CLAUDE.md b/CLAUDE.md index 7552334..69217cc 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -74,6 +74,7 @@ task license-fix # Add missing license headers | `logging` | Pre-configured `*slog.Logger` factory with consistent ToolHive defaults (Alpha) | | `oci/artifact` | Artifact-agnostic OCI tar/gzip/extraction/platform primitives shared by oci/skills and oci/plugins (Alpha) | | `oci/skills` | OCI artifact types, media types, and registry operations for ToolHive skills (Alpha) | +| `oci/plugins` | OCI artifact types, media types, and registry operations for ToolHive plugins (Alpha) | | `postgres` | PostgreSQL connection pool with optional AWS RDS IAM dynamic auth (Alpha) | | `recovery` | HTTP panic recovery middleware (Beta) | | `validation/http` | RFC 7230/8707 compliant HTTP header and URI validation | diff --git a/README.md b/README.md index 436ecbe..90c2f72 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ The ToolHive ecosystem spans multiple Go repositories, and several of these proj | `httperr` | Stable | Wrap errors with HTTP status codes | | `logging` | Alpha | Pre-configured `*slog.Logger` factory with consistent ToolHive defaults | | `oci/skills` | Alpha | OCI artifact types, media types, and registry operations for skills | +| `oci/plugins` | Alpha | OCI artifact types, media types, and registry operations for plugins | | `postgres` | Alpha | PostgreSQL connection pool with optional AWS RDS IAM dynamic auth | | `recovery` | Beta | HTTP panic recovery middleware | | `validation/http` | Stable | RFC 7230/8707 compliant HTTP header and URI validation | diff --git a/oci/plugins/doc.go b/oci/plugins/doc.go new file mode 100644 index 0000000..4f2169f --- /dev/null +++ b/oci/plugins/doc.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +/* +Package plugins provides OCI artifact types, media types, local storage, and +remote registry operations for ToolHive plugin packages. + +A plugin is an OCI artifact containing a .claude-plugin/plugin.json manifest and +its component directories. This package defines the constants, data structures, +storage layer, packager, and registry client that the rest of the ToolHive +ecosystem uses to package, push, pull, and cache plugins as OCI images. + +# Media Types and Constants + +Standard OCI media types and ToolHive-specific annotation/label keys: + + // Artifact type identifies a plugin manifest + plugins.ArtifactTypePlugin // "dev.toolhive.plugins.v1" + + // Annotations carry metadata in manifests + plugins.AnnotationPluginName + plugins.AnnotationPluginVersion + plugins.AnnotationPluginComponents + + // Labels carry metadata in OCI image configs + plugins.LabelPluginName + plugins.LabelPluginFiles + +# Registry Client + +The [Registry] type implements [RegistryClient] for pushing and pulling plugin +artifacts to/from OCI-compliant registries (GHCR, ECR, Docker Hub, etc.). It +uses ORAS for registry operations and the Docker credential store for +authentication by default: + + reg, err := plugins.NewRegistry() + // Push an artifact + err = reg.Push(ctx, store, indexDigest, "ghcr.io/myorg/my-plugin:v1.0.0") + // Pull an artifact + digest, err := reg.Pull(ctx, store, "ghcr.io/myorg/my-plugin:v1.0.0") + +Use functional options to customise behaviour: + + reg, err := plugins.NewRegistry( + plugins.WithPlainHTTP(true), // for local dev registries + plugins.WithCredentialStore(myStore), // custom auth + ) + +# Stability + +This package is Alpha. Breaking changes are possible between minor versions. +*/ +package plugins diff --git a/oci/plugins/errors.go b/oci/plugins/errors.go new file mode 100644 index 0000000..3cfe5fb --- /dev/null +++ b/oci/plugins/errors.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import "errors" + +// Sentinel errors returned (wrapped) by the packager so callers can classify +// failures with errors.Is instead of matching error message strings. The +// underlying error message is preserved at each call site via fmt.Errorf +// with %w; only the classification is added. +var ( + // ErrInvalidPluginDir indicates the plugin directory is missing, not a + // directory, or otherwise unsafe to read (e.g. contains path traversal). + ErrInvalidPluginDir = errors.New("invalid plugin directory") + + // ErrPluginManifestMissing indicates .claude-plugin/plugin.json is not + // present in the plugin directory. + ErrPluginManifestMissing = errors.New(".claude-plugin/plugin.json missing") + + // ErrInvalidPluginManifest indicates the plugin manifest is malformed, + // oversized, or missing required fields such as the plugin name. + ErrInvalidPluginManifest = errors.New("invalid plugin manifest") + + // ErrTooManyFiles indicates the plugin directory exceeds the maximum + // allowed number of files. + ErrTooManyFiles = errors.New("too many files in plugin directory") + + // ErrPluginTooLarge indicates the plugin directory exceeds the maximum + // allowed total size. + ErrPluginTooLarge = errors.New("plugin directory too large") + + // ErrInvalidPluginFile indicates a per-file issue inside the plugin + // directory: a symlink, a non-regular file, or an unreadable entry. + ErrInvalidPluginFile = errors.New("invalid plugin file") +) diff --git a/oci/plugins/integration_test.go b/oci/plugins/integration_test.go new file mode 100644 index 0000000..185b8ac --- /dev/null +++ b/oci/plugins/integration_test.go @@ -0,0 +1,302 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import ( + "encoding/json" + "testing" + "time" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "oras.land/oras-go/v2" + "oras.land/oras-go/v2/content/memory" + "oras.land/oras-go/v2/registry" + + "github.com/stacklok/toolhive-core/oci/artifact" +) + +// TestIntegration_PackagePushPull exercises the full e2e flow: +// package a plugin → push to an in-memory registry → pull into a fresh store → +// verify all content (index, manifests, config, layer, tags, extracted files). +func TestIntegration_PackagePushPull(t *testing.T) { + t.Parallel() + + ctx := t.Context() + ref := "example.com/myorg/integration-plugin:v1.0.0" + + // --- Setup: create a plugin directory with components --- + pluginDir := createTestPluginDir(t) + + // --- Phase 1: Package --- + packageStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(packageStore) + opts := PackageOptions{Epoch: time.Unix(1000000, 0).UTC()} + + result, err := packager.Package(ctx, pluginDir, opts) + require.NoError(t, err) + + assert.Equal(t, testPluginName, result.Config.Name) + assert.Equal(t, "A test plugin for packaging", result.Config.Description) + assert.Equal(t, "1.0.0", result.Config.Version) + assert.Equal(t, "Apache-2.0", result.Config.License) + assert.Contains(t, result.Config.Files, ManifestFileName) + assert.Contains(t, result.Config.Files, "commands/test.md") + assert.Equal(t, artifact.DefaultPlatforms, result.Platforms) + + // Verify the index was stored and is well-formed + isIdx, err := packageStore.IsIndex(ctx, result.IndexDigest) + require.NoError(t, err) + assert.True(t, isIdx, "packaged artifact should be an index") + + idx, err := packageStore.GetIndex(ctx, result.IndexDigest) + require.NoError(t, err) + assert.Equal(t, ocispec.MediaTypeImageIndex, idx.MediaType) + assert.Equal(t, ArtifactTypePlugin, idx.ArtifactType) + require.Len(t, idx.Manifests, len(artifact.DefaultPlatforms)) + + // --- Phase 2: Push to in-memory registry --- + remoteStore := memory.New() + reg := &Registry{ + newTarget: func(_ registry.Reference) (oras.Target, error) { + return remoteStore, nil + }, + } + + err = reg.Push(ctx, packageStore, result.IndexDigest, ref) + require.NoError(t, err) + + // Verify the remote has the content tagged + remoteDesc, err := remoteStore.Resolve(ctx, "v1.0.0") + require.NoError(t, err) + assert.Equal(t, result.IndexDigest, remoteDesc.Digest) + + // --- Phase 3: Pull into a fresh store --- + pullStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + pulledDigest, err := reg.Pull(ctx, pullStore, ref) + require.NoError(t, err) + assert.Equal(t, result.IndexDigest, pulledDigest, "pulled digest should match packaged index digest") + + // --- Phase 4: Verify pulled content --- + + // 4a. Tag resolution + resolved, err := pullStore.Resolve(ctx, ref) + require.NoError(t, err) + assert.Equal(t, pulledDigest, resolved) + + // 4b. Index is intact + pulledIdx, err := pullStore.GetIndex(ctx, pulledDigest) + require.NoError(t, err) + assert.Equal(t, ocispec.MediaTypeImageIndex, pulledIdx.MediaType) + assert.Equal(t, ArtifactTypePlugin, pulledIdx.ArtifactType) + require.Len(t, pulledIdx.Manifests, len(artifact.DefaultPlatforms)) + + // 4c. Each platform manifest, config, and layer are present and correct + for _, desc := range pulledIdx.Manifests { + require.NotNil(t, desc.Platform) + platformStr := desc.Platform.OS + "/" + desc.Platform.Architecture + + // Manifest + manifestBytes, err := pullStore.GetManifest(ctx, desc.Digest) + require.NoError(t, err, "manifest for %s should be present", platformStr) + + var manifest ocispec.Manifest + require.NoError(t, json.Unmarshal(manifestBytes, &manifest)) + + assert.Equal(t, ocispec.MediaTypeImageManifest, manifest.MediaType) + assert.Equal(t, ArtifactTypePlugin, manifest.ArtifactType) + assert.Equal(t, testPluginName, manifest.Annotations[AnnotationPluginName]) + assert.Equal(t, "1.0.0", manifest.Annotations[AnnotationPluginVersion]) + require.Len(t, manifest.Layers, 1) + + // Config + configBytes, err := pullStore.GetBlob(ctx, manifest.Config.Digest) + require.NoError(t, err, "config for %s should be present", platformStr) + + var ociConfig ocispec.Image + require.NoError(t, json.Unmarshal(configBytes, &ociConfig)) + + assert.Equal(t, desc.Platform.OS, ociConfig.OS) + assert.Equal(t, desc.Platform.Architecture, ociConfig.Architecture) + + labels := ociConfig.Config.Labels + require.NotNil(t, labels) + assert.Equal(t, testPluginName, labels[LabelPluginName]) + assert.Equal(t, "1.0.0", labels[LabelPluginVersion]) + + config, err := PluginConfigFromImageConfig(&ociConfig) + require.NoError(t, err) + assert.Equal(t, testPluginName, config.Name) + assert.Equal(t, []string{testRequireServerV1, testRequireSkillV1}, config.Requires) + + // Layer — extract and verify files + layerBytes, err := pullStore.GetBlob(ctx, manifest.Layers[0].Digest) + require.NoError(t, err, "layer for %s should be present", platformStr) + + files, err := artifact.DecompressTar(layerBytes) + require.NoError(t, err) + + fileMap := make(map[string][]byte, len(files)) + for _, f := range files { + fileMap[f.Path] = f.Content + } + + // Verify the plugin manifest is present and has correct content + manifestJSON, ok := fileMap[ManifestFileName] + require.True(t, ok, "%s should be in the layer", ManifestFileName) + assert.Contains(t, string(manifestJSON), testPluginName) + + // Verify a component file is present + command, ok := fileMap["commands/test.md"] + require.True(t, ok, "commands/test.md should be in the layer") + assert.Contains(t, string(command), "Test Command") + } +} + +// TestIntegration_PushPull_TwoVersions verifies that pushing two versions +// of the same plugin and pulling them both results in correct content. +func TestIntegration_PushPull_TwoVersions(t *testing.T) { + t.Parallel() + + ctx := t.Context() + remoteStore := memory.New() + reg := &Registry{ + newTarget: func(_ registry.Reference) (oras.Target, error) { + return remoteStore, nil + }, + } + + // Package and push v1 + v1Dir := createVersionedPluginDir(t, "1.0.0") + v1Store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + v1Result, err := NewPackager(v1Store).Package(ctx, v1Dir, PackageOptions{ + Epoch: time.Unix(1000, 0).UTC(), + }) + require.NoError(t, err) + + ref1 := "example.com/myorg/versioned-plugin:v1.0.0" + err = reg.Push(ctx, v1Store, v1Result.IndexDigest, ref1) + require.NoError(t, err) + + // Package and push v2 + v2Dir := createVersionedPluginDir(t, "2.0.0") + v2Store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + v2Result, err := NewPackager(v2Store).Package(ctx, v2Dir, PackageOptions{ + Epoch: time.Unix(2000, 0).UTC(), + }) + require.NoError(t, err) + + ref2 := "example.com/myorg/versioned-plugin:v2.0.0" + err = reg.Push(ctx, v2Store, v2Result.IndexDigest, ref2) + require.NoError(t, err) + + // Digests should differ + assert.NotEqual(t, v1Result.IndexDigest, v2Result.IndexDigest) + + // Pull both into the same store + pullStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + pulledV1, err := reg.Pull(ctx, pullStore, ref1) + require.NoError(t, err) + assert.Equal(t, v1Result.IndexDigest, pulledV1) + + pulledV2, err := reg.Pull(ctx, pullStore, ref2) + require.NoError(t, err) + assert.Equal(t, v2Result.IndexDigest, pulledV2) + + // Both tags resolve correctly in the same store + resolvedV1, err := pullStore.Resolve(ctx, ref1) + require.NoError(t, err) + assert.Equal(t, pulledV1, resolvedV1) + + resolvedV2, err := pullStore.Resolve(ctx, ref2) + require.NoError(t, err) + assert.Equal(t, pulledV2, resolvedV2) + + // Verify version annotations on each + for _, tc := range []struct { + dig digest.Digest + version string + }{ + {pulledV1, "1.0.0"}, + {pulledV2, "2.0.0"}, + } { + idx, err := pullStore.GetIndex(ctx, tc.dig) + require.NoError(t, err) + assert.Equal(t, tc.version, idx.Annotations[AnnotationPluginVersion]) + } +} + +// TestIntegration_PullPreservesBlobs verifies that after a pull, the pulled +// blobs can be used to reconstruct the original plugin content byte-for-byte. +func TestIntegration_PullPreservesBlobs(t *testing.T) { + t.Parallel() + + ctx := t.Context() + remoteStore := memory.New() + reg := &Registry{ + newTarget: func(_ registry.Reference) (oras.Target, error) { + return remoteStore, nil + }, + } + + pluginDir := createTestPluginDir(t) + packageStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + result, err := NewPackager(packageStore).Package(ctx, pluginDir, opts) + require.NoError(t, err) + + ref := "example.com/myorg/blob-test:v1.0.0" + err = reg.Push(ctx, packageStore, result.IndexDigest, ref) + require.NoError(t, err) + + pullStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + _, err = reg.Pull(ctx, pullStore, ref) + require.NoError(t, err) + + // Get the original layer bytes from the package store + originalLayer, err := packageStore.GetBlob(ctx, result.LayerDigest) + require.NoError(t, err) + + // Get the pulled layer bytes + pulledLayer, err := pullStore.GetBlob(ctx, result.LayerDigest) + require.NoError(t, err) + + assert.Equal(t, originalLayer, pulledLayer, "layer content should be byte-for-byte identical after pull") + + // Same for config + originalConfig, err := packageStore.GetBlob(ctx, result.ConfigDigest) + require.NoError(t, err) + + pulledConfig, err := pullStore.GetBlob(ctx, result.ConfigDigest) + require.NoError(t, err) + + assert.Equal(t, originalConfig, pulledConfig, "config content should be byte-for-byte identical after pull") +} + +// --- integration test helpers --- + +// createVersionedPluginDir creates a plugin directory with the given version. +func createVersionedPluginDir(t *testing.T, version string) string { + t.Helper() + + dir := t.TempDir() + writeManifest(t, dir, `{"name":"versioned-plugin","description":"Versioned plugin","version":"`+version+`"}`) + return dir +} diff --git a/oci/plugins/interfaces.go b/oci/plugins/interfaces.go new file mode 100644 index 0000000..ab216f6 --- /dev/null +++ b/oci/plugins/interfaces.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +//go:generate mockgen -copyright_file=../../.github/license-header.txt -source=interfaces.go -destination=mocks/mock_interfaces.go -package=mocks + +import ( + "context" + "time" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" +) + +// RegistryClient provides remote OCI registry operations for plugins. +type RegistryClient interface { + // Push pushes an artifact from the local store to a remote registry. + Push(ctx context.Context, store *Store, manifestDigest digest.Digest, ref string) error + + // Pull pulls an artifact from a remote registry into the local store. + Pull(ctx context.Context, store *Store, ref string) (digest.Digest, error) +} + +// PluginPackager creates OCI artifacts from plugin directories. +type PluginPackager interface { + // Package packages a plugin directory into an OCI artifact in the local store. + Package(ctx context.Context, pluginDir string, opts PackageOptions) (*PackageResult, error) +} + +// PackageOptions configures plugin packaging. +type PackageOptions struct { + // Epoch is the timestamp to use for reproducible builds. + Epoch time.Time + + // Platforms specifies target platforms for the image index. + // If empty, defaults to DefaultPlatforms. + Platforms []ocispec.Platform +} + +// PackageResult contains the result of packaging a plugin. +type PackageResult struct { + IndexDigest digest.Digest + ManifestDigest digest.Digest + ConfigDigest digest.Digest + LayerDigest digest.Digest + Config *PluginConfig + Platforms []ocispec.Platform +} diff --git a/oci/plugins/mediatypes.go b/oci/plugins/mediatypes.go new file mode 100644 index 0000000..4470ce6 --- /dev/null +++ b/oci/plugins/mediatypes.go @@ -0,0 +1,157 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import ( + "encoding/json" + "fmt" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" +) + +// Artifact type for plugin identification. +const ( + // ArtifactTypePlugin identifies plugin artifacts in manifests. + ArtifactTypePlugin = "dev.toolhive.plugins.v1" +) + +// Annotation keys for plugin metadata in manifests. +const ( + // AnnotationPluginName is the annotation key for plugin name. + AnnotationPluginName = "dev.toolhive.plugins.name" + + // AnnotationPluginDescription is the annotation key for plugin description. + AnnotationPluginDescription = "dev.toolhive.plugins.description" + + // AnnotationPluginVersion is the annotation key for plugin version. + AnnotationPluginVersion = "dev.toolhive.plugins.version" + + // AnnotationPluginLicense is the annotation key for plugin license. + AnnotationPluginLicense = "dev.toolhive.plugins.license" + + // AnnotationPluginFiles is the annotation key for plugin files (JSON array). + AnnotationPluginFiles = "dev.toolhive.plugins.files" + + // AnnotationPluginComponents is the annotation key for plugin component inventory (JSON object). + AnnotationPluginComponents = "dev.toolhive.plugins.components" + + // AnnotationPluginRequires is the annotation key for plugin external dependencies (JSON array of OCI references). + AnnotationPluginRequires = "dev.toolhive.plugins.requires" +) + +// Label keys for plugin metadata in OCI image config. +const ( + // LabelPluginName is the label key for plugin name. + LabelPluginName = "dev.toolhive.plugins.name" + + // LabelPluginDescription is the label key for plugin description. + LabelPluginDescription = "dev.toolhive.plugins.description" + + // LabelPluginVersion is the label key for plugin version. + LabelPluginVersion = "dev.toolhive.plugins.version" + + // LabelPluginLicense is the label key for plugin license. + LabelPluginLicense = "dev.toolhive.plugins.license" + + // LabelPluginFiles is the label key for plugin files (JSON array). + LabelPluginFiles = "dev.toolhive.plugins.files" + + // LabelPluginComponents is the label key for plugin component inventory (JSON object). + LabelPluginComponents = "dev.toolhive.plugins.components" + + // LabelPluginRequires is the label key for plugin external dependencies (JSON array of OCI references). + LabelPluginRequires = "dev.toolhive.plugins.requires" +) + +// ComponentInventory summarizes the component types declared by a plugin. +type ComponentInventory map[string]int + +// PluginConfig represents plugin metadata extracted from OCI image config labels. +type PluginConfig struct { + Name string `json:"name"` + Description string `json:"description"` + Version string `json:"version,omitempty"` + License string `json:"license,omitempty"` + Files []string `json:"files"` + Components ComponentInventory `json:"components,omitempty"` + Requires []string `json:"requires,omitempty"` +} + +// PluginConfigFromImageConfig extracts PluginConfig from OCI image config labels. +func PluginConfigFromImageConfig(imgConfig *ocispec.Image) (*PluginConfig, error) { + if imgConfig == nil { + return nil, fmt.Errorf("image config is nil") + } + + labels := imgConfig.Config.Labels + if labels == nil { + return nil, fmt.Errorf("oci config has no labels") + } + + config := &PluginConfig{ + Name: labels[LabelPluginName], + Description: labels[LabelPluginDescription], + Version: labels[LabelPluginVersion], + License: labels[LabelPluginLicense], + } + + if config.Name == "" { + return nil, fmt.Errorf("plugin name is required in labels") + } + + // Parse JSON-encoded metadata. + if filesJSON := labels[LabelPluginFiles]; filesJSON != "" { + if err := json.Unmarshal([]byte(filesJSON), &config.Files); err != nil { + return nil, fmt.Errorf("parsing files: %w", err) + } + } + + if componentsJSON := labels[LabelPluginComponents]; componentsJSON != "" { + if err := json.Unmarshal([]byte(componentsJSON), &config.Components); err != nil { + return nil, fmt.Errorf("parsing components: %w", err) + } + } + + if requiresJSON := labels[LabelPluginRequires]; requiresJSON != "" { + if err := json.Unmarshal([]byte(requiresJSON), &config.Requires); err != nil { + return nil, fmt.Errorf("parsing requires: %w", err) + } + } + + return config, nil +} + +// ParseComponentsAnnotation parses plugin component inventory from manifest annotations. +// Returns nil if the annotation is missing or invalid. +func ParseComponentsAnnotation(annotations map[string]string) ComponentInventory { + componentsJSON := annotations[AnnotationPluginComponents] + if componentsJSON == "" { + return nil + } + + var components ComponentInventory + if err := json.Unmarshal([]byte(componentsJSON), &components); err != nil { + // Invalid annotation format - return nil rather than propagating error + // since annotations may come from older versions or external sources. + return nil + } + return components +} + +// ParseRequiresAnnotation parses plugin dependency references from manifest annotations. +// Returns nil if the annotation is missing or invalid. +func ParseRequiresAnnotation(annotations map[string]string) []string { + requiresJSON := annotations[AnnotationPluginRequires] + if requiresJSON == "" { + return nil + } + + var refs []string + if err := json.Unmarshal([]byte(requiresJSON), &refs); err != nil { + // Invalid annotation format - return nil rather than propagating error + // since annotations may come from older versions or external sources. + return nil + } + return refs +} diff --git a/oci/plugins/mediatypes_test.go b/oci/plugins/mediatypes_test.go new file mode 100644 index 0000000..7237a0d --- /dev/null +++ b/oci/plugins/mediatypes_test.go @@ -0,0 +1,318 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import ( + "encoding/json" + "testing" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPluginConfigFromImageConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *ocispec.Image + wantName string + wantErr bool + wantFiles []string + wantComponents ComponentInventory + wantRequires []string + }{ + { + name: "all fields populated", + config: &ocispec.Image{ + Config: ocispec.ImageConfig{ + Labels: map[string]string{ + LabelPluginName: testPluginMyPlugin, + LabelPluginDescription: "A test plugin", + LabelPluginVersion: "1.0.0", + LabelPluginLicense: "Apache-2.0", + LabelPluginFiles: `[".claude-plugin/plugin.json","commands/test.md"]`, + LabelPluginComponents: `{"commands":1,"skills":2}`, + LabelPluginRequires: `["ghcr.io/org/server:v1"]`, + }, + }, + }, + wantName: testPluginMyPlugin, + wantFiles: []string{ManifestFileName, "commands/test.md"}, + wantComponents: ComponentInventory{testComponentCommands: 1, "skills": 2}, + wantRequires: []string{"ghcr.io/org/server:v1"}, + }, + { + name: "minimal config", + config: &ocispec.Image{ + Config: ocispec.ImageConfig{ + Labels: map[string]string{ + LabelPluginName: testPluginMinimal, + }, + }, + }, + wantName: testPluginMinimal, + }, + { + name: "nil config", + config: nil, + wantErr: true, + }, + { + name: "nil labels", + config: &ocispec.Image{ + Config: ocispec.ImageConfig{Labels: nil}, + }, + wantErr: true, + }, + { + name: "missing name", + config: &ocispec.Image{ + Config: ocispec.ImageConfig{ + Labels: map[string]string{ + LabelPluginDescription: "no name", + }, + }, + }, + wantErr: true, + }, + { + name: "invalid files JSON", + config: &ocispec.Image{ + Config: ocispec.ImageConfig{ + Labels: map[string]string{ + LabelPluginName: "bad-files", + LabelPluginFiles: testNotJSON, + }, + }, + }, + wantErr: true, + }, + { + name: "invalid components JSON", + config: &ocispec.Image{ + Config: ocispec.ImageConfig{ + Labels: map[string]string{ + LabelPluginName: "bad-components", + LabelPluginComponents: testNotJSON, + }, + }, + }, + wantErr: true, + }, + { + name: "invalid requires JSON", + config: &ocispec.Image{ + Config: ocispec.ImageConfig{ + Labels: map[string]string{ + LabelPluginName: "bad-requires", + LabelPluginRequires: testNotJSON, + }, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := PluginConfigFromImageConfig(tt.config) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantName, got.Name) + if tt.wantFiles != nil { + assert.Equal(t, tt.wantFiles, got.Files) + } + if tt.wantComponents != nil { + assert.Equal(t, tt.wantComponents, got.Components) + } + if tt.wantRequires != nil { + assert.Equal(t, tt.wantRequires, got.Requires) + } + }) + } +} + +// pluginConfigToLabels serializes a PluginConfig into OCI image config Labels +// exactly the way createOCIConfig does in packager.go. Keeping this helper in +// lock-step with createOCIConfig lets the round-trip test guard the full +// serialize → parse cycle without duplicating that logic in each case. +func pluginConfigToLabels(t *testing.T, cfg *PluginConfig) map[string]string { + t.Helper() + + filesJSON, err := json.Marshal(cfg.Files) + require.NoError(t, err) + componentsJSON, err := json.Marshal(cfg.Components) + require.NoError(t, err) + requiresJSON, err := json.Marshal(cfg.Requires) + require.NoError(t, err) + + return map[string]string{ + LabelPluginName: cfg.Name, + LabelPluginDescription: cfg.Description, + LabelPluginVersion: cfg.Version, + LabelPluginLicense: cfg.License, + LabelPluginFiles: string(filesJSON), + LabelPluginComponents: string(componentsJSON), + LabelPluginRequires: string(requiresJSON), + } +} + +// TestPluginConfig_RoundTrip locks in that a PluginConfig serialized into OCI +// image config labels (the way createOCIConfig does) and parsed back with +// PluginConfigFromImageConfig deep-equals the original. +func TestPluginConfig_RoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *PluginConfig + }{ + { + name: "fully populated", + cfg: &PluginConfig{ + Name: testPluginMyPlugin, + Description: "A fully populated plugin", + Version: "1.2.3", + License: "Apache-2.0", + Files: []string{ManifestFileName, "commands/test.md", "agents/reviewer.md"}, + Components: ComponentInventory{testComponentCommands: 1, testComponentSkills: 2}, + Requires: []string{testRequireServerV1, testRequireSkillV1}, + }, + }, + { + name: "minimal name only", + cfg: &PluginConfig{ + Name: testPluginMinimal, + Files: []string{ManifestFileName}, + }, + }, + { + // Regression guard for the Fix-1 nil behaviour: componentInventory + // returns nil for a zero-component plugin so the empty map dropped + // by `omitempty` round-trips back to nil rather than an empty map. + name: "zero components and zero requires", + cfg: &PluginConfig{ + Name: "no-components-plugin", + Description: "A plugin with no components or dependencies", + Version: "0.1.0", + Files: []string{ManifestFileName, "README.md"}, + Components: nil, + Requires: nil, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + img := &ocispec.Image{ + Config: ocispec.ImageConfig{ + Labels: pluginConfigToLabels(t, tt.cfg), + }, + } + + got, err := PluginConfigFromImageConfig(img) + require.NoError(t, err) + assert.Equal(t, tt.cfg, got) + }) + } +} + +func TestParseComponentsAnnotation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + annotations map[string]string + want ComponentInventory + }{ + { + name: "valid components", + annotations: map[string]string{ + AnnotationPluginComponents: `{"commands":3,"agents":1}`, + }, + want: ComponentInventory{testComponentCommands: 3, "agents": 1}, + }, + { + name: "empty annotations", + annotations: map[string]string{}, + want: nil, + }, + { + name: "missing annotation", + annotations: map[string]string{ + "other.key": "value", + }, + want: nil, + }, + { + name: testNameInvalidJSON, + annotations: map[string]string{ + AnnotationPluginComponents: testNotJSON, + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := ParseComponentsAnnotation(tt.annotations) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestParseRequiresAnnotation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + annotations map[string]string + want []string + }{ + { + name: "valid refs", + annotations: map[string]string{ + AnnotationPluginRequires: `["ghcr.io/org/plugin1:v1","ghcr.io/org/plugin2:v2"]`, + }, + want: []string{"ghcr.io/org/plugin1:v1", "ghcr.io/org/plugin2:v2"}, + }, + { + name: "empty annotations", + annotations: map[string]string{}, + want: nil, + }, + { + name: "missing annotation", + annotations: map[string]string{ + "other.key": "value", + }, + want: nil, + }, + { + name: testNameInvalidJSON, + annotations: map[string]string{ + AnnotationPluginRequires: testNotJSON, + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := ParseRequiresAnnotation(tt.annotations) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/oci/plugins/mocks/mock_interfaces.go b/oci/plugins/mocks/mock_interfaces.go new file mode 100644 index 0000000..7051754 --- /dev/null +++ b/oci/plugins/mocks/mock_interfaces.go @@ -0,0 +1,115 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces.go +// +// Generated by this command: +// +// mockgen -copyright_file=../../.github/license-header.txt -source=interfaces.go -destination=mocks/mock_interfaces.go -package=mocks +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + digest "github.com/opencontainers/go-digest" + plugins "github.com/stacklok/toolhive-core/oci/plugins" + gomock "go.uber.org/mock/gomock" +) + +// MockRegistryClient is a mock of RegistryClient interface. +type MockRegistryClient struct { + ctrl *gomock.Controller + recorder *MockRegistryClientMockRecorder + isgomock struct{} +} + +// MockRegistryClientMockRecorder is the mock recorder for MockRegistryClient. +type MockRegistryClientMockRecorder struct { + mock *MockRegistryClient +} + +// NewMockRegistryClient creates a new mock instance. +func NewMockRegistryClient(ctrl *gomock.Controller) *MockRegistryClient { + mock := &MockRegistryClient{ctrl: ctrl} + mock.recorder = &MockRegistryClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRegistryClient) EXPECT() *MockRegistryClientMockRecorder { + return m.recorder +} + +// Pull mocks base method. +func (m *MockRegistryClient) Pull(ctx context.Context, store *plugins.Store, ref string) (digest.Digest, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Pull", ctx, store, ref) + ret0, _ := ret[0].(digest.Digest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Pull indicates an expected call of Pull. +func (mr *MockRegistryClientMockRecorder) Pull(ctx, store, ref any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pull", reflect.TypeOf((*MockRegistryClient)(nil).Pull), ctx, store, ref) +} + +// Push mocks base method. +func (m *MockRegistryClient) Push(ctx context.Context, store *plugins.Store, manifestDigest digest.Digest, ref string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Push", ctx, store, manifestDigest, ref) + ret0, _ := ret[0].(error) + return ret0 +} + +// Push indicates an expected call of Push. +func (mr *MockRegistryClientMockRecorder) Push(ctx, store, manifestDigest, ref any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*MockRegistryClient)(nil).Push), ctx, store, manifestDigest, ref) +} + +// MockPluginPackager is a mock of PluginPackager interface. +type MockPluginPackager struct { + ctrl *gomock.Controller + recorder *MockPluginPackagerMockRecorder + isgomock struct{} +} + +// MockPluginPackagerMockRecorder is the mock recorder for MockPluginPackager. +type MockPluginPackagerMockRecorder struct { + mock *MockPluginPackager +} + +// NewMockPluginPackager creates a new mock instance. +func NewMockPluginPackager(ctrl *gomock.Controller) *MockPluginPackager { + mock := &MockPluginPackager{ctrl: ctrl} + mock.recorder = &MockPluginPackagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPluginPackager) EXPECT() *MockPluginPackagerMockRecorder { + return m.recorder +} + +// Package mocks base method. +func (m *MockPluginPackager) Package(ctx context.Context, pluginDir string, opts plugins.PackageOptions) (*plugins.PackageResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Package", ctx, pluginDir, opts) + ret0, _ := ret[0].(*plugins.PackageResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Package indicates an expected call of Package. +func (mr *MockPluginPackagerMockRecorder) Package(ctx, pluginDir, opts any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Package", reflect.TypeOf((*MockPluginPackager)(nil).Package), ctx, pluginDir, opts) +} diff --git a/oci/plugins/packager.go b/oci/plugins/packager.go new file mode 100644 index 0000000..8bd5f02 --- /dev/null +++ b/oci/plugins/packager.go @@ -0,0 +1,671 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/fs" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + "time" + + "github.com/opencontainers/go-digest" + specs "github.com/opencontainers/image-spec/specs-go" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/stacklok/toolhive-core/oci/artifact" +) + +const ( + // ManifestFileName is the required manifest file name for a plugin directory. + ManifestFileName = ".claude-plugin/plugin.json" + + // maxManifestSize limits plugin.json to prevent JSON parsing attacks. + maxManifestSize = 64 * 1024 + + // maxPluginFiles limits the number of files in a plugin directory to prevent + // memory exhaustion during packaging. + maxPluginFiles = 1_000 + + // maxPluginTotalSize limits the total aggregate size of all files in a plugin + // directory to prevent memory exhaustion during packaging. Kept below + // artifact.MaxDecompressedSize (100 MB) so that per-file tar header overhead + // cannot push a packaged artifact past the limit enforced on extraction. + maxPluginTotalSize int64 = 95 * 1024 * 1024 + + pluginLayerAnnotation = "plugin.tar.gz" +) + +// Packager creates reproducible OCI artifacts from plugin directories. +type Packager struct { + store *Store +} + +// manifestInfo holds a manifest digest along with its size. +type manifestInfo struct { + digest digest.Digest + size int64 +} + +// pluginManifest represents fields ToolHive reads from .claude-plugin/plugin.json. +type pluginManifest struct { + Name string `json:"name"` + Description string `json:"description"` + Version string `json:"version,omitempty"` + License string `json:"license,omitempty"` + Commands json.RawMessage `json:"commands,omitempty"` + Agents json.RawMessage `json:"agents,omitempty"` + Skills json.RawMessage `json:"skills,omitempty"` + Hooks json.RawMessage `json:"hooks,omitempty"` + MCPServers json.RawMessage `json:"mcpServers,omitempty"` + LSPServers json.RawMessage `json:"lspServers,omitempty"` + Dependencies json.RawMessage `json:"dependencies,omitempty"` + Requires json.RawMessage `json:"requires,omitempty"` +} + +// pluginFile is a regular file collected from a plugin directory, retaining its +// permission bits so executable scripts survive a package/extract round-trip. +type pluginFile struct { + content []byte + mode int64 +} + +// pluginDirContent holds the raw files and parsed metadata from a plugin directory. +type pluginDirContent struct { + manifest []byte + // files maps relative paths (e.g., "commands/foo.md") to file content and mode. + files map[string]pluginFile + // pm is the parsed manifest. + pm *pluginManifest +} + +// Compile-time assertion that Packager implements PluginPackager. +var _ PluginPackager = (*Packager)(nil) + +// NewPackager creates a new packager with the given store. +// Panics if store is nil. +func NewPackager(store *Store) *Packager { + if store == nil { + panic("plugins: NewPackager called with nil store") + } + return &Packager{store: store} +} + +// DefaultPackageOptions returns default packaging options. +// Respects SOURCE_DATE_EPOCH for reproducible builds. +func DefaultPackageOptions() PackageOptions { + epoch := time.Unix(0, 0).UTC() + + if sde := os.Getenv("SOURCE_DATE_EPOCH"); sde != "" { + if ts, err := strconv.ParseInt(sde, 10, 64); err == nil { + epoch = time.Unix(ts, 0).UTC() + } + } + + return PackageOptions{ + Epoch: epoch, + Platforms: artifact.DefaultPlatforms, + } +} + +// Package packages a plugin directory into an OCI artifact in the local store. +func (p *Packager) Package(ctx context.Context, pluginDir string, opts PackageOptions) (*PackageResult, error) { + if len(opts.Platforms) == 0 { + opts.Platforms = artifact.DefaultPlatforms + } + // Normalize a zero epoch so the OCI config (which marshals opts.Epoch as-is) + // agrees with the tar/gzip layers (which default a zero epoch to Unix 0) + // when callers pass a bare PackageOptions{}. + if opts.Epoch.IsZero() { + opts.Epoch = time.Unix(0, 0).UTC() + } + + content, err := readPluginDirectory(pluginDir) + if err != nil { + return nil, fmt.Errorf("reading plugin directory: %w", err) + } + + layerBytes, uncompressedTar, err := createContentLayer(content, opts) + if err != nil { + return nil, fmt.Errorf("creating content layer: %w", err) + } + + layerDigest, err := p.store.PutBlob(ctx, layerBytes) + if err != nil { + return nil, fmt.Errorf("storing layer blob: %w", err) + } + + platformManifests := make(map[string]manifestInfo, len(opts.Platforms)) + var primaryManifestDigest, primaryConfigDigest digest.Digest + var pluginConfig *PluginConfig + var manifestAnnotations map[string]string + + for i, platform := range opts.Platforms { + platformStr := artifact.PlatformString(platform) + + ociConfig, cfg := createOCIConfig(content, uncompressedTar, platform, opts) + configBytes, err := json.Marshal(ociConfig) + if err != nil { + return nil, fmt.Errorf("marshaling config for platform %s: %w", platformStr, err) + } + + configDigest, err := p.store.PutBlob(ctx, configBytes) + if err != nil { + return nil, fmt.Errorf("storing config blob for platform %s: %w", platformStr, err) + } + + manifest := createManifest(configBytes, configDigest, layerBytes, layerDigest, cfg, opts) + manifestBytes, err := json.Marshal(manifest) + if err != nil { + return nil, fmt.Errorf("marshaling manifest for platform %s: %w", platformStr, err) + } + + manifestDigest, err := p.store.PutManifest(ctx, manifestBytes) + if err != nil { + return nil, fmt.Errorf("storing manifest for platform %s: %w", platformStr, err) + } + + platformManifests[platformStr] = manifestInfo{ + digest: manifestDigest, + size: int64(len(manifestBytes)), + } + + if i == 0 { + primaryManifestDigest = manifestDigest + primaryConfigDigest = configDigest + pluginConfig = cfg + manifestAnnotations = manifest.Annotations + } + } + + indexDigest, err := p.createIndex(ctx, platformManifests, manifestAnnotations, opts) + if err != nil { + return nil, fmt.Errorf("creating index: %w", err) + } + + return &PackageResult{ + IndexDigest: indexDigest, + ManifestDigest: primaryManifestDigest, + ConfigDigest: primaryConfigDigest, + LayerDigest: layerDigest, + Config: pluginConfig, + Platforms: opts.Platforms, + }, nil +} + +// readPluginDirectory reads a plugin directory, validates its contents, and parses the manifest. +func readPluginDirectory(dir string) (*pluginDirContent, error) { + if err := validatePluginDir(dir); err != nil { + return nil, err + } + + manifestPath := filepath.Join(dir, ManifestFileName) + + // Lstat and size-check before reading: reject a symlinked manifest (TOCTOU + // race against validatePluginDir) and avoid allocating an oversized file + // before parseManifest's size check would reject it. + fi, err := os.Lstat(manifestPath) + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("%s not found in plugin directory: %w", ManifestFileName, ErrPluginManifestMissing) + } + return nil, fmt.Errorf("checking %s: %w", ManifestFileName, err) + } + if fi.Mode()&os.ModeSymlink != 0 || !fi.Mode().IsRegular() { + return nil, fmt.Errorf("%s must be a regular file: %w", ManifestFileName, ErrInvalidPluginFile) + } + if fi.Size() > maxManifestSize { + return nil, fmt.Errorf("manifest file size %d exceeds maximum of %d bytes: %w", + fi.Size(), maxManifestSize, ErrInvalidPluginManifest) + } + + manifest, err := os.ReadFile(manifestPath) //#nosec G304 -- validated dir; Lstat guard above rejects symlinks + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("%s not found in plugin directory: %w", ManifestFileName, ErrPluginManifestMissing) + } + return nil, fmt.Errorf("reading %s: %w", ManifestFileName, err) + } + + pm, err := parseManifest(manifest) + if err != nil { + return nil, fmt.Errorf("parsing %s: %w", ManifestFileName, err) + } + + if pm.Name == "" { + return nil, fmt.Errorf("plugin name is required in %s: %w", ManifestFileName, ErrInvalidPluginManifest) + } + + files, err := collectPluginFiles(dir) + if err != nil { + return nil, err + } + + return &pluginDirContent{ + manifest: manifest, + files: files, + pm: pm, + }, nil +} + +// validatePluginDir checks that the directory exists and is safe to read. +func validatePluginDir(dir string) error { + info, err := os.Stat(dir) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("plugin directory not found: %s: %w", dir, ErrInvalidPluginDir) + } + return fmt.Errorf("accessing plugin directory: %w: %w", err, ErrInvalidPluginDir) + } + if !info.IsDir() { + return fmt.Errorf("path is not a directory: %s: %w", dir, ErrInvalidPluginDir) + } + + // filepath.Clean resolves ".." segments, so a substring check for ".." is + // ineffective against a traversal that resolves to a valid path. Require an + // absolute path instead; relative inputs are rejected up front. + cleanDir := filepath.Clean(dir) + if !filepath.IsAbs(cleanDir) { + return fmt.Errorf("plugin directory must be an absolute path: %s: %w", dir, ErrInvalidPluginDir) + } + + return nil +} + +// collectPluginFiles walks a plugin directory and returns all regular files +// (excluding hidden files except .claude-plugin/plugin.json). It enforces limits +// on file count and total aggregate size to prevent memory exhaustion. +func collectPluginFiles(dir string) (map[string]pluginFile, error) { + files := make(map[string]pluginFile) + var totalSize int64 + if err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if path == dir { + return nil + } + return collectPluginFile(dir, path, d, files, &totalSize) + }); err != nil { + return nil, fmt.Errorf("walking plugin directory: %w", err) + } + return files, nil +} + +func collectPluginFile(dir, path string, d fs.DirEntry, files map[string]pluginFile, totalSize *int64) error { + relPath, err := filepath.Rel(dir, path) + if err != nil { + return fmt.Errorf("getting relative path: %w", err) + } + relPath = filepath.ToSlash(relPath) + + if d.Type()&os.ModeSymlink != 0 { + return fmt.Errorf("symlinks not allowed in plugin directory: %s: %w", relPath, ErrInvalidPluginFile) + } + + if d.IsDir() { + if strings.HasPrefix(filepath.Base(relPath), ".") && relPath != ".claude-plugin" { + return filepath.SkipDir + } + return nil + } + + fileInfo, err := validatePluginFile(path, relPath) + if err != nil { + return err + } + + if isHiddenPath(relPath) && !isAllowedHiddenPluginFile(relPath) { + return nil + } + + if relPath == ManifestFileName { + return nil + } + + if len(files) >= maxPluginFiles { + return fmt.Errorf("plugin directory exceeds maximum of %d files: %w", maxPluginFiles, ErrTooManyFiles) + } + + content, err := os.ReadFile(path) //#nosec G304,G122 -- path from WalkDir, symlink-checked + if err != nil { + return fmt.Errorf("reading %s: %w", relPath, err) + } + + *totalSize += int64(len(content)) + if *totalSize > maxPluginTotalSize { + return fmt.Errorf("plugin directory exceeds maximum total size of %d bytes: %w", maxPluginTotalSize, ErrPluginTooLarge) + } + + files[relPath] = pluginFile{content: content, mode: int64(fileInfo.Mode().Perm())} + return nil +} + +// validatePluginFile checks that a file in the plugin directory is safe to +// include and returns its FileInfo so the caller can preserve the file mode. +func validatePluginFile(absPath, relPath string) (os.FileInfo, error) { + fileInfo, err := os.Lstat(absPath) + if err != nil { + return nil, fmt.Errorf("checking file type for %s: %w", relPath, err) + } + if fileInfo.Mode()&os.ModeSymlink != 0 { + return nil, fmt.Errorf("symlinks not allowed in plugin directory: %s: %w", relPath, ErrInvalidPluginFile) + } + if !fileInfo.Mode().IsRegular() { + return nil, fmt.Errorf("non-regular file not allowed in plugin directory: %s: %w", relPath, ErrInvalidPluginFile) + } + return fileInfo, nil +} + +func isHiddenPath(relPath string) bool { + parts := strings.Split(relPath, "/") + for _, part := range parts { + if strings.HasPrefix(part, ".") { + return true + } + } + return false +} + +func isAllowedHiddenPluginFile(relPath string) bool { + return relPath == ".mcp.json" || relPath == ManifestFileName +} + +// parseManifest parses .claude-plugin/plugin.json content. +func parseManifest(content []byte) (*pluginManifest, error) { + content = bytes.TrimSpace(content) + if len(content) > maxManifestSize { + return nil, fmt.Errorf("manifest exceeds maximum size of %d bytes: %w", maxManifestSize, ErrInvalidPluginManifest) + } + + var pm pluginManifest + if err := json.Unmarshal(content, &pm); err != nil { + return nil, fmt.Errorf("parsing manifest JSON: %w: %w", err, ErrInvalidPluginManifest) + } + + return &pm, nil +} + +// createContentLayer creates a reproducible tar.gz of the plugin content. +// Returns both compressed and uncompressed bytes (uncompressed needed for diff_id). +func createContentLayer(content *pluginDirContent, opts PackageOptions) (compressed, uncompressed []byte, err error) { + var files []artifact.FileEntry + + files = append(files, artifact.FileEntry{ + Path: ManifestFileName, + Content: content.manifest, + }) + + sortedPaths := make([]string, 0, len(content.files)) + for p := range content.files { + sortedPaths = append(sortedPaths, p) + } + slices.Sort(sortedPaths) + + for _, p := range sortedPaths { + f := content.files[p] + files = append(files, artifact.FileEntry{ + Path: p, + Content: f.content, + Mode: f.mode, + }) + } + + tarOpts := artifact.TarOptions{Epoch: opts.Epoch} + gzipOpts := artifact.DefaultGzipOptions() + // Honour the package epoch in the gzip member header too, so the timestamp a + // consumer reads from the gzip header agrees with the tar and OCI metadata. + gzipOpts.Epoch = opts.Epoch + + uncompressed, err = artifact.CreateTar(files, tarOpts) + if err != nil { + return nil, nil, fmt.Errorf("creating tar: %w", err) + } + + compressed, err = artifact.Compress(uncompressed, gzipOpts) + if err != nil { + return nil, nil, fmt.Errorf("compressing tar: %w", err) + } + + return compressed, uncompressed, nil +} + +// createOCIConfig creates the OCI image config with plugin metadata in labels. +func createOCIConfig( + content *pluginDirContent, + uncompressedTar []byte, + platform ocispec.Platform, + opts PackageOptions, +) (*ocispec.Image, *PluginConfig) { + cfg := pluginConfig(content) + + epoch := opts.Epoch + ociConfig := &ocispec.Image{ + Created: &epoch, + Platform: platform, + Config: ocispec.ImageConfig{ + Labels: map[string]string{ + LabelPluginName: cfg.Name, + LabelPluginDescription: cfg.Description, + LabelPluginVersion: cfg.Version, + LabelPluginLicense: cfg.License, + LabelPluginFiles: mustMarshalJSON("files", cfg.Files), + LabelPluginComponents: mustMarshalJSON("components", cfg.Components), + LabelPluginRequires: mustMarshalJSON("requires", cfg.Requires), + }, + }, + RootFS: ocispec.RootFS{ + Type: "layers", + DiffIDs: []digest.Digest{digest.FromBytes(uncompressedTar)}, + }, + History: []ocispec.History{ + { + Created: &epoch, + CreatedBy: "toolhive package", + }, + }, + } + + return ociConfig, cfg +} + +func pluginConfig(content *pluginDirContent) *PluginConfig { + allFiles := []string{ManifestFileName} + for p := range content.files { + allFiles = append(allFiles, p) + } + slices.Sort(allFiles) + + return &PluginConfig{ + Name: content.pm.Name, + Description: content.pm.Description, + Version: content.pm.Version, + License: content.pm.License, + Files: allFiles, + Components: componentInventory(content.pm), + Requires: requires(content.pm), + } +} + +func componentInventory(pm *pluginManifest) ComponentInventory { + components := ComponentInventory{} + addCount := func(name string, raw json.RawMessage) { + if len(bytes.TrimSpace(raw)) == 0 { + return + } + if count := jsonComponentCount(raw); count > 0 { + components[name] = count + } + } + + addCount("commands", pm.Commands) + addCount("agents", pm.Agents) + addCount("skills", pm.Skills) + addCount("hooks", pm.Hooks) + addCount("mcpServers", pm.MCPServers) + addCount("lspServers", pm.LSPServers) + + // Return nil rather than an empty map so the value round-trips cleanly: + // an empty map is dropped by the `omitempty` config/annotation tags and + // reparsed as nil, so a zero-component plugin must produce nil here too. + if len(components) == 0 { + return nil + } + + return components +} + +func jsonComponentCount(raw json.RawMessage) int { + var arr []json.RawMessage + if err := json.Unmarshal(raw, &arr); err == nil { + return len(arr) + } + + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err == nil { + return len(obj) + } + + return 1 +} + +func requires(pm *pluginManifest) []string { + refs := stringArray(pm.Requires) + refs = append(refs, stringArray(pm.Dependencies)...) + slices.Sort(refs) + return slices.Compact(refs) +} + +func stringArray(raw json.RawMessage) []string { + if len(bytes.TrimSpace(raw)) == 0 { + return nil + } + + var refs []string + if err := json.Unmarshal(raw, &refs); err == nil { + return refs + } + + var obj map[string]string + if err := json.Unmarshal(raw, &obj); err == nil { + refs := make([]string, 0, len(obj)) + for _, ref := range obj { + if ref != "" { + refs = append(refs, ref) + } + } + return refs + } + + return nil +} + +// mustMarshalJSON marshals v for embedding in an OCI label/annotation. The +// values passed here ([]string, map[string]int) cannot fail to marshal; a panic +// guards against a future type change silently producing a broken artifact. +func mustMarshalJSON(field string, v any) string { + b, err := json.Marshal(v) + if err != nil { + panic(fmt.Sprintf("plugins: marshal %s: %v", field, err)) + } + return string(b) +} + +// createManifest creates the OCI manifest. +func createManifest( + configBytes []byte, + configDigest digest.Digest, + layerBytes []byte, + layerDigest digest.Digest, + cfg *PluginConfig, + opts PackageOptions, +) *ocispec.Manifest { + annotations := map[string]string{ + ocispec.AnnotationCreated: opts.Epoch.Format(time.RFC3339), + AnnotationPluginName: cfg.Name, + AnnotationPluginDescription: cfg.Description, + AnnotationPluginVersion: cfg.Version, + AnnotationPluginLicense: cfg.License, + AnnotationPluginFiles: mustMarshalJSON("files", cfg.Files), + AnnotationPluginComponents: mustMarshalJSON("components", cfg.Components), + AnnotationPluginRequires: mustMarshalJSON("requires", cfg.Requires), + ocispec.AnnotationVersion: cfg.Version, + ocispec.AnnotationLicenses: cfg.License, + ocispec.AnnotationTitle: cfg.Name, + ocispec.AnnotationDescription: cfg.Description, + } + + return &ocispec.Manifest{ + Versioned: specs.Versioned{SchemaVersion: 2}, + MediaType: ocispec.MediaTypeImageManifest, + ArtifactType: ArtifactTypePlugin, + Config: ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageConfig, + Digest: configDigest, + Size: int64(len(configBytes)), + }, + Layers: []ocispec.Descriptor{ + { + MediaType: ocispec.MediaTypeImageLayerGzip, + Digest: layerDigest, + Size: int64(len(layerBytes)), + Annotations: map[string]string{ + ocispec.AnnotationTitle: pluginLayerAnnotation, + }, + }, + }, + Annotations: annotations, + } +} + +// createIndex creates an OCI image index with per-platform manifests. +func (p *Packager) createIndex( + ctx context.Context, + platformManifests map[string]manifestInfo, + annotations map[string]string, + opts PackageOptions, +) (digest.Digest, error) { + manifests := make([]ocispec.Descriptor, 0, len(opts.Platforms)) + for _, platform := range opts.Platforms { + platformStr := artifact.PlatformString(platform) + info, ok := platformManifests[platformStr] + if !ok { + return "", fmt.Errorf("missing manifest for platform %s", platformStr) + } + + p := platform // copy for pointer + manifests = append(manifests, ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: info.digest, + Size: info.size, + Platform: &p, + }) + } + + index := ocispec.Index{ + Versioned: specs.Versioned{SchemaVersion: 2}, + MediaType: ocispec.MediaTypeImageIndex, + ArtifactType: ArtifactTypePlugin, + Manifests: manifests, + Annotations: annotations, + } + + indexBytes, err := json.Marshal(index) + if err != nil { + return "", fmt.Errorf("marshaling index: %w", err) + } + + indexDigest, err := p.store.PutManifest(ctx, indexBytes) + if err != nil { + return "", fmt.Errorf("storing index: %w", err) + } + + return indexDigest, nil +} diff --git a/oci/plugins/packager_test.go b/oci/plugins/packager_test.go new file mode 100644 index 0000000..320426e --- /dev/null +++ b/oci/plugins/packager_test.go @@ -0,0 +1,694 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive-core/oci/artifact" +) + +const testPluginName = "test-plugin" + +func TestPackager_Package(t *testing.T) { + t.Parallel() + + pluginDir := createTestPluginDir(t) + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + result, err := packager.Package(context.Background(), pluginDir, opts) + require.NoError(t, err) + + assert.NotEmpty(t, result.ManifestDigest.String()) + assert.NotEmpty(t, result.ConfigDigest.String()) + assert.NotEmpty(t, result.LayerDigest.String()) + assert.NotEmpty(t, result.IndexDigest.String()) + + assert.Equal(t, testPluginName, result.Config.Name) + assert.Equal(t, "A test plugin for packaging", result.Config.Description) + assert.Equal(t, "1.0.0", result.Config.Version) + assert.Equal(t, "Apache-2.0", result.Config.License) + assert.Contains(t, result.Config.Files, ManifestFileName) + assert.Contains(t, result.Config.Files, "commands/test.md") + assert.Equal(t, ComponentInventory{testComponentCommands: 1, "agents": 1, "skills": 1, "hooks": 1, "mcpServers": 1}, result.Config.Components) + assert.Equal(t, []string{"ghcr.io/org/server:v1", "ghcr.io/org/skill:v1"}, result.Config.Requires) +} + +func TestPackager_Package_Reproducible(t *testing.T) { + t.Parallel() + + pluginDir := createTestPluginDir(t) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + store1, err := NewStore(t.TempDir()) + require.NoError(t, err) + + store2, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + + result1, err := NewPackager(store1).Package(ctx, pluginDir, opts) + require.NoError(t, err) + + result2, err := NewPackager(store2).Package(ctx, pluginDir, opts) + require.NoError(t, err) + + assert.Equal(t, result1.IndexDigest, result2.IndexDigest, "IndexDigest not reproducible") + assert.Equal(t, result1.ManifestDigest, result2.ManifestDigest, "ManifestDigest not reproducible") + assert.Equal(t, result1.ConfigDigest, result2.ConfigDigest, "ConfigDigest not reproducible") + assert.Equal(t, result1.LayerDigest, result2.LayerDigest, "LayerDigest not reproducible") +} + +func TestPackager_Package_Reproducible_SourceDateEpoch(t *testing.T) { + t.Setenv("SOURCE_DATE_EPOCH", "1234567890") + + pluginDir := createTestPluginDir(t) + ctx := context.Background() + + store1, err := NewStore(t.TempDir()) + require.NoError(t, err) + result1, err := NewPackager(store1).Package(ctx, pluginDir, DefaultPackageOptions()) + require.NoError(t, err) + + store2, err := NewStore(t.TempDir()) + require.NoError(t, err) + result2, err := NewPackager(store2).Package(ctx, pluginDir, DefaultPackageOptions()) + require.NoError(t, err) + + assert.Equal(t, result1.IndexDigest, result2.IndexDigest) + assert.Equal(t, result1.ManifestDigest, result2.ManifestDigest) + assert.Equal(t, result1.ConfigDigest, result2.ConfigDigest) + assert.Equal(t, result1.LayerDigest, result2.LayerDigest) +} + +func TestPackager_Package_VerifyManifest(t *testing.T) { + t.Parallel() + + pluginDir := createTestPluginDir(t) + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + ctx := context.Background() + result, err := packager.Package(ctx, pluginDir, opts) + require.NoError(t, err) + + manifestBytes, err := store.GetManifest(ctx, result.ManifestDigest) + require.NoError(t, err) + + var manifest ocispec.Manifest + require.NoError(t, json.Unmarshal(manifestBytes, &manifest)) + + assert.Equal(t, 2, manifest.SchemaVersion) + assert.Equal(t, ocispec.MediaTypeImageManifest, manifest.MediaType) + assert.Equal(t, ArtifactTypePlugin, manifest.ArtifactType) + assert.Equal(t, ocispec.MediaTypeImageConfig, manifest.Config.MediaType) + require.Len(t, manifest.Layers, 1) + assert.Equal(t, ocispec.MediaTypeImageLayerGzip, manifest.Layers[0].MediaType) + assert.Equal(t, "plugin.tar.gz", manifest.Layers[0].Annotations[ocispec.AnnotationTitle]) + assert.Equal(t, testPluginName, manifest.Annotations[AnnotationPluginName]) + assert.Equal(t, "Apache-2.0", manifest.Annotations[AnnotationPluginLicense]) + assert.JSONEq(t, testPluginComponents, manifest.Annotations[AnnotationPluginComponents]) + assert.JSONEq(t, `["ghcr.io/org/server:v1","ghcr.io/org/skill:v1"]`, manifest.Annotations[AnnotationPluginRequires]) +} + +func TestPackager_Package_VerifyLayer(t *testing.T) { + t.Parallel() + + pluginDir := createTestPluginDir(t) + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + ctx := context.Background() + result, err := packager.Package(ctx, pluginDir, opts) + require.NoError(t, err) + + layerBytes, err := store.GetBlob(ctx, result.LayerDigest) + require.NoError(t, err) + + files, err := artifact.DecompressTar(layerBytes) + require.NoError(t, err) + + fileMap := make(map[string][]byte, len(files)) + for _, f := range files { + fileMap[f.Path] = f.Content + } + + _, ok := fileMap[ManifestFileName] + assert.True(t, ok, "plugin manifest not found in layer") + _, ok = fileMap["commands/test.md"] + assert.True(t, ok, "commands/test.md not found in layer") + _, ok = fileMap[".mcp.json"] + assert.True(t, ok, ".mcp.json should be packaged verbatim") + _, ok = fileMap[".hidden"] + assert.False(t, ok, "hidden file should not be in layer") +} + +func TestPackager_Package_VerifyOCIConfig(t *testing.T) { + t.Parallel() + + pluginDir := createTestPluginDir(t) + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + ctx := context.Background() + result, err := packager.Package(ctx, pluginDir, opts) + require.NoError(t, err) + + configBytes, err := store.GetBlob(ctx, result.ConfigDigest) + require.NoError(t, err) + + var ociConfig ocispec.Image + require.NoError(t, json.Unmarshal(configBytes, &ociConfig)) + + assert.Equal(t, artifact.ArchAMD64, ociConfig.Architecture) + assert.Equal(t, artifact.OSLinux, ociConfig.OS) + assert.NotNil(t, ociConfig.Created, "top-level created field should be set") + assert.Equal(t, "layers", ociConfig.RootFS.Type) + require.Len(t, ociConfig.RootFS.DiffIDs, 1) + assert.Contains(t, ociConfig.RootFS.DiffIDs[0].String(), "sha256:") + + labels := ociConfig.Config.Labels + require.NotNil(t, labels) + assert.Equal(t, testPluginName, labels[LabelPluginName]) + assert.Equal(t, "A test plugin for packaging", labels[LabelPluginDescription]) + assert.Equal(t, "1.0.0", labels[LabelPluginVersion]) + assert.JSONEq(t, testPluginComponents, labels[LabelPluginComponents]) + + cfg, err := PluginConfigFromImageConfig(&ociConfig) + require.NoError(t, err) + assert.Equal(t, result.Config, cfg) + + require.Len(t, ociConfig.History, 1) + assert.Equal(t, "toolhive package", ociConfig.History[0].CreatedBy) +} + +func TestPackager_Package_MultiPlatformConfigMatch(t *testing.T) { + t.Parallel() + + pluginDir := createTestPluginDir(t) + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + platforms := []ocispec.Platform{ + {OS: artifact.OSLinux, Architecture: artifact.ArchAMD64}, + {OS: artifact.OSLinux, Architecture: artifact.ArchARM64}, + } + opts := PackageOptions{ + Epoch: time.Unix(0, 0).UTC(), + Platforms: platforms, + } + + ctx := context.Background() + result, err := packager.Package(ctx, pluginDir, opts) + require.NoError(t, err) + + assert.Equal(t, platforms, result.Platforms) + + indexBytes, err := store.GetManifest(ctx, result.IndexDigest) + require.NoError(t, err) + + var index ocispec.Index + require.NoError(t, json.Unmarshal(indexBytes, &index)) + + require.Len(t, index.Manifests, 2) + + for _, descriptor := range index.Manifests { + require.NotNil(t, descriptor.Platform) + platformStr := descriptor.Platform.OS + "/" + descriptor.Platform.Architecture + + manifestBytes, err := store.GetManifest(ctx, descriptor.Digest) + require.NoError(t, err) + + var manifest ocispec.Manifest + require.NoError(t, json.Unmarshal(manifestBytes, &manifest)) + + configBytes, err := store.GetBlob(ctx, manifest.Config.Digest) + require.NoError(t, err) + + var ociConfig ocispec.Image + require.NoError(t, json.Unmarshal(configBytes, &ociConfig)) + + assert.Equal(t, descriptor.Platform.OS, ociConfig.OS, + "Config OS for platform %s", platformStr) + assert.Equal(t, descriptor.Platform.Architecture, ociConfig.Architecture, + "Config Architecture for platform %s", platformStr) + } +} + +func TestDefaultPackageOptions(t *testing.T) { + t.Parallel() + + opts := DefaultPackageOptions() + assert.False(t, opts.Epoch.IsZero()) + assert.Equal(t, artifact.DefaultPlatforms, opts.Platforms) +} + +func TestDefaultPackageOptions_WithSourceDateEpoch(t *testing.T) { + t.Setenv("SOURCE_DATE_EPOCH", "1234567890") + + opts := DefaultPackageOptions() + expected := time.Unix(1234567890, 0).UTC() + assert.True(t, opts.Epoch.Equal(expected)) +} + +func TestPackager_Package_MissingManifest(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + _, err = packager.Package(context.Background(), dir, opts) + assert.Error(t, err) + assert.Contains(t, err.Error(), ManifestFileName+" not found") + assert.ErrorIs(t, err, ErrPluginManifestMissing) +} + +func TestPackager_Package_MissingName(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + writeManifest(t, dir, `{"description":"A plugin without a name","version":"1.0.0"}`) + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + _, err = packager.Package(context.Background(), dir, opts) + assert.Error(t, err) + assert.Contains(t, err.Error(), "plugin name is required") + assert.ErrorIs(t, err, ErrInvalidPluginManifest) +} + +func TestPackager_Package_DefaultPlatforms(t *testing.T) { + t.Parallel() + + pluginDir := createTestPluginDir(t) + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + result, err := packager.Package(context.Background(), pluginDir, opts) + require.NoError(t, err) + + assert.Equal(t, artifact.DefaultPlatforms, result.Platforms) +} + +func TestPackager_Package_RejectsSymlinks(t *testing.T) { + t.Parallel() + + dir := createTestPluginDir(t) + require.NoError(t, os.Symlink("/etc/passwd", filepath.Join(dir, "evil_link"))) + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + _, err = packager.Package(context.Background(), dir, opts) + assert.Error(t, err) + assert.Contains(t, err.Error(), "symlinks not allowed") + assert.ErrorIs(t, err, ErrInvalidPluginFile) +} + +func TestPackager_Package_RejectsSymlinkedDirectory(t *testing.T) { + t.Parallel() + + dir := createTestPluginDir(t) + require.NoError(t, os.Symlink("/etc", filepath.Join(dir, "evil_dir"))) + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + _, err = packager.Package(context.Background(), dir, opts) + assert.Error(t, err) + assert.Contains(t, err.Error(), "symlinks not allowed") + assert.ErrorIs(t, err, ErrInvalidPluginFile) +} + +func TestNewPackager_NilStore(t *testing.T) { + t.Parallel() + + assert.Panics(t, func() { + NewPackager(nil) + }) +} + +func TestPackager_Package_InvalidManifest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + }{ + { + name: testNameInvalidJSON, + content: `{"name":`, + }, + { + name: "oversized manifest", + content: `{"name":"test","x":"` + string(bytes.Repeat([]byte("a"), maxManifestSize+1)) + `"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + writeManifest(t, dir, tt.content) + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + _, err = packager.Package(context.Background(), dir, opts) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidPluginManifest) + }) + } +} + +func TestPackager_Package_NonexistentDir(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + _, err = packager.Package(context.Background(), "/nonexistent/path", opts) + assert.Error(t, err) + assert.Contains(t, err.Error(), "plugin directory not found") + assert.ErrorIs(t, err, ErrInvalidPluginDir) +} + +func TestPackager_Package_IndexStructure(t *testing.T) { + t.Parallel() + + pluginDir := createTestPluginDir(t) + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + packager := NewPackager(store) + opts := PackageOptions{Epoch: time.Unix(0, 0).UTC()} + + ctx := context.Background() + result, err := packager.Package(ctx, pluginDir, opts) + require.NoError(t, err) + + indexBytes, err := store.GetManifest(ctx, result.IndexDigest) + require.NoError(t, err) + + var index ocispec.Index + require.NoError(t, json.Unmarshal(indexBytes, &index)) + + assert.Equal(t, 2, index.SchemaVersion) + assert.Equal(t, ocispec.MediaTypeImageIndex, index.MediaType) + assert.Equal(t, ArtifactTypePlugin, index.ArtifactType) + assert.NotEmpty(t, index.Annotations) + assert.Equal(t, testPluginName, index.Annotations[AnnotationPluginName]) +} + +func TestParseManifest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + want *pluginManifest + wantErr bool + }{ + { + name: "full manifest", + content: `{ + "name":"my-plugin", + "description":"A great plugin", + "version":"2.0.0", + "license":"MIT", + "commands":{"hello":"commands/hello.md"} +}`, + want: &pluginManifest{ + Name: testPluginMyPlugin, + Description: "A great plugin", + Version: "2.0.0", + License: "MIT", + }, + }, + { + name: testNameInvalidJSON, + content: "not json", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pm, err := parseManifest([]byte(tt.content)) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want.Name, pm.Name) + assert.Equal(t, tt.want.Description, pm.Description) + assert.Equal(t, tt.want.Version, pm.Version) + assert.Equal(t, tt.want.License, pm.License) + }) + } +} + +func TestCollectPluginFiles_ExceedsMaxFiles(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + writeManifest(t, dir, `{"name":"too-many-files","description":"A plugin with too many files","version":"1.0.0"}`) + + // Create maxPluginFiles + 1 extra files (plugin.json is excluded from the count). + for i := range maxPluginFiles + 1 { + name := filepath.Join(dir, fmt.Sprintf("file_%05d.txt", i)) + require.NoError(t, os.WriteFile(name, []byte("x"), 0600)) + } + + _, err := collectPluginFiles(dir) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum") + assert.ErrorIs(t, err, ErrTooManyFiles) +} + +func TestPackager_Package_SentinelErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(t *testing.T) string + wantErr error + }{ + { + name: "missing plugin directory", + setup: func(t *testing.T) string { + t.Helper() + return filepath.Join(t.TempDir(), "does-not-exist") + }, + wantErr: ErrInvalidPluginDir, + }, + { + name: "path is file not directory", + setup: func(t *testing.T) string { + t.Helper() + f := filepath.Join(t.TempDir(), "not-a-dir") + require.NoError(t, os.WriteFile(f, []byte("x"), 0600)) + return f + }, + wantErr: ErrInvalidPluginDir, + }, + { + name: "path contains traversal", + setup: func(_ *testing.T) string { + return "../no-such-plugin-dir" + }, + wantErr: ErrInvalidPluginDir, + }, + { + name: "missing plugin manifest", + setup: func(t *testing.T) string { + t.Helper() + return t.TempDir() + }, + wantErr: ErrPluginManifestMissing, + }, + { + name: "manifest invalid JSON", + setup: func(t *testing.T) string { + t.Helper() + dir := t.TempDir() + writeManifest(t, dir, `{"name":`) + return dir + }, + wantErr: ErrInvalidPluginManifest, + }, + { + name: "manifest missing name", + setup: func(t *testing.T) string { + t.Helper() + dir := t.TempDir() + writeManifest(t, dir, `{"description":"nameless plugin"}`) + return dir + }, + wantErr: ErrInvalidPluginManifest, + }, + { + name: "symlinked file in plugin directory", + setup: func(t *testing.T) string { + t.Helper() + dir := createTestPluginDir(t) + require.NoError(t, os.Symlink("/etc/passwd", filepath.Join(dir, "evil_link"))) + return dir + }, + wantErr: ErrInvalidPluginFile, + }, + { + name: "symlinked directory in plugin directory", + setup: func(t *testing.T) string { + t.Helper() + dir := createTestPluginDir(t) + require.NoError(t, os.Symlink("/etc", filepath.Join(dir, "evil_dir"))) + return dir + }, + wantErr: ErrInvalidPluginFile, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + _, err = NewPackager(store).Package(context.Background(), tt.setup(t), PackageOptions{}) + require.Error(t, err) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func createTestPluginDir(t *testing.T) string { + t.Helper() + + dir := t.TempDir() + writeManifest(t, dir, `{ + "name": "test-plugin", + "description": "A test plugin for packaging", + "version": "1.0.0", + "license": "Apache-2.0", + "commands": {"test": "commands/test.md"}, + "agents": ["agents/reviewer.md"], + "skills": ["skills/foo"], + "hooks": {"PreToolUse": [{"command": "scripts/hook.sh"}]}, + "mcpServers": {"srv": {"command": "node", "args": ["server.js"]}}, + "dependencies": ["ghcr.io/org/skill:v1"], + "requires": ["ghcr.io/org/server:v1"] +}`) + writeFile(t, dir, "commands/test.md", "# Test Command\n") + writeFile(t, dir, "agents/reviewer.md", "# Reviewer\n") + writeFile(t, dir, "skills/foo/SKILL.md", "---\nname: foo\n---\n# Foo\n") + writeFile(t, dir, "scripts/hook.sh", "#!/bin/sh\necho hook\n") + writeFile(t, dir, ".mcp.json", `{"mcpServers":{"srv":{"command":"node","args":["server.js"]}}}`) + writeFile(t, dir, ".hidden", "hidden\n") + return dir +} + +func writeManifest(t *testing.T, dir, content string) { + t.Helper() + writeFile(t, dir, ManifestFileName, content) +} + +func writeFile(t *testing.T, dir, relPath, content string) { + t.Helper() + path := filepath.Join(dir, filepath.FromSlash(relPath)) + require.NoError(t, os.MkdirAll(filepath.Dir(path), 0750)) + require.NoError(t, os.WriteFile(path, []byte(content), 0600)) +} + +// TestPackager_Package_PreservesFileMode verifies that an executable file's +// permission bits survive packaging rather than being flattened to 0644. +func TestPackager_Package_PreservesFileMode(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + writeManifest(t, dir, `{ + "name": "test-plugin", + "description": "plugin with an executable hook", + "version": "1.0.0", + "hooks": {"PreToolUse": [{"command": "scripts/hook.sh"}]} +}`) + scriptPath := filepath.Join(dir, "scripts", "hook.sh") + require.NoError(t, os.MkdirAll(filepath.Dir(scriptPath), 0750)) + require.NoError(t, os.WriteFile(scriptPath, []byte("#!/bin/sh\necho hook\n"), 0700)) //#nosec G306 -- test fixture must be executable + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + result, err := NewPackager(store).Package(ctx, dir, PackageOptions{Epoch: time.Unix(0, 0).UTC()}) + require.NoError(t, err) + + layerBytes, err := store.GetBlob(ctx, result.LayerDigest) + require.NoError(t, err) + + files, err := artifact.DecompressTar(layerBytes) + require.NoError(t, err) + + var found bool + for _, f := range files { + if f.Path == "scripts/hook.sh" { + found = true + assert.Equal(t, int64(0700), f.Mode&0777, "executable bit should be preserved in the layer") + } + } + assert.True(t, found, "scripts/hook.sh not found in layer") +} diff --git a/oci/plugins/registry.go b/oci/plugins/registry.go new file mode 100644 index 0000000..8bf2c23 --- /dev/null +++ b/oci/plugins/registry.go @@ -0,0 +1,166 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import ( + "context" + "fmt" + + "github.com/opencontainers/go-digest" + "oras.land/oras-go/v2" + "oras.land/oras-go/v2/registry" + "oras.land/oras-go/v2/registry/remote" + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/credentials" + "oras.land/oras-go/v2/registry/remote/retry" + + "github.com/stacklok/toolhive-core/oci/artifact" +) + +// Compile-time interface check. +var _ RegistryClient = (*Registry)(nil) + +// Registry provides operations for pushing and pulling plugins from OCI registries. +type Registry struct { + credStore credentials.Store + plainHTTP bool + + // newTarget creates an oras.Target for the given reference. + // Defaults to creating an authenticated remote.Repository. + // Override in tests to inject an in-memory store. + newTarget func(ref registry.Reference) (oras.Target, error) +} + +// RegistryOption configures a Registry. +type RegistryOption func(*Registry) + +// WithPlainHTTP configures whether the registry client uses plain HTTP (insecure) connections. +func WithPlainHTTP(enabled bool) RegistryOption { + return func(r *Registry) { + r.plainHTTP = enabled + } +} + +// WithCredentialStore sets a custom credential store for registry authentication. +// If not provided, the default Docker credential store is used. +func WithCredentialStore(store credentials.Store) RegistryOption { + return func(r *Registry) { + r.credStore = store + } +} + +// NewRegistry creates a new registry client with the given options. +// By default it uses the Docker credential store for authentication. +func NewRegistry(opts ...RegistryOption) (*Registry, error) { + r := &Registry{} + + for _, opt := range opts { + opt(r) + } + + if r.credStore == nil { + credStore, err := credentials.NewStoreFromDocker(credentials.StoreOptions{}) + if err != nil { + return nil, fmt.Errorf("creating credential store: %w", err) + } + r.credStore = credStore + } + + if r.newTarget == nil { + r.newTarget = r.defaultNewTarget + } + + return r, nil +} + +// Push pushes a plugin artifact from the local store to a remote registry. +// The digest can be either an index digest or a manifest digest. +func (r *Registry) Push(ctx context.Context, store *Store, artifactDigest digest.Digest, ref string) error { + parsedRef, err := parseReference(ref) + if err != nil { + return err + } + + // Resolve the artifact to get its full descriptor from the OCI store. + desc, err := store.Target().Resolve(ctx, artifactDigest.String()) + if err != nil { + return fmt.Errorf("resolving artifact descriptor: %w", err) + } + + target, err := r.newTarget(parsedRef) + if err != nil { + return fmt.Errorf("getting repository: %w", err) + } + + // Copy the content graph (blobs → manifests → index) to the remote. + if err := oras.CopyGraph(ctx, store.Target(), target, desc, oras.DefaultCopyGraphOptions); err != nil { + return fmt.Errorf("pushing to registry: %w", err) + } + + // Tag on the remote with the requested reference. + if err := target.Tag(ctx, desc, parsedRef.Reference); err != nil { + return fmt.Errorf("tagging remote: %w", err) + } + + return nil +} + +// Pull pulls a plugin artifact from a remote registry to the local store. +// Returns the digest of the pulled artifact (index or manifest). +func (r *Registry) Pull(ctx context.Context, store *Store, ref string) (digest.Digest, error) { + parsedRef, err := parseReference(ref) + if err != nil { + return "", err + } + + target, err := r.newTarget(parsedRef) + if err != nil { + return "", fmt.Errorf("getting repository: %w", err) + } + + validated := artifact.NewValidatingTarget(store.Target()) + + // Copy from remote to the validated local store, tagging locally under the + // full OCI reference. The local store is shared across all plugins, so using + // the bare tag (e.g. "v1.0.0") as the destination would let one plugin's + // pull silently overwrite another plugin's identically-tagged entry. + desc, err := oras.Copy( + ctx, target, parsedRef.Reference, validated, ref, oras.DefaultCopyOptions, + ) + if err != nil { + return "", fmt.Errorf("pulling from registry: %w", err) + } + + return desc.Digest, nil +} + +// parseReference parses an OCI reference and validates it has a tag or digest. +func parseReference(ref string) (registry.Reference, error) { + parsedRef, err := registry.ParseReference(ref) + if err != nil { + return registry.Reference{}, fmt.Errorf("parsing reference %q: %w", ref, err) + } + if parsedRef.Reference == "" { + return registry.Reference{}, fmt.Errorf("reference %q must include a tag or digest", ref) + } + return parsedRef, nil +} + +// defaultNewTarget creates a remote repository client for the given parsed reference. +func (r *Registry) defaultNewTarget(ref registry.Reference) (oras.Target, error) { + repoPath := ref.Registry + "/" + ref.Repository + + repo, err := remote.NewRepository(repoPath) + if err != nil { + return nil, fmt.Errorf("creating repository for %q: %w", repoPath, err) + } + + repo.Client = &auth.Client{ + Client: retry.DefaultClient, + Credential: credentials.Credential(r.credStore), + } + repo.PlainHTTP = r.plainHTTP + + return repo, nil +} diff --git a/oci/plugins/registry_test.go b/oci/plugins/registry_test.go new file mode 100644 index 0000000..803ce9b --- /dev/null +++ b/oci/plugins/registry_test.go @@ -0,0 +1,232 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import ( + "encoding/json" + "testing" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "oras.land/oras-go/v2" + "oras.land/oras-go/v2/content/memory" + "oras.land/oras-go/v2/registry" + + "github.com/stacklok/toolhive-core/oci/artifact" +) + +func TestNewRegistry_Default(t *testing.T) { + t.Parallel() + + reg, err := NewRegistry() + require.NoError(t, err) + assert.NotNil(t, reg) + assert.NotNil(t, reg.credStore, "default credential store should be set") + assert.False(t, reg.plainHTTP, "plainHTTP should default to false") +} + +func TestNewRegistry_WithOptions(t *testing.T) { + t.Parallel() + + reg, err := NewRegistry( + WithPlainHTTP(true), + ) + require.NoError(t, err) + assert.True(t, reg.plainHTTP, "plainHTTP should be set by option") +} + +func TestParseReference(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ref string + wantErr bool + }{ + {"valid tag", "ghcr.io/myorg/plugin:v1.0.0", false}, + {"valid digest", "ghcr.io/myorg/plugin@sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", false}, + {"missing tag or digest", "ghcr.io/myorg/plugin", true}, + {"invalid reference", ":::invalid", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := parseReference(tt.ref) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func newTestRegistry(t *testing.T, remoteStore *memory.Store) *Registry { + t.Helper() + return &Registry{ + newTarget: func(_ registry.Reference) (oras.Target, error) { + return remoteStore, nil + }, + } +} + +func buildTestManifest(t *testing.T, store *Store) (digest.Digest, []byte) { + t.Helper() + ctx := t.Context() + + configContent := []byte(`{"architecture":"amd64","os":"linux","rootfs":{"type":"layers","diff_ids":[]}}`) + layerContent := []byte("plugin layer content") + + configDigest, err := store.PutBlob(ctx, configContent) + require.NoError(t, err) + layerDigest, err := store.PutBlob(ctx, layerContent) + require.NoError(t, err) + + manifest := ocispec.Manifest{ + MediaType: ocispec.MediaTypeImageManifest, + ArtifactType: ArtifactTypePlugin, + Config: ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageConfig, + Digest: configDigest, + Size: int64(len(configContent)), + }, + Layers: []ocispec.Descriptor{ + { + MediaType: ocispec.MediaTypeImageLayerGzip, + Digest: layerDigest, + Size: int64(len(layerContent)), + }, + }, + } + + manifestBytes, err := json.Marshal(manifest) + require.NoError(t, err) + + manifestDigest, err := store.PutManifest(ctx, manifestBytes) + require.NoError(t, err) + + return manifestDigest, manifestBytes +} + +func TestPushPull_ManifestRoundTrip(t *testing.T) { + t.Parallel() + + ctx := t.Context() + remoteStore := memory.New() + + localStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + manifestDigest, _ := buildTestManifest(t, localStore) + + reg := newTestRegistry(t, remoteStore) + ref := "example.com/myorg/my-plugin:v1.0.0" + + err = reg.Push(ctx, localStore, manifestDigest, ref) + require.NoError(t, err) + + pullStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + pulledDigest, err := reg.Pull(ctx, pullStore, ref) + require.NoError(t, err) + assert.Equal(t, manifestDigest, pulledDigest) + + got, err := pullStore.GetManifest(ctx, pulledDigest) + require.NoError(t, err) + assert.NotEmpty(t, got) + + resolved, err := pullStore.Resolve(ctx, ref) + require.NoError(t, err) + assert.Equal(t, pulledDigest, resolved) +} + +func TestPushPull_IndexRoundTrip(t *testing.T) { + t.Parallel() + + ctx := t.Context() + remoteStore := memory.New() + + localStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + manifestDigest, manifestBytes := buildTestManifest(t, localStore) + + index := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + ArtifactType: ArtifactTypePlugin, + Manifests: []ocispec.Descriptor{ + { + MediaType: ocispec.MediaTypeImageManifest, + Digest: manifestDigest, + Size: int64(len(manifestBytes)), + Platform: &ocispec.Platform{OS: artifact.OSLinux, Architecture: artifact.ArchAMD64}, + }, + }, + } + index.SchemaVersion = 2 + + indexBytes, err := json.Marshal(index) + require.NoError(t, err) + indexDigest, err := localStore.PutManifest(ctx, indexBytes) + require.NoError(t, err) + + reg := newTestRegistry(t, remoteStore) + ref := "example.com/myorg/my-plugin:v2.0.0" + + err = reg.Push(ctx, localStore, indexDigest, ref) + require.NoError(t, err) + + pullStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + pulledDigest, err := reg.Pull(ctx, pullStore, ref) + require.NoError(t, err) + + isIdx, err := pullStore.IsIndex(ctx, pulledDigest) + require.NoError(t, err) + assert.True(t, isIdx) + + pulledIndex, err := pullStore.GetIndex(ctx, pulledDigest) + require.NoError(t, err) + require.Len(t, pulledIndex.Manifests, 1) + assert.Equal(t, manifestDigest, pulledIndex.Manifests[0].Digest) + + pulledManifest, err := pullStore.GetManifest(ctx, manifestDigest) + require.NoError(t, err) + assert.NotEmpty(t, pulledManifest) + + resolved, err := pullStore.Resolve(ctx, ref) + require.NoError(t, err) + assert.Equal(t, pulledDigest, resolved) +} + +func TestPush_InvalidReference(t *testing.T) { + t.Parallel() + + ctx := t.Context() + localStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + reg := newTestRegistry(t, memory.New()) + err = reg.Push(ctx, localStore, digest.FromString("test"), ":::invalid") + require.Error(t, err) + assert.Contains(t, err.Error(), "parsing reference") +} + +func TestPull_InvalidReference(t *testing.T) { + t.Parallel() + + ctx := t.Context() + localStore, err := NewStore(t.TempDir()) + require.NoError(t, err) + + reg := newTestRegistry(t, memory.New()) + _, err = reg.Pull(ctx, localStore, ":::invalid") + require.Error(t, err) + assert.Contains(t, err.Error(), "parsing reference") +} diff --git a/oci/plugins/store.go b/oci/plugins/store.go new file mode 100644 index 0000000..874a263 --- /dev/null +++ b/oci/plugins/store.go @@ -0,0 +1,378 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/adrg/xdg" + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2" + "oras.land/oras-go/v2/content/oci" + "oras.land/oras-go/v2/errdef" + + "github.com/stacklok/toolhive-core/httperr" + "github.com/stacklok/toolhive-core/oci/artifact" +) + +const mediaTypeOctetStream = "application/octet-stream" + +// Store provides local OCI artifact storage backed by an OCI Image Layout. +type Store struct { + root string + inner *oci.Store + + // mu serializes tag mutation and blob deletion so a DeleteBuild cannot race + // a concurrent Tag and delete blobs the newly-added tag now references. + mu sync.Mutex +} + +// NewStore creates a new local OCI store at the given root directory. +// The directory is initialized as an OCI Image Layout with blobs/, oci-layout, and index.json. +func NewStore(root string) (*Store, error) { + inner, err := oci.New(root) + if err != nil { + return nil, fmt.Errorf("creating OCI store at %s: %w", root, err) + } + + return &Store{root: root, inner: inner}, nil +} + +// StoreRoot returns the plugins store root within the given data home directory. +// This is the injectable, testable form. For the standard XDG location, use DefaultStoreRoot. +func StoreRoot(dataHome string) string { + return filepath.Join(dataHome, "toolhive", "plugins") +} + +// DefaultStoreRoot returns the default store root directory using XDG base directory conventions. +func DefaultStoreRoot() string { + return StoreRoot(xdg.DataHome) +} + +// PutBlob stores a blob and returns its digest. +func (s *Store) PutBlob(ctx context.Context, content []byte) (digest.Digest, error) { + d := digest.FromBytes(content) + desc := ocispec.Descriptor{ + MediaType: mediaTypeOctetStream, + Digest: d, + Size: int64(len(content)), + } + + if err := s.inner.Push(ctx, desc, bytes.NewReader(content)); err != nil { + if errors.Is(err, errdef.ErrAlreadyExists) { + return d, nil + } + return "", fmt.Errorf("writing blob: %w", err) + } + + return d, nil +} + +// GetBlob retrieves a blob by digest. +func (s *Store) GetBlob(ctx context.Context, d digest.Digest) ([]byte, error) { + data, err := s.fetchContent(ctx, d) + if err != nil { + return nil, fmt.Errorf("blob not found: %s: %w", d, err) + } + return data, nil +} + +// PutManifest stores a manifest and returns its digest. +func (s *Store) PutManifest(ctx context.Context, content []byte) (digest.Digest, error) { + d := digest.FromBytes(content) + + // Parse media type from content so oci.Store indexes it correctly. + var header struct { + MediaType string `json:"mediaType"` + } + mediaType := mediaTypeOctetStream + if err := json.Unmarshal(content, &header); err == nil && header.MediaType != "" { + mediaType = header.MediaType + } + + desc := ocispec.Descriptor{ + MediaType: mediaType, + Digest: d, + Size: int64(len(content)), + } + + if err := s.inner.Push(ctx, desc, bytes.NewReader(content)); err != nil { + if errors.Is(err, errdef.ErrAlreadyExists) { + return d, nil + } + return "", fmt.Errorf("writing manifest: %w", err) + } + + return d, nil +} + +// GetManifest retrieves a manifest by digest. +func (s *Store) GetManifest(ctx context.Context, d digest.Digest) ([]byte, error) { + data, err := s.fetchContent(ctx, d) + if err != nil { + return nil, fmt.Errorf("manifest not found: %s: %w", d, err) + } + return data, nil +} + +// Tag associates a tag with a manifest digest. +func (s *Store) Tag(ctx context.Context, d digest.Digest, tag string) error { + s.mu.Lock() + defer s.mu.Unlock() + return s.tagLocked(ctx, d, tag) +} + +// tagLocked is the body of Tag. Callers must hold s.mu. +func (s *Store) tagLocked(ctx context.Context, d digest.Digest, tag string) error { + // Resolve the digest to get the full descriptor (manifests are auto-tagged by digest on Push). + desc, err := s.inner.Resolve(ctx, d.String()) + if err != nil { + return fmt.Errorf("resolving digest for tag: %w", err) + } + + if err := s.inner.Tag(ctx, desc, tag); err != nil { + return fmt.Errorf("tagging: %w", err) + } + + return nil +} + +// DeleteTag removes a tag from the store index without deleting the underlying blobs. +func (s *Store) DeleteTag(ctx context.Context, tag string) error { + s.mu.Lock() + defer s.mu.Unlock() + return s.deleteTagLocked(ctx, tag) +} + +// deleteTagLocked is the body of DeleteTag. Callers must hold s.mu. +func (s *Store) deleteTagLocked(ctx context.Context, tag string) error { + if err := s.inner.Untag(ctx, tag); err != nil { + if errors.Is(err, errdef.ErrNotFound) { + return httperr.WithCode( + fmt.Errorf("tag not found: %s: %w", tag, err), + http.StatusNotFound, + ) + } + return fmt.Errorf("removing tag: %w", err) + } + return nil +} + +// DeleteBuild removes a tag and, if no other tag shares the same digest, +// deletes all associated blobs (config, layers, manifest, and index if applicable). +// Use DeleteTag when tag-only removal is desired and blob cleanup is not needed. +func (s *Store) DeleteBuild(ctx context.Context, tag string) error { + s.mu.Lock() + defer s.mu.Unlock() + + d, err := s.Resolve(ctx, tag) + if err != nil { + return httperr.WithCode( + fmt.Errorf("tag not found: %s: %w", tag, err), + http.StatusNotFound, + ) + } + + if err := s.deleteTagLocked(ctx, tag); err != nil { + return err + } + + shared, err := s.isDigestReferenced(ctx, d) + if err != nil { + return fmt.Errorf("checking remaining references: %w", err) + } + if shared { + return nil + } + + return s.deleteOrphanedBlobs(ctx, d) +} + +// Resolve resolves a tag to a manifest digest. +func (s *Store) Resolve(ctx context.Context, tag string) (digest.Digest, error) { + desc, err := s.inner.Resolve(ctx, tag) + if err != nil { + return "", fmt.Errorf("tag not found: %s: %w", tag, err) + } + return desc.Digest, nil +} + +// ListTags returns all tags in the store. +func (s *Store) ListTags(ctx context.Context) ([]string, error) { + var tags []string + if err := s.inner.Tags(ctx, "", func(t []string) error { + tags = append(tags, t...) + return nil + }); err != nil { + return nil, fmt.Errorf("listing tags: %w", err) + } + return tags, nil +} + +// GetIndex retrieves and parses an image index by digest. +func (s *Store) GetIndex(ctx context.Context, d digest.Digest) (*ocispec.Index, error) { + data, err := s.fetchContent(ctx, d) + if err != nil { + return nil, fmt.Errorf("getting index: %w", err) + } + + var index ocispec.Index + if err := json.Unmarshal(data, &index); err != nil { + return nil, fmt.Errorf("parsing index: %w", err) + } + + return &index, nil +} + +// IsIndex checks if the content at the given digest is an image index. +func (s *Store) IsIndex(ctx context.Context, d digest.Digest) (bool, error) { + data, err := s.fetchContent(ctx, d) + if err != nil { + return false, fmt.Errorf("manifest not found: %s: %w", d, err) + } + + var header struct { + MediaType string `json:"mediaType"` + } + if err := json.Unmarshal(data, &header); err != nil { + return false, fmt.Errorf("parsing media type: %w", err) + } + + return header.MediaType == ocispec.MediaTypeImageIndex, nil +} + +// Root returns the store root directory. +func (s *Store) Root() string { + return s.root +} + +// Target returns the underlying oras.Target for direct use by registry operations. +func (s *Store) Target() oras.Target { + return s.inner +} + +// fetchContent retrieves raw content by digest from the underlying store. +func (s *Store) fetchContent(ctx context.Context, d digest.Digest) ([]byte, error) { + // oci.Store's Fetch only uses the Digest field to locate blobs in blobs//. + rc, err := s.inner.Fetch(ctx, ocispec.Descriptor{Digest: d}) + if err != nil { + return nil, err + } + defer func() { _ = rc.Close() }() + + // Bound local reads the same way ValidatingTarget bounds incoming pulls: a + // corrupted or tampered local layout must not trigger an unbounded allocation. + data, err := io.ReadAll(io.LimitReader(rc, artifact.MaxBlobSize+1)) + if err != nil { + return nil, err + } + if int64(len(data)) > artifact.MaxBlobSize { + return nil, fmt.Errorf("blob %s exceeds local fetch size limit of %d bytes", d, artifact.MaxBlobSize) + } + + return data, nil +} + +// isDigestReferenced checks whether any remaining tag still resolves to d. +func (s *Store) isDigestReferenced(ctx context.Context, d digest.Digest) (bool, error) { + tags, err := s.ListTags(ctx) + if err != nil { + return false, err + } + for _, tag := range tags { + resolved, err := s.Resolve(ctx, tag) + if err != nil { + continue + } + if resolved == d { + return true, nil + } + } + return false, nil +} + +// deleteOrphanedBlobs removes all blobs reachable from d (index or manifest), +// including d itself. Callers must ensure no remaining tag references d. +func (s *Store) deleteOrphanedBlobs(ctx context.Context, d digest.Digest) error { + isIdx, err := s.IsIndex(ctx, d) + if err != nil { + return fmt.Errorf("inspecting orphaned digest: %w", err) + } + + if isIdx { + idx, err := s.GetIndex(ctx, d) + if err != nil { + return fmt.Errorf("fetching orphaned index: %w", err) + } + for _, m := range idx.Manifests { + if err := s.deleteManifestBlobs(ctx, m.Digest); err != nil { + return err + } + } + } else { + if err := s.deleteManifestBlobs(ctx, d); err != nil { + return err + } + // deleteManifestBlobs already deletes d when it's a plain manifest. + return nil + } + + return s.deleteBlob(d) +} + +// deleteManifestBlobs fetches the manifest at d, deletes its config and layer +// blobs, then deletes the manifest blob itself. +func (s *Store) deleteManifestBlobs(ctx context.Context, d digest.Digest) error { + data, err := s.fetchContent(ctx, d) + if err != nil { + return fmt.Errorf("fetching manifest %s: %w", d, err) + } + + var m ocispec.Manifest + if err := json.Unmarshal(data, &m); err != nil { + return fmt.Errorf("parsing manifest %s: %w", d, err) + } + + if err := s.deleteBlob(m.Config.Digest); err != nil { + return err + } + for _, layer := range m.Layers { + if err := s.deleteBlob(layer.Digest); err != nil { + return err + } + } + return s.deleteBlob(d) +} + +// deleteBlob removes the blob file for d from the local OCI layout. +// A missing file is treated as success (idempotent). +func (s *Store) deleteBlob(d digest.Digest) error { + // digest.Digest is an unvalidated string typedef; the components below feed + // directly into a filesystem path. Validate before use so a crafted digest + // (e.g. "sha256:../../etc/passwd") cannot escape the store root. + if err := d.Validate(); err != nil { + return fmt.Errorf("deleting blob: invalid digest %q: %w", d, err) + } + blobRoot := filepath.Join(s.root, "blobs") + path := filepath.Join(blobRoot, d.Algorithm().String(), d.Encoded()) + if !strings.HasPrefix(filepath.Clean(path)+string(filepath.Separator), + filepath.Clean(blobRoot)+string(filepath.Separator)) { + return fmt.Errorf("deleting blob: path escapes store root: %s", path) + } + if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("deleting blob %s: %w", d, err) + } + return nil +} diff --git a/oci/plugins/store_test.go b/oci/plugins/store_test.go new file mode 100644 index 0000000..7d82577 --- /dev/null +++ b/oci/plugins/store_test.go @@ -0,0 +1,507 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +import ( + "context" + "encoding/json" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive-core/httperr" + "github.com/stacklok/toolhive-core/oci/artifact" +) + +func TestNewStore(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + storePath := filepath.Join(tmpDir, "store") + + store, err := NewStore(storePath) + require.NoError(t, err) + assert.Equal(t, storePath, store.Root()) + + blobsDir := filepath.Join(storePath, "blobs") + _, err = os.Stat(blobsDir) + assert.NoError(t, err, "blobs directory should exist") + + ociLayoutFile := filepath.Join(storePath, "oci-layout") + _, err = os.Stat(ociLayoutFile) + assert.NoError(t, err, "oci-layout file should exist") + + indexFile := filepath.Join(storePath, "index.json") + _, err = os.Stat(indexFile) + assert.NoError(t, err, "index.json file should exist") +} + +func TestStore_PutGetBlob(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + content := []byte("test blob content") + + d, err := store.PutBlob(ctx, content) + require.NoError(t, err) + assert.Equal(t, digest.FromBytes(content), d) + + retrieved, err := store.GetBlob(ctx, d) + require.NoError(t, err) + assert.Equal(t, content, retrieved) +} + +func TestStore_PutBlob_Idempotent(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + content := []byte("test blob content") + + d1, err := store.PutBlob(ctx, content) + require.NoError(t, err) + + d2, err := store.PutBlob(ctx, content) + require.NoError(t, err) + + assert.Equal(t, d1, d2, "putting the same content twice should return the same digest") +} + +func TestStore_GetBlob_NotFound(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + fakeDigest := digest.FromString("nonexistent") + + _, err = store.GetBlob(ctx, fakeDigest) + require.Error(t, err) + assert.Contains(t, err.Error(), "blob not found") +} + +func TestStore_PutGetManifest(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + manifest := []byte(`{"schemaVersion": 2, "mediaType": "application/vnd.oci.image.manifest.v1+json"}`) + + d, err := store.PutManifest(ctx, manifest) + require.NoError(t, err) + + retrieved, err := store.GetManifest(ctx, d) + require.NoError(t, err) + assert.Equal(t, manifest, retrieved) +} + +func TestStore_TagResolve(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + manifest := []byte(`{"schemaVersion": 2, "mediaType": "application/vnd.oci.image.manifest.v1+json"}`) + + d, err := store.PutManifest(ctx, manifest) + require.NoError(t, err) + + tag := "ghcr.io/myorg/my-plugin:v1.0.0" + err = store.Tag(ctx, d, tag) + require.NoError(t, err) + + resolved, err := store.Resolve(ctx, tag) + require.NoError(t, err) + assert.Equal(t, d, resolved) +} + +func TestStore_Resolve_NotFound(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + + _, err = store.Resolve(ctx, "nonexistent:tag") + require.Error(t, err) + assert.Contains(t, err.Error(), "tag not found") +} + +func TestStore_ListTags(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + + tags, err := store.ListTags(ctx) + require.NoError(t, err) + assert.Empty(t, tags) + + manifest := []byte(`{"schemaVersion": 2, "mediaType": "application/vnd.oci.image.manifest.v1+json"}`) + d, err := store.PutManifest(ctx, manifest) + require.NoError(t, err) + + expectedTags := []string{"tag1", "tag2", "tag3"} + for _, tag := range expectedTags { + err = store.Tag(ctx, d, tag) + require.NoError(t, err) + } + + tags, err = store.ListTags(ctx) + require.NoError(t, err) + assert.Len(t, tags, len(expectedTags)) + for _, expected := range expectedTags { + assert.Contains(t, tags, expected) + } +} + +func TestStore_TagOverwrite(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + + manifest1 := []byte(`{"schemaVersion": 2, "mediaType": "application/vnd.oci.image.manifest.v1+json", "version": 1}`) + manifest2 := []byte(`{"schemaVersion": 2, "mediaType": "application/vnd.oci.image.manifest.v1+json", "version": 2}`) + + d1, err := store.PutManifest(ctx, manifest1) + require.NoError(t, err) + + d2, err := store.PutManifest(ctx, manifest2) + require.NoError(t, err) + + tag := "my-plugin:latest" + err = store.Tag(ctx, d1, tag) + require.NoError(t, err) + + err = store.Tag(ctx, d2, tag) + require.NoError(t, err) + + resolved, err := store.Resolve(ctx, tag) + require.NoError(t, err) + assert.Equal(t, d2, resolved, "tag should resolve to the second manifest after overwrite") +} + +func TestStore_GetIndex(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + + idx := &ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + { + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromString("test"), + Size: 100, + Platform: &ocispec.Platform{OS: artifact.OSLinux, Architecture: artifact.ArchAMD64}, + }, + }, + } + idx.SchemaVersion = 2 + + data, err := json.Marshal(idx) + require.NoError(t, err) + + d, err := store.PutManifest(ctx, data) + require.NoError(t, err) + + got, err := store.GetIndex(ctx, d) + require.NoError(t, err) + assert.Equal(t, 2, got.SchemaVersion) + assert.Equal(t, ocispec.MediaTypeImageIndex, got.MediaType) + require.Len(t, got.Manifests, 1) + assert.Equal(t, artifact.OSLinux, got.Manifests[0].Platform.OS) + assert.Equal(t, artifact.ArchAMD64, got.Manifests[0].Platform.Architecture) +} + +func TestStore_IsIndex(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + + idx := ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + } + idx.SchemaVersion = 2 + indexData, err := json.Marshal(idx) + require.NoError(t, err) + + indexDigest, err := store.PutManifest(ctx, indexData) + require.NoError(t, err) + + isIdx, err := store.IsIndex(ctx, indexDigest) + require.NoError(t, err) + assert.True(t, isIdx) + + manifest := []byte(`{"schemaVersion": 2, "mediaType": "application/vnd.oci.image.manifest.v1+json"}`) + manifestDigest, err := store.PutManifest(ctx, manifest) + require.NoError(t, err) + + isIdx, err = store.IsIndex(ctx, manifestDigest) + require.NoError(t, err) + assert.False(t, isIdx) +} + +func TestStore_DeleteTag(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + manifest := []byte(`{"schemaVersion": 2, "mediaType": "application/vnd.oci.image.manifest.v1+json"}`) + d, err := store.PutManifest(ctx, manifest) + require.NoError(t, err) + + tag := "my-plugin:v1" + require.NoError(t, store.Tag(ctx, d, tag)) + require.NoError(t, store.DeleteTag(ctx, tag)) + + _, err = store.Resolve(ctx, tag) + require.Error(t, err) +} + +func TestStore_DeleteTag_NotFound(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + err = store.DeleteTag(context.Background(), "missing") + require.Error(t, err) + assert.Equal(t, http.StatusNotFound, httperr.Code(err)) +} + +func TestStoreRoot(t *testing.T) { + t.Parallel() + + assert.Equal(t, filepath.Join("tmp", "data", "toolhive", "plugins"), StoreRoot(filepath.Join("tmp", "data"))) +} + +func TestDefaultStoreRoot(t *testing.T) { + t.Parallel() + + root := DefaultStoreRoot() + assert.True(t, strings.HasSuffix(root, filepath.Join("toolhive", "plugins"))) +} + +// putTestManifest creates a realistic OCI artifact (config + layer + manifest) +// in the store WITHOUT tagging it, and returns the config, layer, and manifest digests. +func putTestManifest(ctx context.Context, t *testing.T, s *Store) (configDigest, layerDigest, manifestDigest digest.Digest) { + t.Helper() + + configData := []byte(`{"architecture":"amd64","os":"linux"}`) + configDigest, err := s.PutBlob(ctx, configData) + require.NoError(t, err) + + layerData := []byte("fake layer content") + layerDigest, err = s.PutBlob(ctx, layerData) + require.NoError(t, err) + + manifestContent, err := json.Marshal(ocispec.Manifest{ + MediaType: ocispec.MediaTypeImageManifest, + Config: ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageConfig, + Digest: configDigest, + Size: int64(len(configData)), + }, + Layers: []ocispec.Descriptor{ + { + MediaType: ocispec.MediaTypeImageLayerGzip, + Digest: layerDigest, + Size: int64(len(layerData)), + }, + }, + }) + require.NoError(t, err) + + manifestDigest, err = s.PutManifest(ctx, manifestContent) + require.NoError(t, err) + + return configDigest, layerDigest, manifestDigest +} + +// putTestArtifact creates a realistic OCI artifact (config + layer + manifest) +// in the store, tags it with the given tag, and returns the config, layer, and manifest digests. +func putTestArtifact(ctx context.Context, t *testing.T, s *Store, tag string) (configDigest, layerDigest, manifestDigest digest.Digest) { + t.Helper() + configDigest, layerDigest, manifestDigest = putTestManifest(ctx, t, s) + require.NoError(t, s.Tag(ctx, manifestDigest, tag)) + return configDigest, layerDigest, manifestDigest +} + +// blobExists reports whether the blob file for d is present on disk. +func blobExists(t *testing.T, s *Store, d digest.Digest) bool { + t.Helper() + path := filepath.Join(s.Root(), "blobs", d.Algorithm().String(), d.Encoded()) + _, err := os.Stat(path) + return err == nil +} + +func TestStore_DeleteBuild(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(t *testing.T, s *Store, ctx context.Context) (configDigest, layerDigest, manifestDigest digest.Digest) + tag string + wantErr bool + wantCode int + postCheck func(t *testing.T, s *Store, ctx context.Context, configDigest, layerDigest, manifestDigest digest.Digest) + }{ + { + name: "removes tag and blobs when no other tag shares the digest", + setup: func(t *testing.T, s *Store, ctx context.Context) (digest.Digest, digest.Digest, digest.Digest) { + t.Helper() + return putTestArtifact(ctx, t, s, "v1") + }, + tag: "v1", + postCheck: func(t *testing.T, s *Store, ctx context.Context, configDigest, layerDigest, manifestDigest digest.Digest) { + t.Helper() + _, err := s.Resolve(ctx, "v1") + assert.Error(t, err, "tag should be gone") + + assert.False(t, blobExists(t, s, manifestDigest), "manifest blob should be deleted") + assert.False(t, blobExists(t, s, configDigest), "config blob should be deleted") + assert.False(t, blobExists(t, s, layerDigest), "layer blob should be deleted") + }, + }, + { + name: "keeps blobs when another tag shares the same digest", + setup: func(t *testing.T, s *Store, ctx context.Context) (digest.Digest, digest.Digest, digest.Digest) { + t.Helper() + c, l, m := putTestArtifact(ctx, t, s, "v1") + require.NoError(t, s.Tag(ctx, m, "v2")) + return c, l, m + }, + tag: "v1", + postCheck: func(t *testing.T, s *Store, ctx context.Context, configDigest, layerDigest, manifestDigest digest.Digest) { + t.Helper() + resolved, err := s.Resolve(ctx, "v2") + require.NoError(t, err) + assert.Equal(t, manifestDigest, resolved, "v2 should still resolve") + + assert.True(t, blobExists(t, s, manifestDigest), "manifest blob should be retained") + assert.True(t, blobExists(t, s, configDigest), "config blob should be retained") + assert.True(t, blobExists(t, s, layerDigest), "layer blob should be retained") + }, + }, + { + name: "returns 404 when tag does not exist", + setup: func(_ *testing.T, _ *Store, _ context.Context) (digest.Digest, digest.Digest, digest.Digest) { + return "", "", "" + }, + tag: "nonexistent", + wantErr: true, + wantCode: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + ctx := context.Background() + configDigest, layerDigest, manifestDigest := tt.setup(t, store, ctx) + + err = store.DeleteBuild(ctx, tt.tag) + + if tt.wantErr { + require.Error(t, err) + if tt.wantCode != 0 { + assert.Equal(t, tt.wantCode, httperr.Code(err)) + } + return + } + + require.NoError(t, err) + if tt.postCheck != nil { + tt.postCheck(t, store, ctx, configDigest, layerDigest, manifestDigest) + } + }) + } +} + +// TestStore_DeleteBuild_Index covers the deleteOrphanedBlobs index branch: a +// tagged image index must, on deletion, recursively delete every referenced +// manifest and its config/layer blobs as well as the index blob itself. +func TestStore_DeleteBuild_Index(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + ctx := context.Background() + + configDigest, layerDigest, manifestDigest := putTestManifest(ctx, t, store) + + indexContent, err := json.Marshal(ocispec.Index{ + MediaType: ocispec.MediaTypeImageIndex, + Manifests: []ocispec.Descriptor{ + { + MediaType: ocispec.MediaTypeImageManifest, + Digest: manifestDigest, + }, + }, + }) + require.NoError(t, err) + + indexDigest, err := store.PutManifest(ctx, indexContent) + require.NoError(t, err) + require.NoError(t, store.Tag(ctx, indexDigest, "v1")) + + require.NoError(t, store.DeleteBuild(ctx, "v1")) + + _, err = store.Resolve(ctx, "v1") + assert.Error(t, err, "tag should be gone") + + assert.False(t, blobExists(t, store, indexDigest), "index blob should be deleted") + assert.False(t, blobExists(t, store, manifestDigest), "manifest blob should be deleted") + assert.False(t, blobExists(t, store, configDigest), "config blob should be deleted") + assert.False(t, blobExists(t, store, layerDigest), "layer blob should be deleted") +} + +// TestStore_deleteBlob_RejectsInvalidDigest guards the path-traversal fix: a +// digest that is not a well-formed algorithm:hex pair must be rejected before +// it can be turned into a filesystem path. +func TestStore_deleteBlob_RejectsInvalidDigest(t *testing.T) { + t.Parallel() + + store, err := NewStore(t.TempDir()) + require.NoError(t, err) + + err = store.deleteBlob(digest.Digest("sha256:../../../../etc/passwd")) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid digest") +} diff --git a/oci/plugins/testconsts_test.go b/oci/plugins/testconsts_test.go new file mode 100644 index 0000000..c766291 --- /dev/null +++ b/oci/plugins/testconsts_test.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package plugins + +const ( + testFileA = "a.txt" + testFileB = "b.txt" + testPluginMyPlugin = "my-plugin" + testNotJSON = "not-json" + testNameInvalidJSON = "invalid JSON" + testComponentCommands = "commands" + testPluginComponents = `{"commands":1,"agents":1,"skills":1,"hooks":1,"mcpServers":1}` + testPlatformAMD64 = "linux/amd64" + testPlatformARMv7 = "linux/arm/v7" + testArchARM = "arm" + testComponentSkills = "skills" + testPluginMinimal = "minimal-plugin" + testRequireServerV1 = "ghcr.io/org/server:v1" + testRequireSkillV1 = "ghcr.io/org/skill:v1" +) diff --git a/oci/skills/packager.go b/oci/skills/packager.go index baeb1b8..537b784 100644 --- a/oci/skills/packager.go +++ b/oci/skills/packager.go @@ -107,8 +107,10 @@ const maxFrontmatterSize = 64 * 1024 const maxSkillFiles = 1_000 // maxSkillTotalSize limits the total aggregate size of all files in a skill -// directory to prevent memory exhaustion during packaging (100 MB). -const maxSkillTotalSize int64 = 100 * 1024 * 1024 +// directory to prevent memory exhaustion during packaging. Kept below +// artifact.MaxDecompressedSize (100 MB) so that per-file tar header overhead +// cannot push a packaged artifact past the limit enforced on extraction. +const maxSkillTotalSize int64 = 95 * 1024 * 1024 // Compile-time assertion that Packager implements SkillPackager. var _ SkillPackager = (*Packager)(nil) @@ -422,6 +424,9 @@ func createContentLayer(content *skillDirContent, opts PackageOptions) (compress tarOpts := artifact.TarOptions{Epoch: opts.Epoch} gzipOpts := artifact.DefaultGzipOptions() + // Honour the package epoch in the gzip member header too, so the timestamp a + // consumer reads from the gzip header agrees with the tar and OCI metadata. + gzipOpts.Epoch = opts.Epoch uncompressed, err = artifact.CreateTar(files, tarOpts) if err != nil { diff --git a/oci/skills/registry.go b/oci/skills/registry.go index f9864e9..270525a 100644 --- a/oci/skills/registry.go +++ b/oci/skills/registry.go @@ -13,6 +13,7 @@ import ( "oras.land/oras-go/v2/registry/remote" "oras.land/oras-go/v2/registry/remote/auth" "oras.land/oras-go/v2/registry/remote/credentials" + "oras.land/oras-go/v2/registry/remote/retry" "github.com/stacklok/toolhive-core/oci/artifact" ) @@ -120,20 +121,17 @@ func (r *Registry) Pull(ctx context.Context, store *Store, ref string) (digest.D validated := artifact.NewValidatingTarget(store.Target()) - // Copy from remote to the validated local store + // Copy from remote to the validated local store, tagging locally under the + // full OCI reference. The local store is shared across all skills, so using + // the bare tag (e.g. "v1.0.0") as the destination would let one skill's pull + // silently overwrite another skill's identically-tagged entry. desc, err := oras.Copy( - ctx, target, parsedRef.Reference, validated, parsedRef.Reference, oras.DefaultCopyOptions, + ctx, target, parsedRef.Reference, validated, ref, oras.DefaultCopyOptions, ) if err != nil { return "", fmt.Errorf("pulling from registry: %w", err) } - // oras.Copy already tagged with the short reference (parsedRef.Reference, e.g. "v1.0.0"). - // Additionally tag with the full OCI reference for local resolution. - if err := store.Tag(ctx, desc.Digest, ref); err != nil { - return "", fmt.Errorf("tagging locally: %w", err) - } - return desc.Digest, nil } @@ -159,6 +157,7 @@ func (r *Registry) defaultNewTarget(ref registry.Reference) (oras.Target, error) } repo.Client = &auth.Client{ + Client: retry.DefaultClient, Credential: credentials.Credential(r.credStore), } repo.PlainHTTP = r.plainHTTP diff --git a/oci/skills/store.go b/oci/skills/store.go index 0d1a1d5..607fb8b 100644 --- a/oci/skills/store.go +++ b/oci/skills/store.go @@ -13,6 +13,7 @@ import ( "net/http" "os" "path/filepath" + "strings" "github.com/adrg/xdg" "github.com/opencontainers/go-digest" @@ -22,6 +23,7 @@ import ( "oras.land/oras-go/v2/errdef" "github.com/stacklok/toolhive-core/httperr" + "github.com/stacklok/toolhive-core/oci/artifact" ) const mediaTypeOctetStream = "application/octet-stream" @@ -248,10 +250,15 @@ func (s *Store) fetchContent(ctx context.Context, d digest.Digest) ([]byte, erro } defer func() { _ = rc.Close() }() - data, err := io.ReadAll(rc) + // Bound local reads the same way ValidatingTarget bounds incoming pulls: a + // corrupted or tampered local layout must not trigger an unbounded allocation. + data, err := io.ReadAll(io.LimitReader(rc, artifact.MaxBlobSize+1)) if err != nil { return nil, err } + if int64(len(data)) > artifact.MaxBlobSize { + return nil, fmt.Errorf("blob %s exceeds local fetch size limit of %d bytes", d, artifact.MaxBlobSize) + } return data, nil } @@ -330,7 +337,18 @@ func (s *Store) deleteManifestBlobs(ctx context.Context, d digest.Digest) error // deleteBlob removes the blob file for d from the local OCI layout. // A missing file is treated as success (idempotent). func (s *Store) deleteBlob(d digest.Digest) error { - path := filepath.Join(s.root, "blobs", d.Algorithm().String(), d.Encoded()) + // digest.Digest is an unvalidated string typedef; the components below feed + // directly into a filesystem path. Validate before use so a crafted digest + // (e.g. "sha256:../../etc/passwd") cannot escape the store root. + if err := d.Validate(); err != nil { + return fmt.Errorf("deleting blob: invalid digest %q: %w", d, err) + } + blobRoot := filepath.Join(s.root, "blobs") + path := filepath.Join(blobRoot, d.Algorithm().String(), d.Encoded()) + if !strings.HasPrefix(filepath.Clean(path)+string(filepath.Separator), + filepath.Clean(blobRoot)+string(filepath.Separator)) { + return fmt.Errorf("deleting blob: path escapes store root: %s", path) + } if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("deleting blob %s: %w", d, err) }