Browse Source

Clean up ProgressStatus

- Rename to Broadcaster

- Document exported types

- Change Wait function to just wait. Writing a message to the writer and
  adding the writer to the observers list are now handled by separate
  function calls.

- Avoid importing logrus (the condition where it was used should never
  happen, anyway).

- Make writes non-blocking

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
Aaron Lehmann 10 years ago
parent
commit
26c9b58504
7 changed files with 204 additions and 126 deletions
  1. 1 1
      graph/load.go
  2. 2 2
      graph/pools_test.go
  3. 26 24
      graph/pull_v1.go
  4. 15 13
      graph/pull_v2.go
  5. 14 14
      graph/tags.go
  6. 146 0
      pkg/progressreader/broadcaster.go
  7. 0 72
      pkg/progressreader/progressstatus.go

+ 1 - 1
graph/load.go

@@ -108,7 +108,7 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error {
 		// ensure no two downloads of the same layer happen at the same time
 		// ensure no two downloads of the same layer happen at the same time
 		if ps, found := s.poolAdd("pull", "layer:"+img.ID); found {
 		if ps, found := s.poolAdd("pull", "layer:"+img.ID); found {
 			logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID)
 			logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID)
-			ps.Wait(nil, nil)
+			ps.Wait()
 			return nil
 			return nil
 		}
 		}
 
 

+ 2 - 2
graph/pools_test.go

@@ -13,8 +13,8 @@ func init() {
 
 
 func TestPools(t *testing.T) {
 func TestPools(t *testing.T) {
 	s := &TagStore{
 	s := &TagStore{
-		pullingPool: make(map[string]*progressreader.ProgressStatus),
-		pushingPool: make(map[string]*progressreader.ProgressStatus),
+		pullingPool: make(map[string]*progressreader.Broadcaster),
+		pushingPool: make(map[string]*progressreader.Broadcaster),
 	}
 	}
 
 
 	if _, found := s.poolAdd("pull", "test1"); found {
 	if _, found := s.poolAdd("pull", "test1"); found {

+ 26 - 24
graph/pull_v1.go

@@ -138,29 +138,30 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 			}
 			}
 
 
 			// ensure no two downloads of the same image happen at the same time
 			// ensure no two downloads of the same image happen at the same time
-			ps, found := p.poolAdd("pull", "img:"+img.ID)
+			broadcaster, found := p.poolAdd("pull", "img:"+img.ID)
 			if found {
 			if found {
-				msg := p.sf.FormatProgress(stringid.TruncateID(img.ID), "Layer already being pulled by another client. Waiting.", nil)
-				ps.Wait(out, msg)
+				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Layer already being pulled by another client. Waiting.", nil))
+				broadcaster.Add(out)
+				broadcaster.Wait()
 				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
 				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
 				errors <- nil
 				errors <- nil
 				return
 				return
 			}
 			}
-			ps.AddObserver(out)
+			broadcaster.Add(out)
 			defer p.poolRemove("pull", "img:"+img.ID)
 			defer p.poolRemove("pull", "img:"+img.ID)
 
 
 			// we need to retain it until tagging
 			// we need to retain it until tagging
 			p.graph.Retain(sessionID, img.ID)
 			p.graph.Retain(sessionID, img.ID)
 			imgIDs = append(imgIDs, img.ID)
 			imgIDs = append(imgIDs, img.ID)
 
 
-			ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName), nil))
+			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName), nil))
 			success := false
 			success := false
 			var lastErr, err error
 			var lastErr, err error
 			var isDownloaded bool
 			var isDownloaded bool
 			for _, ep := range p.repoInfo.Index.Mirrors {
 			for _, ep := range p.repoInfo.Index.Mirrors {
 				ep += "v1/"
 				ep += "v1/"
-				ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
-				if isDownloaded, err = p.pullImage(ps, img.ID, ep, repoData.Tokens); err != nil {
+				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
+				if isDownloaded, err = p.pullImage(broadcaster, img.ID, ep, repoData.Tokens); err != nil {
 					// Don't report errors when pulling from mirrors.
 					// Don't report errors when pulling from mirrors.
 					logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err)
 					logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err)
 					continue
 					continue
@@ -171,12 +172,12 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 			}
 			}
 			if !success {
 			if !success {
 				for _, ep := range repoData.Endpoints {
 				for _, ep := range repoData.Endpoints {
-					ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
-					if isDownloaded, err = p.pullImage(ps, img.ID, ep, repoData.Tokens); err != nil {
+					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
+					if isDownloaded, err = p.pullImage(broadcaster, img.ID, ep, repoData.Tokens); err != nil {
 						// It's not ideal that only the last error is returned, it would be better to concatenate the errors.
 						// It's not ideal that only the last error is returned, it would be better to concatenate the errors.
 						// As the error is also given to the output stream the user will see the error.
 						// As the error is also given to the output stream the user will see the error.
 						lastErr = err
 						lastErr = err
-						ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err), nil))
+						broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err), nil))
 						continue
 						continue
 					}
 					}
 					layersDownloaded = layersDownloaded || isDownloaded
 					layersDownloaded = layersDownloaded || isDownloaded
@@ -186,11 +187,11 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 			}
 			}
 			if !success {
 			if !success {
 				err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr)
 				err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr)
-				ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
+				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
 				errors <- err
 				errors <- err
 				return
 				return
 			}
 			}
-			ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
+			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
 
 
 			errors <- nil
 			errors <- nil
 		}
 		}
@@ -244,18 +245,19 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 		id := history[i]
 		id := history[i]
 
 
 		// ensure no two downloads of the same layer happen at the same time
 		// ensure no two downloads of the same layer happen at the same time
-		ps, found := p.poolAdd("pull", "layer:"+id)
+		broadcaster, found := p.poolAdd("pull", "layer:"+id)
 		if found {
 		if found {
 			logrus.Debugf("Image (id: %s) pull is already running, skipping", id)
 			logrus.Debugf("Image (id: %s) pull is already running, skipping", id)
-			msg := p.sf.FormatProgress(stringid.TruncateID(imgID), "Layer already being pulled by another client. Waiting.", nil)
-			ps.Wait(out, msg)
+			out.Write(p.sf.FormatProgress(stringid.TruncateID(imgID), "Layer already being pulled by another client. Waiting.", nil))
+			broadcaster.Add(out)
+			broadcaster.Wait()
 		} else {
 		} else {
-			ps.AddObserver(out)
+			broadcaster.Add(out)
 		}
 		}
 		defer p.poolRemove("pull", "layer:"+id)
 		defer p.poolRemove("pull", "layer:"+id)
 
 
 		if !p.graph.Exists(id) {
 		if !p.graph.Exists(id) {
-			ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil))
+			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil))
 			var (
 			var (
 				imgJSON []byte
 				imgJSON []byte
 				imgSize int64
 				imgSize int64
@@ -266,7 +268,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 			for j := 1; j <= retries; j++ {
 			for j := 1; j <= retries; j++ {
 				imgJSON, imgSize, err = p.session.GetRemoteImageJSON(id, endpoint)
 				imgJSON, imgSize, err = p.session.GetRemoteImageJSON(id, endpoint)
 				if err != nil && j == retries {
 				if err != nil && j == retries {
-					ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
+					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
 					return layersDownloaded, err
 					return layersDownloaded, err
 				} else if err != nil {
 				} else if err != nil {
 					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
@@ -275,7 +277,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 				img, err = image.NewImgJSON(imgJSON)
 				img, err = image.NewImgJSON(imgJSON)
 				layersDownloaded = true
 				layersDownloaded = true
 				if err != nil && j == retries {
 				if err != nil && j == retries {
-					ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
+					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
 					return layersDownloaded, fmt.Errorf("Failed to parse json: %s", err)
 					return layersDownloaded, fmt.Errorf("Failed to parse json: %s", err)
 				} else if err != nil {
 				} else if err != nil {
 					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
@@ -291,7 +293,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 				if j > 1 {
 				if j > 1 {
 					status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
 					status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
 				}
 				}
-				ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), status, nil))
+				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), status, nil))
 				layer, err := p.session.GetRemoteImageLayer(img.ID, endpoint, imgSize)
 				layer, err := p.session.GetRemoteImageLayer(img.ID, endpoint, imgSize)
 				if uerr, ok := err.(*url.Error); ok {
 				if uerr, ok := err.(*url.Error); ok {
 					err = uerr.Err
 					err = uerr.Err
@@ -300,7 +302,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 					continue
 					continue
 				} else if err != nil {
 				} else if err != nil {
-					ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
+					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
 					return layersDownloaded, err
 					return layersDownloaded, err
 				}
 				}
 				layersDownloaded = true
 				layersDownloaded = true
@@ -309,7 +311,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 				err = p.graph.Register(img,
 				err = p.graph.Register(img,
 					progressreader.New(progressreader.Config{
 					progressreader.New(progressreader.Config{
 						In:        layer,
 						In:        layer,
-						Out:       ps,
+						Out:       broadcaster,
 						Formatter: p.sf,
 						Formatter: p.sf,
 						Size:      imgSize,
 						Size:      imgSize,
 						NewLines:  false,
 						NewLines:  false,
@@ -320,14 +322,14 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 					continue
 					continue
 				} else if err != nil {
 				} else if err != nil {
-					ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error downloading dependent layers", nil))
+					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error downloading dependent layers", nil))
 					return layersDownloaded, err
 					return layersDownloaded, err
 				} else {
 				} else {
 					break
 					break
 				}
 				}
 			}
 			}
 		}
 		}
-		ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil))
+		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil))
 	}
 	}
 	return layersDownloaded, nil
 	return layersDownloaded, nil
 }
 }

+ 15 - 13
graph/pull_v2.go

@@ -73,28 +73,29 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {
 
 
 	}
 	}
 
 
-	ps, found := p.poolAdd("pull", taggedName)
+	broadcaster, found := p.poolAdd("pull", taggedName)
 	if found {
 	if found {
 		// Another pull of the same repository is already taking place; just wait for it to finish
 		// Another pull of the same repository is already taking place; just wait for it to finish
-		msg := p.sf.FormatStatus("", "Repository %s already being pulled by another client. Waiting.", p.repoInfo.CanonicalName)
-		ps.Wait(p.config.OutStream, msg)
+		p.config.OutStream.Write(p.sf.FormatStatus("", "Repository %s already being pulled by another client. Waiting.", p.repoInfo.CanonicalName))
+		broadcaster.Add(p.config.OutStream)
+		broadcaster.Wait()
 		return nil
 		return nil
 	}
 	}
 	defer p.poolRemove("pull", taggedName)
 	defer p.poolRemove("pull", taggedName)
-	ps.AddObserver(p.config.OutStream)
+	broadcaster.Add(p.config.OutStream)
 
 
 	var layersDownloaded bool
 	var layersDownloaded bool
 	for _, tag := range tags {
 	for _, tag := range tags {
 		// pulledNew is true if either new layers were downloaded OR if existing images were newly tagged
 		// pulledNew is true if either new layers were downloaded OR if existing images were newly tagged
 		// TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus?
 		// TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus?
-		pulledNew, err := p.pullV2Tag(ps, tag, taggedName)
+		pulledNew, err := p.pullV2Tag(broadcaster, tag, taggedName)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 		layersDownloaded = layersDownloaded || pulledNew
 		layersDownloaded = layersDownloaded || pulledNew
 	}
 	}
 
 
-	writeStatus(taggedName, ps, p.sf, layersDownloaded)
+	writeStatus(taggedName, broadcaster, p.sf, layersDownloaded)
 
 
 	return nil
 	return nil
 }
 }
@@ -119,16 +120,17 @@ func (p *v2Puller) download(di *downloadInfo) {
 
 
 	out := di.out
 	out := di.out
 
 
-	ps, found := p.poolAdd("pull", "img:"+di.img.ID)
+	broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID)
 	if found {
 	if found {
-		msg := p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Layer already being pulled by another client. Waiting.", nil)
-		ps.Wait(out, msg)
+		out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Layer already being pulled by another client. Waiting.", nil))
+		broadcaster.Add(out)
+		broadcaster.Wait()
 		out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 		out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 		di.err <- nil
 		di.err <- nil
 		return
 		return
 	}
 	}
 
 
-	ps.AddObserver(out)
+	broadcaster.Add(out)
 	defer p.poolRemove("pull", "img:"+di.img.ID)
 	defer p.poolRemove("pull", "img:"+di.img.ID)
 	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
 	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
 	if err != nil {
 	if err != nil {
@@ -163,7 +165,7 @@ func (p *v2Puller) download(di *downloadInfo) {
 
 
 	reader := progressreader.New(progressreader.Config{
 	reader := progressreader.New(progressreader.Config{
 		In:        ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
 		In:        ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
-		Out:       ps,
+		Out:       broadcaster,
 		Formatter: p.sf,
 		Formatter: p.sf,
 		Size:      di.size,
 		Size:      di.size,
 		NewLines:  false,
 		NewLines:  false,
@@ -172,7 +174,7 @@ func (p *v2Puller) download(di *downloadInfo) {
 	})
 	})
 	io.Copy(tmpFile, reader)
 	io.Copy(tmpFile, reader)
 
 
-	ps.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))
+	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))
 
 
 	if !verifier.Verified() {
 	if !verifier.Verified() {
 		err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
 		err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
@@ -181,7 +183,7 @@ func (p *v2Puller) download(di *downloadInfo) {
 		return
 		return
 	}
 	}
 
 
-	ps.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
+	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 
 
 	logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, tmpFile.Name())
 	logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, tmpFile.Name())
 	di.layer = layerDownload
 	di.layer = layerDownload

+ 14 - 14
graph/tags.go

@@ -37,8 +37,8 @@ type TagStore struct {
 	sync.Mutex
 	sync.Mutex
 	// FIXME: move push/pull-related fields
 	// FIXME: move push/pull-related fields
 	// to a helper type
 	// to a helper type
-	pullingPool     map[string]*progressreader.ProgressStatus
-	pushingPool     map[string]*progressreader.ProgressStatus
+	pullingPool     map[string]*progressreader.Broadcaster
+	pushingPool     map[string]*progressreader.Broadcaster
 	registryService *registry.Service
 	registryService *registry.Service
 	eventsService   *events.Events
 	eventsService   *events.Events
 	trustService    *trust.TrustStore
 	trustService    *trust.TrustStore
@@ -94,8 +94,8 @@ func NewTagStore(path string, cfg *TagStoreConfig) (*TagStore, error) {
 		graph:           cfg.Graph,
 		graph:           cfg.Graph,
 		trustKey:        cfg.Key,
 		trustKey:        cfg.Key,
 		Repositories:    make(map[string]Repository),
 		Repositories:    make(map[string]Repository),
-		pullingPool:     make(map[string]*progressreader.ProgressStatus),
-		pushingPool:     make(map[string]*progressreader.ProgressStatus),
+		pullingPool:     make(map[string]*progressreader.Broadcaster),
+		pushingPool:     make(map[string]*progressreader.Broadcaster),
 		registryService: cfg.Registry,
 		registryService: cfg.Registry,
 		eventsService:   cfg.Events,
 		eventsService:   cfg.Events,
 		trustService:    cfg.Trust,
 		trustService:    cfg.Trust,
@@ -428,10 +428,10 @@ func validateDigest(dgst string) error {
 	return nil
 	return nil
 }
 }
 
 
-// poolAdd checks if a push or pull is already running, and returns (ps, true)
-// if a running operation is found. Otherwise, it creates a new one and returns
-// (ps, false).
-func (store *TagStore) poolAdd(kind, key string) (*progressreader.ProgressStatus, bool) {
+// poolAdd checks if a push or pull is already running, and returns
+// (broadcaster, true) if a running operation is found. Otherwise, it creates a
+// new one and returns (broadcaster, false).
+func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, bool) {
 	store.Lock()
 	store.Lock()
 	defer store.Unlock()
 	defer store.Unlock()
 
 
@@ -442,18 +442,18 @@ func (store *TagStore) poolAdd(kind, key string) (*progressreader.ProgressStatus
 		return p, true
 		return p, true
 	}
 	}
 
 
-	ps := progressreader.NewProgressStatus()
+	broadcaster := progressreader.NewBroadcaster()
 
 
 	switch kind {
 	switch kind {
 	case "pull":
 	case "pull":
-		store.pullingPool[key] = ps
+		store.pullingPool[key] = broadcaster
 	case "push":
 	case "push":
-		store.pushingPool[key] = ps
+		store.pushingPool[key] = broadcaster
 	default:
 	default:
 		panic("Unknown pool type")
 		panic("Unknown pool type")
 	}
 	}
 
 
-	return ps, false
+	return broadcaster, false
 }
 }
 
 
 func (store *TagStore) poolRemove(kind, key string) error {
 func (store *TagStore) poolRemove(kind, key string) error {
@@ -462,12 +462,12 @@ func (store *TagStore) poolRemove(kind, key string) error {
 	switch kind {
 	switch kind {
 	case "pull":
 	case "pull":
 		if ps, exists := store.pullingPool[key]; exists {
 		if ps, exists := store.pullingPool[key]; exists {
-			ps.Done()
+			ps.Close()
 			delete(store.pullingPool, key)
 			delete(store.pullingPool, key)
 		}
 		}
 	case "push":
 	case "push":
 		if ps, exists := store.pushingPool[key]; exists {
 		if ps, exists := store.pushingPool[key]; exists {
-			ps.Done()
+			ps.Close()
 			delete(store.pushingPool, key)
 			delete(store.pushingPool, key)
 		}
 		}
 	default:
 	default:

+ 146 - 0
pkg/progressreader/broadcaster.go

@@ -0,0 +1,146 @@
+package progressreader
+
+import (
+	"bytes"
+	"errors"
+	"io"
+	"sync"
+)
+
+// Broadcaster keeps track of one or more observers watching the progress
+// of an operation. For example, if multiple clients are trying to pull an
+// image, they share a Broadcaster for the download operation.
+type Broadcaster struct {
+	sync.Mutex
+	// c is a channel that observers block on, waiting for the operation
+	// to finish.
+	c chan struct{}
+	// cond is a condition variable used to wake up observers when there's
+	// new data available.
+	cond *sync.Cond
+	// history is a buffer of the progress output so far, so a new observer
+	// can catch up.
+	history bytes.Buffer
+	// wg is a WaitGroup used to wait for all writes to finish on Close
+	wg sync.WaitGroup
+	// isClosed is set to true when Close is called to avoid closing c
+	// multiple times.
+	isClosed bool
+}
+
+// NewBroadcaster returns a Broadcaster structure
+func NewBroadcaster() *Broadcaster {
+	b := &Broadcaster{
+		c: make(chan struct{}),
+	}
+	b.cond = sync.NewCond(b)
+	return b
+}
+
+// closed returns true if and only if the broadcaster has been closed
+func (broadcaster *Broadcaster) closed() bool {
+	select {
+	case <-broadcaster.c:
+		return true
+	default:
+		return false
+	}
+}
+
+// receiveWrites runs as a goroutine so that writes don't block the Write
+// function. It writes the new data in broadcaster.history each time there's
+// activity on the broadcaster.cond condition variable.
+func (broadcaster *Broadcaster) receiveWrites(observer io.Writer) {
+	n := 0
+
+	broadcaster.Lock()
+
+	// The condition variable wait is at the end of this loop, so that the
+	// first iteration will write the history so far.
+	for {
+		newData := broadcaster.history.Bytes()[n:]
+		// Make a copy of newData so we can release the lock
+		sendData := make([]byte, len(newData), len(newData))
+		copy(sendData, newData)
+		broadcaster.Unlock()
+
+		if len(sendData) > 0 {
+			written, err := observer.Write(sendData)
+			if err != nil {
+				broadcaster.wg.Done()
+				return
+			}
+			n += written
+		}
+
+		broadcaster.Lock()
+
+		// detect closure of the broadcast writer
+		if broadcaster.closed() {
+			broadcaster.Unlock()
+			broadcaster.wg.Done()
+			return
+		}
+
+		if broadcaster.history.Len() == n {
+			broadcaster.cond.Wait()
+		}
+
+		// Mutex is still locked as the loop continues
+	}
+}
+
+// Write adds data to the history buffer, and also writes it to all current
+// observers.
+func (broadcaster *Broadcaster) Write(p []byte) (n int, err error) {
+	broadcaster.Lock()
+	defer broadcaster.Unlock()
+
+	// Is the broadcaster closed? If so, the write should fail.
+	if broadcaster.closed() {
+		return 0, errors.New("attempted write to closed progressreader Broadcaster")
+	}
+
+	broadcaster.history.Write(p)
+	broadcaster.cond.Broadcast()
+
+	return len(p), nil
+}
+
+// Add adds an observer to the Broadcaster. The new observer receives the
+// data from the history buffer, and also all subsequent data.
+func (broadcaster *Broadcaster) Add(w io.Writer) error {
+	// The lock is acquired here so that Add can't race with Close
+	broadcaster.Lock()
+	defer broadcaster.Unlock()
+
+	if broadcaster.closed() {
+		return errors.New("attempted to add observer to closed progressreader Broadcaster")
+	}
+
+	broadcaster.wg.Add(1)
+	go broadcaster.receiveWrites(w)
+
+	return nil
+}
+
+// Close signals to all observers that the operation has finished.
+func (broadcaster *Broadcaster) Close() {
+	broadcaster.Lock()
+	if broadcaster.isClosed {
+		broadcaster.Unlock()
+		return
+	}
+	broadcaster.isClosed = true
+	close(broadcaster.c)
+	broadcaster.cond.Broadcast()
+	broadcaster.Unlock()
+
+	// Don't return from Close until all writers have caught up.
+	broadcaster.wg.Wait()
+}
+
+// Wait blocks until the operation is marked as completed by the Done method.
+func (broadcaster *Broadcaster) Wait() {
+	<-broadcaster.c
+}

+ 0 - 72
pkg/progressreader/progressstatus.go

@@ -1,72 +0,0 @@
-package progressreader
-
-import (
-	"bytes"
-	"io"
-	"sync"
-
-	"github.com/docker/docker/vendor/src/github.com/Sirupsen/logrus"
-)
-
-type ProgressStatus struct {
-	sync.Mutex
-	c         chan struct{}
-	observers []io.Writer
-	history   bytes.Buffer
-}
-
-func NewProgressStatus() *ProgressStatus {
-	return &ProgressStatus{
-		c:         make(chan struct{}),
-		observers: []io.Writer{},
-	}
-}
-
-func (ps *ProgressStatus) Write(p []byte) (n int, err error) {
-	ps.Lock()
-	defer ps.Unlock()
-	ps.history.Write(p)
-	for _, w := range ps.observers {
-		// copy paste from MultiWriter, replaced return with continue
-		n, err = w.Write(p)
-		if err != nil {
-			continue
-		}
-		if n != len(p) {
-			err = io.ErrShortWrite
-			continue
-		}
-	}
-	return len(p), nil
-}
-
-func (ps *ProgressStatus) AddObserver(w io.Writer) {
-	ps.Lock()
-	defer ps.Unlock()
-	w.Write(ps.history.Bytes())
-	ps.observers = append(ps.observers, w)
-}
-
-func (ps *ProgressStatus) Done() {
-	ps.Lock()
-	close(ps.c)
-	ps.history.Reset()
-	ps.Unlock()
-}
-
-func (ps *ProgressStatus) Wait(w io.Writer, msg []byte) error {
-	ps.Lock()
-	channel := ps.c
-	ps.Unlock()
-
-	if channel == nil {
-		// defensive
-		logrus.Debugf("Channel is nil ")
-	}
-	if w != nil {
-		w.Write(msg)
-		ps.AddObserver(w)
-	}
-	<-channel
-	return nil
-}