Browse code

Added support for AMD GPUs in "docker run --gpus".

Added backend code to support the exact same interface
used today for Nvidia GPUs, allowing customers to use
the same docker commands for both Nvidia and AMD GPUs.

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>

Reused common functions from nvidia_linux.go.

Removed duplicate code in amd_linux.go by reusing
the init() and countToDevices() functions in
nvidia_linux.go. The AMD driver is now registered in init().

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>

Renamed amd-container-runtime constant

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>

Removed empty branch to keep linter happy.

Also renamed amd_linux.go to gpu_amd_linux.go.

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>

Renamed nvidia_linux.go and gpu_amd_linux.go.

Signed-off-by: Sudheendra Gopinath <sudheendra.gopinath@amd.com>

Sudheendra Gopinath authored on 2025/05/10 15:42:48
Showing 3 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,27 @@
0
+package daemon
1
+
2
+import (
3
+	"strings"
4
+
5
+	"github.com/opencontainers/runtime-spec/specs-go"
6
+)
7
+
8
+func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
9
+	req := dev.req
10
+	if req.Count != 0 && len(req.DeviceIDs) > 0 {
11
+		return errConflictCountDeviceIDs
12
+	}
13
+
14
+	switch {
15
+	case len(req.DeviceIDs) > 0:
16
+		s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES="+strings.Join(req.DeviceIDs, ","))
17
+	case req.Count > 0:
18
+		s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES="+countToDevices(req.Count))
19
+	case req.Count < 0:
20
+		s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES=all")
21
+	case req.Count == 0:
22
+		s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES=void")
23
+	}
24
+
25
+	return nil
26
+}
0 27
new file mode 100644
... ...
@@ -0,0 +1,127 @@
0
+package daemon
1
+
2
+import (
3
+	"os"
4
+	"os/exec"
5
+	"strconv"
6
+	"strings"
7
+
8
+	"github.com/containerd/containerd/v2/contrib/nvidia"
9
+	"github.com/docker/docker/daemon/internal/capabilities"
10
+	"github.com/opencontainers/runtime-spec/specs-go"
11
+	"github.com/pkg/errors"
12
+)
13
+
14
+// TODO: GPU vendors (nvidia, amd) should not be hard-coded; they should be device plugins on the daemon object instead.
15
+// TODO: add list of device capabilities in daemon/node info
16
+
17
// errConflictCountDeviceIDs is returned by the GPU spec updaters in this file
// (setNvidiaGPUs and setAMDGPUs) when a device request sets both a non-zero
// Count and a non-empty DeviceIDs list.
var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")
18
+
19
const (
	// nvidiaHook is the name of the Nvidia helper binary. init looks it up to
	// decide whether to register the "nvidia" driver, and setNvidiaGPUs
	// installs it as a prestart hook.
	nvidiaHook                        = "nvidia-container-runtime-hook"
	// amdContainerRuntimeExecutableName is the name of the AMD helper binary.
	// init looks it up to decide whether to register the "amd" driver.
	amdContainerRuntimeExecutableName = "amd-container-runtime"
)
23
+
24
// allNvidiaCaps is the set of NVIDIA-specific driver capabilities known to
// this file, copied from allCaps in
// github.com/containerd/containerd/v2/contrib/nvidia.
//
// init adds every entry to the capability set of the "nvidia" driver, and
// setNvidiaGPUs uses it to pick out which selected capabilities are forwarded
// in NVIDIA_DRIVER_CAPABILITIES.
var allNvidiaCaps = map[nvidia.Capability]struct{}{
	nvidia.Compute:  {},
	nvidia.Compat32: {},
	nvidia.Graphics: {},
	nvidia.Utility:  {},
	nvidia.Video:    {},
	nvidia.Display:  {},
}
33
+
34
+func init() {
35
+	// Register Nvidia driver if Nvidia helper binary is present.
36
+	if _, err := exec.LookPath(nvidiaHook); err == nil {
37
+		capset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}}
38
+		for c := range allNvidiaCaps {
39
+			capset[string(c)] = struct{}{}
40
+		}
41
+		registerDeviceDriver("nvidia", &deviceDriver{
42
+			capset:     capset,
43
+			updateSpec: setNvidiaGPUs,
44
+		})
45
+		return
46
+	}
47
+
48
+	// Register AMD driver if AMD helper binary is present.
49
+	if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
50
+		registerDeviceDriver("amd", &deviceDriver{
51
+			capset:     capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}},
52
+			updateSpec: setAMDGPUs,
53
+		})
54
+		return
55
+	}
56
+
57
+	// No "gpu" capability
58
+}
59
+
60
+func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
61
+	req := dev.req
62
+	if req.Count != 0 && len(req.DeviceIDs) > 0 {
63
+		return errConflictCountDeviceIDs
64
+	}
65
+
66
+	switch {
67
+	case len(req.DeviceIDs) > 0:
68
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+strings.Join(req.DeviceIDs, ","))
69
+	case req.Count > 0:
70
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+countToDevices(req.Count))
71
+	case req.Count < 0:
72
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES=all")
73
+	case req.Count == 0:
74
+		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES=void")
75
+	}
76
+
77
+	var nvidiaCaps []string
78
+	// req.Capabilities contains device capabilities, some but not all are NVIDIA driver capabilities.
79
+	for _, c := range dev.selectedCaps {
80
+		nvcap := nvidia.Capability(c)
81
+		if _, isNvidiaCap := allNvidiaCaps[nvcap]; isNvidiaCap {
82
+			nvidiaCaps = append(nvidiaCaps, c)
83
+			continue
84
+		}
85
+		// TODO: nvidia.WithRequiredCUDAVersion
86
+		// for now we let the prestart hook verify cuda versions but errors are not pretty.
87
+	}
88
+
89
+	if nvidiaCaps != nil {
90
+		s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ","))
91
+	}
92
+
93
+	path, err := exec.LookPath(nvidiaHook)
94
+	if err != nil {
95
+		return err
96
+	}
97
+
98
+	if s.Hooks == nil {
99
+		s.Hooks = &specs.Hooks{}
100
+	}
101
+
102
+	// This implementation uses prestart hooks, which are deprecated.
103
+	// CreateRuntime is the closest equivalent, and executed in the same
104
+	// locations as prestart-hooks, but depending on what these hooks do,
105
+	// possibly one of the other hooks could be used instead (such as
106
+	// CreateContainer or StartContainer).
107
+	s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ //nolint:staticcheck // FIXME(thaJeztah); replace prestart hook with a non-deprecated one.
108
+		Path: path,
109
+		Args: []string{
110
+			nvidiaHook,
111
+			"prestart",
112
+		},
113
+		Env: os.Environ(),
114
+	})
115
+
116
+	return nil
117
+}
118
+
119
// countToDevices returns the comma-separated list of device indices
// "0,1,...,count-1", e.g. countToDevices(3) == "0,1,2".
//
// A count of zero or less yields the empty string. The explicit guard matters
// for negative values: passing a negative length to make panics at run time.
func countToDevices(count int) string {
	if count <= 0 {
		return ""
	}
	devices := make([]string, count)
	for i := range devices {
		devices[i] = strconv.Itoa(i)
	}
	return strings.Join(devices, ",")
}
0 127
deleted file mode 100644
... ...
@@ -1,114 +0,0 @@
1
-package daemon
2
-
3
-import (
4
-	"os"
5
-	"os/exec"
6
-	"strconv"
7
-	"strings"
8
-
9
-	"github.com/containerd/containerd/v2/contrib/nvidia"
10
-	"github.com/docker/docker/daemon/internal/capabilities"
11
-	"github.com/opencontainers/runtime-spec/specs-go"
12
-	"github.com/pkg/errors"
13
-)
14
-
15
-// TODO: nvidia should not be hard-coded, and should be a device plugin instead on the daemon object.
16
-// TODO: add list of device capabilities in daemon/node info
17
-
18
-var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")
19
-
20
-const nvidiaHook = "nvidia-container-runtime-hook"
21
-
22
-// These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
23
-var allNvidiaCaps = map[nvidia.Capability]struct{}{
24
-	nvidia.Compute:  {},
25
-	nvidia.Compat32: {},
26
-	nvidia.Graphics: {},
27
-	nvidia.Utility:  {},
28
-	nvidia.Video:    {},
29
-	nvidia.Display:  {},
30
-}
31
-
32
-func init() {
33
-	if _, err := exec.LookPath(nvidiaHook); err != nil {
34
-		// do not register Nvidia driver if helper binary is not present.
35
-		return
36
-	}
37
-	capset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}}
38
-	nvidiaDriver := &deviceDriver{
39
-		capset:     capset,
40
-		updateSpec: setNvidiaGPUs,
41
-	}
42
-	for c := range allNvidiaCaps {
43
-		nvidiaDriver.capset[string(c)] = struct{}{}
44
-	}
45
-	registerDeviceDriver("nvidia", nvidiaDriver)
46
-}
47
-
48
-func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
49
-	req := dev.req
50
-	if req.Count != 0 && len(req.DeviceIDs) > 0 {
51
-		return errConflictCountDeviceIDs
52
-	}
53
-
54
-	switch {
55
-	case len(req.DeviceIDs) > 0:
56
-		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+strings.Join(req.DeviceIDs, ","))
57
-	case req.Count > 0:
58
-		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+countToDevices(req.Count))
59
-	case req.Count < 0:
60
-		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES=all")
61
-	case req.Count == 0:
62
-		s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES=void")
63
-	}
64
-
65
-	var nvidiaCaps []string
66
-	// req.Capabilities contains device capabilities, some but not all are NVIDIA driver capabilities.
67
-	for _, c := range dev.selectedCaps {
68
-		nvcap := nvidia.Capability(c)
69
-		if _, isNvidiaCap := allNvidiaCaps[nvcap]; isNvidiaCap {
70
-			nvidiaCaps = append(nvidiaCaps, c)
71
-			continue
72
-		}
73
-		// TODO: nvidia.WithRequiredCUDAVersion
74
-		// for now we let the prestart hook verify cuda versions but errors are not pretty.
75
-	}
76
-
77
-	if nvidiaCaps != nil {
78
-		s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ","))
79
-	}
80
-
81
-	path, err := exec.LookPath(nvidiaHook)
82
-	if err != nil {
83
-		return err
84
-	}
85
-
86
-	if s.Hooks == nil {
87
-		s.Hooks = &specs.Hooks{}
88
-	}
89
-
90
-	// This implementation uses prestart hooks, which are deprecated.
91
-	// CreateRuntime is the closest equivalent, and executed in the same
92
-	// locations as prestart-hooks, but depending on what these hooks do,
93
-	// possibly one of the other hooks could be used instead (such as
94
-	// CreateContainer or StartContainer).
95
-	s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ //nolint:staticcheck // FIXME(thaJeztah); replace prestart hook with a non-deprecated one.
96
-		Path: path,
97
-		Args: []string{
98
-			nvidiaHook,
99
-			"prestart",
100
-		},
101
-		Env: os.Environ(),
102
-	})
103
-
104
-	return nil
105
-}
106
-
107
-// countToDevices returns the list 0, 1, ... count-1 of deviceIDs.
108
-func countToDevices(count int) string {
109
-	devices := make([]string, count)
110
-	for i := range devices {
111
-		devices[i] = strconv.Itoa(i)
112
-	}
113
-	return strings.Join(devices, ",")
114
-}