Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ The following sets of tools are available (all are on by default):

- **list_discussion_categories** - List discussion categories
- `owner`: Repository owner (string, required)
- `repo`: Repository name (string, required)
- `repo`: Repository name. If not provided, discussion categories will be queried at the organisation level. (string, optional)

- **list_discussions** - List discussions
- `after`: Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs. (string, optional)
Expand Down
24 changes: 14 additions & 10 deletions pkg/github/discussions.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.Translati

func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("list_discussion_categories",
mcp.WithDescription(t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository")),
mcp.WithDescription(t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository or organisation.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
Title: t("TOOL_LIST_DISCUSSION_CATEGORIES_USER_TITLE", "List discussion categories"),
ReadOnlyHint: ToBoolPtr(true),
Expand All @@ -453,19 +453,23 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl
mcp.Description("Repository owner"),
),
mcp.WithString("repo",
mcp.Required(),
mcp.Description("Repository name"),
mcp.Description("Repository name. If not provided, discussion categories will be queried at the organisation level."),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Decode params
var params struct {
Owner string
Repo string
owner, err := RequiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
if err := mapstructure.Decode(request.Params.Arguments, &params); err != nil {
repo, err := OptionalParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
// when not provided, default to the .github repository
// this will query discussion categories at the organisation level
if repo == "" {
repo = ".github"
}

client, err := getGQLClient(ctx)
if err != nil {
Expand All @@ -490,8 +494,8 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl
} `graphql:"repository(owner: $owner, name: $repo)"`
}
vars := map[string]interface{}{
"owner": githubv4.String(params.Owner),
"repo": githubv4.String(params.Repo),
"owner": githubv4.String(owner),
"repo": githubv4.String(repo),
"first": githubv4.Int(25),
}
if err := client.Query(ctx, &q, vars); err != nil {
Expand Down
147 changes: 115 additions & 32 deletions pkg/github/discussions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ func Test_GetDiscussion(t *testing.T) {
assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo", "discussionNumber"})

// Use exact string query that matches implementation output
qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,url,category{name}}}}"
qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,url,category{name}}}}"

vars := map[string]interface{}{
"owner": "owner",
Expand Down Expand Up @@ -638,17 +638,33 @@ func Test_GetDiscussionComments(t *testing.T) {
}

func Test_ListDiscussionCategories(t *testing.T) {
mockClient := githubv4.NewClient(nil)
toolDef, _ := ListDiscussionCategories(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper)
assert.Equal(t, "list_discussion_categories", toolDef.Name)
assert.NotEmpty(t, toolDef.Description)
assert.Contains(t, toolDef.Description, "or organisation")
assert.Contains(t, toolDef.InputSchema.Properties, "owner")
assert.Contains(t, toolDef.InputSchema.Properties, "repo")
assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner"})

// Use exact string query that matches implementation output
qListCategories := "query($first:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussionCategories(first: $first){nodes{id,name},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}"

// Variables matching what GraphQL receives after JSON marshaling/unmarshaling
vars := map[string]interface{}{
// Variables for repository-level categories
varsRepo := map[string]interface{}{
"owner": "owner",
"repo": "repo",
"first": float64(25),
}

mockResp := githubv4mock.DataResponse(map[string]any{
// Variables for organization-level categories (using .github repo)
varsOrg := map[string]interface{}{
"owner": "owner",
"repo": ".github",
"first": float64(25),
}

mockRespRepo := githubv4mock.DataResponse(map[string]any{
"repository": map[string]any{
"discussionCategories": map[string]any{
"nodes": []map[string]any{
Expand All @@ -665,37 +681,104 @@ func Test_ListDiscussionCategories(t *testing.T) {
},
},
})
matcher := githubv4mock.NewQueryMatcher(qListCategories, vars, mockResp)
httpClient := githubv4mock.NewMockedHTTPClient(matcher)
gqlClient := githubv4.NewClient(httpClient)

tool, handler := ListDiscussionCategories(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper)
assert.Equal(t, "list_discussion_categories", tool.Name)
assert.NotEmpty(t, tool.Description)
assert.Contains(t, tool.InputSchema.Properties, "owner")
assert.Contains(t, tool.InputSchema.Properties, "repo")
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"})
mockRespOrg := githubv4mock.DataResponse(map[string]any{
"repository": map[string]any{
"discussionCategories": map[string]any{
"nodes": []map[string]any{
{"id": "789", "name": "Announcements"},
{"id": "101", "name": "General"},
{"id": "112", "name": "Ideas"},
},
"pageInfo": map[string]any{
"hasNextPage": false,
"hasPreviousPage": false,
"startCursor": "",
"endCursor": "",
},
"totalCount": 3,
},
},
})

request := createMCPRequest(map[string]interface{}{"owner": "owner", "repo": "repo"})
result, err := handler(context.Background(), request)
require.NoError(t, err)
tests := []struct {
name string
reqParams map[string]interface{}
vars map[string]interface{}
mockResponse githubv4mock.GQLResponse
expectError bool
expectedCount int
}{
{
name: "list repository-level discussion categories",
reqParams: map[string]interface{}{
"owner": "owner",
"repo": "repo",
},
vars: varsRepo,
mockResponse: mockRespRepo,
expectError: false,
expectedCount: 2,
},
{
name: "list org-level discussion categories (no repo provided)",
reqParams: map[string]interface{}{
"owner": "owner",
// repo is not provided, it will default to ".github"
},
vars: varsOrg,
mockResponse: mockRespOrg,
expectError: false,
expectedCount: 3,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
matcher := githubv4mock.NewQueryMatcher(qListCategories, tc.vars, tc.mockResponse)
httpClient := githubv4mock.NewMockedHTTPClient(matcher)
gqlClient := githubv4.NewClient(httpClient)

text := getTextResult(t, result).Text
_, handler := ListDiscussionCategories(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper)

var response struct {
Categories []map[string]string `json:"categories"`
PageInfo struct {
HasNextPage bool `json:"hasNextPage"`
HasPreviousPage bool `json:"hasPreviousPage"`
StartCursor string `json:"startCursor"`
EndCursor string `json:"endCursor"`
} `json:"pageInfo"`
TotalCount int `json:"totalCount"`
req := createMCPRequest(tc.reqParams)
res, err := handler(context.Background(), req)
text := getTextResult(t, res).Text

if tc.expectError {
require.True(t, res.IsError)
return
}
require.NoError(t, err)

var response struct {
Categories []map[string]string `json:"categories"`
PageInfo struct {
HasNextPage bool `json:"hasNextPage"`
HasPreviousPage bool `json:"hasPreviousPage"`
StartCursor string `json:"startCursor"`
EndCursor string `json:"endCursor"`
} `json:"pageInfo"`
TotalCount int `json:"totalCount"`
}
require.NoError(t, json.Unmarshal([]byte(text), &response))
assert.Len(t, response.Categories, tc.expectedCount)

// Verify specific content based on test case
switch tc.name {
case "list repository-level discussion categories":
assert.Equal(t, "123", response.Categories[0]["id"])
assert.Equal(t, "CategoryOne", response.Categories[0]["name"])
assert.Equal(t, "456", response.Categories[1]["id"])
assert.Equal(t, "CategoryTwo", response.Categories[1]["name"])
case "list org-level discussion categories (no repo provided)":
assert.Equal(t, "789", response.Categories[0]["id"])
assert.Equal(t, "Announcements", response.Categories[0]["name"])
assert.Equal(t, "101", response.Categories[1]["id"])
assert.Equal(t, "General", response.Categories[1]["name"])
assert.Equal(t, "112", response.Categories[2]["id"])
assert.Equal(t, "Ideas", response.Categories[2]["name"])
}
})
}
require.NoError(t, json.Unmarshal([]byte(text), &response))
assert.Len(t, response.Categories, 2)
assert.Equal(t, "123", response.Categories[0]["id"])
assert.Equal(t, "CategoryOne", response.Categories[0]["name"])
assert.Equal(t, "456", response.Categories[1]["id"])
assert.Equal(t, "CategoryTwo", response.Categories[1]["name"])
}
Loading