diff --git a/main.go b/main.go index e978552..870ff5f 100644 --- a/main.go +++ b/main.go @@ -404,6 +404,8 @@ func main() { setupApi(cfg, r, version) setupPages(cfg, r) r.SetRedirectTrailingSlash(false) + routingHandler := proxy.RoutingHandler(cfg) + noRouteHandler := proxy.NoRouteHandler(cfg) r.GET("/github.com/:user/:repo/releases/*filepath", func(c *touka.Context) { // 规范化路径: 移除前导斜杠, 简化后续处理 @@ -433,7 +435,7 @@ func main() { // 根据匹配结果执行最终操作 if isValidDownload { c.Set("matcher", "releases") - proxy.RoutingHandler(cfg)(c) + routingHandler(c) } else { // 任何不符合下载链接格式的 'releases' 路径都被视为浏览页面并拒绝 proxy.ErrorPage(c, proxy.NewErrorWithStatusLookup(400, "unsupported releases page, only download links are allowed")) @@ -443,45 +445,45 @@ func main() { r.GET("/github.com/:user/:repo/archive/*filepath", func(c *touka.Context) { c.Set("matcher", "releases") - proxy.RoutingHandler(cfg)(c) + routingHandler(c) }) r.GET("/github.com/:user/:repo/blob/*filepath", func(c *touka.Context) { c.Set("matcher", "blob") - proxy.RoutingHandler(cfg)(c) + routingHandler(c) }) r.GET("/github.com/:user/:repo/raw/*filepath", func(c *touka.Context) { c.Set("matcher", "raw") - proxy.RoutingHandler(cfg)(c) + routingHandler(c) }) r.GET("/github.com/:user/:repo/info/*filepath", func(c *touka.Context) { c.Set("matcher", "clone") - proxy.RoutingHandler(cfg)(c) + routingHandler(c) }) r.GET("/github.com/:user/:repo/git-upload-pack", func(c *touka.Context) { c.Set("matcher", "clone") - proxy.RoutingHandler(cfg)(c) + routingHandler(c) }) r.POST("/github.com/:user/:repo/git-upload-pack", func(c *touka.Context) { c.Set("matcher", "clone") - proxy.RoutingHandler(cfg)(c) + routingHandler(c) }) r.GET("/raw.githubusercontent.com/:user/:repo/*filepath", func(c *touka.Context) { c.Set("matcher", "raw") - proxy.RoutingHandler(cfg)(c) + routingHandler(c) }) r.GET("/gist.githubusercontent.com/:user/*filepath", func(c *touka.Context) { c.Set("matcher", "gist") - proxy.NoRouteHandler(cfg)(c) + noRouteHandler(c) }) r.ANY("/api.github.com/repos/:user/:repo/*filepath", func(c *touka.Context) { c.Set("matcher", "api") - proxy.RoutingHandler(cfg)(c) + routingHandler(c) }) r.ANY("/v2/*path", @@ -497,7 +499,7 @@ func main() { }) r.NoRoute(func(c *touka.Context) { - proxy.NoRouteHandler(cfg)(c) + noRouteHandler(c) }) fmt.Printf("GHProxy Version: %s\n", version) diff --git a/proxy/chunkreq.go b/proxy/chunkreq.go index 7e3725e..484f81f 100644 --- a/proxy/chunkreq.go +++ b/proxy/chunkreq.go @@ -74,7 +74,7 @@ func ChunkedProxyRequest(ctx context.Context, c *touka.Context, u string, cfg *c // 处理响应体大小限制 var ( - bodySize int + bodySize = -1 contentLength string sizelimit int ) @@ -134,7 +134,7 @@ func ChunkedProxyRequest(ctx context.Context, c *touka.Context, u string, cfg *c var reader io.Reader - reader, _, err = processLinks(bodyReader, c.Request.Host, cfg, c) + reader, _, err = processLinks(bodyReader, c.Request.Host, cfg, c, bodySize) c.WriteStream(reader) if err != nil { c.Errorf("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), c.Request.Method, u, c.UserAgent(), c.Request.Proto, err) diff --git a/proxy/handler.go b/proxy/handler.go index b15f1b5..35032ef 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -3,50 +3,70 @@ package proxy import ( "fmt" "ghproxy/config" - "regexp" "strings" "github.com/infinite-iroha/touka" ) -var re = regexp.MustCompile(`^(http:|https:)?/?/?(.*)`) // 匹配http://或https://开头的路径 +func buildProxyPath(path, matcher string) string { + var sb strings.Builder + sb.Grow(len(path) + 50) + + if matcher == "blob" && strings.HasPrefix(path, "github.com") { + sb.WriteString("https://raw.githubusercontent.com") + pathSegment := path[len("github.com"):] + if i := strings.Index(pathSegment, "/blob/"); i != -1 { + sb.WriteString(pathSegment[:i]) + sb.WriteByte('/') + sb.WriteString(pathSegment[i+len("/blob/"):]) + } else { + sb.WriteString(pathSegment) + } + return sb.String() + } + + sb.WriteString("https://") + sb.WriteString(path) + return sb.String() +} + +func normalizeProxyPath(rawPath string) (string, bool) { + path := strings.TrimLeft(rawPath, "/") + + switch { + case strings.HasPrefix(path, "https:"): + path = path[len("https:"):] + case strings.HasPrefix(path, "http:"): + path = path[len("http:"):] + } + + path = strings.TrimLeft(path, "/") + return path, path != "" +} func NoRouteHandler(cfg *config.Config) touka.HandlerFunc { return func(c *touka.Context) { var ctx = c.Request.Context() var shoudBreak bool - var ( - rawPath string - matches []string - ) - - rawPath = strings.TrimPrefix(c.GetRequestURI(), "/") // 去掉前缀/ - matches = re.FindStringSubmatch(rawPath) // 匹配路径 + path, ok := normalizeProxyPath(c.GetRequestURI()) // 匹配路径错误处理 - if len(matches) < 3 { + if !ok { c.Warnf("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Request.Method, c.Request.URL.Path, c.UserAgent(), c.Request.Proto) ErrorPage(c, NewErrorWithStatusLookup(400, fmt.Sprintf("Invalid URL Format: %s", c.GetRequestURI()))) return } - // 制作url - rawPath = "https://" + matches[2] - - var ( - user string - repo string - matcher string - ) - var matcherErr *GHProxyErrors - user, repo, matcher, matcherErr = Matcher(rawPath, cfg) + user, repo, matcher, matcherErr := Matcher("https://"+path, cfg) if matcherErr != nil { ErrorPage(c, matcherErr) return } + rawPath := buildProxyPath(path, matcher) + shoudBreak = listCheck(cfg, c, user, repo, rawPath) if shoudBreak { return @@ -59,9 +79,6 @@ func NoRouteHandler(cfg *config.Config) touka.HandlerFunc { // 处理blob/raw路径 if matcher == "blob" { - rawPath = rawPath[18:] - rawPath = "https://raw.githubusercontent.com" + rawPath - rawPath = strings.Replace(rawPath, "/blob/", "/", 1) matcher = "raw" } diff --git a/proxy/hotpath_test.go b/proxy/hotpath_test.go new file mode 100644 index 0000000..8d9b4e9 --- /dev/null +++ b/proxy/hotpath_test.go @@ -0,0 +1,192 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "ghproxy/config" + + "github.com/infinite-iroha/touka" +) + +func TestNormalizeProxyPath(t *testing.T) { + testCases := []struct { + name string + rawPath string + expected string + expectValid bool + }{ + {name: "Plain host path", rawPath: "/github.com/owner/repo", expected: "github.com/owner/repo", expectValid: true}, + {name: "HTTPS URL", rawPath: "/https://github.com/owner/repo", expected: "github.com/owner/repo", expectValid: true}, + {name: "HTTP URL", rawPath: "http://github.com/owner/repo", expected: "github.com/owner/repo", expectValid: true}, + {name: "Scheme with single slash", rawPath: "https:/github.com/owner/repo", expected: "github.com/owner/repo", expectValid: true}, + {name: "Extra leading slashes", rawPath: "////github.com/owner/repo", expected: "github.com/owner/repo", expectValid: true}, + {name: "Empty path", rawPath: "", expectValid: false}, + {name: "Slash only", rawPath: "////", expectValid: false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, ok := normalizeProxyPath(tc.rawPath) + if ok != tc.expectValid { + t.Fatalf("valid = %v, want %v", ok, tc.expectValid) + } + if got != tc.expected { + t.Fatalf("path = %q, want %q", got, tc.expected) + } + }) + } +} + +func TestCopyHeaderFiltered(t *testing.T) { + src := http.Header{ + "Accept": {"text/plain"}, + "Connection": {"keep-alive"}, + "X-Test": {"one", "two"}, + "Accept-Encoding": {"gzip"}, + } + dst := make(http.Header) + + copyHeaderFiltered(dst, src, reqHeadersToRemove) + + if got := dst.Values("Accept"); !reflect.DeepEqual(got, []string{"text/plain"}) { + t.Fatalf("Accept = %v, want [text/plain]", got) + } + if got := dst.Values("X-Test"); !reflect.DeepEqual(got, []string{"one", "two"}) { + t.Fatalf("X-Test = %v, want [one two]", got) + } + if got := dst.Values("Connection"); len(got) != 0 { + t.Fatalf("Connection should be filtered, got %v", got) + } + if got := dst.Values("Accept-Encoding"); len(got) != 0 { + t.Fatalf("Accept-Encoding should be filtered, got %v", got) + } +} + +func TestCopyHeaderFiltered_CanonicalizesDenylist(t *testing.T) { + src := http.Header{ + "Cf-Ipcountry": {"CN"}, + "Cf-Ray": {"abc123"}, + "Cf-Ew-Via": {"edge"}, + "X-Forwarded-For": {"127.0.0.1"}, + } + dst := make(http.Header) + + copyHeaderFiltered(dst, src, reqHeadersToRemove) + + if got := dst.Values("Cf-Ipcountry"); len(got) != 0 { + t.Fatalf("Cf-Ipcountry should be filtered, got %v", got) + } + if got := dst.Values("Cf-Ray"); len(got) != 0 { + t.Fatalf("Cf-Ray should be filtered, got %v", got) + } + if got := dst.Values("Cf-Ew-Via"); len(got) != 0 { + t.Fatalf("Cf-Ew-Via should be filtered, got %v", got) + } + if got := dst.Values("X-Forwarded-For"); !reflect.DeepEqual(got, []string{"127.0.0.1"}) { + t.Fatalf("X-Forwarded-For = %v, want [127.0.0.1]", got) + } +} + +func TestCopyHeaderFiltered_AllowsAllWhenDenylistEmpty(t *testing.T) { + src := http.Header{ + "X-Test": {"one", "two"}, + } + dst := make(http.Header) + + copyHeaderFiltered(dst, src, nil) + + if got := dst.Values("X-Test"); !reflect.DeepEqual(got, []string{"one", "two"}) { + t.Fatalf("X-Test = %v, want [one two]", got) + } +} + +func TestBuildProxyPath(t *testing.T) { + testCases := []struct { + name string + path string + matcher string + expected string + }{ + { + name: "Blob path rewrites to raw host", + path: "github.com/owner/repo/blob/main/file.go", + matcher: "blob", + expected: "https://raw.githubusercontent.com/owner/repo/main/file.go", + }, + { + name: "Non blob path keeps host", + path: "raw.githubusercontent.com/owner/repo/main/file.go", + matcher: "raw", + expected: "https://raw.githubusercontent.com/owner/repo/main/file.go", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := buildProxyPath(tc.path, tc.matcher); got != tc.expected { + t.Fatalf("buildProxyPath() = %q, want %q", got, tc.expected) + } + }) + } +} + +func TestNoRouteHandler_InvalidURI_ReturnsBadRequest(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "http://client.example/", nil) + req.RequestURI = "/" + + ctx, _ := touka.CreateTestContextWithRequest(recorder, req) + NoRouteHandler(&config.Config{})(ctx) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusBadRequest) + } + if body := recorder.Body.String(); body == "" { + t.Fatal("expected error response body to be written") + } +} + +func TestNoRouteHandler_NormalizesAbsoluteRequestURIForAPI(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "http://client.example/placeholder", nil) + req.RequestURI = "/https://api.github.com/repos/WJQSERVER-STUDIO/ghproxy/releases?per_page=1" + + ctx, _ := touka.CreateTestContextWithRequest(recorder, req) + cfg := &config.Config{} + NoRouteHandler(cfg)(ctx) + + if recorder.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusForbidden) + } + if body := recorder.Body.String(); body == "" { + t.Fatal("expected error response body to be written") + } +} + +func BenchmarkNormalizeProxyPath(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = normalizeProxyPath("/https://github.com/WJQSERVER-STUDIO/ghproxy/releases/download/v1.0.0/asset.tar.gz") + } +} + +func BenchmarkCopyHeaderFiltered(b *testing.B) { + src := http.Header{ + "Accept": {"text/plain"}, + "Accept-Encoding": {"gzip"}, + "Connection": {"keep-alive"}, + "User-Agent": {"curl/8.0.1"}, + "X-Test": {"one", "two"}, + "CF-Connecting-IP": {"127.0.0.1"}, + "X-Forwarded-For": {"127.0.0.1"}, + "Transfer-Encoding": {"chunked"}, + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + dst := make(http.Header) + copyHeaderFiltered(dst, src, reqHeadersToRemove) + } +} diff --git a/proxy/httpc.go b/proxy/httpc.go index 857f3f0..3abbdfb 100644 --- a/proxy/httpc.go +++ b/proxy/httpc.go @@ -39,10 +39,13 @@ func initHTTPClient(cfg *config.Config) *httpc.Client { switch cfg.Httpc.Mode { case "auto", "": tr = &http.Transport{ - IdleConnTimeout: 30 * time.Second, - WriteBufferSize: 32 * 1024, // 32KB - ReadBufferSize: 32 * 1024, // 32KB - Protocols: proTolcols, + MaxIdleConns: cfg.Httpc.MaxIdleConns, + MaxConnsPerHost: cfg.Httpc.MaxConnsPerHost, + MaxIdleConnsPerHost: cfg.Httpc.MaxIdleConnsPerHost, + IdleConnTimeout: 30 * time.Second, + WriteBufferSize: 32 * 1024, // 32KB + ReadBufferSize: 32 * 1024, // 32KB + Protocols: proTolcols, } case "advanced": tr = &http.Transport{ @@ -77,9 +80,12 @@ func initGitHTTPClient(cfg *config.Config) { switch cfg.Httpc.Mode { case "auto", "": gittr = &http.Transport{ - IdleConnTimeout: 30 * time.Second, - WriteBufferSize: 32 * 1024, // 32KB - ReadBufferSize: 32 * 1024, // 32KB + MaxIdleConns: cfg.Httpc.MaxIdleConns, + MaxConnsPerHost: cfg.Httpc.MaxConnsPerHost, + MaxIdleConnsPerHost: cfg.Httpc.MaxIdleConnsPerHost, + IdleConnTimeout: 30 * time.Second, + WriteBufferSize: 32 * 1024, // 32KB + ReadBufferSize: 32 * 1024, // 32KB } case "advanced": gittr = &http.Transport{ diff --git a/proxy/httpc_test.go b/proxy/httpc_test.go new file mode 100644 index 0000000..31898fa --- /dev/null +++ b/proxy/httpc_test.go @@ -0,0 +1,64 @@ +package proxy + +import ( + "ghproxy/config" + "testing" +) + +func TestInitHTTPClient_AutoModeUsesConfiguredPoolSizes(t *testing.T) { + oldTr, oldClient := tr, client + t.Cleanup(func() { + tr = oldTr + client = oldClient + }) + + cfg := &config.Config{} + cfg.Httpc.Mode = "auto" + cfg.Httpc.MaxIdleConns = 123 + cfg.Httpc.MaxIdleConnsPerHost = 45 + cfg.Httpc.MaxConnsPerHost = 67 + + initHTTPClient(cfg) + + if tr == nil { + t.Fatal("transport was not initialized") + } + if tr.MaxIdleConns != 123 { + t.Fatalf("MaxIdleConns = %d, want 123", tr.MaxIdleConns) + } + if tr.MaxIdleConnsPerHost != 45 { + t.Fatalf("MaxIdleConnsPerHost = %d, want 45", tr.MaxIdleConnsPerHost) + } + if tr.MaxConnsPerHost != 67 { + t.Fatalf("MaxConnsPerHost = %d, want 67", tr.MaxConnsPerHost) + } +} + +func TestInitGitHTTPClient_AutoModeUsesConfiguredPoolSizes(t *testing.T) { + oldGitTr, oldGitClient := gittr, gitclient + t.Cleanup(func() { + gittr = oldGitTr + gitclient = oldGitClient + }) + + cfg := &config.Config{} + cfg.Httpc.Mode = "auto" + cfg.Httpc.MaxIdleConns = 98 + cfg.Httpc.MaxIdleConnsPerHost = 76 + cfg.Httpc.MaxConnsPerHost = 54 + + initGitHTTPClient(cfg) + + if gittr == nil { + t.Fatal("git transport was not initialized") + } + if gittr.MaxIdleConns != 98 { + t.Fatalf("MaxIdleConns = %d, want 98", gittr.MaxIdleConns) + } + if gittr.MaxIdleConnsPerHost != 76 { + t.Fatalf("MaxIdleConnsPerHost = %d, want 76", gittr.MaxIdleConnsPerHost) + } + if gittr.MaxConnsPerHost != 54 { + t.Fatalf("MaxConnsPerHost = %d, want 54", gittr.MaxConnsPerHost) + } +} diff --git a/proxy/match.go b/proxy/match.go index 9353c8b..29a6cdf 100644 --- a/proxy/match.go +++ b/proxy/match.go @@ -116,11 +116,19 @@ func Matcher(rawPath string, cfg *config.Config) (string, string, string, *GHPro // 匹配 "https://raw.githubusercontent.com/" if strings.HasPrefix(rawPath, rawPrefix) { remaining := rawPath[rawPrefixLen:] - parts := strings.SplitN(remaining, "/", 3) - if len(parts) < 3 { + i := strings.IndexByte(remaining, '/') + if i <= 0 { + return "", "", "", NewErrorWithStatusLookup(400, "malformed raw url: path too short") + } + user := remaining[:i] + remaining = remaining[i+1:] + + i = strings.IndexByte(remaining, '/') + if i <= 0 || i == len(remaining)-1 { return "", "", "", NewErrorWithStatusLookup(400, "malformed raw url: path too short") } - return parts[0], parts[1], "raw", nil + + return user, remaining[:i], "raw", nil } // 匹配 "https://gist.github.com/" 或 "https://gist.githubusercontent.com/" @@ -132,11 +140,16 @@ func Matcher(rawPath string, cfg *config.Config) (string, string, string, *GHPro } else { remaining = rawPath[gistContentPrefixLen:] } - parts := strings.SplitN(remaining, "/", 2) - if len(parts) == 0 || parts[0] == "" { + if remaining == "" { return "", "", "", NewErrorWithStatusLookup(400, "malformed gist url: missing user") } - return parts[0], "", "gist", nil + if i := strings.IndexByte(remaining, '/'); i != -1 { + if i == 0 { + return "", "", "", NewErrorWithStatusLookup(400, "malformed gist url: missing user") + } + remaining = remaining[:i] + } + return remaining, "", "gist", nil } // 匹配 "https://api.github.com/" @@ -147,15 +160,35 @@ func Matcher(rawPath string, cfg *config.Config) (string, string, string, *GHPro remaining := rawPath[apiPrefixLen:] var user, repo string if strings.HasPrefix(remaining, "repos/") { - parts := strings.SplitN(remaining[6:], "/", 3) - if len(parts) >= 2 { - user = parts[0] - repo = parts[1] + remaining = remaining[6:] + if q := strings.IndexByte(remaining, '?'); q != -1 { + remaining = remaining[:q] + } + if remaining != "" && !strings.ContainsRune(remaining, '/') { + user = remaining + return user, "", "api", nil + } + i := strings.IndexByte(remaining, '/') + if i > 0 { + user = remaining[:i] + rest := remaining[i+1:] + if j := strings.IndexByte(rest, '/'); j != -1 { + repo = rest[:j] + } else { + repo = rest + } } } else if strings.HasPrefix(remaining, "users/") { - parts := strings.SplitN(remaining[6:], "/", 2) - if len(parts) >= 1 { - user = parts[0] + remaining = remaining[6:] + if q := strings.IndexByte(remaining, '?'); q != -1 { + remaining = remaining[:q] + } + if remaining != "" { + if i := strings.IndexByte(remaining, '/'); i != -1 { + user = remaining[:i] + } else { + user = remaining + } } } return user, repo, "api", nil diff --git a/proxy/matcher_test.go b/proxy/matcher_test.go index 07f3e4a..9a58a7e 100644 --- a/proxy/matcher_test.go +++ b/proxy/matcher_test.go @@ -99,12 +99,24 @@ func TestMatcher_Compatibility(t *testing.T) { config: cfgWithAuth, expectedUser: "owner", expectedRepo: "repo", expectedMatcher: "raw", }, + { + name: "Malformed Raw Path (missing branch)", + rawPath: "https://raw.githubusercontent.com/owner/repo", + config: cfgWithAuth, + expectError: true, expectedErrCode: 400, + }, { name: "Gist Path", rawPath: "https://gist.github.com/user/abcdef1234567890", config: cfgWithAuth, expectedUser: "user", expectedRepo: "", expectedMatcher: "gist", }, + { + name: "Gist Path (user only)", + rawPath: "https://gist.github.com/user", + config: cfgWithAuth, + expectedUser: "user", expectedRepo: "", expectedMatcher: "gist", + }, { name: "Gist UserContent Path", rawPath: "https://gist.githubusercontent.com/user/abcdef1234567890", @@ -135,6 +147,30 @@ func TestMatcher_Compatibility(t *testing.T) { config: cfgApiForceAllowed, // Auth disabled, but force allowed expectedUser: "owner", expectedRepo: "repo", expectedMatcher: "api", }, + { + name: "API Repos Path (missing repo)", + rawPath: "https://api.github.com/repos/owner", + config: cfgWithAuth, + expectedUser: "owner", expectedRepo: "", expectedMatcher: "api", + }, + { + name: "API Repos Path (trailing slash)", + rawPath: "https://api.github.com/repos/owner/", + config: cfgWithAuth, + expectedUser: "owner", expectedRepo: "", expectedMatcher: "api", + }, + { + name: "API Repos Path (missing repo with query)", + rawPath: "https://api.github.com/repos/owner?per_page=1", + config: cfgWithAuth, + expectedUser: "owner", expectedRepo: "", expectedMatcher: "api", + }, + { + name: "API Users Path (exact user)", + rawPath: "https://api.github.com/users/someuser", + config: cfgWithAuth, + expectedUser: "someuser", expectedRepo: "", expectedMatcher: "api", + }, { name: "Malformed GH Path (no repo)", rawPath: "https://github.com/owner/", @@ -265,10 +301,11 @@ func TestExtractParts_Compatibility(t *testing.T) { }, { name: "Empty path segments", - rawURL: "https://example.com//repo/a", // Will be treated as /repo/a - expectedOwner: "", // First part is empty + rawURL: "https://example.com//repo/a", + expectedOwner: "/", expectedRepo: "/repo", expectedRem: "/a", + expectedQuery: url.Values{}, }, { name: "Invalid URL format", diff --git a/proxy/nest.go b/proxy/nest.go index 675e4a3..ffa2aa4 100644 --- a/proxy/nest.go +++ b/proxy/nest.go @@ -2,14 +2,27 @@ package proxy import ( "bufio" + "bytes" "fmt" "ghproxy/config" "io" "strings" + "sync" "github.com/infinite-iroha/touka" ) +var ( + prefixGithub = []byte("https://github.com") + prefixRawUser = []byte("https://raw.githubusercontent.com") + prefixRaw = []byte("https://raw.github.com") + prefixGistUser = []byte("https://gist.githubusercontent.com") + prefixGist = []byte("https://gist.github.com") + prefixAPIBytes = []byte("https://api.github.com") + prefixHTTP = []byte("http://") + prefixHTTPS = []byte("https://") +) + func EditorMatcher(rawPath string, cfg *config.Config) (bool, error) { // 匹配 "https://github.com"开头的链接 if strings.HasPrefix(rawPath, "https://github.com") { @@ -40,6 +53,28 @@ func EditorMatcher(rawPath string, cfg *config.Config) (bool, error) { return false, nil } +func EditorMatcherBytes(rawPath []byte, cfg *config.Config) bool { + if bytes.HasPrefix(rawPath, prefixGithub) { + return true + } + if bytes.HasPrefix(rawPath, prefixRawUser) { + return true + } + if bytes.HasPrefix(rawPath, prefixRaw) { + return true + } + if bytes.HasPrefix(rawPath, prefixGistUser) { + return true + } + if bytes.HasPrefix(rawPath, prefixGist) { + return true + } + if cfg.Shell.RewriteAPI && bytes.HasPrefix(rawPath, prefixAPIBytes) { + return true + } + return false +} + // 匹配文件扩展名是sh的rawPath func MatcherShell(rawPath string) bool { return strings.HasSuffix(rawPath, ".sh") @@ -64,87 +99,140 @@ func modifyURL(url string, host string, cfg *config.Config) string { return url } -// processLinks 处理链接,返回包含处理后数据的 io.Reader -func processLinks(input io.ReadCloser, host string, cfg *config.Config, c *touka.Context) (readerOut io.Reader, written int64, err error) { - pipeReader, pipeWriter := io.Pipe() // 创建 io.Pipe +func modifyURLBytes(url []byte, host []byte, cfg *config.Config) []byte { + if !EditorMatcherBytes(url, cfg) { + return url + } + + var trimmed []byte + if bytes.HasPrefix(url, prefixHTTPS) { + trimmed = url[len(prefixHTTPS):] + } else if bytes.HasPrefix(url, prefixHTTP) { + trimmed = url[len(prefixHTTP):] + } else { + trimmed = url + } + + newURL := make([]byte, len(prefixHTTPS)+len(host)+1+len(trimmed)) + written := 0 + written += copy(newURL[written:], prefixHTTPS) + written += copy(newURL[written:], host) + written += copy(newURL[written:], []byte("/")) + copy(newURL[written:], trimmed) + + return newURL +} + +var bufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + +func processLinksStreamingInternal(input io.ReadCloser, host string, cfg *config.Config, c *touka.Context) (readerOut io.Reader, written int64, err error) { + pipeReader, pipeWriter := io.Pipe() readerOut = pipeReader - go func() { // 在 Goroutine 中执行写入操作 + go func() { defer func() { - if pipeWriter != nil { // 确保 pipeWriter 关闭,即使发生错误 - if err != nil { - if closeErr := pipeWriter.CloseWithError(err); closeErr != nil { // 如果有错误,传递错误给 reader - c.Errorf("pipeWriter close with error failed: %v, original error: %v", closeErr, err) - } - } else { - if closeErr := pipeWriter.Close(); closeErr != nil { // 没有错误,正常关闭 - c.Errorf("pipeWriter close failed: %v", closeErr) - if err == nil { // 如果之前没有错误,记录关闭错误 - err = closeErr - } - } - } + if err != nil { + _ = pipeWriter.CloseWithError(err) + return } + _ = pipeWriter.Close() }() - defer func() { - if err := input.Close(); err != nil { - c.Errorf("input close failed: %v", err) + if closeErr := input.Close(); closeErr != nil && c != nil { + c.Errorf("input close failed: %v", closeErr) } - }() - var bufReader *bufio.Reader - - bufReader = bufio.NewReader(input) - - var bufWriter *bufio.Writer - - bufWriter = bufio.NewWriterSize(pipeWriter, 4096) // 使用 pipeWriter - - //确保writer关闭 + bufReader := bufio.NewReader(input) + bufWriter := bufio.NewWriterSize(pipeWriter, 4096) defer func() { - if flushErr := bufWriter.Flush(); flushErr != nil { - c.Errorf("writer flush failed %v", flushErr) - // 如果已经存在错误,则保留。否则,记录此错误。 - if err == nil { - err = flushErr - } + if flushErr := bufWriter.Flush(); flushErr != nil && err == nil { + err = fmt.Errorf("flush writer failed: %w", flushErr) } }() - // 使用正则表达式匹配 http 和 https 链接 for { line, readErr := bufReader.ReadString('\n') - if readErr != nil { - if readErr == io.EOF { - break // 文件结束 - } - err = fmt.Errorf("读取行错误: %v", readErr) // 传递错误 - return // Goroutine 中使用 return 返回错误 + if readErr != nil && readErr != io.EOF { + err = fmt.Errorf("read error: %w", readErr) + return } - // 替换所有匹配的 URL - modifiedLine := urlPattern.ReplaceAllStringFunc(line, func(originalURL string) string { - return modifyURL(originalURL, host, cfg) // 假设 modifyURL 函数已定义 - }) + if len(line) > 0 { + modifiedLine := urlPattern.ReplaceAllStringFunc(line, func(originalURL string) string { + return modifyURL(originalURL, host, cfg) + }) - n, writeErr := bufWriter.WriteString(modifiedLine) - written += int64(n) // 更新写入的字节数 - if writeErr != nil { - err = fmt.Errorf("写入文件错误: %v", writeErr) // 传递错误 - return // Goroutine 中使用 return 返回错误 + n, writeErr := bufWriter.WriteString(modifiedLine) + written += int64(n) + if writeErr != nil { + err = fmt.Errorf("write error: %w", writeErr) + return + } + } + + if readErr == io.EOF { + break } } + }() + + return readerOut, written, nil +} + +func processLinks(input io.ReadCloser, host string, cfg *config.Config, c *touka.Context, bodySize int) (readerOut io.Reader, written int64, err error) { + const sizeThreshold = 256 * 1024 + + if bodySize == -1 || bodySize > sizeThreshold { + return processLinksStreamingInternal(input, host, cfg, c) + } + + return processLinksBufferedInternal(input, host, cfg, c) +} + +func processLinksBufferedInternal(input io.ReadCloser, host string, cfg *config.Config, c *touka.Context) (readerOut io.Reader, written int64, err error) { + pipeReader, pipeWriter := io.Pipe() + readerOut = pipeReader + hostBytes := []byte(host) - // 在返回之前,再刷新一次 (虽然 defer 中已经有 flush,但这里再加一次确保及时刷新) - if flushErr := bufWriter.Flush(); flushErr != nil { - if err == nil { // 避免覆盖之前的错误 - err = flushErr + go func() { + defer func() { + if closeErr := input.Close(); closeErr != nil && c != nil { + c.Errorf("input close failed: %v", closeErr) + } + }() + defer func() { + if err != nil { + _ = pipeWriter.CloseWithError(err) + return } - return // Goroutine 中使用 return 返回错误 + _ = pipeWriter.Close() + }() + + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufferPool.Put(buf) + + if _, err = buf.ReadFrom(input); err != nil { + err = fmt.Errorf("reading input failed: %w", err) + return + } + + modifiedBytes := urlPattern.ReplaceAllFunc(buf.Bytes(), func(originalURL []byte) []byte { + return modifyURLBytes(originalURL, hostBytes, cfg) + }) + + var n int + n, err = pipeWriter.Write(modifiedBytes) + written = int64(n) + if err != nil { + err = fmt.Errorf("writing to pipe failed: %w", err) } }() - return readerOut, written, nil // 返回 reader 和 written,error 由 Goroutine 通过 pipeWriter.CloseWithError 传递 + return readerOut, written, nil } diff --git a/proxy/nest_bench_test.go b/proxy/nest_bench_test.go new file mode 100644 index 0000000..c8bc878 --- /dev/null +++ b/proxy/nest_bench_test.go @@ -0,0 +1,63 @@ +package proxy + +import ( + "ghproxy/config" + "io" + "strings" + "testing" +) + +const benchmarkInput = ` +Some text here. +Link to be replaced: http://github.com/user/repo +Another link: https://google.com +And one more: http://example.com/some/path +This should not be replaced: notalink +End of text. +` + +func BenchmarkProcessLinksStreaming(b *testing.B) { + cfg := &config.Config{} + host := "my-proxy.com" + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + b.StopTimer() + input := io.NopCloser(strings.NewReader(benchmarkInput)) + b.StartTimer() + + reader, _, err := processLinksStreamingInternal(input, host, cfg, nil) + if err != nil { + b.Fatalf("processLinksStreamingInternal failed: %v", err) + } + + if _, err = io.ReadAll(reader); err != nil { + b.Fatalf("failed to read from processed reader: %v", err) + } + } +} + +func BenchmarkProcessLinksBuffered(b *testing.B) { + cfg := &config.Config{} + host := "my-proxy.com" + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + b.StopTimer() + input := io.NopCloser(strings.NewReader(benchmarkInput)) + b.StartTimer() + + reader, _, err := processLinksBufferedInternal(input, host, cfg, nil) + if err != nil { + b.Fatalf("processLinksBufferedInternal failed: %v", err) + } + + if _, err = io.ReadAll(reader); err != nil { + b.Fatalf("failed to read from processed reader: %v", err) + } + } +} diff --git a/proxy/perf_compare_test.go b/proxy/perf_compare_test.go new file mode 100644 index 0000000..d5ebe86 --- /dev/null +++ b/proxy/perf_compare_test.go @@ -0,0 +1,88 @@ +package proxy + +import ( + "net/http" + "testing" + + "ghproxy/config" + + "github.com/infinite-iroha/touka" +) + +var benchmarkHeaderSource = http.Header{ + "Accept": {"text/plain"}, + "Accept-Encoding": {"gzip"}, + "Connection": {"keep-alive"}, + "User-Agent": {"curl/8.0.1"}, + "X-Test": {"one", "two"}, + "CF-Connecting-IP": {"127.0.0.1"}, + "X-Forwarded-For": {"127.0.0.1"}, + "Transfer-Encoding": {"chunked"}, +} + +func BenchmarkMatcherGithubRelease(b *testing.B) { + cfg := &config.Config{ + Auth: config.AuthConfig{Enabled: true, Method: "header", ForceAllowApi: false}, + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, _, _ = Matcher("https://github.com/WJQSERVER-STUDIO/ghproxy/releases/download/v1.0.0/asset.tar.gz", cfg) + } +} + +func BenchmarkMatcherRaw(b *testing.B) { + cfg := &config.Config{ + Auth: config.AuthConfig{Enabled: true, Method: "header", ForceAllowApi: false}, + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, _, _ = Matcher("https://raw.githubusercontent.com/WJQSERVER-STUDIO/ghproxy/main/README.md", cfg) + } +} + +func BenchmarkMatcherGist(b *testing.B) { + cfg := &config.Config{ + Auth: config.AuthConfig{Enabled: true, Method: "header", ForceAllowApi: false}, + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, _, _ = Matcher("https://gist.githubusercontent.com/user/abcdef1234567890/raw/file.txt", cfg) + } +} + +func BenchmarkMatcherAPI(b *testing.B) { + cfg := &config.Config{ + Auth: config.AuthConfig{Enabled: true, Method: "header", ForceAllowApi: false}, + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, _, _ = Matcher("https://api.github.com/repos/WJQSERVER-STUDIO/ghproxy/releases", cfg) + } +} + +func BenchmarkSetRequestHeadersClone(b *testing.B) { + ctx := &touka.Context{Request: &http.Request{Header: benchmarkHeaderSource}} + cfg := &config.Config{} + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + req := &http.Request{Header: make(http.Header)} + setRequestHeaders(ctx, req, cfg, "clone") + } +} + +func BenchmarkSetRequestHeadersRawCustom(b *testing.B) { + ctx := &touka.Context{Request: &http.Request{Header: benchmarkHeaderSource}} + cfg := &config.Config{} + cfg.Httpc.UseCustomRawHeaders = true + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + req := &http.Request{Header: make(http.Header)} + setRequestHeaders(ctx, req, cfg, "raw") + } +} diff --git a/proxy/reqheader.go b/proxy/reqheader.go index 57d8542..dce4e2f 100644 --- a/proxy/reqheader.go +++ b/proxy/reqheader.go @@ -60,6 +60,37 @@ func copyHeader(dst, src http.Header) { } } +func canonicalizeHeaderSet(headers map[string]struct{}) map[string]struct{} { + canonicalized := make(map[string]struct{}, len(headers)) + for key := range headers { + canonicalized[http.CanonicalHeaderKey(key)] = struct{}{} + } + return canonicalized +} + +func init() { + reqHeadersToRemove = canonicalizeHeaderSet(reqHeadersToRemove) + cloneHeadersToRemove = canonicalizeHeaderSet(cloneHeadersToRemove) + respHeadersToRemove = canonicalizeHeaderSet(respHeadersToRemove) + defaultHeaders = map[string]string{ + "Accept": "*/*", + "Accept-Encoding": "", + "Transfer-Encoding": "chunked", + "User-Agent": "GHProxy/1.0", + } +} + +func copyHeaderFiltered(dst, src http.Header, denylist map[string]struct{}) { + for k, vv := range src { + if _, denied := denylist[k]; denied { + continue + } + for _, v := range vv { + dst.Add(k, v) + } + } +} + func setRequestHeaders(c *touka.Context, req *http.Request, cfg *config.Config, matcher string) { if matcher == "raw" && cfg.Httpc.UseCustomRawHeaders { // 使用预定义Header @@ -67,14 +98,8 @@ func setRequestHeaders(c *touka.Context, req *http.Request, cfg *config.Config, req.Header.Set(key, value) } } else if matcher == "clone" { - copyHeader(req.Header, c.Request.Header) - for key := range cloneHeadersToRemove { - req.Header.Del(key) - } + copyHeaderFiltered(req.Header, c.Request.Header, cloneHeadersToRemove) } else { - copyHeader(req.Header, c.Request.Header) - for key := range reqHeadersToRemove { - req.Header.Del(key) - } + copyHeaderFiltered(req.Header, c.Request.Header, reqHeadersToRemove) } } diff --git a/proxy/routing.go b/proxy/routing.go index 7a5748f..2c80601 100644 --- a/proxy/routing.go +++ b/proxy/routing.go @@ -44,17 +44,12 @@ func RoutingHandler(cfg *config.Config) touka.HandlerFunc { return } - // 处理blob/raw路径 + rawPath = buildProxyPath(rawPath, matcher) + if matcher == "blob" { - rawPath = rawPath[10:] - rawPath = "raw.githubusercontent.com" + rawPath - rawPath = strings.Replace(rawPath, "/blob/", "/", 1) matcher = "raw" } - // 为rawpath加入https:// 头 - rawPath = "https://" + rawPath - switch matcher { case "releases", "blob", "raw", "gist", "api": ChunkedProxyRequest(ctx, c, rawPath, cfg, matcher)