Skip to content

Commit

Permalink
fix: parallel containers clean race (#2790)
Browse files Browse the repository at this point in the history
Simply the logic in parallel containers, eliminating a clean up race
condition where multiple clean ups on the same container could occur at
the same time.
  • Loading branch information
stevenh authored Sep 23, 2024
1 parent 1d01e21 commit 309dec1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 60 deletions.
70 changes: 29 additions & 41 deletions parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package testcontainers

import (
"context"
"errors"
"fmt"
"sync"
)
Expand Down Expand Up @@ -32,24 +31,27 @@ func (gpe ParallelContainersError) Error() string {
return fmt.Sprintf("%v", gpe.Errors)
}

// parallelContainersResult represents result.
type parallelContainersResult struct {
ParallelContainersRequestError
Container Container
}

func parallelContainersRunner(
ctx context.Context,
requests <-chan GenericContainerRequest,
errorsCh chan<- ParallelContainersRequestError,
containers chan<- Container,
results chan<- parallelContainersResult,
wg *sync.WaitGroup,
) {
defer wg.Done()
for req := range requests {
c, err := GenericContainer(ctx, req)
res := parallelContainersResult{Container: c}
if err != nil {
errorsCh <- ParallelContainersRequestError{
Request: req,
Error: errors.Join(err, TerminateContainer(c)),
}
continue
res.Request = req
res.Error = err
}
containers <- c
results <- res
}
}

Expand All @@ -65,41 +67,26 @@ func ParallelContainers(ctx context.Context, reqs ParallelContainerRequest, opt
}

tasksChan := make(chan GenericContainerRequest, tasksChanSize)
errsChan := make(chan ParallelContainersRequestError)
resChan := make(chan Container)
waitRes := make(chan struct{})

containers := make([]Container, 0)
errors := make([]ParallelContainersRequestError, 0)
resultsChan := make(chan parallelContainersResult, tasksChanSize)
done := make(chan struct{})

wg := sync.WaitGroup{}
var wg sync.WaitGroup
wg.Add(tasksChanSize)

// run workers
for i := 0; i < tasksChanSize; i++ {
go parallelContainersRunner(ctx, tasksChan, errsChan, resChan, &wg)
go parallelContainersRunner(ctx, tasksChan, resultsChan, &wg)
}

var errs []ParallelContainersRequestError
containers := make([]Container, 0, len(reqs))
go func() {
for {
select {
case c, ok := <-resChan:
if !ok {
resChan = nil
} else {
containers = append(containers, c)
}
case e, ok := <-errsChan:
if !ok {
errsChan = nil
} else {
errors = append(errors, e)
}
}

if resChan == nil && errsChan == nil {
waitRes <- struct{}{}
break
defer close(done)
for res := range resultsChan {
if res.Error != nil {
errs = append(errs, res.ParallelContainersRequestError)
} else {
containers = append(containers, res.Container)
}
}
}()
Expand All @@ -108,14 +95,15 @@ func ParallelContainers(ctx context.Context, reqs ParallelContainerRequest, opt
tasksChan <- req
}
close(tasksChan)

wg.Wait()
close(resChan)
close(errsChan)

<-waitRes
close(resultsChan)

<-done

if len(errors) != 0 {
return containers, ParallelContainersError{Errors: errors}
if len(errs) != 0 {
return containers, ParallelContainersError{Errors: errs}
}

return containers, nil
Expand Down
29 changes: 10 additions & 19 deletions parallel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package testcontainers

import (
"context"
"errors"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -99,23 +98,18 @@ func TestParallelContainers(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
res, err := ParallelContainers(context.Background(), tc.reqs, ParallelContainersOptions{})
if err != nil {
require.NotZero(t, tc.expErrors)
var e ParallelContainersError
errors.As(err, &e)
if len(e.Errors) != tc.expErrors {
t.Fatalf("expected errors: %d, got: %d\n", tc.expErrors, len(e.Errors))
}
}

for _, c := range res {
c := c
CleanupContainer(t, c)
}

if len(res) != tc.resLen {
t.Fatalf("expected containers: %d, got: %d\n", tc.resLen, len(res))
if tc.expErrors != 0 {
require.Error(t, err)
var errs ParallelContainersError
require.ErrorAs(t, err, &errs)
require.Len(t, errs.Errors, tc.expErrors)
}

require.Len(t, res, tc.resLen)
})
}
}
Expand Down Expand Up @@ -157,11 +151,8 @@ func TestParallelContainersWithReuse(t *testing.T) {
ctx := context.Background()

res, err := ParallelContainers(ctx, parallelRequest, ParallelContainersOptions{})
if err != nil {
var e ParallelContainersError
errors.As(err, &e)
t.Fatalf("expected errors: %d, got: %d\n", 0, len(e.Errors))
for _, c := range res {
CleanupContainer(t, c)
}
// Container is reused, only terminate first container
CleanupContainer(t, res[0])
require.NoError(t, err)
}

0 comments on commit 309dec1

Please sign in to comment.