Ver código fonte

Merge pull request #1623 from mhennings/1592-fix-race-conditions-in-parallel-pull

Fix race conditions in parallel pull
Michael Crosby 12 anos atrás
pai
commit
262d57e387
1 arquivos alterados com 55 adições e 9 exclusões
  1. 55 9
      server.go

+ 55 - 9
server.go

@@ -419,19 +419,30 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-
+	out.Write(sf.FormatProgress(utils.TruncateID(imgID), "Pulling", "dependend layers"))
 	// FIXME: Try to stream the images?
 	// FIXME: Try to stream the images?
 	// FIXME: Launch the getRemoteImage() in goroutines
 	// FIXME: Launch the getRemoteImage() in goroutines
+
 	for _, id := range history {
 	for _, id := range history {
+
+		// ensure no two downloads of the same layer happen at the same time
+		if err := srv.poolAdd("pull", "layer:"+id); err != nil {
+			utils.Debugf("Image (id: %s) pull is already running, skipping: %v", id, err)
+			return nil
+		}
+		defer srv.poolRemove("pull", "layer:"+id)
+
 		if !srv.runtime.graph.Exists(id) {
 		if !srv.runtime.graph.Exists(id) {
 			out.Write(sf.FormatProgress(utils.TruncateID(id), "Pulling", "metadata"))
 			out.Write(sf.FormatProgress(utils.TruncateID(id), "Pulling", "metadata"))
 			imgJSON, imgSize, err := r.GetRemoteImageJSON(id, endpoint, token)
 			imgJSON, imgSize, err := r.GetRemoteImageJSON(id, endpoint, token)
 			if err != nil {
 			if err != nil {
+				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error", "pulling dependend layers"))
 				// FIXME: Keep going in case of error?
 				// FIXME: Keep going in case of error?
 				return err
 				return err
 			}
 			}
 			img, err := NewImgJSON(imgJSON)
 			img, err := NewImgJSON(imgJSON)
 			if err != nil {
 			if err != nil {
+				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error", "pulling dependend layers"))
 				return fmt.Errorf("Failed to parse json: %s", err)
 				return fmt.Errorf("Failed to parse json: %s", err)
 			}
 			}
 
 
@@ -439,13 +450,17 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin
 			out.Write(sf.FormatProgress(utils.TruncateID(id), "Pulling", "fs layer"))
 			out.Write(sf.FormatProgress(utils.TruncateID(id), "Pulling", "fs layer"))
 			layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token)
 			layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token)
 			if err != nil {
 			if err != nil {
+				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error", "pulling dependend layers"))
 				return err
 				return err
 			}
 			}
 			defer layer.Close()
 			defer layer.Close()
 			if err := srv.runtime.graph.Register(imgJSON, utils.ProgressReader(layer, imgSize, out, sf.FormatProgress(utils.TruncateID(id), "Downloading", "%8v/%v (%v)"), sf, false), img); err != nil {
 			if err := srv.runtime.graph.Register(imgJSON, utils.ProgressReader(layer, imgSize, out, sf.FormatProgress(utils.TruncateID(id), "Downloading", "%8v/%v (%v)"), sf, false), img); err != nil {
+				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error", "downloading dependend layers"))
 				return err
 				return err
 			}
 			}
 		}
 		}
+		out.Write(sf.FormatProgress(utils.TruncateID(id), "Download", "complete"))
+
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -493,29 +508,57 @@ func (srv *Server) pullRepository(r *registry.Registry, out io.Writer, localName
 		downloadImage := func(img *registry.ImgData) {
 		downloadImage := func(img *registry.ImgData) {
 			if askedTag != "" && img.Tag != askedTag {
 			if askedTag != "" && img.Tag != askedTag {
 				utils.Debugf("(%s) does not match %s (id: %s), skipping", img.Tag, askedTag, img.ID)
 				utils.Debugf("(%s) does not match %s (id: %s), skipping", img.Tag, askedTag, img.ID)
-				errors <- nil
+				if parallel {
+					errors <- nil
+				}
 				return
 				return
 			}
 			}
 
 
 			if img.Tag == "" {
 			if img.Tag == "" {
 				utils.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID)
 				utils.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID)
-				errors <- nil
+				if parallel {
+					errors <- nil
+				}
 				return
 				return
 			}
 			}
+
+			// ensure no two downloads of the same image happen at the same time
+			if err := srv.poolAdd("pull", "img:"+img.ID); err != nil {
+				utils.Debugf("Image (id: %s) pull is already running, skipping: %v", img.ID, err)
+				if parallel {
+					errors <- nil
+				}
+				return
+			}
+			defer srv.poolRemove("pull", "img:"+img.ID)
+
 			out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Pulling", fmt.Sprintf("image (%s) from %s", img.Tag, localName)))
 			out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Pulling", fmt.Sprintf("image (%s) from %s", img.Tag, localName)))
 			success := false
 			success := false
+			var lastErr error
 			for _, ep := range repoData.Endpoints {
 			for _, ep := range repoData.Endpoints {
+				out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Pulling", fmt.Sprintf("image (%s) from %s, endpoint: %s", img.Tag, localName, ep)))
 				if err := srv.pullImage(r, out, img.ID, ep, repoData.Tokens, sf); err != nil {
 				if err := srv.pullImage(r, out, img.ID, ep, repoData.Tokens, sf); err != nil {
-					out.Write(sf.FormatStatus(utils.TruncateID(img.ID), "Error while retrieving image for tag: %s (%s); checking next endpoint", askedTag, err))
+					// Its not ideal that only the last error  is returned, it would be better to concatenate the errors.
+					// As the error is also given to the output stream the user will see the error.
+					lastErr = err
+					out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Error pulling", fmt.Sprintf("image (%s) from %s, endpoint: %s, %s", img.Tag, localName, ep, err)))
 					continue
 					continue
 				}
 				}
 				success = true
 				success = true
 				break
 				break
 			}
 			}
 			if !success {
 			if !success {
-				errors <- fmt.Errorf("Could not find repository on any of the indexed registries.")
+				out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Error pulling", fmt.Sprintf("image (%s) from %s, %s", img.Tag, localName, lastErr)))
+				if parallel {
+					errors <- fmt.Errorf("Could not find repository on any of the indexed registries.")
+					return
+				}
+			}
+			out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Download", "complete"))
+
+			if parallel {
+				errors <- nil
 			}
 			}
-			errors <- nil
 		}
 		}
 
 
 		if parallel {
 		if parallel {
@@ -524,15 +567,18 @@ func (srv *Server) pullRepository(r *registry.Registry, out io.Writer, localName
 			downloadImage(image)
 			downloadImage(image)
 		}
 		}
 	}
 	}
-
 	if parallel {
 	if parallel {
+		var lastError error
 		for i := 0; i < len(repoData.ImgList); i++ {
 		for i := 0; i < len(repoData.ImgList); i++ {
 			if err := <-errors; err != nil {
 			if err := <-errors; err != nil {
-				return err
+				lastError = err
 			}
 			}
 		}
 		}
-	}
+		if lastError != nil {
+			return lastError
+		}
 
 
+	}
 	for tag, id := range tagsList {
 	for tag, id := range tagsList {
 		if askedTag != "" && tag != askedTag {
 		if askedTag != "" && tag != askedTag {
 			continue
 			continue