Use CDI for GPU injection for AMD devices for --gpus
| ... | ... |
@@ -285,6 +285,8 @@ func (cli *daemonCLI) start(ctx context.Context) (retErr error) {
|
| 285 | 285 |
cdiCache = daemon.RegisterCDIDriver(cli.Config.CDISpecDirs...) |
| 286 | 286 |
} |
| 287 | 287 |
|
| 288 |
+ daemon.RegisterGPUDeviceDrivers(cdiCache) |
|
| 289 |
+ |
|
| 288 | 290 |
var apiServer apiserver.Server |
| 289 | 291 |
authz, err := initMiddlewares(ctx, &apiServer, cli.Config, pluginStore) |
| 290 | 292 |
if err != nil {
|
| ... | ... |
@@ -2,6 +2,7 @@ package daemon |
| 2 | 2 |
|
| 3 | 3 |
import ( |
| 4 | 4 |
"context" |
| 5 |
+ "errors" |
|
| 5 | 6 |
|
| 6 | 7 |
"github.com/containerd/log" |
| 7 | 8 |
"github.com/moby/moby/api/types/container" |
| ... | ... |
@@ -38,6 +39,18 @@ func registerDeviceDriver(name string, d *deviceDriver) {
|
| 38 | 38 |
deviceDrivers[name] = d |
| 39 | 39 |
} |
| 40 | 40 |
|
| 41 |
// errNoKnownGPUVendor is returned by getFirstAvailableVendor when none of the
// known GPU vendors appears in the supplied vendor list.
var errNoKnownGPUVendor = errors.New("no known GPU vendor found")

// getFirstAvailableVendor returns the highest-priority known GPU vendor that
// is present in vendorList.
//
// Known vendors are tried in a fixed priority order ("nvidia.com" first, then
// "amd.com"), so the order of the entries in vendorList does not influence the
// result. When vendorList contains no known vendor (including when it is nil
// or empty), an empty string and errNoKnownGPUVendor are returned.
func getFirstAvailableVendor(vendorList []string) (string, error) {
	// A fixed-size array literal keeps the priority list local to the
	// function without building a new slice on every call.
	for _, vendor := range [...]string{"nvidia.com", "amd.com"} {
		for _, available := range vendorList {
			if vendor == available {
				return vendor, nil
			}
		}
	}
	return "", errNoKnownGPUVendor
}
|
| 52 |
+ |
|
| 41 | 53 |
func (daemon *Daemon) handleDevice(req container.DeviceRequest, spec *specs.Spec) error {
|
| 42 | 54 |
if req.Driver == "" {
|
| 43 | 55 |
// If no driver is explicitly requested, we iterate over the registered |
| ... | ... |
@@ -1,9 +1,18 @@ |
| 1 | 1 |
package daemon |
| 2 | 2 |
|
| 3 | 3 |
import ( |
| 4 |
+ "errors" |
|
| 5 |
+ "fmt" |
|
| 6 |
+ "os/exec" |
|
| 4 | 7 |
"strings" |
| 5 | 8 |
|
| 9 |
+ "github.com/moby/moby/v2/daemon/internal/capabilities" |
|
| 6 | 10 |
"github.com/opencontainers/runtime-spec/specs-go" |
| 11 |
+ "tags.cncf.io/container-device-interface/pkg/cdi" |
|
| 12 |
+) |
|
| 13 |
+ |
|
| 14 |
// amdContainerRuntimeExecutableName is the name of the AMD container runtime
// binary that is looked up on PATH to decide whether the runtime-based AMD
// GPU updater can be offered.
const amdContainerRuntimeExecutableName = "amd-container-runtime"
| 8 | 17 |
|
| 9 | 18 |
func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
|
| ... | ... |
@@ -25,3 +34,46 @@ func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
|
| 25 | 25 |
|
| 26 | 26 |
return nil |
| 27 | 27 |
} |
| 28 |
+ |
|
| 29 |
+func createAMDCDIUpdater(cdiCache *cdi.Cache) func(*specs.Spec, *deviceInstance) error {
|
|
| 30 |
+ return func(s *specs.Spec, dev *deviceInstance) error {
|
|
| 31 |
+ vendor, err := getFirstAvailableVendor(cdiCache.ListVendors()) |
|
| 32 |
+ if err != nil {
|
|
| 33 |
+ return fmt.Errorf("failed to discover GPU vendor from CDI: %w", err)
|
|
| 34 |
+ } |
|
| 35 |
+ |
|
| 36 |
+ if vendor != "amd.com" {
|
|
| 37 |
+ return errors.New("AMD CDI spec not found")
|
|
| 38 |
+ } |
|
| 39 |
+ |
|
| 40 |
+ injector := &cdiDeviceInjector{
|
|
| 41 |
+ defaultCDIDeviceKind: "amd.com/gpu", |
|
| 42 |
+ } |
|
| 43 |
+ return injector.injectDevices(s, dev) |
|
| 44 |
+ } |
|
| 45 |
+} |
|
| 46 |
+ |
|
| 47 |
+func getAMDDeviceDrivers(cdiCache *cdi.Cache) *deviceDriver {
|
|
| 48 |
+ var composite firstSuccessfulUpdater |
|
| 49 |
+ |
|
| 50 |
+ if cdiCache != nil {
|
|
| 51 |
+ composite = append(composite, createAMDCDIUpdater(cdiCache)) |
|
| 52 |
+ } |
|
| 53 |
+ |
|
| 54 |
+ if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
|
|
| 55 |
+ composite = append(composite, setAMDGPUs) |
|
| 56 |
+ } |
|
| 57 |
+ |
|
| 58 |
+ if len(composite) == 0 {
|
|
| 59 |
+ return nil |
|
| 60 |
+ } |
|
| 61 |
+ |
|
| 62 |
+ // We do not support specifying driver with device requests for AMD GPUs. |
|
| 63 |
+ // Hence only use the composite updater and try cdi and runtime driver in sequence |
|
| 64 |
+ // based on availability. |
|
| 65 |
+ capset := capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}}
|
|
| 66 |
+ return &deviceDriver{
|
|
| 67 |
+ capset: capset, |
|
| 68 |
+ updateSpec: composite.updateSpec, |
|
| 69 |
+ } |
|
| 70 |
+} |
| 28 | 71 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,22 @@ |
| 0 |
+package daemon |
|
| 1 |
+ |
|
| 2 |
+import "tags.cncf.io/container-device-interface/pkg/cdi" |
|
| 3 |
+ |
|
| 4 |
+// RegisterGPUDeviceDrivers registers GPU device drivers. |
|
| 5 |
+// If the cdiCache is provided, it is used to detect presence of CDI specs for AMD GPUs. |
|
| 6 |
+// For NVIDIA GPUs, presence of CDI specs is detected by checking for the nvidia-cdi-hook binary. |
|
| 7 |
+func RegisterGPUDeviceDrivers(cdiCache *cdi.Cache) {
|
|
| 8 |
+ // Register NVIDIA device drivers. |
|
| 9 |
+ if nvidiaDrivers := getNVIDIADeviceDrivers(); len(nvidiaDrivers) > 0 {
|
|
| 10 |
+ for name, driver := range nvidiaDrivers {
|
|
| 11 |
+ registerDeviceDriver(name, driver) |
|
| 12 |
+ } |
|
| 13 |
+ return |
|
| 14 |
+ } |
|
| 15 |
+ |
|
| 16 |
+ // Register AMD driver if AMD CDI spec or helper binary is present. |
|
| 17 |
+ if amdDriver := getAMDDeviceDrivers(cdiCache); amdDriver != nil {
|
|
| 18 |
+ registerDeviceDriver("amd", amdDriver)
|
|
| 19 |
+ return |
|
| 20 |
+ } |
|
| 21 |
+} |
| ... | ... |
@@ -23,7 +23,6 @@ var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs
|
| 23 | 23 |
const ( |
| 24 | 24 |
nvidiaContainerRuntimeHookExecutableName = "nvidia-container-runtime-hook" |
| 25 | 25 |
nvidiaCDIHookExecutableName = "nvidia-cdi-hook" |
| 26 |
- amdContainerRuntimeExecutableName = "amd-container-runtime" |
|
| 27 | 26 |
) |
| 28 | 27 |
|
| 29 | 28 |
// These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps |
| ... | ... |
@@ -36,27 +35,6 @@ var allNvidiaCaps = map[string]struct{}{
|
| 36 | 36 |
"display": {},
|
| 37 | 37 |
} |
| 38 | 38 |
|
| 39 |
-func init() {
|
|
| 40 |
- // Register NVIDIA device drivers. |
|
| 41 |
- if nvidiaDrivers := getNVIDIADeviceDrivers(); len(nvidiaDrivers) > 0 {
|
|
| 42 |
- for name, driver := range nvidiaDrivers {
|
|
| 43 |
- registerDeviceDriver(name, driver) |
|
| 44 |
- } |
|
| 45 |
- return |
|
| 46 |
- } |
|
| 47 |
- |
|
| 48 |
- // Register AMD driver if AMD helper binary is present. |
|
| 49 |
- if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
|
|
| 50 |
- registerDeviceDriver("amd", &deviceDriver{
|
|
| 51 |
- capset: capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}},
|
|
| 52 |
- updateSpec: setAMDGPUs, |
|
| 53 |
- }) |
|
| 54 |
- return |
|
| 55 |
- } |
|
| 56 |
- |
|
| 57 |
- // No "gpu" capability |
|
| 58 |
-} |
|
| 59 |
- |
|
| 60 | 39 |
func getNVIDIADeviceDrivers() map[string]*deviceDriver {
|
| 61 | 40 |
var composite firstSuccessfulUpdater |
| 62 | 41 |
nvidiaDrivers := make(map[string]*deviceDriver) |
| 63 | 42 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,55 @@ |
| 0 |
+package daemon |
|
| 1 |
+ |
|
| 2 |
+import ( |
|
| 3 |
+ "testing" |
|
| 4 |
+ |
|
| 5 |
+ "gotest.tools/v3/assert" |
|
| 6 |
+) |
|
| 7 |
+ |
|
| 8 |
+func TestGetFirstAvailableVendor(t *testing.T) {
|
|
| 9 |
+ tests := []struct {
|
|
| 10 |
+ name string |
|
| 11 |
+ vendors []string |
|
| 12 |
+ expectVendor string |
|
| 13 |
+ expectError string |
|
| 14 |
+ }{
|
|
| 15 |
+ {
|
|
| 16 |
+ name: "NVIDIA vendor", |
|
| 17 |
+ vendors: []string{"nvidia.com"},
|
|
| 18 |
+ expectVendor: "nvidia.com", |
|
| 19 |
+ }, |
|
| 20 |
+ {
|
|
| 21 |
+ name: "AMD vendor", |
|
| 22 |
+ vendors: []string{"amd.com"},
|
|
| 23 |
+ expectVendor: "amd.com", |
|
| 24 |
+ }, |
|
| 25 |
+ {
|
|
| 26 |
+ name: "No vendors", |
|
| 27 |
+ vendors: nil, |
|
| 28 |
+ expectError: "no known GPU vendor found", |
|
| 29 |
+ }, |
|
| 30 |
+ {
|
|
| 31 |
+ name: "Unknown vendor", |
|
| 32 |
+ vendors: []string{"unknown.com"},
|
|
| 33 |
+ expectError: "no known GPU vendor found", |
|
| 34 |
+ }, |
|
| 35 |
+ {
|
|
| 36 |
+ name: "Mixed vendor", |
|
| 37 |
+ vendors: []string{"amd.com", "nvidia.com"},
|
|
| 38 |
+ expectVendor: "nvidia.com", |
|
| 39 |
+ }, |
|
| 40 |
+ } |
|
| 41 |
+ |
|
| 42 |
+ for _, tt := range tests {
|
|
| 43 |
+ t.Run(tt.name, func(t *testing.T) {
|
|
| 44 |
+ vendor, err := getFirstAvailableVendor(tt.vendors) |
|
| 45 |
+ |
|
| 46 |
+ if tt.expectError != "" {
|
|
| 47 |
+ assert.Error(t, err, tt.expectError) |
|
| 48 |
+ } else {
|
|
| 49 |
+ assert.NilError(t, err) |
|
| 50 |
+ assert.Equal(t, tt.expectVendor, vendor) |
|
| 51 |
+ } |
|
| 52 |
+ }) |
|
| 53 |
+ } |
|
| 54 |
+} |