Skip to content

Commit 771c4b0

Browse files
authored
Merge pull request #92 from sourcegraph/context-pool-panic
pool: let WithCancelOnError also cancel on panics
2 parents 14c5081 + e8fb035 commit 771c4b0

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

pool/context_pool.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ type ContextPool struct {
2323
// are busy, a call to Go() will block until the task can be started.
2424
func (p *ContextPool) Go(f func(ctx context.Context) error) {
2525
p.errorPool.Go(func() error {
26+
if p.cancelOnError {
27+
// If we are cancelling on error, then we also want to cancel if a
28+
// panic is raised. To do this, we need to recover, cancel, and then
29+
// re-throw the caught panic.
30+
defer func() {
31+
if r := recover(); r != nil {
32+
p.cancel()
33+
panic(r)
34+
}
35+
}()
36+
}
37+
2638
err := f(p.ctx)
2739
if err != nil && p.cancelOnError {
2840
// Leaky abstraction warning: We add the error directly because
@@ -56,7 +68,7 @@ func (p *ContextPool) WithFirstError() *ContextPool {
5668
}
5769

5870
// WithCancelOnError configures the pool to cancel its context as soon as
59-
// any task returns an error. By default, the pool's context is not
71+
// any task returns an error or panics. By default, the pool's context is not
6072
// canceled until the parent context is canceled.
6173
//
6274
// In this case, all errors returned from the pool after the first will

pool/context_pool_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"testing"
1010
"time"
1111

12+
"github.com/stretchr/testify/assert"
1213
"github.com/stretchr/testify/require"
1314
)
1415

@@ -187,6 +188,27 @@ func TestContextPool(t *testing.T) {
187188
require.NotErrorIs(t, err, context.Canceled)
188189
})
189190

191+
t.Run("WithCancelOnError and panic", func(t *testing.T) {
192+
t.Parallel()
193+
p := New().WithContext(bgctx).WithCancelOnError()
194+
var cancelledTasks atomic.Int64
195+
p.Go(func(ctx context.Context) error {
196+
<-ctx.Done()
197+
cancelledTasks.Add(1)
198+
return ctx.Err()
199+
})
200+
p.Go(func(ctx context.Context) error {
201+
<-ctx.Done()
202+
cancelledTasks.Add(1)
203+
return ctx.Err()
204+
})
205+
p.Go(func(ctx context.Context) error {
206+
panic("abort!")
207+
})
208+
assert.Panics(t, func() { _ = p.Wait() })
209+
assert.EqualValues(t, 2, cancelledTasks.Load())
210+
})
211+
190212
t.Run("limit", func(t *testing.T) {
191213
t.Parallel()
192214
for _, maxConcurrent := range []int{1, 10, 100} {

pool/result_context_pool_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"testing"
1111
"time"
1212

13+
"github.com/stretchr/testify/assert"
1314
"github.com/stretchr/testify/require"
1415
)
1516

@@ -100,6 +101,29 @@ func TestResultContextPool(t *testing.T) {
100101
require.ErrorIs(t, err, err1)
101102
})
102103

104+
t.Run("WithCancelOnError and panic", func(t *testing.T) {
105+
t.Parallel()
106+
p := NewWithResults[int]().
107+
WithContext(context.Background()).
108+
WithCancelOnError()
109+
var cancelledTasks atomic.Int64
110+
p.Go(func(ctx context.Context) (int, error) {
111+
<-ctx.Done()
112+
cancelledTasks.Add(1)
113+
return 0, ctx.Err()
114+
})
115+
p.Go(func(ctx context.Context) (int, error) {
116+
<-ctx.Done()
117+
cancelledTasks.Add(1)
118+
return 0, ctx.Err()
119+
})
120+
p.Go(func(ctx context.Context) (int, error) {
121+
panic("abort!")
122+
})
123+
assert.Panics(t, func() { _, _ = p.Wait() })
124+
assert.EqualValues(t, 2, cancelledTasks.Load())
125+
})
126+
103127
t.Run("no WithCancelOnError", func(t *testing.T) {
104128
t.Parallel()
105129
g := NewWithResults[int]().WithContext(context.Background())

0 commit comments

Comments
 (0)