Browse code

Merge pull request #52048 from shiv-tyagi/vendor-detection

Use CDI for GPU injection for AMD devices for --gpus

Paweł Gronowski authored on 2026/03/05 21:06:54
Showing 7 changed files
... ...
@@ -285,6 +285,8 @@ func (cli *daemonCLI) start(ctx context.Context) (retErr error) {
285 285
 		cdiCache = daemon.RegisterCDIDriver(cli.Config.CDISpecDirs...)
286 286
 	}
287 287
 
288
+	daemon.RegisterGPUDeviceDrivers(cdiCache)
289
+
288 290
 	var apiServer apiserver.Server
289 291
 	authz, err := initMiddlewares(ctx, &apiServer, cli.Config, pluginStore)
290 292
 	if err != nil {
... ...
@@ -2,6 +2,7 @@ package daemon
2 2
 
3 3
 import (
4 4
 	"context"
5
+	"errors"
5 6
 
6 7
 	"github.com/containerd/log"
7 8
 	"github.com/moby/moby/api/types/container"
... ...
@@ -38,6 +39,18 @@ func registerDeviceDriver(name string, d *deviceDriver) {
38 38
 	deviceDrivers[name] = d
39 39
 }
40 40
 
41
+func getFirstAvailableVendor(vendorList []string) (string, error) {
42
+	knownVendors := []string{"nvidia.com", "amd.com"}
43
+	for _, vendor := range knownVendors {
44
+		for _, available := range vendorList {
45
+			if vendor == available {
46
+				return vendor, nil
47
+			}
48
+		}
49
+	}
50
+	return "", errors.New("no known GPU vendor found")
51
+}
52
+
41 53
 func (daemon *Daemon) handleDevice(req container.DeviceRequest, spec *specs.Spec) error {
42 54
 	if req.Driver == "" {
43 55
 		// If no driver is explicitly requested, we iterate over the registered
... ...
@@ -1,9 +1,18 @@
1 1
 package daemon
2 2
 
3 3
 import (
4
+	"errors"
5
+	"fmt"
6
+	"os/exec"
4 7
 	"strings"
5 8
 
9
+	"github.com/moby/moby/v2/daemon/internal/capabilities"
6 10
 	"github.com/opencontainers/runtime-spec/specs-go"
11
+	"tags.cncf.io/container-device-interface/pkg/cdi"
12
+)
13
+
14
+const (
15
+	amdContainerRuntimeExecutableName = "amd-container-runtime"
7 16
 )
8 17
 
9 18
 func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
... ...
@@ -25,3 +34,46 @@ func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
25 25
 
26 26
 	return nil
27 27
 }
28
+
29
+func createAMDCDIUpdater(cdiCache *cdi.Cache) func(*specs.Spec, *deviceInstance) error {
30
+	return func(s *specs.Spec, dev *deviceInstance) error {
31
+		vendor, err := getFirstAvailableVendor(cdiCache.ListVendors())
32
+		if err != nil {
33
+			return fmt.Errorf("failed to discover GPU vendor from CDI: %w", err)
34
+		}
35
+
36
+		if vendor != "amd.com" {
37
+			return errors.New("AMD CDI spec not found")
38
+		}
39
+
40
+		injector := &cdiDeviceInjector{
41
+			defaultCDIDeviceKind: "amd.com/gpu",
42
+		}
43
+		return injector.injectDevices(s, dev)
44
+	}
45
+}
46
+
47
+func getAMDDeviceDrivers(cdiCache *cdi.Cache) *deviceDriver {
48
+	var composite firstSuccessfulUpdater
49
+
50
+	if cdiCache != nil {
51
+		composite = append(composite, createAMDCDIUpdater(cdiCache))
52
+	}
53
+
54
+	if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
55
+		composite = append(composite, setAMDGPUs)
56
+	}
57
+
58
+	if len(composite) == 0 {
59
+		return nil
60
+	}
61
+
62
+	// We do not support specifying driver with device requests for AMD GPUs.
63
+	// Hence only use the composite updater and try cdi and runtime driver in sequence
64
+	// based on availability.
65
+	capset := capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}}
66
+	return &deviceDriver{
67
+		capset:     capset,
68
+		updateSpec: composite.updateSpec,
69
+	}
70
+}
28 71
new file mode 100644
... ...
@@ -0,0 +1,22 @@
0
+package daemon
1
+
2
+import "tags.cncf.io/container-device-interface/pkg/cdi"
3
+
4
+// RegisterGPUDeviceDrivers registers GPU device drivers.
5
+// If the cdiCache is provided, it is used to detect presence of CDI specs for AMD GPUs.
6
+// For NVIDIA GPUs, presence of CDI specs is detected by checking for the nvidia-cdi-hook binary.
7
+func RegisterGPUDeviceDrivers(cdiCache *cdi.Cache) {
8
+	// Register NVIDIA device drivers.
9
+	if nvidiaDrivers := getNVIDIADeviceDrivers(); len(nvidiaDrivers) > 0 {
10
+		for name, driver := range nvidiaDrivers {
11
+			registerDeviceDriver(name, driver)
12
+		}
13
+		return
14
+	}
15
+
16
+	// Register AMD driver if AMD CDI spec or helper binary is present.
17
+	if amdDriver := getAMDDeviceDrivers(cdiCache); amdDriver != nil {
18
+		registerDeviceDriver("amd", amdDriver)
19
+		return
20
+	}
21
+}
0 22
new file mode 100644
... ...
@@ -0,0 +1,8 @@
0
+//go:build !linux
1
+
2
+package daemon
3
+
4
+import "tags.cncf.io/container-device-interface/pkg/cdi"
5
+
6
+// RegisterGPUDeviceDrivers is a no-op on non-Linux platforms.
7
+func RegisterGPUDeviceDrivers(_ *cdi.Cache) {}
... ...
@@ -23,7 +23,6 @@ var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs
23 23
 const (
24 24
 	nvidiaContainerRuntimeHookExecutableName = "nvidia-container-runtime-hook"
25 25
 	nvidiaCDIHookExecutableName              = "nvidia-cdi-hook"
26
-	amdContainerRuntimeExecutableName        = "amd-container-runtime"
27 26
 )
28 27
 
29 28
 // These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
... ...
@@ -36,27 +35,6 @@ var allNvidiaCaps = map[string]struct{}{
36 36
 	"display":  {},
37 37
 }
38 38
 
39
-func init() {
40
-	// Register NVIDIA device drivers.
41
-	if nvidiaDrivers := getNVIDIADeviceDrivers(); len(nvidiaDrivers) > 0 {
42
-		for name, driver := range nvidiaDrivers {
43
-			registerDeviceDriver(name, driver)
44
-		}
45
-		return
46
-	}
47
-
48
-	// Register AMD driver if AMD helper binary is present.
49
-	if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
50
-		registerDeviceDriver("amd", &deviceDriver{
51
-			capset:     capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}},
52
-			updateSpec: setAMDGPUs,
53
-		})
54
-		return
55
-	}
56
-
57
-	// No "gpu" capability
58
-}
59
-
60 39
 func getNVIDIADeviceDrivers() map[string]*deviceDriver {
61 40
 	var composite firstSuccessfulUpdater
62 41
 	nvidiaDrivers := make(map[string]*deviceDriver)
63 42
new file mode 100644
... ...
@@ -0,0 +1,55 @@
0
+package daemon
1
+
2
+import (
3
+	"testing"
4
+
5
+	"gotest.tools/v3/assert"
6
+)
7
+
8
+func TestGetFirstAvailableVendor(t *testing.T) {
9
+	tests := []struct {
10
+		name         string
11
+		vendors      []string
12
+		expectVendor string
13
+		expectError  string
14
+	}{
15
+		{
16
+			name:         "NVIDIA vendor",
17
+			vendors:      []string{"nvidia.com"},
18
+			expectVendor: "nvidia.com",
19
+		},
20
+		{
21
+			name:         "AMD vendor",
22
+			vendors:      []string{"amd.com"},
23
+			expectVendor: "amd.com",
24
+		},
25
+		{
26
+			name:        "No vendors",
27
+			vendors:     nil,
28
+			expectError: "no known GPU vendor found",
29
+		},
30
+		{
31
+			name:        "Unknown vendor",
32
+			vendors:     []string{"unknown.com"},
33
+			expectError: "no known GPU vendor found",
34
+		},
35
+		{
36
+			name:         "Mixed vendor",
37
+			vendors:      []string{"amd.com", "nvidia.com"},
38
+			expectVendor: "nvidia.com",
39
+		},
40
+	}
41
+
42
+	for _, tt := range tests {
43
+		t.Run(tt.name, func(t *testing.T) {
44
+			vendor, err := getFirstAvailableVendor(tt.vendors)
45
+
46
+			if tt.expectError != "" {
47
+				assert.Error(t, err, tt.expectError)
48
+			} else {
49
+				assert.NilError(t, err)
50
+				assert.Equal(t, tt.expectVendor, vendor)
51
+			}
52
+		})
53
+	}
54
+}