Sfoglia il codice sorgente

Merge pull request #18943 from cpuguy83/fix_write_flusher

Remove Exists from backend
David Calavera 9 anni fa
parent
commit
b3cb0d196d

+ 2 - 4
api/server/router/container/backend.go

@@ -44,14 +44,13 @@ type stateBackend interface {
 	ContainerUnpause(name string) error
 	ContainerUpdate(name string, hostConfig *container.HostConfig) ([]string, error)
 	ContainerWait(name string, timeout time.Duration) (int, error)
-	Exists(id string) bool
 }
 
 // monitorBackend includes functions to implement to provide containers monitoring functionality.
 type monitorBackend interface {
 	ContainerChanges(name string) ([]archive.Change, error)
 	ContainerInspect(name string, size bool, version version.Version) (interface{}, error)
-	ContainerLogs(name string, config *backend.ContainerLogsConfig) error
+	ContainerLogs(name string, config *backend.ContainerLogsConfig, started chan struct{}) error
 	ContainerStats(name string, config *backend.ContainerStatsConfig) error
 	ContainerTop(name string, psArgs string) (*types.ContainerProcessList, error)
 
@@ -60,8 +59,7 @@ type monitorBackend interface {
 
 // attachBackend includes function to implement to provide container attaching functionality.
 type attachBackend interface {
-	ContainerAttachWithLogs(name string, c *backend.ContainerAttachWithLogsConfig) error
-	ContainerWsAttachWithLogs(name string, c *backend.ContainerWsAttachWithLogsConfig) error
+	ContainerAttach(name string, c *backend.ContainerAttachConfig) error
 }
 
 // Backend is all the methods that need to be implemented to provide container specific functionality.

+ 80 - 53
api/server/router/container/container_routes.go

@@ -66,14 +66,8 @@ func (s *containerRouter) getContainersStats(ctx context.Context, w http.Respons
 	}
 
 	stream := httputils.BoolValueOrDefault(r, "stream", true)
-	var out io.Writer
 	if !stream {
 		w.Header().Set("Content-Type", "application/json")
-		out = w
-	} else {
-		wf := ioutils.NewWriteFlusher(w)
-		out = wf
-		defer wf.Close()
 	}
 
 	var closeNotifier <-chan bool
@@ -83,7 +77,7 @@ func (s *containerRouter) getContainersStats(ctx context.Context, w http.Respons
 
 	config := &backend.ContainerStatsConfig{
 		Stream:    stream,
-		OutStream: out,
+		OutStream: w,
 		Stop:      closeNotifier,
 		Version:   string(httputils.VersionFromContext(ctx)),
 	}
@@ -112,22 +106,6 @@ func (s *containerRouter) getContainersLogs(ctx context.Context, w http.Response
 	}
 
 	containerName := vars["name"]
-
-	if !s.backend.Exists(containerName) {
-		return derr.ErrorCodeNoSuchContainer.WithArgs(containerName)
-	}
-
-	// write an empty chunk of data (this is to ensure that the
-	// HTTP Response is sent immediately, even if the container has
-	// not yet produced any data)
-	w.WriteHeader(http.StatusOK)
-	if flusher, ok := w.(http.Flusher); ok {
-		flusher.Flush()
-	}
-
-	output := ioutils.NewWriteFlusher(w)
-	defer output.Close()
-
 	logsConfig := &backend.ContainerLogsConfig{
 		ContainerLogsOptions: types.ContainerLogsOptions{
 			Follow:     httputils.BoolValue(r, "follow"),
@@ -137,15 +115,21 @@ func (s *containerRouter) getContainersLogs(ctx context.Context, w http.Response
 			ShowStdout: stdout,
 			ShowStderr: stderr,
 		},
-		OutStream: output,
+		OutStream: w,
 		Stop:      closeNotifier,
 	}
 
-	if err := s.backend.ContainerLogs(containerName, logsConfig); err != nil {
-		// The client may be expecting all of the data we're sending to
-		// be multiplexed, so send it through OutStream, which will
-		// have been set up to handle that if needed.
-		fmt.Fprintf(logsConfig.OutStream, "Error running logs job: %s\n", utils.GetErrorMessage(err))
+	chStarted := make(chan struct{})
+	if err := s.backend.ContainerLogs(containerName, logsConfig, chStarted); err != nil {
+		select {
+		case <-chStarted:
+			// The client may be expecting all of the data we're sending to
+			// be multiplexed, so send it through OutStream, which will
+			// have been set up to handle that if needed.
+			fmt.Fprintf(logsConfig.OutStream, "Error running logs job: %s\n", utils.GetErrorMessage(err))
+		default:
+			return err
+		}
 	}
 
 	return nil
@@ -443,18 +427,45 @@ func (s *containerRouter) postContainersAttach(ctx context.Context, w http.Respo
 		}
 	}
 
-	attachWithLogsConfig := &backend.ContainerAttachWithLogsConfig{
-		Hijacker:   w.(http.Hijacker),
-		Upgrade:    upgrade,
+	hijacker, ok := w.(http.Hijacker)
+	if !ok {
+		return derr.ErrorCodeNoHijackConnection.WithArgs(containerName)
+	}
+
+	setupStreams := func() (io.ReadCloser, io.Writer, io.Writer, error) {
+		conn, _, err := hijacker.Hijack()
+		if err != nil {
+			return nil, nil, nil, err
+		}
+
+		// set raw mode
+		conn.Write([]byte{})
+
+		if upgrade {
+			fmt.Fprintf(conn, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\r\n")
+		} else {
+			fmt.Fprintf(conn, "HTTP/1.1 200 OK\r\nContent-Type: application/vnd.docker.raw-stream\r\n\r\n")
+		}
+
+		closer := func() error {
+			httputils.CloseStreams(conn)
+			return nil
+		}
+		return ioutils.NewReadCloserWrapper(conn, closer), conn, conn, nil
+	}
+
+	attachConfig := &backend.ContainerAttachConfig{
+		GetStreams: setupStreams,
 		UseStdin:   httputils.BoolValue(r, "stdin"),
 		UseStdout:  httputils.BoolValue(r, "stdout"),
 		UseStderr:  httputils.BoolValue(r, "stderr"),
 		Logs:       httputils.BoolValue(r, "logs"),
 		Stream:     httputils.BoolValue(r, "stream"),
 		DetachKeys: keys,
+		MuxStreams: true,
 	}
 
-	return s.backend.ContainerAttachWithLogs(containerName, attachWithLogsConfig)
+	return s.backend.ContainerAttach(containerName, attachConfig)
 }
 
 func (s *containerRouter) wsContainersAttach(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
@@ -463,10 +474,6 @@ func (s *containerRouter) wsContainersAttach(ctx context.Context, w http.Respons
 	}
 	containerName := vars["name"]
 
-	if !s.backend.Exists(containerName) {
-		return derr.ErrorCodeNoSuchContainer.WithArgs(containerName)
-	}
-
 	var keys []byte
 	var err error
 	detachKeys := r.FormValue("detachKeys")
@@ -477,24 +484,44 @@ func (s *containerRouter) wsContainersAttach(ctx context.Context, w http.Respons
 		}
 	}
 
-	h := websocket.Handler(func(ws *websocket.Conn) {
-		defer ws.Close()
+	done := make(chan struct{})
+	started := make(chan struct{})
 
-		wsAttachWithLogsConfig := &backend.ContainerWsAttachWithLogsConfig{
-			InStream:   ws,
-			OutStream:  ws,
-			ErrStream:  ws,
-			Logs:       httputils.BoolValue(r, "logs"),
-			Stream:     httputils.BoolValue(r, "stream"),
-			DetachKeys: keys,
+	setupStreams := func() (io.ReadCloser, io.Writer, io.Writer, error) {
+		wsChan := make(chan *websocket.Conn)
+		h := func(conn *websocket.Conn) {
+			wsChan <- conn
+			<-done
 		}
 
-		if err := s.backend.ContainerWsAttachWithLogs(containerName, wsAttachWithLogsConfig); err != nil {
-			logrus.Errorf("Error attaching websocket: %s", utils.GetErrorMessage(err))
-		}
-	})
-	ws := websocket.Server{Handler: h, Handshake: nil}
-	ws.ServeHTTP(w, r)
+		srv := websocket.Server{Handler: h, Handshake: nil}
+		go func() {
+			close(started)
+			srv.ServeHTTP(w, r)
+		}()
 
-	return nil
+		conn := <-wsChan
+		return conn, conn, conn, nil
+	}
+
+	attachConfig := &backend.ContainerAttachConfig{
+		GetStreams: setupStreams,
+		Logs:       httputils.BoolValue(r, "logs"),
+		Stream:     httputils.BoolValue(r, "stream"),
+		DetachKeys: keys,
+		UseStdin:   true,
+		UseStdout:  true,
+		UseStderr:  true,
+		MuxStreams: false, // TODO: this should be true since it's a single stream for both stdout and stderr
+	}
+
+	err = s.backend.ContainerAttach(containerName, attachConfig)
+	close(done)
+	select {
+	case <-started:
+		logrus.Errorf("Error attaching websocket: %s", err)
+		return nil
+	default:
+	}
+	return err
 }

+ 1 - 8
api/server/router/system/system_routes.go

@@ -68,16 +68,9 @@ func (s *systemRouter) getEvents(ctx context.Context, w http.ResponseWriter, r *
 	}
 
 	w.Header().Set("Content-Type", "application/json")
-
-	// This is to ensure that the HTTP status code is sent immediately,
-	// so that it will not block the receiver.
-	w.WriteHeader(http.StatusOK)
-	if flusher, ok := w.(http.Flusher); ok {
-		flusher.Flush()
-	}
-
 	output := ioutils.NewWriteFlusher(w)
 	defer output.Close()
+	output.Flush()
 
 	enc := json.NewEncoder(output)
 

+ 8 - 15
api/types/backend/backend.go

@@ -5,32 +5,25 @@ package backend
 
 import (
 	"io"
-	"net/http"
 
 	"github.com/docker/engine-api/types"
 )
 
-// ContainerAttachWithLogsConfig holds the streams to use when connecting to a container to view logs.
-type ContainerAttachWithLogsConfig struct {
-	Hijacker   http.Hijacker
-	Upgrade    bool
+// ContainerAttachConfig holds the streams to use when connecting to a container to view logs.
+type ContainerAttachConfig struct {
+	GetStreams func() (io.ReadCloser, io.Writer, io.Writer, error)
 	UseStdin   bool
 	UseStdout  bool
 	UseStderr  bool
 	Logs       bool
 	Stream     bool
 	DetachKeys []byte
-}
 
-// ContainerWsAttachWithLogsConfig attach with websockets, since all
-// stream data is delegated to the websocket to handle there.
-type ContainerWsAttachWithLogsConfig struct {
-	InStream   io.ReadCloser // Reader to attach to stdin of container
-	OutStream  io.Writer     // Writer to attach to stdout of container
-	ErrStream  io.Writer     // Writer to attach to stderr of container
-	Logs       bool          // If true return log output
-	Stream     bool          // If true return stream output
-	DetachKeys []byte
+	// Used to signify that streams are multiplexed and therefore need a StdWriter to encode stdout/sderr messages accordingly.
+	// TODO @cpuguy83: This shouldn't be needed. It was only added so that http and websocket endpoints can use the same function, and the websocket function was not using a stdwriter prior to this change...
+	// HOWEVER, the websocket endpoint is using a single stream and SHOULD be encoded with stdout/stderr as is done for HTTP since it is still just a single stream.
+	// Since such a change is an API change unrelated to the current changeset we'll keep it as is here and change separately.
+	MuxStreams bool
 }
 
 // ContainerLogsConfig holds configs for logging operations. Exists

+ 1 - 1
builder/builder.go

@@ -106,7 +106,7 @@ type Backend interface {
 	// Pull tells Docker to pull image referenced by `name`.
 	PullOnBuild(name string, authConfigs map[string]types.AuthConfig, output io.Writer) (Image, error)
 	// ContainerAttach attaches to container.
-	ContainerAttachOnBuild(cID string, stdin io.ReadCloser, stdout, stderr io.Writer, stream bool) error
+	ContainerAttachRaw(cID string, stdin io.ReadCloser, stdout, stderr io.Writer, stream bool) error
 	// ContainerCreate creates a new Docker container and returns potential warnings
 	ContainerCreate(types.ContainerCreateConfig) (types.ContainerCreateResponse, error)
 	// ContainerRm removes a container specified by `id`.

+ 1 - 1
builder/dockerfile/internals.go

@@ -541,7 +541,7 @@ func (b *Builder) create() (string, error) {
 func (b *Builder) run(cID string) (err error) {
 	errCh := make(chan error)
 	go func() {
-		errCh <- b.docker.ContainerAttachOnBuild(cID, nil, b.Stdout, b.Stderr, true)
+		errCh <- b.docker.ContainerAttachRaw(cID, nil, b.Stdout, b.Stderr, true)
 	}()
 
 	finished := make(chan struct{})

+ 11 - 38
daemon/attach.go

@@ -13,11 +13,8 @@ import (
 	"github.com/docker/docker/pkg/stdcopy"
 )
 
-// ContainerAttachWithLogs attaches to logs according to the config passed in. See ContainerAttachWithLogsConfig.
-func (daemon *Daemon) ContainerAttachWithLogs(prefixOrName string, c *backend.ContainerAttachWithLogsConfig) error {
-	if c.Hijacker == nil {
-		return derr.ErrorCodeNoHijackConnection.WithArgs(prefixOrName)
-	}
+// ContainerAttach attaches to logs according to the config passed in. See ContainerAttachConfig.
+func (daemon *Daemon) ContainerAttach(prefixOrName string, c *backend.ContainerAttachConfig) error {
 	container, err := daemon.GetContainer(prefixOrName)
 	if err != nil {
 		return derr.ErrorCodeNoSuchContainer.WithArgs(prefixOrName)
@@ -26,29 +23,15 @@ func (daemon *Daemon) ContainerAttachWithLogs(prefixOrName string, c *backend.Co
 		return derr.ErrorCodePausedContainer.WithArgs(prefixOrName)
 	}
 
-	conn, _, err := c.Hijacker.Hijack()
+	inStream, outStream, errStream, err := c.GetStreams()
 	if err != nil {
 		return err
 	}
-	defer conn.Close()
-	// Flush the options to make sure the client sets the raw mode
-	conn.Write([]byte{})
-	inStream := conn.(io.ReadCloser)
-	outStream := conn.(io.Writer)
-
-	if c.Upgrade {
-		fmt.Fprintf(outStream, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\r\n")
-	} else {
-		fmt.Fprintf(outStream, "HTTP/1.1 200 OK\r\nContent-Type: application/vnd.docker.raw-stream\r\n\r\n")
-	}
+	defer inStream.Close()
 
-	var errStream io.Writer
-
-	if !container.Config.Tty {
-		errStream = stdcopy.NewStdWriter(outStream, stdcopy.Stderr)
+	if !container.Config.Tty && c.MuxStreams {
+		errStream = stdcopy.NewStdWriter(errStream, stdcopy.Stderr)
 		outStream = stdcopy.NewStdWriter(outStream, stdcopy.Stdout)
-	} else {
-		errStream = outStream
 	}
 
 	var stdin io.ReadCloser
@@ -64,32 +47,22 @@ func (daemon *Daemon) ContainerAttachWithLogs(prefixOrName string, c *backend.Co
 		stderr = errStream
 	}
 
-	if err := daemon.attachWithLogs(container, stdin, stdout, stderr, c.Logs, c.Stream, c.DetachKeys); err != nil {
+	if err := daemon.containerAttach(container, stdin, stdout, stderr, c.Logs, c.Stream, c.DetachKeys); err != nil {
 		fmt.Fprintf(outStream, "Error attaching: %s\n", err)
 	}
 	return nil
 }
 
-// ContainerWsAttachWithLogs websocket connection
-func (daemon *Daemon) ContainerWsAttachWithLogs(prefixOrName string, c *backend.ContainerWsAttachWithLogsConfig) error {
+// ContainerAttachRaw attaches the provided streams to the container's stdio
+func (daemon *Daemon) ContainerAttachRaw(prefixOrName string, stdin io.ReadCloser, stdout, stderr io.Writer, stream bool) error {
 	container, err := daemon.GetContainer(prefixOrName)
 	if err != nil {
 		return err
 	}
-	return daemon.attachWithLogs(container, c.InStream, c.OutStream, c.ErrStream, c.Logs, c.Stream, c.DetachKeys)
-}
-
-// ContainerAttachOnBuild attaches streams to the container cID. If stream is true, it streams the output.
-func (daemon *Daemon) ContainerAttachOnBuild(cID string, stdin io.ReadCloser, stdout, stderr io.Writer, stream bool) error {
-	return daemon.ContainerWsAttachWithLogs(cID, &backend.ContainerWsAttachWithLogsConfig{
-		InStream:  stdin,
-		OutStream: stdout,
-		ErrStream: stderr,
-		Stream:    stream,
-	})
+	return daemon.containerAttach(container, stdin, stdout, stderr, false, stream, nil)
 }
 
-func (daemon *Daemon) attachWithLogs(container *container.Container, stdin io.ReadCloser, stdout, stderr io.Writer, logs, stream bool, keys []byte) error {
+func (daemon *Daemon) containerAttach(container *container.Container, stdin io.ReadCloser, stdout, stderr io.Writer, logs, stream bool, keys []byte) error {
 	if logs {
 		logDriver, err := daemon.getLogger(container)
 		if err != nil {

+ 14 - 9
daemon/logs.go

@@ -11,13 +11,14 @@ import (
 	"github.com/docker/docker/daemon/logger"
 	"github.com/docker/docker/daemon/logger/jsonfilelog"
 	derr "github.com/docker/docker/errors"
+	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/stdcopy"
 	timetypes "github.com/docker/engine-api/types/time"
 )
 
 // ContainerLogs hooks up a container's stdout and stderr streams
 // configured with the given struct.
-func (daemon *Daemon) ContainerLogs(containerName string, config *backend.ContainerLogsConfig) error {
+func (daemon *Daemon) ContainerLogs(containerName string, config *backend.ContainerLogsConfig, started chan struct{}) error {
 	container, err := daemon.GetContainer(containerName)
 	if err != nil {
 		return derr.ErrorCodeNoSuchContainer.WithArgs(containerName)
@@ -27,14 +28,6 @@ func (daemon *Daemon) ContainerLogs(containerName string, config *backend.Contai
 		return derr.ErrorCodeNeedStream
 	}
 
-	outStream := config.OutStream
-	errStream := outStream
-	if !container.Config.Tty {
-		errStream = stdcopy.NewStdWriter(outStream, stdcopy.Stderr)
-		outStream = stdcopy.NewStdWriter(outStream, stdcopy.Stdout)
-	}
-	config.OutStream = outStream
-
 	cLog, err := daemon.getLogger(container)
 	if err != nil {
 		return err
@@ -67,6 +60,18 @@ func (daemon *Daemon) ContainerLogs(containerName string, config *backend.Contai
 	}
 	logs := logReader.ReadLogs(readConfig)
 
+	wf := ioutils.NewWriteFlusher(config.OutStream)
+	defer wf.Close()
+	close(started)
+	wf.Flush()
+
+	var outStream io.Writer = wf
+	errStream := outStream
+	if !container.Config.Tty {
+		errStream = stdcopy.NewStdWriter(outStream, stdcopy.Stderr)
+		outStream = stdcopy.NewStdWriter(outStream, stdcopy.Stdout)
+	}
+
 	for {
 		select {
 		case err := <-logs.Err:

+ 7 - 5
daemon/stats.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/docker/docker/api/types/backend"
 	"github.com/docker/docker/daemon/execdriver"
+	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/version"
 	"github.com/docker/engine-api/types"
 	"github.com/docker/engine-api/types/versions/v1p20"
@@ -31,11 +32,12 @@ func (daemon *Daemon) ContainerStats(prefixOrName string, config *backend.Contai
 		return json.NewEncoder(config.OutStream).Encode(&types.Stats{})
 	}
 
+	outStream := config.OutStream
 	if config.Stream {
-		// Write an empty chunk of data.
-		// This is to ensure that the HTTP status code is sent immediately,
-		// even if the container has not yet produced any data.
-		config.OutStream.Write(nil)
+		wf := ioutils.NewWriteFlusher(outStream)
+		defer wf.Close()
+		wf.Flush()
+		outStream = wf
 	}
 
 	var preCPUStats types.CPUStats
@@ -50,7 +52,7 @@ func (daemon *Daemon) ContainerStats(prefixOrName string, config *backend.Contai
 		return ss
 	}
 
-	enc := json.NewEncoder(config.OutStream)
+	enc := json.NewEncoder(outStream)
 
 	updates := daemon.subscribeToContainerStats(container)
 	defer daemon.unsubscribeToContainerStats(container, updates)

+ 19 - 11
integration-cli/docker_api_containers_test.go

@@ -416,22 +416,30 @@ func (s *DockerSuite) TestGetContainerStatsNoStream(c *check.C) {
 func (s *DockerSuite) TestGetStoppedContainerStats(c *check.C) {
 	// Problematic on Windows as Windows does not support stats
 	testRequires(c, DaemonIsLinux)
-	// TODO: this test does nothing because we are c.Assert'ing in goroutine
-	var (
-		name = "statscontainer"
-	)
+	name := "statscontainer"
 	dockerCmd(c, "create", "--name", name, "busybox", "top")
 
+	type stats struct {
+		status int
+		err    error
+	}
+	chResp := make(chan stats)
+
+	// We expect an immediate response, but if it's not immediate, the test would hang, so put it in a goroutine
+	// below we'll check this on a timeout.
 	go func() {
-		// We'll never get return for GET stats from sockRequest as of now,
-		// just send request and see if panic or error would happen on daemon side.
-		status, _, err := sockRequest("GET", "/containers/"+name+"/stats", nil)
-		c.Assert(err, checker.IsNil)
-		c.Assert(status, checker.Equals, http.StatusOK)
+		resp, body, err := sockRequestRaw("GET", "/containers/"+name+"/stats", nil, "")
+		body.Close()
+		chResp <- stats{resp.StatusCode, err}
 	}()
 
-	// allow some time to send request and let daemon deal with it
-	time.Sleep(1 * time.Second)
+	select {
+	case r := <-chResp:
+		c.Assert(r.err, checker.IsNil)
+		c.Assert(r.status, checker.Equals, http.StatusOK)
+	case <-time.After(10 * time.Second):
+		c.Fatal("timeout waiting for stats reponse for stopped container")
+	}
 }
 
 // #9981 - Allow a docker created volume (ie, one in /var/lib/docker/volumes) to be used to overwrite (via passing in Binds on api start) an existing volume

+ 41 - 41
pkg/ioutils/writeflusher.go

@@ -1,9 +1,7 @@
 package ioutils
 
 import (
-	"errors"
 	"io"
-	"net/http"
 	"sync"
 )
 
@@ -11,45 +9,43 @@ import (
 // is a flush. In addition, the Close method can be called to intercept
 // Read/Write calls if the targets lifecycle has already ended.
 type WriteFlusher struct {
-	mu      sync.Mutex
-	w       io.Writer
-	flusher http.Flusher
-	flushed bool
-	closed  error
+	w           io.Writer
+	flusher     flusher
+	flushed     chan struct{}
+	flushedOnce sync.Once
+	closed      chan struct{}
+	closeLock   sync.Mutex
+}
 
-	// TODO(stevvooe): Use channel for closed instead, remove mutex. Using a
-	// channel will allow one to properly order the operations.
+type flusher interface {
+	Flush()
 }
 
-var errWriteFlusherClosed = errors.New("writeflusher: closed")
+var errWriteFlusherClosed = io.EOF
 
 func (wf *WriteFlusher) Write(b []byte) (n int, err error) {
-	wf.mu.Lock()
-	defer wf.mu.Unlock()
-	if wf.closed != nil {
-		return 0, wf.closed
+	select {
+	case <-wf.closed:
+		return 0, errWriteFlusherClosed
+	default:
 	}
 
 	n, err = wf.w.Write(b)
-	wf.flush() // every write is a flush.
+	wf.Flush() // every write is a flush.
 	return n, err
 }
 
 // Flush the stream immediately.
 func (wf *WriteFlusher) Flush() {
-	wf.mu.Lock()
-	defer wf.mu.Unlock()
-
-	wf.flush()
-}
-
-// flush the stream immediately without taking a lock. Used internally.
-func (wf *WriteFlusher) flush() {
-	if wf.closed != nil {
+	select {
+	case <-wf.closed:
 		return
+	default:
 	}
 
-	wf.flushed = true
+	wf.flushedOnce.Do(func() {
+		close(wf.flushed)
+	})
 	wf.flusher.Flush()
 }
 
@@ -59,34 +55,38 @@ func (wf *WriteFlusher) Flushed() bool {
 	// BUG(stevvooe): Remove this method. Its use is inherently racy. Seems to
 	// be used to detect whether or a response code has been issued or not.
 	// Another hook should be used instead.
-	wf.mu.Lock()
-	defer wf.mu.Unlock()
-
-	return wf.flushed
+	var flushed bool
+	select {
+	case <-wf.flushed:
+		flushed = true
+	default:
+	}
+	return flushed
 }
 
 // Close closes the write flusher, disallowing any further writes to the
 // target. After the flusher is closed, all calls to write or flush will
 // result in an error.
 func (wf *WriteFlusher) Close() error {
-	wf.mu.Lock()
-	defer wf.mu.Unlock()
-
-	if wf.closed != nil {
-		return wf.closed
+	wf.closeLock.Lock()
+	defer wf.closeLock.Unlock()
+
+	select {
+	case <-wf.closed:
+		return errWriteFlusherClosed
+	default:
+		close(wf.closed)
 	}
-
-	wf.closed = errWriteFlusherClosed
 	return nil
 }
 
 // NewWriteFlusher returns a new WriteFlusher.
 func NewWriteFlusher(w io.Writer) *WriteFlusher {
-	var flusher http.Flusher
-	if f, ok := w.(http.Flusher); ok {
-		flusher = f
+	var fl flusher
+	if f, ok := w.(flusher); ok {
+		fl = f
 	} else {
-		flusher = &NopFlusher{}
+		fl = &NopFlusher{}
 	}
-	return &WriteFlusher{w: w, flusher: flusher}
+	return &WriteFlusher{w: w, flusher: fl, closed: make(chan struct{}), flushed: make(chan struct{})}
 }