Переглянути джерело

Fix race conditions in parallel pull

During a parallel pull of a repository, it can happen that the same layer
is pulled more than once.

To fix this, I have extended the locking code to:
- avoid multiple pulls of the same image
- avoid multiple pulls of the same layer


If an error occurs, the other layers are awaited before returning. Leaving
the scope before the goroutines have finished sometimes crashes the server,
because the download status may be updated after the HTTP stream has already
been closed.


Besides this, I have extended the status display.
Marco Hennings 12 роки тому
батько
коміт
3f802f4a13
1 змінених файлів з 55 додано та 9 видалено
  1. 55 9
      server.go

+ 55 - 9
server.go

@@ -419,19 +419,30 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-
+	out.Write(sf.FormatProgress(utils.TruncateID(imgID), "Pulling", "dependend layers"))
 	// FIXME: Try to stream the images?
 	// FIXME: Try to stream the images?
 	// FIXME: Launch the getRemoteImage() in goroutines
 	// FIXME: Launch the getRemoteImage() in goroutines
+
 	for _, id := range history {
 	for _, id := range history {
+
+		// ensure no two downloads of the same layer happen at the same time
+		if err := srv.poolAdd("pull", "layer:"+id); err != nil {
+			utils.Debugf("Image (id: %s) pull is already running, skipping: %v", id, err)
+			return nil
+		}
+		defer srv.poolRemove("pull", "layer:"+id)
+
 		if !srv.runtime.graph.Exists(id) {
 		if !srv.runtime.graph.Exists(id) {
 			out.Write(sf.FormatProgress(utils.TruncateID(id), "Pulling", "metadata"))
 			out.Write(sf.FormatProgress(utils.TruncateID(id), "Pulling", "metadata"))
 			imgJSON, imgSize, err := r.GetRemoteImageJSON(id, endpoint, token)
 			imgJSON, imgSize, err := r.GetRemoteImageJSON(id, endpoint, token)
 			if err != nil {
 			if err != nil {
+				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error", "pulling dependend layers"))
 				// FIXME: Keep going in case of error?
 				// FIXME: Keep going in case of error?
 				return err
 				return err
 			}
 			}
 			img, err := NewImgJSON(imgJSON)
 			img, err := NewImgJSON(imgJSON)
 			if err != nil {
 			if err != nil {
+				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error", "pulling dependend layers"))
 				return fmt.Errorf("Failed to parse json: %s", err)
 				return fmt.Errorf("Failed to parse json: %s", err)
 			}
 			}
 
 
@@ -439,13 +450,17 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin
 			out.Write(sf.FormatProgress(utils.TruncateID(id), "Pulling", "fs layer"))
 			out.Write(sf.FormatProgress(utils.TruncateID(id), "Pulling", "fs layer"))
 			layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token)
 			layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token)
 			if err != nil {
 			if err != nil {
+				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error", "pulling dependend layers"))
 				return err
 				return err
 			}
 			}
 			defer layer.Close()
 			defer layer.Close()
 			if err := srv.runtime.graph.Register(imgJSON, utils.ProgressReader(layer, imgSize, out, sf.FormatProgress(utils.TruncateID(id), "Downloading", "%8v/%v (%v)"), sf, false), img); err != nil {
 			if err := srv.runtime.graph.Register(imgJSON, utils.ProgressReader(layer, imgSize, out, sf.FormatProgress(utils.TruncateID(id), "Downloading", "%8v/%v (%v)"), sf, false), img); err != nil {
+				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error", "downloading dependend layers"))
 				return err
 				return err
 			}
 			}
 		}
 		}
+		out.Write(sf.FormatProgress(utils.TruncateID(id), "Download", "complete"))
+
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -493,29 +508,57 @@ func (srv *Server) pullRepository(r *registry.Registry, out io.Writer, localName
 		downloadImage := func(img *registry.ImgData) {
 		downloadImage := func(img *registry.ImgData) {
 			if askedTag != "" && img.Tag != askedTag {
 			if askedTag != "" && img.Tag != askedTag {
 				utils.Debugf("(%s) does not match %s (id: %s), skipping", img.Tag, askedTag, img.ID)
 				utils.Debugf("(%s) does not match %s (id: %s), skipping", img.Tag, askedTag, img.ID)
-				errors <- nil
+				if parallel {
+					errors <- nil
+				}
 				return
 				return
 			}
 			}
 
 
 			if img.Tag == "" {
 			if img.Tag == "" {
 				utils.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID)
 				utils.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID)
-				errors <- nil
+				if parallel {
+					errors <- nil
+				}
 				return
 				return
 			}
 			}
+
+			// ensure no two downloads of the same image happen at the same time
+			if err := srv.poolAdd("pull", "img:"+img.ID); err != nil {
+				utils.Debugf("Image (id: %s) pull is already running, skipping: %v", img.ID, err)
+				if parallel {
+					errors <- nil
+				}
+				return
+			}
+			defer srv.poolRemove("pull", "img:"+img.ID)
+
 			out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Pulling", fmt.Sprintf("image (%s) from %s", img.Tag, localName)))
 			out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Pulling", fmt.Sprintf("image (%s) from %s", img.Tag, localName)))
 			success := false
 			success := false
+			var lastErr error
 			for _, ep := range repoData.Endpoints {
 			for _, ep := range repoData.Endpoints {
+				out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Pulling", fmt.Sprintf("image (%s) from %s, endpoint: %s", img.Tag, localName, ep)))
 				if err := srv.pullImage(r, out, img.ID, ep, repoData.Tokens, sf); err != nil {
 				if err := srv.pullImage(r, out, img.ID, ep, repoData.Tokens, sf); err != nil {
-					out.Write(sf.FormatStatus(utils.TruncateID(img.ID), "Error while retrieving image for tag: %s (%s); checking next endpoint", askedTag, err))
+					// Its not ideal that only the last error  is returned, it would be better to concatenate the errors.
+					// As the error is also given to the output stream the user will see the error.
+					lastErr = err
+					out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Error pulling", fmt.Sprintf("image (%s) from %s, endpoint: %s, %s", img.Tag, localName, ep, err)))
 					continue
 					continue
 				}
 				}
 				success = true
 				success = true
 				break
 				break
 			}
 			}
 			if !success {
 			if !success {
-				errors <- fmt.Errorf("Could not find repository on any of the indexed registries.")
+				out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Error pulling", fmt.Sprintf("image (%s) from %s, %s", img.Tag, localName, lastErr)))
+				if parallel {
+					errors <- fmt.Errorf("Could not find repository on any of the indexed registries.")
+					return
+				}
+			}
+			out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Download", "complete"))
+
+			if parallel {
+				errors <- nil
 			}
 			}
-			errors <- nil
 		}
 		}
 
 
 		if parallel {
 		if parallel {
@@ -524,15 +567,18 @@ func (srv *Server) pullRepository(r *registry.Registry, out io.Writer, localName
 			downloadImage(image)
 			downloadImage(image)
 		}
 		}
 	}
 	}
-
 	if parallel {
 	if parallel {
+		var lastError error
 		for i := 0; i < len(repoData.ImgList); i++ {
 		for i := 0; i < len(repoData.ImgList); i++ {
 			if err := <-errors; err != nil {
 			if err := <-errors; err != nil {
-				return err
+				lastError = err
 			}
 			}
 		}
 		}
-	}
+		if lastError != nil {
+			return lastError
+		}
 
 
+	}
 	for tag, id := range tagsList {
 	for tag, id := range tagsList {
 		if askedTag != "" && tag != askedTag {
 		if askedTag != "" && tag != askedTag {
 			continue
 			continue