Explorar o código

Cleanup api server creation

Current implementation is hard to reason about because of trying to mix
unix/tcp server implementations, even though they are quite different.
This cleans that up.

Also makes it possible to create and manage a new API server easily,
e.g. for adding an introspection socket to a container.

Built in such a way as to allow a non-HTTP server to work as well, such
as libchan.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Brian Goff %!s(int64=10) %!d(string=hai) anos
pai
achega
dacae746b7
Modificáronse 1 ficheiros con 127 adicións e 72 borrados
  1. 127 72
      api/server/server.go

+ 127 - 72
api/server/server.go

@@ -3,8 +3,7 @@ package server
 import (
 	"bufio"
 	"bytes"
-	"crypto/tls"
-	"crypto/x509"
+
 	"encoding/base64"
 	"encoding/json"
 	"expvar"
@@ -19,6 +18,9 @@ import (
 	"strings"
 	"syscall"
 
+	"crypto/tls"
+	"crypto/x509"
+
 	"code.google.com/p/go.net/websocket"
 	"github.com/docker/libcontainer/user"
 	"github.com/gorilla/mux"
@@ -39,6 +41,18 @@ var (
 	activationLock chan struct{}
 )
 
+type HttpServer struct {
+	srv *http.Server
+	l   net.Listener
+}
+
+func (s *HttpServer) Serve() error {
+	return s.srv.Serve(s.l)
+}
+func (s *HttpServer) Close() error {
+	return s.l.Close()
+}
+
 type HttpApiFunc func(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error
 
 func hijackServer(w http.ResponseWriter) (io.ReadCloser, io.Writer, error) {
@@ -1334,9 +1348,14 @@ func ServeRequest(eng *engine.Engine, apiversion version.Version, w http.Respons
 	return nil
 }
 
-// ServeFD creates an http.Server and sets it up to serve given a socket activated
+// serveFd creates an http.Server and sets it up to serve given a socket activated
 // argument.
-func ServeFd(addr string, handle http.Handler) error {
+func serveFd(addr string, job *engine.Job) error {
+	r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
+	if err != nil {
+		return err
+	}
+
 	ls, e := systemd.ListenFD(addr)
 	if e != nil {
 		return e
@@ -1354,7 +1373,7 @@ func ServeFd(addr string, handle http.Handler) error {
 	for i := range ls {
 		listener := ls[i]
 		go func() {
-			httpSrv := http.Server{Handler: handle}
+			httpSrv := http.Server{Handler: r}
 			chErrors <- httpSrv.Serve(listener)
 		}()
 	}
@@ -1382,6 +1401,41 @@ func lookupGidByName(nameOrGid string) (int, error) {
 	return -1, fmt.Errorf("Group %s not found", nameOrGid)
 }
 
+func setupTls(cert, key, ca string, l net.Listener) (net.Listener, error) {
+	tlsCert, err := tls.LoadX509KeyPair(cert, key)
+	if err != nil {
+		return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?",
+			cert, key, err)
+	}
+	tlsConfig := &tls.Config{
+		NextProtos:   []string{"http/1.1"},
+		Certificates: []tls.Certificate{tlsCert},
+		// Avoid fallback on insecure SSL protocols
+		MinVersion: tls.VersionTLS10,
+	}
+
+	if ca != "" {
+		certPool := x509.NewCertPool()
+		file, err := ioutil.ReadFile(ca)
+		if err != nil {
+			return nil, fmt.Errorf("Couldn't read CA certificate: %s", err)
+		}
+		certPool.AppendCertsFromPEM(file)
+		tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+		tlsConfig.ClientCAs = certPool
+	}
+
+	return tls.NewListener(l, tlsConfig), nil
+}
+
+func newListener(proto, addr string, bufferRequests bool) (net.Listener, error) {
+	if bufferRequests {
+		return listenbuffer.NewListenBuffer(proto, addr, activationLock)
+	}
+
+	return net.Listen(proto, addr)
+}
+
 func changeGroup(addr string, nameOrGid string) error {
 	gid, err := lookupGidByName(nameOrGid)
 	if err != nil {
@@ -1392,99 +1446,95 @@ func changeGroup(addr string, nameOrGid string) error {
 	return os.Chown(addr, 0, gid)
 }
 
-// ListenAndServe sets up the required http.Server and gets it listening for
-// each addr passed in and does protocol specific checking.
-func ListenAndServe(proto, addr string, job *engine.Job) error {
-	var l net.Listener
+func setSocketGroup(addr, group string) error {
+	if group == "" {
+		return nil
+	}
+
+	if err := changeGroup(addr, group); err != nil {
+		if group != "docker" {
+			return err
+		}
+		log.Debugf("Warning: could not chgrp %s to docker: %v", addr, err)
+	}
+
+	return nil
+}
+
+func setupUnixHttp(addr string, job *engine.Job) (*HttpServer, error) {
 	r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
 	if err != nil {
-		return err
+		return nil, err
 	}
 
-	if proto == "fd" {
-		return ServeFd(addr, r)
+	if err := syscall.Unlink(addr); err != nil && !os.IsNotExist(err) {
+		return nil, err
 	}
+	mask := syscall.Umask(0777)
+	defer syscall.Umask(mask)
 
-	if proto == "unix" {
-		if err := syscall.Unlink(addr); err != nil && !os.IsNotExist(err) {
-			return err
-		}
+	l, err := newListener("unix", addr, job.GetenvBool("BufferRequests"))
+	if err != nil {
+		return nil, err
 	}
 
-	var oldmask int
-	if proto == "unix" {
-		oldmask = syscall.Umask(0777)
+	if err := setSocketGroup(addr, job.Getenv("SocketGroup")); err != nil {
+		return nil, err
 	}
 
-	if job.GetenvBool("BufferRequests") {
-		l, err = listenbuffer.NewListenBuffer(proto, addr, activationLock)
-	} else {
-		l, err = net.Listen(proto, addr)
+	if err := os.Chmod(addr, 0660); err != nil {
+		return nil, err
 	}
 
-	if proto == "unix" {
-		syscall.Umask(oldmask)
+	return &HttpServer{&http.Server{Addr: addr, Handler: r}, l}, nil
+}
+
+func setupTcpHttp(addr string, job *engine.Job) (*HttpServer, error) {
+	if !strings.HasPrefix(addr, "127.0.0.1") && !job.GetenvBool("TlsVerify") {
+		log.Infof("/!\\ DON'T BIND ON ANOTHER IP ADDRESS THAN 127.0.0.1 IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\")
 	}
+
+	r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
 	if err != nil {
-		return err
+		return nil, err
 	}
 
-	if proto != "unix" && (job.GetenvBool("Tls") || job.GetenvBool("TlsVerify")) {
-		tlsCert := job.Getenv("TlsCert")
-		tlsKey := job.Getenv("TlsKey")
-		cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
-		if err != nil {
-			return fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?",
-				tlsCert, tlsKey, err)
-		}
-		tlsConfig := &tls.Config{
-			NextProtos:   []string{"http/1.1"},
-			Certificates: []tls.Certificate{cert},
-			// Avoid fallback on insecure SSL protocols
-			MinVersion: tls.VersionTLS10,
-		}
-		if job.GetenvBool("TlsVerify") {
-			certPool := x509.NewCertPool()
-			file, err := ioutil.ReadFile(job.Getenv("TlsCa"))
-			if err != nil {
-				return fmt.Errorf("Couldn't read CA certificate: %s", err)
-			}
-			certPool.AppendCertsFromPEM(file)
+	l, err := newListener("tcp", addr, job.GetenvBool("BufferRequests"))
+	if err != nil {
+		return nil, err
+	}
 
-			tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
-			tlsConfig.ClientCAs = certPool
+	if job.GetenvBool("Tls") || job.GetenvBool("TlsVerify") {
+		var tlsCa string
+		if job.GetenvBool("TlsVerify") {
+			tlsCa = job.Getenv("TlsCa")
+		}
+		l, err = setupTls(job.Getenv("TlsCert"), job.Getenv("TlsKey"), tlsCa, l)
+		if err != nil {
+			return nil, err
 		}
-		l = tls.NewListener(l, tlsConfig)
 	}
+	return &HttpServer{&http.Server{Addr: addr, Handler: r}, l}, nil
+}
 
+// NewServer sets up the required Server and does protocol specific checking.
+func NewServer(proto, addr string, job *engine.Job) (Server, error) {
 	// Basic error and sanity checking
 	switch proto {
+	case "fd":
+		return nil, serveFd(addr, job)
 	case "tcp":
-		if !strings.HasPrefix(addr, "127.0.0.1") && !job.GetenvBool("TlsVerify") {
-			log.Infof("/!\\ DON'T BIND ON ANOTHER IP ADDRESS THAN 127.0.0.1 IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\")
-		}
+		return setupTcpHttp(addr, job)
 	case "unix":
-		socketGroup := job.Getenv("SocketGroup")
-		if socketGroup != "" {
-			if err := changeGroup(addr, socketGroup); err != nil {
-				if socketGroup == "docker" {
-					// if the user hasn't explicitly specified the group ownership, don't fail on errors.
-					log.Debugf("Warning: could not chgrp %s to docker: %s", addr, err.Error())
-				} else {
-					return err
-				}
-			}
-
-		}
-		if err := os.Chmod(addr, 0660); err != nil {
-			return err
-		}
+		return setupUnixHttp(addr, job)
 	default:
-		return fmt.Errorf("Invalid protocol format.")
+		return nil, fmt.Errorf("Invalid protocol format.")
 	}
+}
 
-	httpSrv := http.Server{Addr: addr, Handler: r}
-	return httpSrv.Serve(l)
+type Server interface {
+	Serve() error
+	Close() error
 }
 
 // ServeApi loops through all of the protocols sent in to docker and spawns
@@ -1506,7 +1556,12 @@ func ServeApi(job *engine.Job) engine.Status {
 		}
 		go func() {
 			log.Infof("Listening for HTTP on %s (%s)", protoAddrParts[0], protoAddrParts[1])
-			chErrors <- ListenAndServe(protoAddrParts[0], protoAddrParts[1], job)
+			srv, err := NewServer(protoAddrParts[0], protoAddrParts[1], job)
+			if err != nil {
+				chErrors <- err
+				return
+			}
+			chErrors <- srv.Serve()
 		}()
 	}