Browse code

Update nvidia_devices to call into nvidia-container-runtime-hook

Signed-off-by: Renaud Gaubert <rgaubert@nvidia.com>

Renaud Gaubert authored on 2019/03/26 16:56:17
Showing 1 changed files
... ...
@@ -1,8 +1,10 @@
1 1
 package daemon
2 2
 
3 3
 import (
4
+	"os"
4 5
 	"os/exec"
5 6
 	"strconv"
7
+	"strings"
6 8
 
7 9
 	"github.com/containerd/containerd/contrib/nvidia"
8 10
 	"github.com/docker/docker/pkg/capabilities"
... ...
@@ -15,8 +17,7 @@ import (
15 15
 
16 16
 var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")
17 17
 
18
-// stolen from github.com/containerd/containerd/contrib/nvidia
19
-const nvidiaCLI = "nvidia-container-cli"
18
+const nvidiaHook = "nvidia-container-runtime-hook"
20 19
 
21 20
 // These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
22 21
 var allNvidiaCaps = map[nvidia.Capability]struct{}{
... ...
@@ -29,7 +30,7 @@ var allNvidiaCaps = map[nvidia.Capability]struct{}{
29 29
 }
30 30
 
31 31
 func init() {
32
-	if _, err := exec.LookPath(nvidiaCLI); err != nil {
32
+	if _, err := exec.LookPath(nvidiaHook); err != nil {
33 33
 		// do not register Nvidia driver if helper binary is not present.
34 34
 		return
35 35
 	}
... ...
@@ -45,45 +46,25 @@ func init() {
45 45
 }
46 46
 
47 47
 func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
48
-	var opts []nvidia.Opts
49
-
50 48
 	req := dev.req
51 49
 	if req.Count != 0 && len(req.DeviceIDs) > 0 {
52 50
 		return errConflictCountDeviceIDs
53 51
 	}
54 52
 
55 53
 	if len(req.DeviceIDs) > 0 {
56
-		var ids []int
57
-		var uuids []string
58
-		for _, devID := range req.DeviceIDs {
59
-			id, err := strconv.Atoi(devID)
60
-			if err == nil {
61
-				ids = append(ids, id)
62
-				continue
63
-			}
64
-			// if not an integer, then assume UUID.
65
-			uuids = append(uuids, devID)
66
-		}
67
-		if len(ids) > 0 {
68
-			opts = append(opts, nvidia.WithDevices(ids...))
69
-		}
70
-		if len(uuids) > 0 {
71
-			opts = append(opts, nvidia.WithDeviceUUIDs(uuids...))
72
-		}
73
-	}
74
-
75
-	if req.Count < 0 {
76
-		opts = append(opts, nvidia.WithAllDevices)
54
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+strings.Join(req.DeviceIDs, ","))
77 55
 	} else if req.Count > 0 {
78
-		opts = append(opts, nvidia.WithDevices(countToDevices(req.Count)...))
56
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+countToDevices(req.Count))
57
+	} else if req.Count < 0 {
58
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES=all")
79 59
 	}
80 60
 
81
-	var nvidiaCaps []nvidia.Capability
61
+	var nvidiaCaps []string
82 62
 	// req.Capabilities contains device capabilities, some but not all are NVIDIA driver capabilities.
83 63
 	for _, c := range dev.selectedCaps {
84 64
 		nvcap := nvidia.Capability(c)
85 65
 		if _, isNvidiaCap := allNvidiaCaps[nvcap]; isNvidiaCap {
86
-			nvidiaCaps = append(nvidiaCaps, nvcap)
66
+			nvidiaCaps = append(nvidiaCaps, c)
87 67
 			continue
88 68
 		}
89 69
 		// TODO: nvidia.WithRequiredCUDAVersion
... ...
@@ -91,17 +72,34 @@ func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
91 91
 	}
92 92
 
93 93
 	if nvidiaCaps != nil {
94
-		opts = append(opts, nvidia.WithCapabilities(nvidiaCaps...))
94
+		s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ","))
95
+	}
96
+
97
+	path, err := exec.LookPath(nvidiaHook)
98
+	if err != nil {
99
+		return err
100
+	}
101
+
102
+	if s.Hooks == nil {
103
+		s.Hooks = &specs.Hooks{}
95 104
 	}
105
+	s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{
106
+		Path: path,
107
+		Args: []string{
108
+			nvidiaHook,
109
+			"prestart",
110
+		},
111
+		Env: os.Environ(),
112
+	})
96 113
 
97
-	return nvidia.WithGPUs(opts...)(nil, nil, nil, s)
114
+	return nil
98 115
 }
99 116
 
100 117
 // countToDevices returns the list 0, 1, ... count-1 of deviceIDs.
101
-func countToDevices(count int) []int {
102
-	devices := make([]int, count)
118
+func countToDevices(count int) string {
119
+	devices := make([]string, count)
103 120
 	for i := range devices {
104
-		devices[i] = i
121
+		devices[i] = strconv.Itoa(i)
105 122
 	}
106
-	return devices
123
+	return strings.Join(devices, ",")
107 124
 }