Skip to content

Commit 57ea3a5

Browse files
committed
feat: add concurrent multi-url batch fetching
1 parent d5415c6 commit 57ea3a5

3 files changed

Lines changed: 242 additions & 9 deletions

File tree

cmd/agent-fetch/batch.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"strings"
8+
"sync"
9+
"time"
10+
11+
"github.com/firede/agent-fetch/internal/fetcher"
12+
)
13+
14+
type fetchFunc func(context.Context, string, fetcher.Config) (fetcher.Result, error)
15+
16+
type taskResult struct {
17+
index int
18+
inputURL string
19+
markdown string
20+
err error
21+
}
22+
23+
func fetchBatch(ctx context.Context, urls []string, cfg fetcher.Config, concurrency int, fetch fetchFunc) []taskResult {
24+
if concurrency < 1 {
25+
concurrency = 1
26+
}
27+
28+
results := make([]taskResult, len(urls))
29+
sem := make(chan struct{}, concurrency)
30+
31+
var wg sync.WaitGroup
32+
for i, url := range urls {
33+
i, url := i, url
34+
wg.Add(1)
35+
go func() {
36+
defer wg.Done()
37+
sem <- struct{}{}
38+
defer func() { <-sem }()
39+
40+
reqCtx, cancel := context.WithTimeout(ctx, maxDuration(cfg.Timeout, cfg.BrowserTimeout)+5*time.Second)
41+
defer cancel()
42+
43+
res, err := fetch(reqCtx, url, cfg)
44+
results[i] = taskResult{
45+
index: i + 1,
46+
inputURL: url,
47+
markdown: res.Markdown,
48+
err: err,
49+
}
50+
}()
51+
}
52+
wg.Wait()
53+
54+
return results
55+
}
56+
57+
func writeBatchMarkdown(w io.Writer, results []taskResult) error {
58+
total := len(results)
59+
failed := failedCount(results)
60+
succeeded := total - failed
61+
62+
if _, err := fmt.Fprintf(w, "<!-- count: %d, succeeded: %d, failed: %d -->\n", total, succeeded, failed); err != nil {
63+
return err
64+
}
65+
66+
for i, result := range results {
67+
if i > 0 {
68+
if _, err := io.WriteString(w, "\n"); err != nil {
69+
return err
70+
}
71+
}
72+
73+
url := sanitizeForComment(result.inputURL)
74+
if result.err != nil {
75+
if _, err := fmt.Fprintf(w, "<!-- task[%d](failed): %s -->\n", result.index, url); err != nil {
76+
return err
77+
}
78+
errMsg := sanitizeForComment(result.err.Error())
79+
if _, err := fmt.Fprintf(w, "<!-- error[%d]: %s -->\n", result.index, errMsg); err != nil {
80+
return err
81+
}
82+
continue
83+
}
84+
85+
if _, err := fmt.Fprintf(w, "<!-- task[%d]: %s -->\n", result.index, url); err != nil {
86+
return err
87+
}
88+
if _, err := io.WriteString(w, result.markdown); err != nil {
89+
return err
90+
}
91+
if !strings.HasSuffix(result.markdown, "\n") {
92+
if _, err := io.WriteString(w, "\n"); err != nil {
93+
return err
94+
}
95+
}
96+
if _, err := fmt.Fprintf(w, "<!-- /task[%d] -->\n", result.index); err != nil {
97+
return err
98+
}
99+
}
100+
101+
return nil
102+
}
103+
104+
func failedCount(results []taskResult) int {
105+
n := 0
106+
for _, result := range results {
107+
if result.err != nil {
108+
n++
109+
}
110+
}
111+
return n
112+
}
113+
114+
func sanitizeForComment(s string) string {
115+
s = strings.ReplaceAll(s, "\r", " ")
116+
s = strings.ReplaceAll(s, "\n", " ")
117+
s = strings.ReplaceAll(s, "-->", "-- >")
118+
return strings.TrimSpace(s)
119+
}

cmd/agent-fetch/batch_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"errors"
6+
"strings"
7+
"testing"
8+
"time"
9+
10+
"github.com/firede/agent-fetch/internal/fetcher"
11+
)
12+
13+
func TestWriteBatchMarkdown(t *testing.T) {
14+
results := []taskResult{
15+
{
16+
index: 1,
17+
inputURL: "https://example.com/hello",
18+
markdown: "# hello\n",
19+
},
20+
{
21+
index: 2,
22+
inputURL: "https://abc.com",
23+
err: errors.New("http request failed: timeout"),
24+
},
25+
{
26+
index: 3,
27+
inputURL: "https://example.net/hi",
28+
markdown: "hi",
29+
},
30+
}
31+
32+
var b strings.Builder
33+
if err := writeBatchMarkdown(&b, results); err != nil {
34+
t.Fatalf("write batch markdown: %v", err)
35+
}
36+
37+
got := b.String()
38+
want := strings.Join([]string{
39+
"<!-- count: 3, succeeded: 2, failed: 1 -->",
40+
"<!-- task[1]: https://example.com/hello -->",
41+
"# hello",
42+
"<!-- /task[1] -->",
43+
"",
44+
"<!-- task[2](failed): https://abc.com -->",
45+
"<!-- error[2]: http request failed: timeout -->",
46+
"",
47+
"<!-- task[3]: https://example.net/hi -->",
48+
"hi",
49+
"<!-- /task[3] -->",
50+
"",
51+
}, "\n")
52+
53+
if got != want {
54+
t.Fatalf("unexpected output\n--- got ---\n%s\n--- want ---\n%s", got, want)
55+
}
56+
}
57+
58+
func TestFetchBatchPreservesInputOrder(t *testing.T) {
59+
urls := []string{
60+
"https://example.com/1",
61+
"https://example.com/2",
62+
"https://example.com/3",
63+
}
64+
65+
delayByURL := map[string]time.Duration{
66+
urls[0]: 50 * time.Millisecond,
67+
urls[1]: 5 * time.Millisecond,
68+
urls[2]: 25 * time.Millisecond,
69+
}
70+
71+
fetch := func(ctx context.Context, url string, cfg fetcher.Config) (fetcher.Result, error) {
72+
select {
73+
case <-ctx.Done():
74+
return fetcher.Result{}, ctx.Err()
75+
case <-time.After(delayByURL[url]):
76+
}
77+
return fetcher.Result{Markdown: "content-" + url}, nil
78+
}
79+
80+
cfg := fetcher.DefaultConfig()
81+
results := fetchBatch(context.Background(), urls, cfg, 3, fetch)
82+
if len(results) != 3 {
83+
t.Fatalf("unexpected result count: %d", len(results))
84+
}
85+
86+
for i := range urls {
87+
if results[i].index != i+1 {
88+
t.Fatalf("unexpected index at %d: got %d want %d", i, results[i].index, i+1)
89+
}
90+
if results[i].inputURL != urls[i] {
91+
t.Fatalf("unexpected url at %d: got %q want %q", i, results[i].inputURL, urls[i])
92+
}
93+
if results[i].err != nil {
94+
t.Fatalf("unexpected error at %d: %v", i, results[i].err)
95+
}
96+
}
97+
}

cmd/agent-fetch/main.go

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func main() {
3333
cmd := &cli.Command{
3434
Name: "agent-fetch",
3535
Usage: "Fetch web content and return markdown-friendly output",
36-
UsageText: "agent-fetch [options] <url>",
36+
UsageText: "agent-fetch [options] <url> [url ...]",
3737
Version: versionString(),
3838
Flags: []cli.Flag{
3939
&cli.StringFlag{Name: "mode", Value: defaultCfg.Mode, Usage: "fetch mode: auto|static|browser|raw"},
@@ -44,13 +44,14 @@ func main() {
4444
&cli.StringFlag{Name: "wait-selector", Usage: "CSS selector to wait for in browser mode"},
4545
&cli.StringFlag{Name: "user-agent", Value: defaultCfg.UserAgent, Usage: "User-Agent header"},
4646
&cli.Int64Flag{Name: "max-body-bytes", Value: defaultCfg.MaxBodyBytes, Usage: "max response bytes to read"},
47+
&cli.IntFlag{Name: "concurrency", Value: 4, Usage: "max concurrent URL fetches when multiple URLs are provided"},
4748
&cli.StringSliceFlag{
4849
Name: "header",
4950
Usage: "custom request header, repeatable. Example: --header 'Authorization: Bearer token'",
5051
},
5152
},
5253
Action: func(ctx context.Context, c *cli.Command) error {
53-
if c.Args().Len() != 1 {
54+
if c.Args().Len() < 1 {
5455
_ = cli.ShowRootCommandHelp(c)
5556
return &exitStatusError{code: 2}
5657
}
@@ -71,18 +72,34 @@ func main() {
7172
}
7273
cfg.Headers = parsedHeaders
7374

74-
url := c.Args().First()
75-
reqCtx, cancel := context.WithTimeout(ctx, maxDuration(cfg.Timeout, cfg.BrowserTimeout)+5*time.Second)
76-
defer cancel()
75+
urls := c.Args().Slice()
76+
concurrency := c.Int("concurrency")
77+
if concurrency < 1 {
78+
return &exitStatusError{code: 2, msg: "invalid concurrency: must be >= 1"}
79+
}
7780

78-
res, err := fetcher.Fetch(reqCtx, url, cfg)
79-
if err != nil {
80-
return &exitStatusError{code: 1, msg: fmt.Sprintf("fetch failed: %v", err)}
81+
if len(urls) == 1 {
82+
reqCtx, cancel := context.WithTimeout(ctx, maxDuration(cfg.Timeout, cfg.BrowserTimeout)+5*time.Second)
83+
defer cancel()
84+
85+
res, err := fetcher.Fetch(reqCtx, urls[0], cfg)
86+
if err != nil {
87+
return &exitStatusError{code: 1, msg: fmt.Sprintf("fetch failed: %v", err)}
88+
}
89+
90+
if _, err := os.Stdout.WriteString(res.Markdown); err != nil {
91+
return &exitStatusError{code: 1, msg: fmt.Sprintf("write failed: %v", err)}
92+
}
93+
return nil
8194
}
8295

83-
if _, err := os.Stdout.WriteString(res.Markdown); err != nil {
96+
results := fetchBatch(ctx, urls, cfg, concurrency, fetcher.Fetch)
97+
if err := writeBatchMarkdown(os.Stdout, results); err != nil {
8498
return &exitStatusError{code: 1, msg: fmt.Sprintf("write failed: %v", err)}
8599
}
100+
if failed := failedCount(results); failed > 0 {
101+
return &exitStatusError{code: 1}
102+
}
86103
return nil
87104
},
88105
}

0 commit comments

Comments
 (0)