Skip to content

Commit 4f1dee2

Browse files
Add content filtering functionality and tests
Co-authored-by: SamMorrowDrums <4811358+SamMorrowDrums@users.noreply.github.com>
1 parent e21c81b commit 4f1dee2

File tree

4 files changed

+742
-0
lines changed

4 files changed

+742
-0
lines changed

cmd/github-mcp-server/main.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ var (
5353
ExportTranslations: viper.GetBool("export-translations"),
5454
EnableCommandLogging: viper.GetBool("enable-command-logging"),
5555
LogFilePath: viper.GetString("log-file"),
56+
TrustedRepo: viper.GetString("trusted-repo"),
5657
}
5758

5859
return ghmcp.RunStdioServer(stdioServerConfig)
@@ -73,6 +74,7 @@ func init() {
7374
rootCmd.PersistentFlags().Bool("enable-command-logging", false, "When enabled, the server will log all command requests and responses to the log file")
7475
rootCmd.PersistentFlags().Bool("export-translations", false, "Save translations to a JSON file")
7576
rootCmd.PersistentFlags().String("gh-host", "", "Specify the GitHub hostname (for GitHub Enterprise etc.)")
77+
rootCmd.PersistentFlags().String("trusted-repo", "", "Limit content to users with push access to the specified repo (format: owner/repo)")
7678

7779
// Bind flag to viper
7880
_ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets"))
@@ -82,6 +84,7 @@ func init() {
8284
_ = viper.BindPFlag("enable-command-logging", rootCmd.PersistentFlags().Lookup("enable-command-logging"))
8385
_ = viper.BindPFlag("export-translations", rootCmd.PersistentFlags().Lookup("export-translations"))
8486
_ = viper.BindPFlag("host", rootCmd.PersistentFlags().Lookup("gh-host"))
87+
_ = viper.BindPFlag("trusted-repo", rootCmd.PersistentFlags().Lookup("trusted-repo"))
8588

8689
// Add subcommands
8790
rootCmd.AddCommand(stdioCmd)

internal/ghmcp/server.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,17 @@ type MCPServerConfig struct {
4343
// ReadOnly indicates if we should only offer read-only tools
4444
ReadOnly bool
4545

46+
// TrustedRepo is a repository in the format "owner/repo" used to limit content
47+
// to users with push access to the specified repo
48+
TrustedRepo string
49+
4650
// Translator provides translated text for the server tooling
4751
Translator translations.TranslationHelperFunc
4852
}
4953

5054
func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
55+
ctx := context.Background()
56+
5157
apiHost, err := parseAPIHost(cfg.Host)
5258
if err != nil {
5359
return nil, fmt.Errorf("failed to parse API host: %w", err)
@@ -112,6 +118,14 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
112118
return gqlClient, nil // closing over client
113119
}
114120

121+
// Initialize the content filter if a trusted repo is specified
122+
if cfg.TrustedRepo != "" {
123+
ctx, err := InitContentFilter(ctx, cfg.TrustedRepo, getGQLClient)
124+
if err != nil {
125+
return nil, fmt.Errorf("failed to initialize content filter: %w", err)
126+
}
127+
}
128+
115129
// Create default toolsets
116130
toolsets, err := github.InitToolsets(
117131
enabledToolsets,
@@ -169,6 +183,10 @@ type StdioServerConfig struct {
169183

170184
// Path to the log file if not stderr
171185
LogFilePath string
186+
187+
// TrustedRepo is a repository in the format "owner/repo" used to limit content
188+
// to users with push access to the specified repo
189+
TrustedRepo string
172190
}
173191

174192
// RunStdioServer is not concurrent safe.
@@ -186,6 +204,7 @@ func RunStdioServer(cfg StdioServerConfig) error {
186204
EnabledToolsets: cfg.EnabledToolsets,
187205
DynamicToolsets: cfg.DynamicToolsets,
188206
ReadOnly: cfg.ReadOnly,
207+
TrustedRepo: cfg.TrustedRepo,
189208
Translator: t,
190209
})
191210
if err != nil {

pkg/github/content_filter.go

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strings"
7+
"sync"
8+
9+
"github.com/shurcooL/githubv4"
10+
)
11+
12+
// contextKey is a private type used for context keys
13+
type contextKey int
14+
15+
const (
16+
// ContentFilterKey is the key used to access content filter settings from context
17+
contentFilterKey contextKey = iota
18+
)
19+
20+
// ContentFilterSettings holds the configuration for content filtering
21+
type ContentFilterSettings struct {
22+
// Enabled indicates if content filtering is enabled
23+
Enabled bool
24+
// TrustedRepo is the repository in format "owner/repo" that is used to check permissions
25+
TrustedRepo string
26+
// OwnerRepo is the parsed owner and repo from TrustedRepo
27+
OwnerRepo OwnerRepo
28+
// IsPrivate indicates if the trusted repo is private
29+
IsPrivate bool
30+
// TrustedUsers is a map of users who have been verified to have push access
31+
TrustedUsers map[string]bool
32+
// mu protects the TrustedUsers map
33+
mu sync.RWMutex
34+
}
35+
36+
// OwnerRepo holds the parsed owner and repo from a string in the format "owner/repo"
37+
type OwnerRepo struct {
38+
Owner string
39+
Repo string
40+
}
41+
42+
// ParseOwnerRepo parses a string in the format "owner/repo" into an OwnerRepo struct
43+
func ParseOwnerRepo(s string) (OwnerRepo, error) {
44+
parts := strings.Split(s, "/")
45+
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
46+
return OwnerRepo{}, fmt.Errorf("invalid format for owner/repo: %s", s)
47+
}
48+
return OwnerRepo{Owner: parts[0], Repo: parts[1]}, nil
49+
}
50+
51+
// GetContentFilterFromContext retrieves the content filter settings from the context
52+
func GetContentFilterFromContext(ctx context.Context) (*ContentFilterSettings, bool) {
53+
if ctx == nil {
54+
return nil, false
55+
}
56+
settings, ok := ctx.Value(contentFilterKey).(*ContentFilterSettings)
57+
return settings, ok
58+
}
59+
60+
// InitContentFilter initializes the content filter in the context
61+
func InitContentFilter(ctx context.Context, trustedRepo string, getGQLClient GetGQLClientFn) (context.Context, error) {
62+
if trustedRepo == "" {
63+
// Content filtering is not enabled
64+
return ctx, nil
65+
}
66+
67+
ownerRepo, err := ParseOwnerRepo(trustedRepo)
68+
if err != nil {
69+
return ctx, err
70+
}
71+
72+
settings := &ContentFilterSettings{
73+
Enabled: true,
74+
TrustedRepo: trustedRepo,
75+
OwnerRepo: ownerRepo,
76+
TrustedUsers: map[string]bool{},
77+
}
78+
79+
// Check if the repository is private, if so, disable content filtering
80+
isPrivate, err := IsRepoPrivate(ctx, settings.OwnerRepo, getGQLClient)
81+
if err != nil {
82+
return ctx, fmt.Errorf("failed to check repository visibility: %w", err)
83+
}
84+
settings.IsPrivate = isPrivate
85+
86+
return context.WithValue(ctx, contentFilterKey, settings), nil
87+
}
88+
89+
// IsRepoPrivate checks if a repository is private using GraphQL
90+
func IsRepoPrivate(ctx context.Context, ownerRepo OwnerRepo, getGQLClient GetGQLClientFn) (bool, error) {
91+
client, err := getGQLClient(ctx)
92+
if err != nil {
93+
return false, fmt.Errorf("failed to get GraphQL client: %w", err)
94+
}
95+
96+
var query struct {
97+
Repository struct {
98+
IsPrivate githubv4.Boolean
99+
} `graphql:"repository(owner: $owner, name: $name)"`
100+
}
101+
102+
variables := map[string]interface{}{
103+
"owner": githubv4.String(ownerRepo.Owner),
104+
"name": githubv4.String(ownerRepo.Repo),
105+
}
106+
107+
err = client.Query(ctx, &query, variables)
108+
if err != nil {
109+
return false, fmt.Errorf("failed to query repository visibility: %w", err)
110+
}
111+
112+
return bool(query.Repository.IsPrivate), nil
113+
}
114+
115+
// HasPushAccess checks if a user has push access to the trusted repository
116+
func HasPushAccess(ctx context.Context, username string, getGQLClient GetGQLClientFn) (bool, error) {
117+
settings, ok := GetContentFilterFromContext(ctx)
118+
if !ok || !settings.Enabled || settings.IsPrivate {
119+
// If filtering is not enabled or repo is private, all users are trusted
120+
return true, nil
121+
}
122+
123+
// Check cache first
124+
settings.mu.RLock()
125+
trusted, found := settings.TrustedUsers[username]
126+
settings.mu.RUnlock()
127+
if found {
128+
return trusted, nil
129+
}
130+
131+
// Query GitHub API for permission
132+
client, err := getGQLClient(ctx)
133+
if err != nil {
134+
return false, fmt.Errorf("failed to get GraphQL client: %w", err)
135+
}
136+
137+
var query struct {
138+
Repository struct {
139+
Collaborators struct {
140+
Edges []struct {
141+
Permission githubv4.String
142+
Node struct {
143+
Login githubv4.String
144+
}
145+
}
146+
} `graphql:"collaborators(query: $username, first: 1)"`
147+
} `graphql:"repository(owner: $owner, name: $name)"`
148+
}
149+
150+
variables := map[string]interface{}{
151+
"owner": githubv4.String(settings.OwnerRepo.Owner),
152+
"name": githubv4.String(settings.OwnerRepo.Repo),
153+
"username": githubv4.String(username),
154+
}
155+
156+
err = client.Query(ctx, &query, variables)
157+
if err != nil {
158+
return false, fmt.Errorf("failed to query user permissions: %w", err)
159+
}
160+
161+
// Check if the user has push access
162+
hasPush := false
163+
for _, edge := range query.Repository.Collaborators.Edges {
164+
login := string(edge.Node.Login)
165+
if strings.EqualFold(login, username) {
166+
permission := string(edge.Permission)
167+
// WRITE, ADMIN, and MAINTAIN permissions have push access
168+
hasPush = permission == "WRITE" || permission == "ADMIN" || permission == "MAINTAIN"
169+
break
170+
}
171+
}
172+
173+
// Cache the result
174+
settings.mu.Lock()
175+
settings.TrustedUsers[username] = hasPush
176+
settings.mu.Unlock()
177+
178+
return hasPush, nil
179+
}
180+
181+
// ShouldIncludeContent checks if content from a user should be included
182+
func ShouldIncludeContent(ctx context.Context, username string, getGQLClient GetGQLClientFn) bool {
183+
settings, ok := GetContentFilterFromContext(ctx)
184+
if !ok || !settings.Enabled || settings.IsPrivate {
185+
// If filtering is not enabled or repo is private, include all content
186+
return true
187+
}
188+
189+
// Check if user has push access
190+
hasPush, err := HasPushAccess(ctx, username, getGQLClient)
191+
if err != nil {
192+
// If there's an error checking permissions, default to not including the content for safety
193+
return false
194+
}
195+
return hasPush
196+
}

0 commit comments

Comments
 (0)