package daemon

import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"strconv"
	"strings"

	"github.com/containerd/log"
	"github.com/moby/moby/api/types/container"
	"github.com/moby/moby/v2/daemon/internal/capabilities"
	"github.com/opencontainers/runtime-spec/specs-go"
)

// TODO: nvidia should not be hard-coded, and should be a device plugin instead on the daemon object.
// TODO: add list of device capabilities in daemon/node info

// errConflictCountDeviceIDs is returned when a device request sets both a
// non-zero Count and an explicit list of DeviceIDs.
var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")

const (
	nvidiaContainerRuntimeHookExecutableName = "nvidia-container-runtime-hook"
	nvidiaCDIHookExecutableName              = "nvidia-cdi-hook"
)

// These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
var allNvidiaCaps = map[string]struct{}{
	"compute":  {},
	"compat32": {},
	"graphics": {},
	"utility":  {},
	"video":    {},
	"display":  {},
}

// getNVIDIADeviceDrivers returns the NVIDIA device drivers that can be used on
// this host, keyed by driver name.
//
// A driver is registered for each NVIDIA executable found on the PATH:
//   - "nvidia.cdi" when nvidia-cdi-hook is present,
//   - "nvidia.runtime-hook" when nvidia-container-runtime-hook is present.
//
// In addition, a composite "nvidia" driver is registered which carries the
// "gpu" and "nvidia" capabilities plus all NVIDIA driver capabilities, and
// which tries the available mechanisms in the order above. If neither
// executable is found, nil is returned.
func getNVIDIADeviceDrivers() map[string]*deviceDriver {
	var composite firstSuccessfulUpdater
	nvidiaDrivers := make(map[string]*deviceDriver)

	if _, err := exec.LookPath(nvidiaCDIHookExecutableName); err == nil {
		// Register a driver specific to CDI if present.
		// This has no capabilities associated to not inadvertently match requests.
		cdiDeviceDriver := &deviceDriver{
			updateSpec: (&cdiDeviceInjector{
				defaultCDIDeviceKind: "nvidia.com/gpu",
			}).injectDevices,
		}
		nvidiaDrivers["nvidia.cdi"] = cdiDeviceDriver
		composite = append(composite, cdiDeviceDriver.updateSpec)
	}

	if _, err := exec.LookPath(nvidiaContainerRuntimeHookExecutableName); err == nil {
		// Register a driver specific to the nvidia-container-runtime-hook if present.
		// This has no capabilities associated to not inadvertently match requests.
		runtimeHookDeviceDriver := &deviceDriver{
			updateSpec: injectNVIDIARuntimeHook,
		}
		nvidiaDrivers["nvidia.runtime-hook"] = runtimeHookDeviceDriver
		composite = append(composite, runtimeHookDeviceDriver.updateSpec)
	}

	if len(nvidiaDrivers) == 0 {
		return nil
	}

	// We associate all NVIDIA capabilities with the composite updater.
	capset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}}
	for c := range allNvidiaCaps {
		capset[c] = struct{}{}
	}
	nvidiaDrivers["nvidia"] = &deviceDriver{
		capset:     capset,
		updateSpec: composite.updateSpec,
	}

	return nvidiaDrivers
}

// firstSuccessfulUpdater is an ordered list of functions that update an OCI
// spec for a given device instance. The functions are treated as alternatives:
// see updateSpec.
type firstSuccessfulUpdater []func(*specs.Spec, *deviceInstance) error

// updateSpec invokes the updaters in order and returns nil as soon as one of
// them succeeds. Nil entries are skipped. Errors from updaters that failed
// before the successful one are logged as a warning and otherwise ignored.
// If no updater succeeds, the joined errors of all attempts are returned
// (which is nil when the list contains no non-nil updaters).
func (us firstSuccessfulUpdater) updateSpec(s *specs.Spec, dev *deviceInstance) error {
	var errs []error
	for _, u := range us {
		if u == nil {
			continue
		}
		if err := u(s, dev); err != nil {
			errs = append(errs, err)
			continue
		}
		if len(errs) > 0 {
			log.G(context.TODO()).WithError(errors.Join(errs...)).Warning("Ignoring previous errors updating spec")
		}
		return nil
	}
	return errors.Join(errs...)
}

// injectNVIDIARuntimeHook handles requests for NVIDIA GPUs.
// This is done by updating the OCI runtime spec to include the correct value
// for the NVIDIA_VISIBLE_DEVICES environment variable and injecting the
// NVIDIA Container Runtime Hook as a container prestart hook.
//
// A request that resolves to no devices is a no-op. The hook executable is
// resolved before the spec is modified, so that a failed lookup returns an
// error without leaving a partially updated spec behind.
func injectNVIDIARuntimeHook(s *specs.Spec, dev *deviceInstance) error {
	deviceIDs, err := getRequestedDevicesIDs(dev.req)
	if err != nil {
		return err
	}
	if len(deviceIDs) == 0 {
		return nil
	}

	// Resolve the hook first: this function may be one of several alternatives
	// in a firstSuccessfulUpdater, so it must not touch the spec if it is
	// going to fail.
	path, err := exec.LookPath(nvidiaContainerRuntimeHookExecutableName)
	if err != nil {
		return err
	}

	s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+strings.Join(deviceIDs, ","))

	var nvidiaCaps []string
	// req.Capabilities contains device capabilities, some but not all are NVIDIA driver capabilities.
	for _, c := range dev.selectedCaps {
		if _, isNvidiaCap := allNvidiaCaps[c]; isNvidiaCap {
			nvidiaCaps = append(nvidiaCaps, c)
		}
		// TODO: nvidia.WithRequiredCUDAVersion
		// for now we let the prestart hook verify cuda versions but errors are not pretty.
	}
	if len(nvidiaCaps) > 0 {
		s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ","))
	}

	if s.Hooks == nil {
		s.Hooks = &specs.Hooks{}
	}

	// This implementation uses prestart hooks, which are deprecated.
	// CreateRuntime is the closest equivalent, and executed in the same
	// locations as prestart-hooks, but depending on what these hooks do,
	// possibly one of the other hooks could be used instead (such as
	// CreateContainer or StartContainer).
	s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ //nolint:staticcheck // FIXME(thaJeztah); replace prestart hook with a non-deprecated one.
		Path: path,
		Args: []string{
			nvidiaContainerRuntimeHookExecutableName,
			"prestart",
		},
		// The hook is given a copy of the current process environment.
		Env: os.Environ(),
	})

	return nil
}

// getRequestedDevicesIDs returns the list of requested devices by ID based on
// the device request:
//   - explicit DeviceIDs are returned as-is,
//   - a positive Count n yields the IDs "0" through "n-1",
//   - a negative Count yields the single ID "all",
//   - a zero Count without DeviceIDs yields no devices.
//
// Setting both a non-zero Count and DeviceIDs returns
// errConflictCountDeviceIDs.
func getRequestedDevicesIDs(req container.DeviceRequest) ([]string, error) {
	if req.Count != 0 && len(req.DeviceIDs) > 0 {
		return nil, errConflictCountDeviceIDs
	}

	switch {
	case len(req.DeviceIDs) > 0:
		return req.DeviceIDs, nil
	case req.Count > 0:
		return countToDevices(req.Count), nil
	case req.Count < 0:
		return []string{"all"}, nil
	default:
		// No IDs and a zero count: nothing was requested.
		return nil, nil
	}
}

// countToDevices returns the list 0, 1, ... count-1 of deviceIDs.
// A count of zero or less yields no devices.
func countToDevices(count int) []string {
	if count <= 0 {
		// Guard against make panicking on a negative length.
		return nil
	}
	devices := make([]string, count)
	for i := range devices {
		devices[i] = strconv.Itoa(i)
	}
	return devices
}

// A cdiDeviceInjector is used to map regular device requests to CDI device
// requests.
type cdiDeviceInjector struct {
	// defaultCDIDeviceKind is the kind (vendor/class) used to qualify device
	// IDs that are not already fully-qualified CDI device names.
	defaultCDIDeviceKind string
}

// injectDevices converts an incoming device request to a request for devices
// using CDI.
// The requested device IDs are converted to CDI device names if required using
// the specified default kind.
//
// A request that resolves to no devices is a no-op. An error is returned if
// the request is invalid or if no "cdi" device driver is registered;
// otherwise the result of the "cdi" driver's spec update is returned.
func (i *cdiDeviceInjector) injectDevices(s *specs.Spec, dev *deviceInstance) error {
	deviceIDs, err := getRequestedDevicesIDs(dev.req)
	if err != nil {
		return err
	}
	if len(deviceIDs) == 0 {
		return nil
	}

	// If the cdi device driver is not available then we return an error.
	cdiDeviceDriver := deviceDrivers["cdi"]
	if cdiDeviceDriver == nil {
		return fmt.Errorf("no CDI device driver registered: %w", incompatibleDeviceRequest{dev.req.Driver, dev.req.Capabilities})
	}

	cdiDeviceIDs := make([]string, 0, len(deviceIDs))
	for _, deviceID := range deviceIDs {
		cdiDeviceIDs = append(cdiDeviceIDs, i.normalizeDeviceID(deviceID))
	}

	// We construct a device instance using the CDI device IDs and forward this
	// to the cdiDeviceDriver. Count is intentionally left unset since the
	// devices are now identified explicitly by ID.
	return cdiDeviceDriver.updateSpec(s, &deviceInstance{
		req: container.DeviceRequest{
			Driver:       dev.req.Driver,
			DeviceIDs:    cdiDeviceIDs,
			Capabilities: dev.req.Capabilities,
		},
		selectedCaps: nil,
	})
}

// normalizeDeviceID ensures that the specified deviceID is a fully-qualified
// CDI device name.
// If the deviceID is already a fully-qualified CDI device name it is returned
// as-is, otherwise, the default CDI device kind (vendor/class) is used to
// construct a fully qualified CDI device name.
//
// Note that any deviceID containing "=" is currently treated as already being
// fully qualified; no further validation is performed.
func (i *cdiDeviceInjector) normalizeDeviceID(deviceID string) string {
	// if deviceID is of the form vendor.com/class=name, we return it as-is.
	// TODO: We should ideally use the parser from the tags.cncf.io/cdi packages.
	if strings.Contains(deviceID, "=") {
		return deviceID
	}
	return i.defaultCDIDeviceKind + "=" + deviceID
}