diff --git a/graph/load.go b/graph/load.go index 8f7efa6af5..a58c5a3cf9 100644 --- a/graph/load.go +++ b/graph/load.go @@ -106,14 +106,10 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error { } // ensure no two downloads of the same layer happen at the same time - if c, err := s.poolAdd("pull", "layer:"+img.ID); err != nil { - if c != nil { - logrus.Debugf("Image (id: %s) load is already running, waiting: %v", img.ID, err) - <-c - return nil - } - - return err + if ps, found := s.poolAdd("pull", "layer:"+img.ID); found { + logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID) + ps.Wait() + return nil } defer s.poolRemove("pull", "layer:"+img.ID) diff --git a/graph/pools_test.go b/graph/pools_test.go index 129a5e1fec..a7b27271b7 100644 --- a/graph/pools_test.go +++ b/graph/pools_test.go @@ -3,6 +3,7 @@ package graph import ( "testing" + "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/reexec" ) @@ -12,24 +13,21 @@ func init() { func TestPools(t *testing.T) { s := &TagStore{ - pullingPool: make(map[string]chan struct{}), - pushingPool: make(map[string]chan struct{}), + pullingPool: make(map[string]*progressreader.Broadcaster), + pushingPool: make(map[string]*progressreader.Broadcaster), } - if _, err := s.poolAdd("pull", "test1"); err != nil { - t.Fatal(err) + if _, found := s.poolAdd("pull", "test1"); found { + t.Fatal("Expected pull test1 not to be in progress") } - if _, err := s.poolAdd("pull", "test2"); err != nil { - t.Fatal(err) + if _, found := s.poolAdd("pull", "test2"); found { + t.Fatal("Expected pull test2 not to be in progress") } - if _, err := s.poolAdd("push", "test1"); err == nil || err.Error() != "pull test1 is already in progress" { - t.Fatalf("Expected `pull test1 is already in progress`") + if _, found := s.poolAdd("push", "test1"); !found { + t.Fatalf("Expected pull test1 to be in progress`") } - if _, err := s.poolAdd("pull", "test1"); err == nil || err.Error() != "pull test1 is already in progress" { - t.Fatalf("Expected `pull test1 is already in progress`") - } - if _, err := s.poolAdd("wait", "test3"); err == nil || err.Error() != "Unknown pool type" { - t.Fatalf("Expected `Unknown pool type`") + if _, found := s.poolAdd("pull", "test1"); !found { + t.Fatalf("Expected pull test1 to be in progress`") } if err := s.poolRemove("pull", "test2"); err != nil { t.Fatal(err) @@ -43,7 +41,4 @@ func TestPools(t *testing.T) { if err := s.poolRemove("push", "test1"); err != nil { t.Fatal(err) } - if err := s.poolRemove("wait", "test3"); err == nil || err.Error() != "Unknown pool type" { - t.Fatalf("Expected `Unknown pool type`") - } } diff --git a/graph/pull_v1.go b/graph/pull_v1.go index 0280f1f950..13741292fd 100644 --- a/graph/pull_v1.go +++ b/graph/pull_v1.go @@ -3,6 +3,7 @@ package graph import ( "errors" "fmt" + "io" "net" "net/url" "strings" @@ -137,31 +138,29 @@ func (p *v1Puller) pullRepository(askedTag string) error { } // ensure no two downloads of the same image happen at the same time - if c, err := p.poolAdd("pull", "img:"+img.ID); err != nil { - if c != nil { - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Layer already being pulled by another client. Waiting.", nil)) - <-c - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) - } else { - logrus.Debugf("Image (id: %s) pull is already running, skipping: %v", img.ID, err) - } + broadcaster, found := p.poolAdd("pull", "img:"+img.ID) + if found { + broadcaster.Add(out) + broadcaster.Wait() + out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) errors <- nil return } + broadcaster.Add(out) defer p.poolRemove("pull", "img:"+img.ID) // we need to retain it until tagging p.graph.Retain(sessionID, img.ID) imgIDs = append(imgIDs, img.ID) - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName), nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName), nil)) success := false var lastErr, err error var isDownloaded bool for _, ep := range p.repoInfo.Index.Mirrors { ep += "v1/" - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil)) - if isDownloaded, err = p.pullImage(img.ID, ep, repoData.Tokens); err != nil { + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil)) + if isDownloaded, err = p.pullImage(broadcaster, img.ID, ep, repoData.Tokens); err != nil { // Don't report errors when pulling from mirrors. logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err) continue @@ -172,12 +171,12 @@ func (p *v1Puller) pullRepository(askedTag string) error { } if !success { for _, ep := range repoData.Endpoints { - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil)) - if isDownloaded, err = p.pullImage(img.ID, ep, repoData.Tokens); err != nil { + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil)) + if isDownloaded, err = p.pullImage(broadcaster, img.ID, ep, repoData.Tokens); err != nil { // It's not ideal that only the last error is returned, it would be better to concatenate the errors. // As the error is also given to the output stream the user will see the error. lastErr = err - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err), nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err), nil)) continue } layersDownloaded = layersDownloaded || isDownloaded @@ -187,11 +186,11 @@ func (p *v1Puller) pullRepository(askedTag string) error { } if !success { err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr) - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil)) errors <- err return } - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) errors <- nil } @@ -226,12 +225,11 @@ func (p *v1Puller) pullRepository(askedTag string) error { return nil } -func (p *v1Puller) pullImage(imgID, endpoint string, token []string) (bool, error) { +func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (bool, error) { history, err := p.session.GetRemoteHistory(imgID, endpoint) if err != nil { return false, err } - out := p.config.OutStream out.Write(p.sf.FormatProgress(stringid.TruncateID(imgID), "Pulling dependent layers", nil)) // FIXME: Try to stream the images? // FIXME: Launch the getRemoteImage() in goroutines @@ -246,14 +244,18 @@ func (p *v1Puller) pullImage(imgID, endpoint string, token []string) (bool, erro id := history[i] // ensure no two downloads of the same layer happen at the same time - if c, err := p.poolAdd("pull", "layer:"+id); err != nil { - logrus.Debugf("Image (id: %s) pull is already running, skipping: %v", id, err) - <-c + broadcaster, found := p.poolAdd("pull", "layer:"+id) + if found { + logrus.Debugf("Image (id: %s) pull is already running, skipping", id) + broadcaster.Add(out) + broadcaster.Wait() + } else { + broadcaster.Add(out) } defer p.poolRemove("pull", "layer:"+id) if !p.graph.Exists(id) { - out.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil)) var ( imgJSON []byte imgSize int64 @@ -264,7 +266,7 @@ func (p *v1Puller) pullImage(imgID, endpoint string, token []string) (bool, erro for j := 1; j <= retries; j++ { imgJSON, imgSize, err = p.session.GetRemoteImageJSON(id, endpoint) if err != nil && j == retries { - out.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) return layersDownloaded, err } else if err != nil { time.Sleep(time.Duration(j) * 500 * time.Millisecond) @@ -273,7 +275,7 @@ func (p *v1Puller) pullImage(imgID, endpoint string, token []string) (bool, erro img, err = image.NewImgJSON(imgJSON) layersDownloaded = true if err != nil && j == retries { - out.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) return layersDownloaded, fmt.Errorf("Failed to parse json: %s", err) } else if err != nil { time.Sleep(time.Duration(j) * 500 * time.Millisecond) @@ -289,7 +291,7 @@ func (p *v1Puller) pullImage(imgID, endpoint string, token []string) (bool, erro if j > 1 { status = fmt.Sprintf("Pulling fs layer [retries: %d]", j) } - out.Write(p.sf.FormatProgress(stringid.TruncateID(id), status, nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), status, nil)) layer, err := p.session.GetRemoteImageLayer(img.ID, endpoint, imgSize) if uerr, ok := err.(*url.Error); ok { err = uerr.Err @@ -298,7 +300,7 @@ func (p *v1Puller) pullImage(imgID, endpoint string, token []string) (bool, erro time.Sleep(time.Duration(j) * 500 * time.Millisecond) continue } else if err != nil { - out.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) return layersDownloaded, err } layersDownloaded = true @@ -307,7 +309,7 @@ func (p *v1Puller) pullImage(imgID, endpoint string, token []string) (bool, erro err = p.graph.Register(img, progressreader.New(progressreader.Config{ In: layer, - Out: out, + Out: broadcaster, Formatter: p.sf, Size: imgSize, NewLines: false, @@ -318,14 +320,14 @@ func (p *v1Puller) pullImage(imgID, endpoint string, token []string) (bool, erro time.Sleep(time.Duration(j) * 500 * time.Millisecond) continue } else if err != nil { - out.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error downloading dependent layers", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error downloading dependent layers", nil)) return layersDownloaded, err } else { break } } } - out.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil)) } return layersDownloaded, nil } diff --git a/graph/pull_v2.go b/graph/pull_v2.go index 0b21c1f75a..96f47b873a 100644 --- a/graph/pull_v2.go +++ b/graph/pull_v2.go @@ -73,30 +73,28 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) { } - c, err := p.poolAdd("pull", taggedName) - if err != nil { - if c != nil { - // Another pull of the same repository is already taking place; just wait for it to finish - p.config.OutStream.Write(p.sf.FormatStatus("", "Repository %s already being pulled by another client. Waiting.", p.repoInfo.CanonicalName)) - <-c - return nil - } - return err + broadcaster, found := p.poolAdd("pull", taggedName) + if found { + // Another pull of the same repository is already taking place; just wait for it to finish + broadcaster.Add(p.config.OutStream) + broadcaster.Wait() + return nil } defer p.poolRemove("pull", taggedName) + broadcaster.Add(p.config.OutStream) var layersDownloaded bool for _, tag := range tags { // pulledNew is true if either new layers were downloaded OR if existing images were newly tagged // TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus? - pulledNew, err := p.pullV2Tag(tag, taggedName) + pulledNew, err := p.pullV2Tag(broadcaster, tag, taggedName) if err != nil { return err } layersDownloaded = layersDownloaded || pulledNew } - writeStatus(taggedName, p.config.OutStream, p.sf, layersDownloaded) + writeStatus(taggedName, broadcaster, p.sf, layersDownloaded) return nil } @@ -121,18 +119,16 @@ func (p *v2Puller) download(di *downloadInfo) { out := di.out - if c, err := p.poolAdd("pull", "img:"+di.img.ID); err != nil { - if c != nil { - out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Layer already being pulled by another client. Waiting.", nil)) - <-c - out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil)) - } else { - logrus.Debugf("Image (id: %s) pull is already running, skipping: %v", di.img.ID, err) - } + broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID) + if found { + broadcaster.Add(out) + broadcaster.Wait() + out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil)) di.err <- nil return } + broadcaster.Add(out) defer p.poolRemove("pull", "img:"+di.img.ID) tmpFile, err := ioutil.TempFile("", "GetImageBlob") if err != nil { @@ -167,7 +163,7 @@ func (p *v2Puller) download(di *downloadInfo) { reader := progressreader.New(progressreader.Config{ In: ioutil.NopCloser(io.TeeReader(layerDownload, verifier)), - Out: out, + Out: broadcaster, Formatter: p.sf, Size: di.size, NewLines: false, @@ -176,7 +172,7 @@ func (p *v2Puller) download(di *downloadInfo) { }) io.Copy(tmpFile, reader) - out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil)) if !verifier.Verified() { err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest) @@ -185,7 +181,7 @@ func (p *v2Puller) download(di *downloadInfo) { return } - out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil)) logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, tmpFile.Name()) di.layer = layerDownload @@ -193,9 +189,8 @@ func (p *v2Puller) download(di *downloadInfo) { di.err <- nil } -func (p *v2Puller) pullV2Tag(tag, taggedName string) (verified bool, err error) { +func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bool, err error) { logrus.Debugf("Pulling tag from V2 registry: %q", tag) - out := p.config.OutStream manSvc, err := p.repo.Manifests(context.Background()) if err != nil { diff --git a/graph/push_v1.go b/graph/push_v1.go index 6de293403b..9769ac711f 100644 --- a/graph/push_v1.go +++ b/graph/push_v1.go @@ -214,7 +214,6 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageIDs []string, tags // pushRepository pushes layers that do not already exist on the registry. func (p *v1Pusher) pushRepository(tag string) error { - logrus.Debugf("Local repo: %s", p.localRepo) p.out = ioutils.NewWriteFlusher(p.config.OutStream) imgList, tags, err := p.getImageList(tag) @@ -229,8 +228,8 @@ func (p *v1Pusher) pushRepository(tag string) error { logrus.Debugf("Pushing ID: %s with Tag: %s", data.ID, data.Tag) } - if _, err := p.poolAdd("push", p.repoInfo.LocalName); err != nil { - return err + if _, found := p.poolAdd("push", p.repoInfo.LocalName); found { + return fmt.Errorf("push or pull %s is already in progress", p.repoInfo.LocalName) } defer p.poolRemove("push", p.repoInfo.LocalName) diff --git a/graph/push_v2.go b/graph/push_v2.go index 6823ac5928..7d5ca44d96 100644 --- a/graph/push_v2.go +++ b/graph/push_v2.go @@ -57,8 +57,8 @@ func (p *v2Pusher) getImageTags(askedTag string) ([]string, error) { func (p *v2Pusher) pushV2Repository(tag string) error { localName := p.repoInfo.LocalName - if _, err := p.poolAdd("push", localName); err != nil { - return err + if _, found := p.poolAdd("push", localName); found { + return fmt.Errorf("push or pull %s is already in progress", localName) } defer p.poolRemove("push", localName) diff --git a/graph/tags.go b/graph/tags.go index 293ac33186..09a9bc8530 100644 --- a/graph/tags.go +++ b/graph/tags.go @@ -17,6 +17,7 @@ import ( "github.com/docker/docker/graph/tags" "github.com/docker/docker/image" "github.com/docker/docker/pkg/parsers" + "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/registry" "github.com/docker/docker/trust" @@ -36,8 +37,8 @@ type TagStore struct { sync.Mutex // FIXME: move push/pull-related fields // to a helper type - pullingPool map[string]chan struct{} - pushingPool map[string]chan struct{} + pullingPool map[string]*progressreader.Broadcaster + pushingPool map[string]*progressreader.Broadcaster registryService *registry.Service eventsService *events.Events trustService *trust.Store @@ -93,8 +94,8 @@ func NewTagStore(path string, cfg *TagStoreConfig) (*TagStore, error) { graph: cfg.Graph, trustKey: cfg.Key, Repositories: make(map[string]Repository), - pullingPool: make(map[string]chan struct{}), - pushingPool: make(map[string]chan struct{}), + pullingPool: make(map[string]*progressreader.Broadcaster), + pushingPool: make(map[string]*progressreader.Broadcaster), registryService: cfg.Registry, eventsService: cfg.Events, trustService: cfg.Trust, @@ -433,27 +434,32 @@ func validateDigest(dgst string) error { return nil } -func (store *TagStore) poolAdd(kind, key string) (chan struct{}, error) { +// poolAdd checks if a push or pull is already running, and returns +// (broadcaster, true) if a running operation is found. Otherwise, it creates a +// new one and returns (broadcaster, false). +func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, bool) { store.Lock() defer store.Unlock() - if c, exists := store.pullingPool[key]; exists { - return c, fmt.Errorf("pull %s is already in progress", key) + if p, exists := store.pullingPool[key]; exists { + return p, true } - if c, exists := store.pushingPool[key]; exists { - return c, fmt.Errorf("push %s is already in progress", key) + if p, exists := store.pushingPool[key]; exists { + return p, true } - c := make(chan struct{}) + broadcaster := progressreader.NewBroadcaster() + switch kind { case "pull": - store.pullingPool[key] = c + store.pullingPool[key] = broadcaster case "push": - store.pushingPool[key] = c + store.pushingPool[key] = broadcaster default: - return nil, fmt.Errorf("Unknown pool type") + panic("Unknown pool type") } - return c, nil + + return broadcaster, false } func (store *TagStore) poolRemove(kind, key string) error { @@ -461,13 +467,13 @@ func (store *TagStore) poolRemove(kind, key string) error { defer store.Unlock() switch kind { case "pull": - if c, exists := store.pullingPool[key]; exists { - close(c) + if ps, exists := store.pullingPool[key]; exists { + ps.Close() delete(store.pullingPool, key) } case "push": - if c, exists := store.pushingPool[key]; exists { - close(c) + if ps, exists := store.pushingPool[key]; exists { + ps.Close() delete(store.pushingPool, key) } default: diff --git a/pkg/progressreader/broadcaster.go b/pkg/progressreader/broadcaster.go new file mode 100644 index 0000000000..429b1d0f1b --- /dev/null +++ b/pkg/progressreader/broadcaster.go @@ -0,0 +1,152 @@ +package progressreader + +import ( + "errors" + "io" + "sync" +) + +// Broadcaster keeps track of one or more observers watching the progress +// of an operation. For example, if multiple clients are trying to pull an +// image, they share a Broadcaster for the download operation. +type Broadcaster struct { + sync.Mutex + // c is a channel that observers block on, waiting for the operation + // to finish. + c chan struct{} + // cond is a condition variable used to wake up observers when there's + // new data available. + cond *sync.Cond + // history is a buffer of the progress output so far, so a new observer + // can catch up. The history is stored as a slice of separate byte + // slices, so that if the writer is a WriteFlusher, the flushes will + // happen in the right places. + history [][]byte + // wg is a WaitGroup used to wait for all writes to finish on Close + wg sync.WaitGroup + // isClosed is set to true when Close is called to avoid closing c + // multiple times. + isClosed bool +} + +// NewBroadcaster returns a Broadcaster structure +func NewBroadcaster() *Broadcaster { + b := &Broadcaster{ + c: make(chan struct{}), + } + b.cond = sync.NewCond(b) + return b +} + +// closed returns true if and only if the broadcaster has been closed +func (broadcaster *Broadcaster) closed() bool { + select { + case <-broadcaster.c: + return true + default: + return false + } +} + +// receiveWrites runs as a goroutine so that writes don't block the Write +// function. It writes the new data in broadcaster.history each time there's +// activity on the broadcaster.cond condition variable. +func (broadcaster *Broadcaster) receiveWrites(observer io.Writer) { + n := 0 + + broadcaster.Lock() + + // The condition variable wait is at the end of this loop, so that the + // first iteration will write the history so far. + for { + newData := broadcaster.history[n:] + // Make a copy of newData so we can release the lock + sendData := make([][]byte, len(newData), len(newData)) + copy(sendData, newData) + broadcaster.Unlock() + + for len(sendData) > 0 { + _, err := observer.Write(sendData[0]) + if err != nil { + broadcaster.wg.Done() + return + } + n++ + sendData = sendData[1:] + } + + broadcaster.Lock() + + // detect closure of the broadcast writer + if broadcaster.closed() { + broadcaster.Unlock() + broadcaster.wg.Done() + return + } + + if len(broadcaster.history) == n { + broadcaster.cond.Wait() + } + + // Mutex is still locked as the loop continues + } +} + +// Write adds data to the history buffer, and also writes it to all current +// observers. +func (broadcaster *Broadcaster) Write(p []byte) (n int, err error) { + broadcaster.Lock() + defer broadcaster.Unlock() + + // Is the broadcaster closed? If so, the write should fail. + if broadcaster.closed() { + return 0, errors.New("attempted write to closed progressreader Broadcaster") + } + + // Add message in p to the history slice + newEntry := make([]byte, len(p), len(p)) + copy(newEntry, p) + broadcaster.history = append(broadcaster.history, newEntry) + + broadcaster.cond.Broadcast() + + return len(p), nil +} + +// Add adds an observer to the Broadcaster. The new observer receives the +// data from the history buffer, and also all subsequent data. +func (broadcaster *Broadcaster) Add(w io.Writer) error { + // The lock is acquired here so that Add can't race with Close + broadcaster.Lock() + defer broadcaster.Unlock() + + if broadcaster.closed() { + return errors.New("attempted to add observer to closed progressreader Broadcaster") + } + + broadcaster.wg.Add(1) + go broadcaster.receiveWrites(w) + + return nil +} + +// Close signals to all observers that the operation has finished. +func (broadcaster *Broadcaster) Close() { + broadcaster.Lock() + if broadcaster.isClosed { + broadcaster.Unlock() + return + } + broadcaster.isClosed = true + close(broadcaster.c) + broadcaster.cond.Broadcast() + broadcaster.Unlock() + + // Don't return from Close until all writers have caught up. + broadcaster.wg.Wait() +} + +// Wait blocks until the operation is marked as completed by the Done method. +func (broadcaster *Broadcaster) Wait() { + <-broadcaster.c +}