package daemon
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strconv"
"strings"
"github.com/containerd/log"
"github.com/moby/moby/api/types/container"
"github.com/moby/moby/v2/daemon/internal/capabilities"
"github.com/opencontainers/runtime-spec/specs-go"
)
// TODO: nvidia should not be hard-coded, and should be a device plugin instead on the daemon object.
// TODO: add list of device capabilities in daemon/node info

// errConflictCountDeviceIDs is returned when a device request specifies both a
// non-zero Count and a non-empty list of DeviceIDs.
var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")
// Names of the NVIDIA helper executables. Each one is looked up on PATH to
// decide whether the corresponding driver is registered.
const (
// nvidiaContainerRuntimeHookExecutableName is the executable that is
// injected into the OCI spec as a prestart hook.
nvidiaContainerRuntimeHookExecutableName = "nvidia-container-runtime-hook"
// nvidiaCDIHookExecutableName is the executable whose presence enables the
// CDI-based NVIDIA driver.
nvidiaCDIHookExecutableName = "nvidia-cdi-hook"
)
// allNvidiaCaps is the set of NVIDIA-specific driver capabilities, taken from
// github.com/containerd/containerd/contrib/nvidia.allCaps.
//
// Every entry is associated with the composite "nvidia" driver, and selected
// capabilities that appear in this set are passed to the runtime hook through
// the NVIDIA_DRIVER_CAPABILITIES environment variable.
var allNvidiaCaps = map[string]struct{}{
"compute": {},
"compat32": {},
"graphics": {},
"utility": {},
"video": {},
"display": {},
}
// getNVIDIADeviceDrivers returns the NVIDIA device drivers that are usable on
// this host, keyed by driver name.
//
// One driver is registered per injection mechanism whose helper executable is
// found on PATH: "nvidia.cdi" for the NVIDIA CDI hook and
// "nvidia.runtime-hook" for the NVIDIA container runtime hook. Neither of
// these has capabilities associated, so that they do not inadvertently match
// requests. If at least one mechanism is available, a composite "nvidia"
// driver is added as well; it carries the "gpu" and "nvidia" capabilities plus
// every entry of allNvidiaCaps, and tries the mechanisms in registration order
// (CDI first). It returns nil if neither executable is found.
func getNVIDIADeviceDrivers() map[string]*deviceDriver {
	var (
		drivers  = map[string]*deviceDriver{}
		fallback firstSuccessfulUpdater
	)

	// register stores drv under name and adds its updater to the chain that
	// backs the composite "nvidia" driver.
	register := func(name string, drv *deviceDriver) {
		drivers[name] = drv
		fallback = append(fallback, drv.updateSpec)
	}

	if _, err := exec.LookPath(nvidiaCDIHookExecutableName); err == nil {
		// CDI-specific driver; unqualified device IDs default to NVIDIA GPUs.
		injector := &cdiDeviceInjector{defaultCDIDeviceKind: "nvidia.com/gpu"}
		register("nvidia.cdi", &deviceDriver{updateSpec: injector.injectDevices})
	}

	if _, err := exec.LookPath(nvidiaContainerRuntimeHookExecutableName); err == nil {
		// Driver specific to the nvidia-container-runtime-hook.
		register("nvidia.runtime-hook", &deviceDriver{updateSpec: injectNVIDIARuntimeHook})
	}

	if len(drivers) == 0 {
		return nil
	}

	// All NVIDIA capabilities are associated with the composite updater.
	capset := make(capabilities.Set, len(allNvidiaCaps)+2)
	capset["gpu"] = struct{}{}
	capset["nvidia"] = struct{}{}
	for name := range allNvidiaCaps {
		capset[name] = struct{}{}
	}

	drivers["nvidia"] = &deviceDriver{
		capset:     capset,
		updateSpec: fallback.updateSpec,
	}
	return drivers
}
// firstSuccessfulUpdater is an ordered list of functions, each of which
// updates an OCI spec for a given device instance.
type firstSuccessfulUpdater []func(*specs.Spec, *deviceInstance) error
// updateSpec runs the updaters in order and stops at the first one that
// succeeds, in which case it returns nil. Nil entries are skipped.
//
// Errors from updaters that failed before the successful one are not
// returned; they are logged as a single warning. If no updater succeeds, the
// collected errors are joined and returned (nil when there were none).
func (us firstSuccessfulUpdater) updateSpec(s *specs.Spec, dev *deviceInstance) error {
	var failures []error
	for _, update := range us {
		if update == nil {
			continue
		}

		err := update(s, dev)
		if err == nil {
			// Success: earlier failures are only worth a warning.
			if len(failures) > 0 {
				log.G(context.TODO()).WithError(errors.Join(failures...)).Warning("Ignoring previous errors updating spec")
			}
			return nil
		}
		failures = append(failures, err)
	}
	return errors.Join(failures...)
}
// injectNVIDIARuntimeHook handles requests for NVIDIA GPUs.
//
// It updates the OCI runtime spec s so that:
//   - NVIDIA_VISIBLE_DEVICES lists the devices requested by dev.req,
//   - NVIDIA_DRIVER_CAPABILITIES lists the selected capabilities that are
//     NVIDIA driver capabilities (only set if there is at least one), and
//   - the NVIDIA Container Runtime Hook runs as a container prestart hook.
//
// A request that resolves to no devices leaves the spec untouched and returns
// nil. An error is returned if the request is invalid or if the hook
// executable cannot be found on PATH; in both cases the spec is not modified.
//
// s.Process is dereferenced without a nil check, so the caller must pass a
// spec whose Process is set.
func injectNVIDIARuntimeHook(s *specs.Spec, dev *deviceInstance) error {
	deviceIDs, err := getRequestedDevicesIDs(dev.req)
	if err != nil {
		return err
	}
	if len(deviceIDs) == 0 {
		return nil
	}

	// Resolve the hook before touching the spec, so that a failed lookup does
	// not leave the spec with the environment set but no hook installed.
	path, err := exec.LookPath(nvidiaContainerRuntimeHookExecutableName)
	if err != nil {
		return err
	}

	s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+strings.Join(deviceIDs, ","))

	// The selected capabilities are device capabilities; some but not all of
	// them are NVIDIA driver capabilities.
	var nvidiaCaps []string
	for _, c := range dev.selectedCaps {
		if _, isNvidiaCap := allNvidiaCaps[c]; isNvidiaCap {
			nvidiaCaps = append(nvidiaCaps, c)
		}
		// TODO: nvidia.WithRequiredCUDAVersion
		// for now we let the prestart hook verify cuda versions but errors are not pretty.
	}
	if len(nvidiaCaps) > 0 {
		s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ","))
	}

	if s.Hooks == nil {
		s.Hooks = &specs.Hooks{}
	}

	// This implementation uses prestart hooks, which are deprecated.
	// CreateRuntime is the closest equivalent, and executed in the same
	// locations as prestart-hooks, but depending on what these hooks do,
	// possibly one of the other hooks could be used instead (such as
	// CreateContainer or StartContainer).
	s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ //nolint:staticcheck // FIXME(thaJeztah); replace prestart hook with a non-deprecated one.
		Path: path,
		Args: []string{
			nvidiaContainerRuntimeHookExecutableName,
			"prestart",
		},
		// The hook is given the environment of the current process.
		Env: os.Environ(),
	})
	return nil
}
// getRequestedDevicesIDs returns the list of requested devices by ID based on
// the device request:
//
//   - explicit DeviceIDs are returned as-is (the slice is not copied),
//   - a positive Count yields the IDs "0" through Count-1,
//   - a negative Count yields the single ID "all",
//   - a request with neither yields a nil list.
//
// Setting both a non-zero Count and DeviceIDs is rejected with
// errConflictCountDeviceIDs.
func getRequestedDevicesIDs(req container.DeviceRequest) ([]string, error) {
	if len(req.DeviceIDs) > 0 {
		if req.Count != 0 {
			return nil, errConflictCountDeviceIDs
		}
		return req.DeviceIDs, nil
	}

	if req.Count < 0 {
		return []string{"all"}, nil
	}
	if req.Count == 0 {
		return nil, nil
	}
	return countToDevices(req.Count), nil
}
// countToDevices returns the device IDs "0", "1", ..., "count-1" as decimal
// strings, in ascending order. A count of zero yields an empty list.
func countToDevices(count int) []string {
	ids := make([]string, 0, count)
	for n := 0; n < count; n++ {
		ids = append(ids, strconv.Itoa(n))
	}
	return ids
}
// A cdiDeviceInjector is used to map regular device requests to CDI device
// requests.
type cdiDeviceInjector struct {
// defaultCDIDeviceKind is the CDI device kind (vendor/class, for example
// "nvidia.com/gpu") used to qualify device IDs that are not already
// fully-qualified CDI device names.
defaultCDIDeviceKind string
}
// injectDevices converts an incoming device request to a request for devices
// using CDI, and forwards it to the driver registered as "cdi".
//
// The requested device IDs are turned into fully-qualified CDI device names
// where needed, using the injector's default kind. The forwarded request keeps
// the original Driver and Capabilities but carries no selected capabilities.
//
// A request that resolves to no devices is a no-op. An error is returned if
// the request is invalid, if no "cdi" driver is registered, or if the CDI
// driver itself fails.
func (i *cdiDeviceInjector) injectDevices(s *specs.Spec, dev *deviceInstance) error {
	requested, err := getRequestedDevicesIDs(dev.req)
	switch {
	case err != nil:
		return err
	case len(requested) == 0:
		return nil
	}

	// Without the generic CDI driver there is nothing to forward to.
	driver := deviceDrivers["cdi"]
	if driver == nil {
		return fmt.Errorf("no CDI device driver registered: %w", incompatibleDeviceRequest{dev.req.Driver, dev.req.Capabilities})
	}

	qualified := make([]string, len(requested))
	for idx, id := range requested {
		qualified[idx] = i.normalizeDeviceID(id)
	}

	// Build a device instance for the CDI device names and hand it to the
	// CDI driver.
	forwarded := &deviceInstance{
		req: container.DeviceRequest{
			Driver:       dev.req.Driver,
			DeviceIDs:    qualified,
			Capabilities: dev.req.Capabilities,
		},
		selectedCaps: nil,
	}
	return driver.updateSpec(s, forwarded)
}
// normalizeDeviceID returns deviceID as a fully-qualified CDI device name.
//
// Any deviceID containing "=" is treated as already being of the form
// vendor.com/class=name and is returned unchanged. Otherwise the injector's
// default CDI device kind (vendor/class) is prepended, giving
// "<kind>=<deviceID>".
func (i *cdiDeviceInjector) normalizeDeviceID(deviceID string) string {
	// TODO: We should ideally use the parser from the tags.cncf.io/cdi packages.
	if strings.Contains(deviceID, "=") {
		return deviceID
	}
	return i.defaultCDIDeviceKind + "=" + deviceID
}