Skip to content

Commit 9426075

Browse files
authored
Fallback to default branch in get_file_contents when main doesn't exist (#1669)
* Fallback to default branch in get_file_contents when main doesn't exist * Addressing review comments
1 parent 7e32623 commit 9426075

File tree

2 files changed

+142
-21
lines changed

2 files changed

+142
-21
lines changed

pkg/github/repositories.go

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,8 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
671671
if err != nil {
672672
return utils.NewToolResultError(err.Error()), nil, nil
673673
}
674+
originalRef := ref
675+
674676
sha, err := OptionalParam[string](args, "sha")
675677
if err != nil {
676678
return utils.NewToolResultError(err.Error()), nil, nil
@@ -681,7 +683,7 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
681683
return utils.NewToolResultError("failed to get GitHub client"), nil, nil
682684
}
683685

684-
rawOpts, err := resolveGitReference(ctx, client, owner, repo, ref, sha)
686+
rawOpts, fallbackUsed, err := resolveGitReference(ctx, client, owner, repo, ref, sha)
685687
if err != nil {
686688
return utils.NewToolResultError(fmt.Sprintf("failed to resolve git reference: %s", err)), nil, nil
687689
}
@@ -747,6 +749,12 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
747749
}
748750
}
749751

752+
// main branch ref passed in ref parameter but it doesn't exist - default branch was used
753+
var successNote string
754+
if fallbackUsed {
755+
successNote = fmt.Sprintf(" Note: the provided ref '%s' does not exist, default branch '%s' was used instead.", originalRef, rawOpts.Ref)
756+
}
757+
750758
// Determine if content is text or binary
751759
isTextContent := strings.HasPrefix(contentType, "text/") ||
752760
contentType == "application/json" ||
@@ -762,9 +770,9 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
762770
}
763771
// Include SHA in the result metadata
764772
if fileSHA != "" {
765-
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA), result), nil, nil
773+
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA)+successNote, result), nil, nil
766774
}
767-
return utils.NewToolResultResource("successfully downloaded text file", result), nil, nil
775+
return utils.NewToolResultResource("successfully downloaded text file"+successNote, result), nil, nil
768776
}
769777

770778
result := &mcp.ResourceContents{
@@ -774,9 +782,9 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
774782
}
775783
// Include SHA in the result metadata
776784
if fileSHA != "" {
777-
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA), result), nil, nil
785+
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA)+successNote, result), nil, nil
778786
}
779-
return utils.NewToolResultResource("successfully downloaded binary file", result), nil, nil
787+
return utils.NewToolResultResource("successfully downloaded binary file"+successNote, result), nil, nil
780788
}
781789

782790
// Raw API call failed
@@ -1876,15 +1884,15 @@ func looksLikeSHA(s string) bool {
18761884
//
18771885
// Any unexpected (non-404) errors during the resolution process are returned
18781886
// immediately. All API errors are logged with rich context to aid diagnostics.
1879-
func resolveGitReference(ctx context.Context, githubClient *github.Client, owner, repo, ref, sha string) (*raw.ContentOpts, error) {
1887+
func resolveGitReference(ctx context.Context, githubClient *github.Client, owner, repo, ref, sha string) (*raw.ContentOpts, bool, error) {
18801888
// 1) If SHA explicitly provided, it's the highest priority.
18811889
if sha != "" {
1882-
return &raw.ContentOpts{Ref: "", SHA: sha}, nil
1890+
return &raw.ContentOpts{Ref: "", SHA: sha}, false, nil
18831891
}
18841892

18851893
// 1a) If sha is empty but ref looks like a SHA, return it without changes
18861894
if looksLikeSHA(ref) {
1887-
return &raw.ContentOpts{Ref: "", SHA: ref}, nil
1895+
return &raw.ContentOpts{Ref: "", SHA: ref}, false, nil
18881896
}
18891897

18901898
originalRef := ref // Keep original ref for clearer error messages down the line.
@@ -1893,16 +1901,16 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
18931901
var reference *github.Reference
18941902
var resp *github.Response
18951903
var err error
1904+
var fallbackUsed bool
18961905

18971906
switch {
18981907
case originalRef == "":
18991908
// 2a) If ref is empty, determine the default branch.
1900-
repoInfo, resp, err := githubClient.Repositories.Get(ctx, owner, repo)
1909+
reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo)
19011910
if err != nil {
1902-
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository info", resp, err)
1903-
return nil, fmt.Errorf("failed to get repository info: %w", err)
1911+
return nil, false, err // Error is already wrapped in resolveDefaultBranch.
19041912
}
1905-
ref = fmt.Sprintf("refs/heads/%s", repoInfo.GetDefaultBranch())
1913+
ref = reference.GetRef()
19061914
case strings.HasPrefix(originalRef, "refs/"):
19071915
// 2b) Already fully qualified. The reference will be fetched at the end.
19081916
case strings.HasPrefix(originalRef, "heads/") || strings.HasPrefix(originalRef, "tags/"):
@@ -1928,19 +1936,26 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
19281936
ghErr2, isGhErr2 := err.(*github.ErrorResponse)
19291937
if isGhErr2 && ghErr2.Response.StatusCode == http.StatusNotFound {
19301938
if originalRef == "main" {
1931-
return nil, fmt.Errorf("could not find branch or tag 'main'. Some repositories use 'master' as the default branch name")
1939+
reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo)
1940+
if err != nil {
1941+
return nil, false, err // Error is already wrapped in resolveDefaultBranch.
1942+
}
1943+
// Update ref to the actual default branch ref so the note can be generated
1944+
ref = reference.GetRef()
1945+
fallbackUsed = true
1946+
break
19321947
}
1933-
return nil, fmt.Errorf("could not resolve ref %q as a branch or a tag", originalRef)
1948+
return nil, false, fmt.Errorf("could not resolve ref %q as a branch or a tag", originalRef)
19341949
}
19351950

19361951
// The tag lookup failed for a different reason.
19371952
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get reference (tag)", resp, err)
1938-
return nil, fmt.Errorf("failed to get reference for tag '%s': %w", originalRef, err)
1953+
return nil, false, fmt.Errorf("failed to get reference for tag '%s': %w", originalRef, err)
19391954
}
19401955
} else {
19411956
// The branch lookup failed for a different reason.
19421957
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get reference (branch)", resp, err)
1943-
return nil, fmt.Errorf("failed to get reference for branch '%s': %w", originalRef, err)
1958+
return nil, false, fmt.Errorf("failed to get reference for branch '%s': %w", originalRef, err)
19441959
}
19451960
}
19461961
}
@@ -1949,15 +1964,48 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
19491964
reference, resp, err = githubClient.Git.GetRef(ctx, owner, repo, ref)
19501965
if err != nil {
19511966
if ref == "refs/heads/main" {
1952-
return nil, fmt.Errorf("could not find branch 'main'. Some repositories use 'master' as the default branch name")
1967+
reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo)
1968+
if err != nil {
1969+
return nil, false, err // Error is already wrapped in resolveDefaultBranch.
1970+
}
1971+
// Update ref to the actual default branch ref so the note can be generated
1972+
ref = reference.GetRef()
1973+
fallbackUsed = true
1974+
} else {
1975+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get final reference", resp, err)
1976+
return nil, false, fmt.Errorf("failed to get final reference for %q: %w", ref, err)
19531977
}
1954-
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get final reference", resp, err)
1955-
return nil, fmt.Errorf("failed to get final reference for %q: %w", ref, err)
19561978
}
19571979
}
19581980

19591981
sha = reference.GetObject().GetSHA()
1960-
return &raw.ContentOpts{Ref: ref, SHA: sha}, nil
1982+
return &raw.ContentOpts{Ref: ref, SHA: sha}, fallbackUsed, nil
1983+
}
1984+
1985+
func resolveDefaultBranch(ctx context.Context, githubClient *github.Client, owner, repo string) (*github.Reference, error) {
1986+
repoInfo, resp, err := githubClient.Repositories.Get(ctx, owner, repo)
1987+
if err != nil {
1988+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository info", resp, err)
1989+
return nil, fmt.Errorf("failed to get repository info: %w", err)
1990+
}
1991+
1992+
if resp != nil && resp.Body != nil {
1993+
_ = resp.Body.Close()
1994+
}
1995+
1996+
defaultBranch := repoInfo.GetDefaultBranch()
1997+
1998+
defaultRef, resp, err := githubClient.Git.GetRef(ctx, owner, repo, "heads/"+defaultBranch)
1999+
if err != nil {
2000+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get default branch reference", resp, err)
2001+
return nil, fmt.Errorf("failed to get default branch reference: %w", err)
2002+
}
2003+
2004+
if resp != nil && resp.Body != nil {
2005+
defer func() { _ = resp.Body.Close() }()
2006+
}
2007+
2008+
return defaultRef, nil
19612009
}
19622010

19632011
// ListStarredRepositories creates a tool to list starred repositories for the authenticated user or a specified user.

pkg/github/repositories_test.go

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ func Test_GetFileContents(t *testing.T) {
6969
expectedResult interface{}
7070
expectedErrMsg string
7171
expectStatus int
72+
expectedMsg string // optional: expected message text to verify in result
7273
}{
7374
{
7475
name: "successful text content fetch",
@@ -290,6 +291,70 @@ func Test_GetFileContents(t *testing.T) {
290291
MIMEType: "text/markdown",
291292
},
292293
},
294+
{
295+
name: "successful text content fetch with note when ref falls back to default branch",
296+
mockedClient: mock.NewMockedHTTPClient(
297+
mock.WithRequestMatchHandler(
298+
mock.GetReposByOwnerByRepo,
299+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
300+
w.WriteHeader(http.StatusOK)
301+
_, _ = w.Write([]byte(`{"name": "repo", "default_branch": "develop"}`))
302+
}),
303+
),
304+
mock.WithRequestMatchHandler(
305+
mock.GetReposGitRefByOwnerByRepoByRef,
306+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
307+
// Request for "refs/heads/main" -> 404 (doesn't exist)
308+
// Request for "refs/heads/develop" (default branch) -> 200
309+
switch {
310+
case strings.Contains(r.URL.Path, "heads/main"):
311+
w.WriteHeader(http.StatusNotFound)
312+
_, _ = w.Write([]byte(`{"message": "Not Found"}`))
313+
case strings.Contains(r.URL.Path, "heads/develop"):
314+
w.WriteHeader(http.StatusOK)
315+
_, _ = w.Write([]byte(`{"ref": "refs/heads/develop", "object": {"sha": "abc123def456"}}`))
316+
default:
317+
w.WriteHeader(http.StatusNotFound)
318+
_, _ = w.Write([]byte(`{"message": "Not Found"}`))
319+
}
320+
}),
321+
),
322+
mock.WithRequestMatchHandler(
323+
mock.GetReposContentsByOwnerByRepoByPath,
324+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
325+
w.WriteHeader(http.StatusOK)
326+
fileContent := &github.RepositoryContent{
327+
Name: github.Ptr("README.md"),
328+
Path: github.Ptr("README.md"),
329+
SHA: github.Ptr("abc123"),
330+
Type: github.Ptr("file"),
331+
}
332+
contentBytes, _ := json.Marshal(fileContent)
333+
_, _ = w.Write(contentBytes)
334+
}),
335+
),
336+
mock.WithRequestMatchHandler(
337+
raw.GetRawReposContentsByOwnerByRepoBySHAByPath,
338+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
339+
w.Header().Set("Content-Type", "text/markdown")
340+
_, _ = w.Write(mockRawContent)
341+
}),
342+
),
343+
),
344+
requestArgs: map[string]interface{}{
345+
"owner": "owner",
346+
"repo": "repo",
347+
"path": "README.md",
348+
"ref": "main",
349+
},
350+
expectError: false,
351+
expectedResult: mcp.ResourceContents{
352+
URI: "repo://owner/repo/abc123def456/contents/README.md",
353+
Text: "# Test Repository\n\nThis is a test repository.",
354+
MIMEType: "text/markdown",
355+
},
356+
expectedMsg: " Note: the provided ref 'main' does not exist, default branch 'refs/heads/develop' was used instead.",
357+
},
293358
{
294359
name: "content fetch fails",
295360
mockedClient: mock.NewMockedHTTPClient(
@@ -358,6 +423,14 @@ func Test_GetFileContents(t *testing.T) {
358423
// Handle both text and blob resources
359424
resource := getResourceResult(t, result)
360425
assert.Equal(t, expected, *resource)
426+
427+
// If expectedMsg is set, verify the message text
428+
if tc.expectedMsg != "" {
429+
require.Len(t, result.Content, 2)
430+
textContent, ok := result.Content[0].(*mcp.TextContent)
431+
require.True(t, ok, "expected Content[0] to be TextContent")
432+
assert.Contains(t, textContent.Text, tc.expectedMsg)
433+
}
361434
case []*github.RepositoryContent:
362435
// Directory content fetch returns a text result (JSON array)
363436
textContent := getTextResult(t, result)
@@ -3288,7 +3361,7 @@ func Test_resolveGitReference(t *testing.T) {
32883361
t.Run(tc.name, func(t *testing.T) {
32893362
// Setup client with mock
32903363
client := github.NewClient(tc.mockSetup())
3291-
opts, err := resolveGitReference(ctx, client, owner, repo, tc.ref, tc.sha)
3364+
opts, _, err := resolveGitReference(ctx, client, owner, repo, tc.ref, tc.sha)
32923365

32933366
if tc.expectError {
32943367
require.Error(t, err)

0 commit comments

Comments
 (0)