nvidia_linux.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package daemon
  2. import (
  3. "os"
  4. "os/exec"
  5. "strconv"
  6. "strings"
  7. "github.com/containerd/containerd/contrib/nvidia"
  8. "github.com/docker/docker/pkg/capabilities"
  9. specs "github.com/opencontainers/runtime-spec/specs-go"
  10. "github.com/pkg/errors"
  11. )
  12. // TODO: nvidia should not be hard-coded, and should be a device plugin instead on the daemon object.
  13. // TODO: add list of device capabilities in daemon/node info
  14. var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")
  15. const nvidiaHook = "nvidia-container-runtime-hook"
  16. // These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
  17. var allNvidiaCaps = map[nvidia.Capability]struct{}{
  18. nvidia.Compute: {},
  19. nvidia.Compat32: {},
  20. nvidia.Graphics: {},
  21. nvidia.Utility: {},
  22. nvidia.Video: {},
  23. nvidia.Display: {},
  24. }
  25. func init() {
  26. if _, err := exec.LookPath(nvidiaHook); err != nil {
  27. // do not register Nvidia driver if helper binary is not present.
  28. return
  29. }
  30. capset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}}
  31. nvidiaDriver := &deviceDriver{
  32. capset: capset,
  33. updateSpec: setNvidiaGPUs,
  34. }
  35. for c := range allNvidiaCaps {
  36. nvidiaDriver.capset[string(c)] = struct{}{}
  37. }
  38. registerDeviceDriver("nvidia", nvidiaDriver)
  39. }
  40. func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
  41. req := dev.req
  42. if req.Count != 0 && len(req.DeviceIDs) > 0 {
  43. return errConflictCountDeviceIDs
  44. }
  45. if len(req.DeviceIDs) > 0 {
  46. s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+strings.Join(req.DeviceIDs, ","))
  47. } else if req.Count > 0 {
  48. s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+countToDevices(req.Count))
  49. } else if req.Count < 0 {
  50. s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES=all")
  51. }
  52. var nvidiaCaps []string
  53. // req.Capabilities contains device capabilities, some but not all are NVIDIA driver capabilities.
  54. for _, c := range dev.selectedCaps {
  55. nvcap := nvidia.Capability(c)
  56. if _, isNvidiaCap := allNvidiaCaps[nvcap]; isNvidiaCap {
  57. nvidiaCaps = append(nvidiaCaps, c)
  58. continue
  59. }
  60. // TODO: nvidia.WithRequiredCUDAVersion
  61. // for now we let the prestart hook verify cuda versions but errors are not pretty.
  62. }
  63. if nvidiaCaps != nil {
  64. s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ","))
  65. }
  66. path, err := exec.LookPath(nvidiaHook)
  67. if err != nil {
  68. return err
  69. }
  70. if s.Hooks == nil {
  71. s.Hooks = &specs.Hooks{}
  72. }
  73. s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{
  74. Path: path,
  75. Args: []string{
  76. nvidiaHook,
  77. "prestart",
  78. },
  79. Env: os.Environ(),
  80. })
  81. return nil
  82. }
  83. // countToDevices returns the list 0, 1, ... count-1 of deviceIDs.
  84. func countToDevices(count int) string {
  85. devices := make([]string, count)
  86. for i := range devices {
  87. devices[i] = strconv.Itoa(i)
  88. }
  89. return strings.Join(devices, ",")
  90. }