Browse Source

Add native Go downloading functionality with progress bar. Ref #49, #50

Charles Hooper 12 years ago
parent
commit
1bfd827701
2 changed files with 41 additions and 18 deletions
  1. 26 4
      future/future.go
  2. 15 14
      server/server.go

+ 26 - 4
future/future.go

@@ -107,9 +107,8 @@ func Curl(url string, stderr io.Writer) (io.Reader, error) {
 }
 
 // Request a given URL and return an io.Reader
-func Download(url string, stderr io.Writer) (io.Reader, error) {
+func Download(url string, stderr io.Writer) (*http.Response, error) {
 	var resp *http.Response
-	var archive io.ReadCloser = nil
 	var err error = nil
 
 	fmt.Fprintf(stderr, "Download start\n") // FIXME: Replace with progress bar
@@ -119,7 +118,30 @@ func Download(url string, stderr io.Writer) (io.Reader, error) {
 	if resp.StatusCode >= 400 {
 		return nil, errors.New("Got HTTP status code >= 400: " + resp.Status)
 	}
-	archive = resp.Body
 	fmt.Fprintf(stderr, "Download end\n") // FIXME: Replace with progress bar
-	return archive, nil
+	return resp, nil
+}
+
+// Reader with progress bar
+type progressReader struct {
+	reader io.ReadCloser   // Stream to read from
+	output io.Writer   // Where to send progress bar to
+	read_total   int    // Expected stream length (bytes)
+	read_progress int  // How much has been read so far (bytes)
+}
+func (r *progressReader) Read(p []byte) (n int, err error) {
+	read, err := io.ReadCloser(r.reader).Read(p)
+	// FIXME: Don't print progress bar on every read
+	r.read_progress += read
+	fmt.Fprintf(r.output, "%d/%d (%.2f%%)\n", 
+		r.read_progress,
+		r.read_total,
+		float64(r.read_progress) / float64(r.read_total) * 100)
+	return read, err
+}
+func (r *progressReader) Close() error {
+	return io.ReadCloser(r.reader).Close()
+}
+func ProgressReader(r io.ReadCloser, size int, output io.Writer) *progressReader {
+	return &progressReader{r, output, size, 0}
 }

+ 15 - 14
server/server.go

@@ -12,7 +12,6 @@ import (
 	"github.com/dotcloud/docker/rcli"
 	"io"
 	"io/ioutil"
-	"net/http"
 	"net/url"
 	"os"
 	"path"
@@ -427,19 +426,11 @@ func (srv *Server) CmdPull(stdin io.ReadCloser, stdout io.Writer, args ...string
 		u.Path = path.Join("/docker.io/images", u.Path)
 	}
 	fmt.Fprintf(stdout, "Downloading from %s\n", u.String())
-	// Download with curl (pretty progress bar)
-	// If curl is not available or receives a HTTP error, fallback
-	// to http.Get()
-	archive, err := future.Curl(u.String(), stdout)
+	resp, err := future.Download(u.String(), stdout)
+	// FIXME: Validate ContentLength
+	archive := future.ProgressReader(resp.Body, int(resp.ContentLength), stdout)
 	if err != nil {
-		if resp, err := http.Get(u.String()); err != nil {
-			return err
-		} else {
-			if resp.StatusCode >= 400 {
-				return errors.New("Got HTTP status code >= 400: " + resp.Status)
-			}
-			archive = resp.Body
-		}
+		return err
 	}
 	fmt.Fprintf(stdout, "Unpacking to %s\n", name)
 	img, err := srv.images.Import(name, archive, nil)
@@ -815,7 +806,10 @@ func (srv *Server) CmdRun(stdin io.ReadCloser, stdout io.Writer, args ...string)
 		return nil
 	}
 	name := cmd.Arg(0)
+	var img_name string
+	//var img_version string  // Only here for reference
 	var cmdline []string
+
 	if len(cmd.Args()) >= 2 {
 		cmdline = cmd.Args()[1:]
 	}
@@ -823,6 +817,13 @@ func (srv *Server) CmdRun(stdin io.ReadCloser, stdout io.Writer, args ...string)
 	if name == "" {
 		name = "base"
 	}
+
+	// Separate the name:version tag
+    if strings.Contains(name, ":") {
+        parts := strings.SplitN(name, ":", 2)
+        img_name = parts[0]
+		//img_version = parts[1]   // Only here for reference
+    }
 	// Choose a default command if needed
 	if len(cmdline) == 0 {
 		*fl_stdin = true
@@ -835,7 +836,7 @@ func (srv *Server) CmdRun(stdin io.ReadCloser, stdout io.Writer, args ...string)
 	img = srv.images.Find(name)
 	if img == nil {
 		stdin_noclose := ioutil.NopCloser(stdin)
-		if err := srv.CmdPull(stdin_noclose, stdout, name); err != nil {
+		if err := srv.CmdPull(stdin_noclose, stdout, img_name); err != nil {
 			return err
 		}
 		img = srv.images.Find(name)