Browse Source

Update nvidia_devices to call into nvidia-container-runtime-hook

Signed-off-by: Renaud Gaubert <rgaubert@nvidia.com>
Renaud Gaubert 6 years ago
parent
commit
bd3d46a9e5
1 changed file with 33 additions and 35 deletions
  1. 33 35
      daemon/nvidia_linux.go

+ 33 - 35
daemon/nvidia_linux.go

@@ -1,8 +1,10 @@
 package daemon
 package daemon
 
 
 import (
 import (
+	"os"
 	"os/exec"
 	"os/exec"
 	"strconv"
 	"strconv"
+	"strings"
 
 
 	"github.com/containerd/containerd/contrib/nvidia"
 	"github.com/containerd/containerd/contrib/nvidia"
 	"github.com/docker/docker/pkg/capabilities"
 	"github.com/docker/docker/pkg/capabilities"
@@ -15,8 +17,7 @@ import (
 
 
 var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")
 var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")
 
 
-// stolen from github.com/containerd/containerd/contrib/nvidia
-const nvidiaCLI = "nvidia-container-cli"
+const nvidiaHook = "nvidia-container-runtime-hook"
 
 
 // These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
 // These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
 var allNvidiaCaps = map[nvidia.Capability]struct{}{
 var allNvidiaCaps = map[nvidia.Capability]struct{}{
@@ -29,7 +30,7 @@ var allNvidiaCaps = map[nvidia.Capability]struct{}{
 }
 }
 
 
 func init() {
 func init() {
-	if _, err := exec.LookPath(nvidiaCLI); err != nil {
+	if _, err := exec.LookPath(nvidiaHook); err != nil {
 		// do not register Nvidia driver if helper binary is not present.
 		// do not register Nvidia driver if helper binary is not present.
 		return
 		return
 	}
 	}
@@ -45,45 +46,25 @@ func init() {
 }
 }
 
 
 func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
 func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
-	var opts []nvidia.Opts
-
 	req := dev.req
 	req := dev.req
 	if req.Count != 0 && len(req.DeviceIDs) > 0 {
 	if req.Count != 0 && len(req.DeviceIDs) > 0 {
 		return errConflictCountDeviceIDs
 		return errConflictCountDeviceIDs
 	}
 	}
 
 
 	if len(req.DeviceIDs) > 0 {
 	if len(req.DeviceIDs) > 0 {
-		var ids []int
-		var uuids []string
-		for _, devID := range req.DeviceIDs {
-			id, err := strconv.Atoi(devID)
-			if err == nil {
-				ids = append(ids, id)
-				continue
-			}
-			// if not an integer, then assume UUID.
-			uuids = append(uuids, devID)
-		}
-		if len(ids) > 0 {
-			opts = append(opts, nvidia.WithDevices(ids...))
-		}
-		if len(uuids) > 0 {
-			opts = append(opts, nvidia.WithDeviceUUIDs(uuids...))
-		}
-	}
-
-	if req.Count < 0 {
-		opts = append(opts, nvidia.WithAllDevices)
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+strings.Join(req.DeviceIDs, ","))
 	} else if req.Count > 0 {
 	} else if req.Count > 0 {
-		opts = append(opts, nvidia.WithDevices(countToDevices(req.Count)...))
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+countToDevices(req.Count))
+	} else if req.Count < 0 {
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES=all")
 	}
 	}
 
 
-	var nvidiaCaps []nvidia.Capability
+	var nvidiaCaps []string
 	// req.Capabilities contains device capabilities, some but not all are NVIDIA driver capabilities.
 	// req.Capabilities contains device capabilities, some but not all are NVIDIA driver capabilities.
 	for _, c := range dev.selectedCaps {
 	for _, c := range dev.selectedCaps {
 		nvcap := nvidia.Capability(c)
 		nvcap := nvidia.Capability(c)
 		if _, isNvidiaCap := allNvidiaCaps[nvcap]; isNvidiaCap {
 		if _, isNvidiaCap := allNvidiaCaps[nvcap]; isNvidiaCap {
-			nvidiaCaps = append(nvidiaCaps, nvcap)
+			nvidiaCaps = append(nvidiaCaps, c)
 			continue
 			continue
 		}
 		}
 		// TODO: nvidia.WithRequiredCUDAVersion
 		// TODO: nvidia.WithRequiredCUDAVersion
@@ -91,17 +72,34 @@ func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
 	}
 	}
 
 
 	if nvidiaCaps != nil {
 	if nvidiaCaps != nil {
-		opts = append(opts, nvidia.WithCapabilities(nvidiaCaps...))
+		s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ","))
+	}
+
+	path, err := exec.LookPath(nvidiaHook)
+	if err != nil {
+		return err
+	}
+
+	if s.Hooks == nil {
+		s.Hooks = &specs.Hooks{}
 	}
 	}
+	s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{
+		Path: path,
+		Args: []string{
+			nvidiaHook,
+			"prestart",
+		},
+		Env: os.Environ(),
+	})
 
 
-	return nvidia.WithGPUs(opts...)(nil, nil, nil, s)
+	return nil
 }
 }
 
 
 // countToDevices returns the list 0, 1, ... count-1 of deviceIDs.
 // countToDevices returns the list 0, 1, ... count-1 of deviceIDs.
-func countToDevices(count int) []int {
-	devices := make([]int, count)
+func countToDevices(count int) string {
+	devices := make([]string, count)
 	for i := range devices {
 	for i := range devices {
-		devices[i] = i
+		devices[i] = strconv.Itoa(i)
 	}
 	}
-	return devices
+	return strings.Join(devices, ",")
 }
 }