Skip to content

Commit 1a33a4d

Browse files
committed
Add code
1 parent ba2af0d commit 1a33a4d

2 files changed

Lines changed: 107 additions & 0 deletions

File tree

cc.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package cc
2+
3+
import "sync"
4+
5+
// Pool manages a pool of concurrent workers. It works a bit like a Waitgroup, but with error reporting and concurrency limits
6+
// You create one with New, and run functions with Run. Then you wait on it like a regular WaitGroup.
7+
//
8+
// Example:
9+
//
10+
// p := cc.New(4)
11+
// p.Run(func() error {
12+
// afunction()
13+
// return nil
14+
// })
15+
// errs := p.Wait()
16+
//
17+
// for err := range errs {
18+
//
19+
// }
20+
type Pool struct {
21+
errors chan error
22+
23+
semaphore chan bool
24+
wg *sync.WaitGroup
25+
}
26+
27+
// New returns a new pool where a limited number (concurrency) of goroutine can work at the same time
28+
func New(concurrency int) *Pool {
29+
wg := sync.WaitGroup{}
30+
p := Pool{
31+
errors: make(chan error),
32+
semaphore: make(chan bool, concurrency),
33+
wg: &wg,
34+
}
35+
return &p
36+
}
37+
38+
// Wait blocks and ensures that the channels are closed when all the goroutines end.
39+
// It returns a list of all the errors returned by the goroutine
40+
func (p *Pool) Wait() []error {
41+
go func() {
42+
p.wg.Wait()
43+
44+
close(p.semaphore)
45+
close(p.errors)
46+
}()
47+
48+
errs := []error{}
49+
50+
for err := range p.errors {
51+
if err != nil {
52+
errs = append(errs, err)
53+
}
54+
}
55+
56+
return errs
57+
}
58+
59+
// Run wraps the given function into a goroutine and ensure that the concurrency limits are respected.
60+
// The error returned by the function is stored into the error list returned by Wait
61+
func (p *Pool) Run(fn func() error) {
62+
p.wg.Add(1)
63+
go func() {
64+
p.semaphore <- true
65+
p.errors <- fn()
66+
<-p.semaphore
67+
p.wg.Done()
68+
}()
69+
}

cc_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package cc_test
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/codeclysm/cc"
9+
)
10+
11+
func Example() {
12+
p := cc.New(4)
13+
p.Run(func() error {
14+
return errors.New("fail1")
15+
})
16+
p.Run(func() error {
17+
return errors.New("fail2")
18+
})
19+
p.Run(func() error {
20+
return nil
21+
})
22+
23+
errs := p.Wait()
24+
fmt.Println(len(errs))
25+
// Output: 2
26+
}
27+
28+
func TestRace(t *testing.T) {
29+
p := cc.New(4)
30+
31+
for i := 0; i < 1000; i++ {
32+
p.Run(func() error {
33+
return errors.New("fail")
34+
})
35+
}
36+
37+
p.Wait()
38+
}

0 commit comments

Comments
 (0)