Browse Source

put the websocket route in the map containing all routes

Instead of handling the websocket differently just handle it as a normal
route and upgrade it to a websocket.
benoitc 12 years ago
parent
commit
166eba3e28
1 changed files with 73 additions and 68 deletions
  1. 73 68
      api.go

+ 73 - 68
api.go

@@ -22,6 +22,9 @@ const APIVERSION = 1.3
 const DEFAULTHTTPHOST string = "127.0.0.1"
 const DEFAULTHTTPPORT int = 4243
 
+type HttpApiFunc func(srv *Server, version float64, w http.ResponseWriter, r *http.Request, vars map[string]string) error
+type WsApiFunc func(srv *Server, ws *websocket.Conn, vars map[string]string) error
+
 func hijackServer(w http.ResponseWriter) (io.ReadCloser, io.Writer, error) {
 	conn, _, err := w.(http.Hijacker).Hijack()
 	if err != nil {
@@ -694,9 +697,7 @@ func postContainersAttach(srv *Server, version float64, w http.ResponseWriter, r
 	return nil
 }
 
-func wsContainersAttach(srv *Server, ws *websocket.Conn, vars map[string]string) error {
-
-	r := ws.Request()
+func wsContainersAttach(srv *Server, version float64, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
 
 	if err := parseForm(r); err != nil {
 		return err
@@ -731,11 +732,14 @@ func wsContainersAttach(srv *Server, ws *websocket.Conn, vars map[string]string)
 		return err
 	}
 
-	defer ws.Close()
+	h := websocket.Handler(func(ws *websocket.Conn) {
+		defer ws.Close()
 
-	if err := srv.ContainerAttach(name, logs, stream, stdin, stdout, stderr, ws, ws); err != nil {
-		return err
-	}
+		if err := srv.ContainerAttach(name, logs, stream, stdin, stdout, stderr, ws, ws); err != nil {
+			utils.Debugf("Error: %s", err)
+		}
+	})
+	h.ServeHTTP(w, r)
 
 	return nil
 }
@@ -873,24 +877,68 @@ func writeCorsHeaders(w http.ResponseWriter, r *http.Request) {
 	w.Header().Add("Access-Control-Allow-Methods", "GET, POST, DELETE, PUT, OPTIONS")
 }
 
+func logRequest(logging bool, localMethod string, localRoute string, r *http.Request) {
+	utils.Debugf("Calling %s %s", localMethod, localRoute)
+
+	if logging {
+		log.Println(r.Method, r.RequestURI)
+	}
+}
+
+func makeHttpHandler(srv *Server, logging bool, localMethod string, localRoute string, handlerFunc HttpApiFunc) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		// log the request
+		utils.Debugf("Calling %s %s", localMethod, localRoute)
+
+		if logging {
+			log.Println(r.Method, r.RequestURI)
+		}
+
+		if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
+			userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
+			if len(userAgent) == 2 && userAgent[1] != VERSION {
+				utils.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], VERSION)
+			}
+		}
+		version, err := strconv.ParseFloat(mux.Vars(r)["version"], 64)
+		if err != nil {
+			version = APIVERSION
+		}
+		if srv.enableCors {
+			writeCorsHeaders(w, r)
+		}
+
+		if version == 0 || version > APIVERSION {
+			w.WriteHeader(http.StatusNotFound)
+			return
+		}
+
+		if err := handlerFunc(srv, version, w, r, mux.Vars(r)); err != nil {
+			utils.Debugf("Error: %s", err)
+			httpError(w, err)
+		}
+	}
+}
+
 func createRouter(srv *Server, logging bool) (*mux.Router, error) {
 	r := mux.NewRouter()
 
 	m := map[string]map[string]func(*Server, float64, http.ResponseWriter, *http.Request, map[string]string) error{
 		"GET": {
-			"/auth":                         getAuth,
-			"/version":                      getVersion,
-			"/info":                         getInfo,
-			"/images/json":                  getImagesJSON,
-			"/images/viz":                   getImagesViz,
-			"/images/search":                getImagesSearch,
-			"/images/{name:.*}/history":     getImagesHistory,
-			"/images/{name:.*}/json":        getImagesByName,
-			"/containers/ps":                getContainersJSON,
-			"/containers/json":              getContainersJSON,
-			"/containers/{name:.*}/export":  getContainersExport,
-			"/containers/{name:.*}/changes": getContainersChanges,
-			"/containers/{name:.*}/json":    getContainersByName,
+			"/auth":                           getAuth,
+			"/version":                        getVersion,
+			"/info":                           getInfo,
+			"/images/json":                    getImagesJSON,
+			"/images/viz":                     getImagesViz,
+			"/images/search":                  getImagesSearch,
+			"/images/{name:.*}/history":       getImagesHistory,
+			"/images/{name:.*}/json":          getImagesByName,
+			"/containers/ps":                  getContainersJSON,
+			"/containers/json":                getContainersJSON,
+			"/containers/{name:.*}/export":    getContainersExport,
+			"/containers/{name:.*}/changes":   getContainersChanges,
+			"/containers/{name:.*}/json":      getContainersByName,
+			"/containers/{name:.*}/attach/ws": wsContainersAttach,
 		},
 		"POST": {
 			"/auth":                         postAuth,
@@ -924,64 +972,21 @@ func createRouter(srv *Server, logging bool) (*mux.Router, error) {
 			utils.Debugf("Registering %s, %s", method, route)
 			// NOTE: scope issue, make sure the variables are local and won't be changed
 			localRoute := route
-			localMethod := method
 			localFct := fct
-			f := func(w http.ResponseWriter, r *http.Request) {
-				utils.Debugf("Calling %s %s", localMethod, localRoute)
-
-				if logging {
-					log.Println(r.Method, r.RequestURI)
-				}
-				if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
-					userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
-					if len(userAgent) == 2 && userAgent[1] != VERSION {
-						utils.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], VERSION)
-					}
-				}
-				version, err := strconv.ParseFloat(mux.Vars(r)["version"], 64)
-				if err != nil {
-					version = APIVERSION
-				}
-				if srv.enableCors {
-					writeCorsHeaders(w, r)
-				}
-
-				if version == 0 || version > APIVERSION {
-					w.WriteHeader(http.StatusNotFound)
-					return
-				}
-
-				if err := localFct(srv, version, w, r, mux.Vars(r)); err != nil {
-					utils.Debugf("Error: %s", err)
-					httpError(w, err)
-				}
-			}
+			localMethod := method
 
+			// build the handler function
+			f := makeHttpHandler(srv, logging, localMethod, localRoute, localFct)
+
+			// add the new route
 			if localRoute == "" {
 				r.Methods(localMethod).HandlerFunc(f)
 			} else {
 				r.Path("/v{version:[0-9.]+}" + localRoute).Methods(localMethod).HandlerFunc(f)
 				r.Path(localRoute).Methods(localMethod).HandlerFunc(f)
 			}
-
 		}
 	}
-	attachHandler := websocket.Handler(func(ws *websocket.Conn) {
-		r := ws.Request()
-		utils.Debugf("Calling %s %s", r.Method, r.RequestURI)
-
-		if logging {
-			log.Println(r.Method, r.RequestURI)
-		}
-		if err := wsContainersAttach(srv, ws, mux.Vars(r)); err != nil {
-			utils.Debugf("Error: %s", err)
-			ws.Close()
-		}
-	})
-
-	attachRoute := "/containers/{name:.*}/attach/ws"
-	r.Path("/v{version:[0-9.]+}" + attachRoute).Methods("GET").Handler(attachHandler)
-	r.Path(attachRoute).Methods("GET").Handler(attachHandler)
 
 	return r, nil
 }