Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,13 @@ linters:
- errcheck
- dupl
- gosec
- goconst
path: (.+)_test\.go
# Test helper / fixture files are not named *_test.go but contain
# the same kind of repeated literal data that goconst flags noisily.
- linters:
- goconst
path: (test_helpers|testutils.*)\.go
- linters:
- lll
path: .golangci.yml
Expand Down
12 changes: 12 additions & 0 deletions internal/audit/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ const (
ResourceTypeSkill = "skill"
)

// Target field keys.
const (
targetFieldMethod = "method"
targetFieldPath = "path"
targetFieldResourceType = "resource_type"
targetFieldResourceName = "resource_name"
targetFieldRegistryName = "registry_name"
targetFieldNamespace = "namespace"
targetFieldEntryType = "entry_type"
targetFieldVersion = "version"
)

// Event types for the MCP registry v0.1 discovery API.
const (
EventServerList = "server.list"
Expand Down
54 changes: 27 additions & 27 deletions internal/audit/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ func setRouteInfo(ctx context.Context, info *RouteInfo) {
func Audited(eventType, resourceType, nameParam string, h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
target := map[string]string{
"method": r.Method,
"path": r.URL.Path,
targetFieldMethod: r.Method,
targetFieldPath: r.URL.Path,
}
if resourceType != "" {
target["resource_type"] = resourceType
target[targetFieldResourceType] = resourceType
}
if nameParam != "" {
if name := chi.URLParam(r, nameParam); name != "" {
target["resource_name"] = name
target[targetFieldResourceName] = name
}
}
setRouteInfo(r.Context(), &RouteInfo{
Expand All @@ -86,15 +86,15 @@ func Audited(eventType, resourceType, nameParam string, h http.HandlerFunc) http
func AuditedUpsert(onCreate, onUpdate, resourceType, nameParam string, h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
target := map[string]string{
"method": r.Method,
"path": r.URL.Path,
targetFieldMethod: r.Method,
targetFieldPath: r.URL.Path,
}
if resourceType != "" {
target["resource_type"] = resourceType
target[targetFieldResourceType] = resourceType
}
if nameParam != "" {
if name := chi.URLParam(r, nameParam); name != "" {
target["resource_name"] = name
target[targetFieldResourceName] = name
}
}
setRouteInfo(r.Context(), &RouteInfo{
Expand All @@ -111,18 +111,18 @@ func AuditedUpsert(onCreate, onUpdate, resourceType, nameParam string, h http.Ha
func AuditedEntry(eventType string, h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
target := map[string]string{
"method": r.Method,
"path": r.URL.Path,
"resource_type": ResourceTypeEntry,
targetFieldMethod: r.Method,
targetFieldPath: r.URL.Path,
targetFieldResourceType: ResourceTypeEntry,
}
if entryType := chi.URLParam(r, "type"); entryType != "" {
target["entry_type"] = entryType
target[targetFieldEntryType] = entryType
}
if name := chi.URLParam(r, "name"); name != "" {
target["resource_name"] = name
target[targetFieldResourceName] = name
}
if version := chi.URLParam(r, "version"); version != "" {
target["version"] = version
target[targetFieldVersion] = version
}
setRouteInfo(r.Context(), &RouteInfo{
EventType: eventType,
Expand All @@ -137,18 +137,18 @@ func AuditedEntry(eventType string, h http.HandlerFunc) http.HandlerFunc {
func AuditedServer(eventType string, h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
target := map[string]string{
"method": r.Method,
"path": r.URL.Path,
"resource_type": ResourceTypeServer,
targetFieldMethod: r.Method,
targetFieldPath: r.URL.Path,
targetFieldResourceType: ResourceTypeServer,
}
if registryName := chi.URLParam(r, "registryName"); registryName != "" {
target["registry_name"] = registryName
target[targetFieldRegistryName] = registryName
}
if serverName := chi.URLParam(r, "serverName"); serverName != "" {
target["resource_name"] = serverName
target[targetFieldResourceName] = serverName
}
if version := chi.URLParam(r, "version"); version != "" {
target["version"] = version
target[targetFieldVersion] = version
}
setRouteInfo(r.Context(), &RouteInfo{
EventType: eventType,
Expand All @@ -163,21 +163,21 @@ func AuditedServer(eventType string, h http.HandlerFunc) http.HandlerFunc {
func AuditedSkill(eventType string, h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
target := map[string]string{
"method": r.Method,
"path": r.URL.Path,
"resource_type": ResourceTypeSkill,
targetFieldMethod: r.Method,
targetFieldPath: r.URL.Path,
targetFieldResourceType: ResourceTypeSkill,
}
if registryName := chi.URLParam(r, "registryName"); registryName != "" {
target["registry_name"] = registryName
target[targetFieldRegistryName] = registryName
}
if namespace := chi.URLParam(r, "namespace"); namespace != "" {
target["namespace"] = namespace
target[targetFieldNamespace] = namespace
}
if name := chi.URLParam(r, "name"); name != "" {
target["resource_name"] = name
target[targetFieldResourceName] = name
}
if version := chi.URLParam(r, "version"); version != "" {
target["version"] = version
target[targetFieldVersion] = version
}
setRouteInfo(r.Context(), &RouteInfo{
EventType: eventType,
Expand Down
4 changes: 2 additions & 2 deletions internal/audit/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ func emitAuthFailureEvent(r *http.Request, logger *Logger) {
source := SourceFromRequest(r)

target := map[string]string{
"method": r.Method,
"path": r.URL.Path,
targetFieldMethod: r.Method,
targetFieldPath: r.URL.Path,
}

event := audit.NewAuditEvent(
Expand Down
16 changes: 15 additions & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ func (p *OAuthProviderConfig) validateProvider(index int, insecureAllowHTTP bool
// Enforce HTTPS unless THV_REGISTRY_INSECURE_URL=true or localhost
if issuerURL.Scheme != "https" && !insecureAllowHTTP {
host := issuerURL.Hostname()
if host != "localhost" && host != "127.0.0.1" && host != "::1" {
if !isLoopbackHost(host) {
const msg = "must use HTTPS (set THV_REGISTRY_INSECURE_URL=true to allow HTTP)"
return fmt.Errorf("auth.oauth.providers[%d].issuerUrl %s", index, msg)
}
Expand Down Expand Up @@ -1239,3 +1239,17 @@ func (c *Config) validateAuth() error {

return nil
}

// hostLocalhost is the DNS loopback alias accepted alongside 127.0.0.1 and ::1.
const hostLocalhost = "localhost"

// isLoopbackHost reports whether host is a loopback identifier for which
// HTTP issuer URLs are accepted without THV_REGISTRY_INSECURE_URL.
func isLoopbackHost(host string) bool {
switch host {
case hostLocalhost, "127.0.0.1", "::1":
return true
default:
return false
}
}
10 changes: 6 additions & 4 deletions internal/kubernetes/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ func (r *MCPServerReconciler) Reconcile(ctx context.Context, req ctrl.Request) (
return ctrl.Result{}, nil
}

// annotationValueTrue is the literal value treated as opt-in for boolean
// annotations on registry-export CRDs. Kubernetes annotations are string-typed,
// so we accept only the exact string "true".
const annotationValueTrue = "true"

func checkAnnotation(annotations map[string]string, annotation string) bool {
if annotations == nil {
return false
Expand All @@ -84,10 +89,7 @@ func checkAnnotation(annotations map[string]string, annotation string) bool {
if !ok {
return false
}
if value == "true" {
return true
}
return false
return value == annotationValueTrue
}

func makeNewObjectPredicate[T client.Object](
Expand Down
38 changes: 25 additions & 13 deletions internal/kubernetes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ import (
mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
)

const (
// defaultServerVersion is used when an MCPServer resource does not carry
// an explicit version annotation.
defaultServerVersion = "1.0.0"
// defaultServerStatus is the registry.ServerExtensions.Status value used
// for entries derived from running Kubernetes resources.
defaultServerStatus = "active"
// defaultImageVersion is returned by parseImageTagOrDigest when no tag or
// digest can be extracted from the container image reference.
defaultImageVersion = "latest"
)

// extractServer converts an MCPServer to a ServerJSON object
//
//nolint:unparam
Expand All @@ -27,9 +39,9 @@ func extractServer(mcpServer *mcpv1alpha1.MCPServer) (*upstreamv0.ServerJSON, er
// Note: MCPServer is a Kubernetes deployment resource, so we extract
// what information is available and create a minimal ServerJSON
serverJSON := &upstreamv0.ServerJSON{
Schema: "https://static.modelcontextprotocol.io/schemas/2025-12-11/server.schema.json",
Schema: model.CurrentSchemaURL,
Name: serverName,
Version: "1.0.0", // Default version, could be extracted from annotations or labels
Version: defaultServerVersion, // Default version, could be extracted from annotations or labels
}

// Extract packages from MCPServer spec (using the container image)
Expand Down Expand Up @@ -69,7 +81,7 @@ func extractServer(mcpServer *mcpv1alpha1.MCPServer) (*upstreamv0.ServerJSON, er

// Create ServerExtensions with Kubernetes metadata
extensions := &registry.ServerExtensions{
Status: "active", // Default status
Status: defaultServerStatus, // Default status
Metadata: &registry.Metadata{
Kubernetes: &registry.KubernetesMetadata{
Kind: mcpServer.Kind,
Expand Down Expand Up @@ -119,9 +131,9 @@ func extractVirtualMCPServer(virtualMCPServer *mcpv1alpha1.VirtualMCPServer) (*u
// Note: VirtualMCPServer is a Kubernetes deployment resource, so we extract
// what information is available and create a minimal ServerJSON
serverJSON := &upstreamv0.ServerJSON{
Schema: "https://static.modelcontextprotocol.io/schemas/2025-12-11/server.schema.json",
Schema: model.CurrentSchemaURL,
Name: serverName,
Version: "1.0.0", // Default version, could be extracted from annotations or labels
Version: defaultServerVersion, // Default version, could be extracted from annotations or labels
}

annotations := virtualMCPServer.GetAnnotations()
Expand Down Expand Up @@ -157,7 +169,7 @@ func extractVirtualMCPServer(virtualMCPServer *mcpv1alpha1.VirtualMCPServer) (*u

// Create ServerExtensions with Kubernetes metadata
extensions := &registry.ServerExtensions{
Status: "active", // Default status
Status: defaultServerStatus, // Default status
Metadata: &registry.Metadata{
Kubernetes: &registry.KubernetesMetadata{
Kind: virtualMCPServer.Kind,
Expand Down Expand Up @@ -205,9 +217,9 @@ func extractMCPRemoteProxy(mcpRemoteProxy *mcpv1alpha1.MCPRemoteProxy) (*upstrea
// Note: MCPRemoteProxy is a Kubernetes deployment resource, so we extract
// what information is available and create a minimal ServerJSON
serverJSON := &upstreamv0.ServerJSON{
Schema: "https://static.modelcontextprotocol.io/schemas/2025-12-11/server.schema.json",
Schema: model.CurrentSchemaURL,
Name: serverName,
Version: "1.0.0", // Default version, could be extracted from annotations or labels
Version: defaultServerVersion, // Default version, could be extracted from annotations or labels
}

annotations := mcpRemoteProxy.GetAnnotations()
Expand Down Expand Up @@ -243,7 +255,7 @@ func extractMCPRemoteProxy(mcpRemoteProxy *mcpv1alpha1.MCPRemoteProxy) (*upstrea

// Create ServerExtensions with Kubernetes metadata
extensions := &registry.ServerExtensions{
Status: "active", // Default status
Status: defaultServerStatus, // Default status
Metadata: &registry.Metadata{
Kubernetes: &registry.KubernetesMetadata{
Kind: mcpRemoteProxy.Kind,
Expand Down Expand Up @@ -329,17 +341,17 @@ func extractPackages(mcpServer *mcpv1alpha1.MCPServer) []model.Package {
transportType = model.TransportTypeStreamableHTTP
}

if transportType == "stdio" {
if transportType == model.TransportTypeStdio {
if mcpServer.Spec.ProxyMode != "" {
transportType = mcpServer.Spec.ProxyMode
} else {
transportType = "streamable-http"
transportType = model.TransportTypeStreamableHTTP
}
}

version := parseImageTagOrDigest(mcpServer.Spec.Image)
packageModel := model.Package{
RegistryType: "oci",
RegistryType: model.RegistryTypeOCI,
Identifier: mcpServer.Spec.Image,
Version: version,
Transport: model.Transport{
Expand Down Expand Up @@ -375,7 +387,7 @@ func parseImageTagOrDigest(image string) string {
return potentialTag
}

return "latest"
return defaultImageVersion
}

// structToMap converts a struct to map[string]any using JSON marshaling.
Expand Down
21 changes: 15 additions & 6 deletions internal/sync/writer/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ const (
iconThemeDark = "DARK"
)

// colServerID is the FK column shared by every per-server temp table fed by COPY.
const colServerID = "server_id"

// Skill package registry types (see toolhive-core skills_types.go).
const (
skillRegistryTypeOCI = "oci"
skillRegistryTypeGit = "git"
)

// dbSyncWriter is a SyncWriter implementation that persists data to a database
type dbSyncWriter struct {
pool *pgxpool.Pool
Expand Down Expand Up @@ -365,7 +374,7 @@ func bulkInsertPackages(

// COPY into temp table
_, err := tx.CopyFrom(ctx, pgx.Identifier{"temp_mcp_server_package"},
[]string{"server_id", "registry_type", "pkg_registry_url", "pkg_identifier", "pkg_version",
[]string{colServerID, "registry_type", "pkg_registry_url", "pkg_identifier", "pkg_version",
"runtime_hint", "runtime_arguments", "package_arguments", "env_vars", "sha256_hash",
"transport", "transport_url", "transport_headers"},
pgx.CopyFromRows(packageRows))
Expand Down Expand Up @@ -460,7 +469,7 @@ func bulkInsertRemotes(

// COPY into temp table
_, err := tx.CopyFrom(ctx, pgx.Identifier{"temp_mcp_server_remote"},
[]string{"server_id", "transport", "transport_url", "transport_headers"},
[]string{colServerID, "transport", "transport_url", "transport_headers"},
pgx.CopyFromRows(remoteRows))
if err != nil {
return fmt.Errorf("failed to copy remotes: %w", err)
Expand Down Expand Up @@ -544,7 +553,7 @@ func bulkInsertIcons(

// COPY into temp table
_, err := tx.CopyFrom(ctx, pgx.Identifier{"temp_mcp_server_icon"},
[]string{"server_id", "source_uri", "mime_type", "theme"},
[]string{colServerID, "source_uri", "mime_type", "theme"},
pgx.CopyFromRows(iconRows))
if err != nil {
return fmt.Errorf("failed to copy icons: %w", err)
Expand Down Expand Up @@ -1111,7 +1120,7 @@ func replaceSkillPackages(

for _, pkg := range skill.Packages {
switch pkg.RegistryType {
case "oci":
case skillRegistryTypeOCI:
if err := querier.InsertSkillOciPackage(ctx, sqlc.InsertSkillOciPackageParams{
SkillID: skillVersionID,
Identifier: pkg.Identifier,
Expand All @@ -1120,7 +1129,7 @@ func replaceSkillPackages(
}); err != nil {
return fmt.Errorf("failed to insert OCI package for skill %s: %w", key, err)
}
case "git":
case skillRegistryTypeGit:
if err := querier.InsertSkillGitPackage(ctx, sqlc.InsertSkillGitPackageParams{
SkillID: skillVersionID,
Url: pkg.URL,
Expand Down Expand Up @@ -1191,7 +1200,7 @@ func marshalJSONOrNil(v any) ([]byte, error) {
return nil, nil
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr && rv.IsNil() {
if rv.Kind() == reflect.Pointer && rv.IsNil() {
return nil, nil
}
return json.Marshal(v)
Expand Down
Loading
Loading