Skip to content

Commit 7e17b9e

Browse files
committed
pool: let WithCancelOnError also cancel on panics
1 parent f9b38e5 commit 7e17b9e

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

pool/context_pool.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package pool
22

33
import (
44
"context"
5+
6+
"github.com/sourcegraph/conc/panics"
57
)
68

79
// ContextPool is a pool that runs tasks that take a context.
@@ -23,14 +25,31 @@ type ContextPool struct {
2325
// are busy, a call to Go() will block until the task can be started.
2426
func (p *ContextPool) Go(f func(ctx context.Context) error) {
2527
p.errorPool.Go(func() error {
26-
err := f(p.ctx)
27-
if err != nil && p.cancelOnError {
28+
// If we aren't cancelling on error, just return the result of f.
29+
if !p.cancelOnError {
30+
return f(p.ctx)
31+
}
32+
33+
// If we are cancelling on error, then we also want to cancel on panic.
34+
// To do this, we need to recover from any panic f raises.
35+
var err error
36+
recovered := panics.Try(func() { err = f(p.ctx) })
37+
if err != nil || recovered != nil {
2838
// Leaky abstraction warning: We add the error directly because
2939
// otherwise, canceling could cause another goroutine to exit and
3040
// return an error before this error was added, which breaks the
3141
// expectations of WithFirstError().
32-
p.errorPool.addErr(err)
42+
if err != nil {
43+
p.errorPool.addErr(err)
44+
}
45+
3346
p.cancel()
47+
48+
// Now that context is cancelled, if we caught a panic we can
49+
// propagate it.
50+
if recovered != nil {
51+
panic(recovered)
52+
}
3453
return nil
3554
}
3655
return err
@@ -56,7 +75,7 @@ func (p *ContextPool) WithFirstError() *ContextPool {
5675
}
5776

5877
// WithCancelOnError configures the pool to cancel its context as soon as
59-
// any task returns an error. By default, the pool's context is not
78+
// any task returns an error or panics. By default, the pool's context is not
6079
// canceled until the parent context is canceled.
6180
//
6281
// In this case, all errors returned from the pool after the first will

pool/context_pool_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"testing"
99
"time"
1010

11+
"github.com/stretchr/testify/assert"
1112
"github.com/stretchr/testify/require"
1213

1314
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -188,6 +189,27 @@ func TestContextPool(t *testing.T) {
188189
require.NotErrorIs(t, err, context.Canceled)
189190
})
190191

192+
t.Run("WithCancelOnError and panic", func(t *testing.T) {
193+
t.Parallel()
194+
p := New().WithContext(bgctx).WithCancelOnError()
195+
var cancelledTasks atomic.Int64
196+
p.Go(func(ctx context.Context) error {
197+
<-ctx.Done()
198+
cancelledTasks.Add(1)
199+
return ctx.Err()
200+
})
201+
p.Go(func(ctx context.Context) error {
202+
<-ctx.Done()
203+
cancelledTasks.Add(1)
204+
return ctx.Err()
205+
})
206+
p.Go(func(ctx context.Context) error {
207+
panic("abort!")
208+
})
209+
assert.Panics(t, func() { _ = p.Wait() })
210+
assert.EqualValues(t, 2, cancelledTasks.Load())
211+
})
212+
191213
t.Run("limit", func(t *testing.T) {
192214
t.Parallel()
193215
for _, maxConcurrent := range []int{1, 10, 100} {

0 commit comments

Comments
 (0)