Prechádzať zdrojové kódy

integration: Extract stream demultiplexing from container.Exec

The original code in container.Exec was potentially leaking the copy
goroutine when the context was cancelled or timed out. The new
`demultiplexStreams()` function won't return until the goroutine has
finished its work, and to ensure that it takes care of closing the
hijacked connection.

Signed-off-by: Albin Kerouanton <albinker@gmail.com>
Albin Kerouanton 2 rokov pred
rodič
commit
d7fb4dd170

+ 36 - 0
integration/internal/container/container.go

@@ -1,14 +1,17 @@
 package container
 
 import (
+	"bytes"
 	"context"
 	"runtime"
+	"sync"
 	"testing"
 
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/api/types/network"
 	"github.com/docker/docker/client"
+	"github.com/docker/docker/pkg/stdcopy"
 	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
 	"gotest.tools/v3/assert"
 )
@@ -71,3 +74,36 @@ func Run(ctx context.Context, t *testing.T, client client.APIClient, ops ...func
 
 	return id
 }
+
+type streams struct {
+	stdout, stderr bytes.Buffer
+}
+
+// demultiplexStreams starts a goroutine to demultiplex stdout and stderr from the types.HijackedResponse resp and
+// waits until either multiplexed stream reaches EOF or the context expires. It unconditionally closes resp and waits
+// until the demultiplexing goroutine has finished its work before returning.
+func demultiplexStreams(ctx context.Context, resp types.HijackedResponse) (streams, error) {
+	var s streams
+	outputDone := make(chan error, 1)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		_, err := stdcopy.StdCopy(&s.stdout, &s.stderr, resp.Reader)
+		outputDone <- err
+		wg.Done()
+	}()
+
+	var err error
+	select {
+	case copyErr := <-outputDone:
+		err = copyErr
+		break
+	case <-ctx.Done():
+		err = ctx.Err()
+	}
+
+	resp.Close()
+	wg.Wait()
+	return s, err
+}

+ 4 - 21
integration/internal/container/exec.go

@@ -6,7 +6,6 @@ import (
 
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/client"
-	"github.com/docker/docker/pkg/stdcopy"
 )
 
 // ExecResult represents a result returned from Exec()
@@ -58,27 +57,11 @@ func Exec(ctx context.Context, cli client.APIClient, id string, cmd []string, op
 	if err != nil {
 		return ExecResult{}, err
 	}
-	defer aresp.Close()
 
 	// read the output
-	var outBuf, errBuf bytes.Buffer
-	outputDone := make(chan error, 1)
-
-	go func() {
-		// StdCopy demultiplexes the stream into two buffers
-		_, err = stdcopy.StdCopy(&outBuf, &errBuf, aresp.Reader)
-		outputDone <- err
-	}()
-
-	select {
-	case err := <-outputDone:
-		if err != nil {
-			return ExecResult{}, err
-		}
-		break
-
-	case <-ctx.Done():
-		return ExecResult{}, ctx.Err()
+	s, err := demultiplexStreams(ctx, aresp)
+	if err != nil {
+		return ExecResult{}, err
 	}
 
 	// get the exit code
@@ -87,5 +70,5 @@ func Exec(ctx context.Context, cli client.APIClient, id string, cmd []string, op
 		return ExecResult{}, err
 	}
 
-	return ExecResult{ExitCode: iresp.ExitCode, outBuffer: &outBuf, errBuffer: &errBuf}, nil
+	return ExecResult{ExitCode: iresp.ExitCode, outBuffer: &s.stdout, errBuffer: &s.stderr}, nil
 }