瀏覽代碼

Merge pull request #6048 from crquan/use-timeout-conn

Use Timeout Conn wrapper to set read deadline for downloading layer
Michael Crosby 11 年之前
父節點
當前提交
a86975f29b
共有 3 個文件被更改,包括 70 次插入11 次删除
  1. 11 0
      registry/registry.go
  2. 33 11
      server/server.go
  3. 26 0
      utils/timeoutconn.go

+ 11 - 0
registry/registry.go

@@ -726,7 +726,17 @@ type Registry struct {
 }
 }
 
 
 func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string) (r *Registry, err error) {
 func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string) (r *Registry, err error) {
+	httpDial := func(proto string, addr string) (net.Conn, error) {
+		conn, err := net.Dial(proto, addr)
+		if err != nil {
+			return nil, err
+		}
+		conn = utils.NewTimeoutConn(conn, time.Duration(1)*time.Minute)
+		return conn, nil
+	}
+
 	httpTransport := &http.Transport{
 	httpTransport := &http.Transport{
+		Dial:              httpDial,
 		DisableKeepAlives: true,
 		DisableKeepAlives: true,
 		Proxy:             http.ProxyFromEnvironment,
 		Proxy:             http.ProxyFromEnvironment,
 	}
 	}
@@ -738,6 +748,7 @@ func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, inde
 		},
 		},
 		indexEndpoint: indexEndpoint,
 		indexEndpoint: indexEndpoint,
 	}
 	}
+
 	r.client.Jar, err = cookiejar.New(nil)
 	r.client.Jar, err = cookiejar.New(nil)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err

+ 33 - 11
server/server.go

@@ -27,6 +27,7 @@ import (
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"log"
 	"log"
+	"net"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"os"
 	"os"
@@ -1134,17 +1135,38 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin
 				}
 				}
 			}
 			}
 
 
-			// Get the layer
-			out.Write(sf.FormatProgress(utils.TruncateID(id), "Pulling fs layer", nil))
-			layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token)
-			if err != nil {
-				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error pulling dependent layers", nil))
-				return err
-			}
-			defer layer.Close()
-			if err := srv.daemon.Graph().Register(imgJSON, utils.ProgressReader(layer, imgSize, out, sf, false, utils.TruncateID(id), "Downloading"), img); err != nil {
-				out.Write(sf.FormatProgress(utils.TruncateID(id), "Error downloading dependent layers", nil))
-				return err
+			for j := 1; j <= retries; j++ {
+				// Get the layer
+				status := "Pulling fs layer"
+				if j > 1 {
+					status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
+				}
+				out.Write(sf.FormatProgress(utils.TruncateID(id), status, nil))
+				layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token)
+				if uerr, ok := err.(*url.Error); ok {
+					err = uerr.Err
+				}
+				if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
+					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
+					continue
+				} else if err != nil {
+					out.Write(sf.FormatProgress(utils.TruncateID(id), "Error pulling dependent layers", nil))
+					return err
+				}
+				defer layer.Close()
+
+				err = srv.daemon.Graph().Register(imgJSON,
+					utils.ProgressReader(layer, imgSize, out, sf, false, utils.TruncateID(id), "Downloading"),
+					img)
+				if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
+					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
+					continue
+				} else if err != nil {
+					out.Write(sf.FormatProgress(utils.TruncateID(id), "Error downloading dependent layers", nil))
+					return err
+				} else {
+					break
+				}
 			}
 			}
 		}
 		}
 		out.Write(sf.FormatProgress(utils.TruncateID(id), "Download complete", nil))
 		out.Write(sf.FormatProgress(utils.TruncateID(id), "Download complete", nil))

+ 26 - 0
utils/timeoutconn.go

@@ -0,0 +1,26 @@
+package utils
+
+import (
+	"net"
+	"time"
+)
+
+func NewTimeoutConn(conn net.Conn, timeout time.Duration) net.Conn {
+	return &TimeoutConn{conn, timeout}
+}
+
+// A net.Conn that sets a deadline for every Read or Write operation
+type TimeoutConn struct {
+	net.Conn
+	timeout time.Duration
+}
+
+func (c *TimeoutConn) Read(b []byte) (int, error) {
+	if c.timeout > 0 {
+		err := c.Conn.SetReadDeadline(time.Now().Add(c.timeout))
+		if err != nil {
+			return 0, err
+		}
+	}
+	return c.Conn.Read(b)
+}