Skip to content

Commit 790c20e

Browse files
authored
Add import cache for offline workflow compilation with SHA-based storage (#3981)
1 parent 17d880e commit 790c20e

7 files changed

Lines changed: 610 additions & 31 deletions

File tree

pkg/cli/mcp_inspect.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ func InspectWorkflowMCP(workflowFile string, serverFilter string, toolFilter str
151151

152152
// Process imports from frontmatter to merge imported MCP servers
153153
markdownDir := filepath.Dir(workflowPath)
154-
importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir)
154+
importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir, nil)
155155
if err != nil {
156156
return fmt.Errorf("failed to process imports from frontmatter: %w", err)
157157
}
@@ -295,7 +295,7 @@ func spawnMCPInspector(workflowFile string, serverFilter string, verbose bool) e
295295

296296
// Process imports from frontmatter to merge imported MCP servers
297297
markdownDir := filepath.Dir(workflowPath)
298-
importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir)
298+
importsResult, err := parser.ProcessImportsFromFrontmatterWithManifest(workflowData.Frontmatter, markdownDir, nil)
299299
if err != nil {
300300
return fmt.Errorf("failed to process imports from frontmatter: %w", err)
301301
}

pkg/parser/frontmatter.go

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ func ExtractMarkdown(filePath string) (string, error) {
371371
// ProcessImportsFromFrontmatter processes imports field from frontmatter
372372
// Returns merged tools and engines from imported files
373373
func ProcessImportsFromFrontmatter(frontmatter map[string]any, baseDir string) (mergedTools string, mergedEngines []string, err error) {
374-
result, err := ProcessImportsFromFrontmatterWithManifest(frontmatter, baseDir)
374+
result, err := ProcessImportsFromFrontmatterWithManifest(frontmatter, baseDir, nil)
375375
if err != nil {
376376
return "", nil, err
377377
}
@@ -389,7 +389,7 @@ type importQueueItem struct {
389389
// ProcessImportsFromFrontmatterWithManifest processes imports field from frontmatter
390390
// Returns result containing merged tools, engines, markdown content, and list of imported files
391391
// Uses BFS traversal with queues for deterministic ordering and cycle detection
392-
func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseDir string) (*ImportsResult, error) {
392+
func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseDir string, cache *ImportCache) (*ImportsResult, error) {
393393
// Check if imports field exists
394394
importsField, exists := frontmatter["imports"]
395395
if !exists {
@@ -451,7 +451,7 @@ func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseD
451451
}
452452

453453
// Resolve import path (supports workflowspec format)
454-
fullPath, err := resolveIncludePath(filePath, baseDir)
454+
fullPath, err := resolveIncludePath(filePath, baseDir, cache)
455455
if err != nil {
456456
return nil, fmt.Errorf("failed to resolve import '%s': %w", filePath, err)
457457
}
@@ -558,7 +558,7 @@ func ProcessImportsFromFrontmatterWithManifest(frontmatter map[string]any, baseD
558558
}
559559

560560
// Resolve nested import path relative to the workflows directory, not the nested file's directory
561-
nestedFullPath, err := resolveIncludePath(nestedFilePath, baseDir)
561+
nestedFullPath, err := resolveIncludePath(nestedFilePath, baseDir, cache)
562562
if err != nil {
563563
return nil, fmt.Errorf("failed to resolve nested import '%s' from '%s': %w", nestedFilePath, item.fullPath, err)
564564
}
@@ -724,7 +724,7 @@ func processIncludesWithVisited(content, baseDir string, extractTools bool, visi
724724
}
725725

726726
// Resolve file path first to get the canonical path
727-
fullPath, err := resolveIncludePath(filePath, baseDir)
727+
fullPath, err := resolveIncludePath(filePath, baseDir, nil)
728728
if err != nil {
729729
if isOptional {
730730
// For optional includes, show a friendly informational message to stdout
@@ -796,12 +796,12 @@ func isUnderWorkflowsDirectory(filePath string) bool {
796796
}
797797

798798
// resolveIncludePath resolves include path based on workflowspec format or relative path
799-
func resolveIncludePath(filePath, baseDir string) (string, error) {
799+
func resolveIncludePath(filePath, baseDir string, cache *ImportCache) (string, error) {
800800
// Check if this is a workflowspec (contains owner/repo/path format)
801801
// Format: owner/repo/path@ref or owner/repo/path@ref#section
802802
if isWorkflowSpec(filePath) {
803-
// Download from GitHub using workflowspec
804-
return downloadIncludeFromWorkflowSpec(filePath)
803+
// Download from GitHub using workflowspec (with cache support)
804+
return downloadIncludeFromWorkflowSpec(filePath, cache)
805805
}
806806

807807
// Regular path, resolve relative to base directory
@@ -850,7 +850,8 @@ func isWorkflowSpec(path string) bool {
850850
}
851851

852852
// downloadIncludeFromWorkflowSpec downloads an include file from GitHub using workflowspec
853-
func downloadIncludeFromWorkflowSpec(spec string) (string, error) {
853+
// It first checks the cache, and only downloads if not cached
854+
func downloadIncludeFromWorkflowSpec(spec string, cache *ImportCache) (string, error) {
854855
// Parse the workflowspec
855856
// Format: owner/repo/path@ref or owner/repo/path@ref#section
856857

@@ -880,13 +881,47 @@ func downloadIncludeFromWorkflowSpec(spec string) (string, error) {
880881
repo := slashParts[1]
881882
filePath := strings.Join(slashParts[2:], "/")
882883

884+
// Resolve ref to SHA for cache lookup
885+
var sha string
886+
if cache != nil {
887+
// Only resolve SHA if we're using the cache
888+
resolvedSHA, err := resolveRefToSHA(owner, repo, ref)
889+
if err != nil {
890+
// If the error is an authentication error, propagate it immediately
891+
lowerErr := strings.ToLower(err.Error())
892+
if strings.Contains(lowerErr, "auth") || strings.Contains(lowerErr, "unauthoriz") || strings.Contains(lowerErr, "forbidden") || strings.Contains(lowerErr, "token") || strings.Contains(lowerErr, "permission denied") {
893+
return "", fmt.Errorf("failed to resolve ref to SHA due to authentication error: %w", err)
894+
}
895+
log.Printf("Failed to resolve ref to SHA, will skip cache: %v", err)
896+
// Continue without caching if SHA resolution fails
897+
} else {
898+
sha = resolvedSHA
899+
// Check cache using SHA
900+
if cachedPath, found := cache.Get(owner, repo, filePath, sha); found {
901+
log.Printf("Using cached import: %s/%s/%s@%s (SHA: %s)", owner, repo, filePath, ref, sha)
902+
return cachedPath, nil
903+
}
904+
}
905+
}
906+
883907
// Download the file content from GitHub
884908
content, err := downloadFileFromGitHub(owner, repo, filePath, ref)
885909
if err != nil {
886910
return "", fmt.Errorf("failed to download include from %s: %w", spec, err)
887911
}
888912

889-
// Create a temporary file to store the downloaded content
913+
// If cache is available and we have a SHA, store in cache
914+
if cache != nil && sha != "" {
915+
cachedPath, err := cache.Set(owner, repo, filePath, sha, content)
916+
if err != nil {
917+
log.Printf("Failed to cache import: %v", err)
918+
// Don't fail the compilation, fall back to temp file
919+
} else {
920+
return cachedPath, nil
921+
}
922+
}
923+
924+
// Fallback: Create a temporary file to store the downloaded content
890925
tempFile, err := os.CreateTemp("", "gh-aw-include-*.md")
891926
if err != nil {
892927
return "", fmt.Errorf("failed to create temp file: %w", err)
@@ -906,7 +941,52 @@ func downloadIncludeFromWorkflowSpec(spec string) (string, error) {
906941
return tempFile.Name(), nil
907942
}
908943

909-
// downloadFileFromGitHub downloads a file from GitHub using gh CLI
944+
// resolveRefToSHA resolves a git ref (branch, tag, or SHA) to its commit SHA
945+
func resolveRefToSHA(owner, repo, ref string) (string, error) {
946+
// If ref is already a full SHA (40 hex characters), return it as-is
947+
if len(ref) == 40 && isHexString(ref) {
948+
return ref, nil
949+
}
950+
951+
// Use gh CLI to get the commit SHA for the ref
952+
// This works for branches, tags, and short SHAs
953+
cmd := exec.Command("gh", "api", fmt.Sprintf("/repos/%s/%s/commits/%s", owner, repo, ref), "--jq", ".sha")
954+
955+
output, err := cmd.CombinedOutput()
956+
if err != nil {
957+
outputStr := string(output)
958+
if strings.Contains(outputStr, "GH_TOKEN") || strings.Contains(outputStr, "authentication") || strings.Contains(outputStr, "not logged into") {
959+
return "", fmt.Errorf("failed to resolve ref to SHA: GitHub authentication required. Please run 'gh auth login' or set GH_TOKEN/GITHUB_TOKEN environment variable: %w", err)
960+
}
961+
return "", fmt.Errorf("failed to resolve ref %s to SHA for %s/%s: %s: %w", ref, owner, repo, strings.TrimSpace(outputStr), err)
962+
}
963+
964+
sha := strings.TrimSpace(string(output))
965+
if sha == "" {
966+
return "", fmt.Errorf("empty SHA returned for ref %s in %s/%s", ref, owner, repo)
967+
}
968+
969+
// Validate it's a valid SHA (40 hex characters)
970+
if len(sha) != 40 || !isHexString(sha) {
971+
return "", fmt.Errorf("invalid SHA format returned: %s", sha)
972+
}
973+
974+
return sha, nil
975+
}
976+
977+
// isHexString reports whether s is non-empty and consists solely of
// hexadecimal digits (0-9, a-f, A-F).
func isHexString(s string) bool {
	if s == "" {
		return false
	}
	for _, r := range s {
		switch {
		case '0' <= r && r <= '9':
		case 'a' <= r && r <= 'f':
		case 'A' <= r && r <= 'F':
		default:
			// Any rune outside the three hex ranges disqualifies the string.
			return false
		}
	}
	return true
}
989+
910990
func downloadFileFromGitHub(owner, repo, path, ref string) ([]byte, error) {
911991
// Use go-gh/v2 to download the file
912992
stdout, stderr, err := gh.Exec("api", fmt.Sprintf("/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), "--jq", ".content")
@@ -1321,7 +1401,7 @@ func processIncludesForField(content, baseDir string, extractFunc func(string) (
13211401
}
13221402

13231403
// Resolve file path
1324-
fullPath, err := resolveIncludePath(filePath, baseDir)
1404+
fullPath, err := resolveIncludePath(filePath, baseDir, nil)
13251405
if err != nil {
13261406
if isOptional {
13271407
// For optional includes, skip extraction

pkg/parser/frontmatter_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ func TestResolveIncludePath(t *testing.T) {
865865

866866
for _, tt := range tests {
867867
t.Run(tt.name, func(t *testing.T) {
868-
result, err := resolveIncludePath(tt.filePath, tt.baseDir)
868+
result, err := resolveIncludePath(tt.filePath, tt.baseDir, nil)
869869

870870
if tt.wantErr {
871871
if err == nil {

pkg/parser/import_cache.go

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
package parser
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"path/filepath"
7+
"strings"
8+
9+
"github.com/githubnext/gh-aw/pkg/logger"
10+
)
11+
12+
// importCacheLog is the logger used by the import cache code in this file;
// it is created with the name "parser:import_cache".
var importCacheLog = logger.New("parser:import_cache")
13+
14+
// ImportCacheDir is the directory where cached imports are stored.
// ImportCache joins it onto its base directory to locate the cache root.
const ImportCacheDir = ".github/aw/imports"
18+
19+
// sanitizePath flattens a file path into a single file name.
//
// The path is normalized with filepath.Clean and every OS path separator in
// the result is replaced by an underscore, so the returned name never contains
// a separator. Clean does not strip a leading ".." and a path that cleans to
// "." or ".." is returned unchanged, so callers that need protection against
// traversal must validate the input themselves.
func sanitizePath(path string) string {
	return strings.ReplaceAll(filepath.Clean(path), string(filepath.Separator), "_")
}
29+
30+
// validatePathComponents checks the four values used to build a cache path.
//
// Every component must be non-empty, must not contain a ".." sequence, and
// must not be an absolute path. In addition owner, repo and sha must not
// contain a path separator: each of them is joined into the cache path as a
// single directory level, so a separator would silently change the nesting.
// path may contain separators because it is flattened before use.
//
// Returns nil when all components are acceptable, otherwise an error
// describing the first violation found.
func validatePathComponents(owner, repo, path, sha string) error {
	for _, comp := range []string{owner, repo, path, sha} {
		// Check for empty components
		if comp == "" {
			return fmt.Errorf("empty component in path")
		}
		// Check for path traversal attempts
		if strings.Contains(comp, "..") {
			return fmt.Errorf("component contains '..' sequence: %s", comp)
		}
		// Check for absolute paths
		if filepath.IsAbs(comp) {
			return fmt.Errorf("component is absolute path: %s", comp)
		}
	}

	// Both separator styles are rejected regardless of the host OS so the
	// result does not depend on where the cache is populated.
	for _, comp := range []string{owner, repo, sha} {
		if strings.ContainsAny(comp, `/\`) {
			return fmt.Errorf("component contains path separator: %s", comp)
		}
	}
	return nil
}
49+
50+
// ImportCache manages cached imported workflow files.
// Entries are stored on disk below ImportCacheDir, resolved against baseDir.
type ImportCache struct {
	// baseDir is the base directory for the cache (typically the repo root).
	baseDir string
}
54+
55+
// NewImportCache creates a new import cache instance
56+
func NewImportCache(repoRoot string) *ImportCache {
57+
importCacheLog.Printf("Creating import cache with base dir: %s", repoRoot)
58+
return &ImportCache{
59+
baseDir: repoRoot,
60+
}
61+
}
62+
63+
// Get retrieves a cached file path if it exists
64+
// sha parameter should be the resolved commit SHA
65+
func (c *ImportCache) Get(owner, repo, path, sha string) (string, bool) {
66+
// Use SHA-based approach: cache files are stored by commit SHA
67+
// Cache path: .github/aw/imports/owner/repo/sha/sanitized_path.md
68+
sanitizedPath := sanitizePath(path)
69+
relativeCachePath := filepath.Join(ImportCacheDir, owner, repo, sha, sanitizedPath)
70+
fullCachePath := filepath.Join(c.baseDir, relativeCachePath)
71+
72+
// Check if the cached file exists
73+
if _, err := os.Stat(fullCachePath); err != nil {
74+
if os.IsNotExist(err) {
75+
importCacheLog.Printf("Cache miss: %s/%s/%s@%s", owner, repo, path, sha)
76+
} else {
77+
// Log other types of errors (permissions, I/O issues, etc.)
78+
importCacheLog.Printf("Cache access error for %s/%s/%s@%s: %v", owner, repo, path, sha, err)
79+
}
80+
return "", false
81+
}
82+
83+
importCacheLog.Printf("Cache hit: %s/%s/%s@%s -> %s", owner, repo, path, sha, fullCachePath)
84+
return fullCachePath, true
85+
}
86+
87+
// Set stores a new cache entry by saving the content to the cache directory
88+
// sha parameter should be the resolved commit SHA
89+
func (c *ImportCache) Set(owner, repo, path, sha string, content []byte) (string, error) {
90+
// Validate file size (max 10MB)
91+
const maxFileSize = 10 * 1024 * 1024
92+
if len(content) > maxFileSize {
93+
return "", fmt.Errorf("file size (%d bytes) exceeds maximum allowed size (%d bytes)", len(content), maxFileSize)
94+
}
95+
96+
// Validate path components to prevent path traversal
97+
if err := validatePathComponents(owner, repo, path, sha); err != nil {
98+
return "", fmt.Errorf("invalid path components: %w", err)
99+
}
100+
101+
// Use SHA in path for consistent caching
102+
// This ensures that different refs pointing to the same commit reuse the same cache
103+
sanitizedPath := sanitizePath(path)
104+
relativeCachePath := filepath.Join(ImportCacheDir, owner, repo, sha, sanitizedPath)
105+
fullCachePath := filepath.Join(c.baseDir, relativeCachePath)
106+
107+
// Ensure directory exists
108+
dir := filepath.Dir(fullCachePath)
109+
if err := os.MkdirAll(dir, 0755); err != nil {
110+
importCacheLog.Printf("Failed to create cache directory: %v", err)
111+
return "", err
112+
}
113+
114+
// Ensure .gitattributes file exists in cache root
115+
if err := c.ensureGitAttributes(); err != nil {
116+
importCacheLog.Printf("Failed to ensure .gitattributes: %v", err)
117+
// Non-fatal error - continue with caching
118+
}
119+
120+
// Write content to cache file
121+
if err := os.WriteFile(fullCachePath, content, 0644); err != nil {
122+
importCacheLog.Printf("Failed to write cache file: %v", err)
123+
return "", err
124+
}
125+
126+
importCacheLog.Printf("Cached import: %s/%s/%s@%s -> %s", owner, repo, path, sha, fullCachePath)
127+
return fullCachePath, nil
128+
}
129+
130+
// GetCacheDir returns the base cache directory path
131+
func (c *ImportCache) GetCacheDir() string {
132+
return filepath.Join(c.baseDir, ImportCacheDir)
133+
}
134+
135+
// ensureGitAttributes creates the .gitattributes file in the cache directory if it doesn't exist
136+
func (c *ImportCache) ensureGitAttributes() error {
137+
gitAttributesPath := filepath.Join(c.GetCacheDir(), ".gitattributes")
138+
139+
// Check if .gitattributes already exists
140+
if _, err := os.Stat(gitAttributesPath); err == nil {
141+
// File already exists, nothing to do
142+
return nil
143+
}
144+
145+
// Ensure cache root directory exists
146+
cacheDir := c.GetCacheDir()
147+
if err := os.MkdirAll(cacheDir, 0755); err != nil {
148+
return err
149+
}
150+
151+
// Create .gitattributes file with content
152+
content := `# Mark all cached import files as generated
153+
* linguist-generated=true
154+
155+
# Use 'ours' merge strategy to keep local cached versions
156+
* merge=ours
157+
`
158+
159+
if err := os.WriteFile(gitAttributesPath, []byte(content), 0644); err != nil {
160+
return err
161+
}
162+
163+
importCacheLog.Printf("Created .gitattributes in cache directory: %s", gitAttributesPath)
164+
return nil
165+
}

0 commit comments

Comments
 (0)