/* * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package main import ( "flag" "log" "net" "os" "path" "time" "golang.org/x/net/context" "google.golang.org/grpc" pluginapi "k8s.io/kubernetes/pkg/kubelet/apis/deviceplugin/v1beta1" ) var passDeviceSpecs = flag.Bool("pass-device-specs", false, "pass the list of DeviceSpecs to the kubelet on Allocate()") // NvidiaDevicePlugin implements the Kubernetes device plugin API type NvidiaDevicePlugin struct { devs []*pluginapi.Device resourceName string allocateEnvvar string socket string id string server *grpc.Server stop chan interface{} health chan *pluginapi.Device } // NewNvidiaDevicePlugin returns an initialized NvidiaDevicePlugin func NewNvidiaDevicePlugin(resourceName string, allocateEnvvar string, socket string, id string) *NvidiaDevicePlugin { return &NvidiaDevicePlugin{ resourceName: resourceName, allocateEnvvar: allocateEnvvar, socket: socket, id: id, } } func (m *NvidiaDevicePlugin) initialize() { m.server = grpc.NewServer([]grpc.ServerOption{}...) } func (m *NvidiaDevicePlugin) cleanup() { } // Start starts the gRPC server, registers the device plugin with the Kubelet, // and starts the device healthchecks. func (m *NvidiaDevicePlugin) Start() error { m.initialize() err := m.Serve() if err != nil { log.Printf("Could not start device plugin for '%s': %s", m.resourceName, err) m.cleanup() return err } log.Printf("Starting to serve '%s' on %s", m.resourceName, m.socket) err = m.Register() if err != nil { log.Printf("Could not register device plugin: %s", err) m.Stop() return err } log.Printf("Registered device plugin for '%s' with Kubelet", m.resourceName) return nil } // Stop stops the gRPC server. func (m *NvidiaDevicePlugin) Stop() error { if m == nil || m.server == nil { return nil } log.Printf("Stopping to serve '%s' on %s", m.resourceName, m.socket) m.server.Stop() if err := os.Remove(m.socket); err != nil && !os.IsNotExist(err) { return err } m.cleanup() return nil } // Serve starts the gRPC server of the device plugin. func (m *NvidiaDevicePlugin) Serve() error { sock, err := net.Listen("unix", m.socket) if err != nil { return err } pluginapi.RegisterDevicePluginServer(m.server, m) go func() { lastCrashTime := time.Now() restartCount := 0 for { log.Printf("Starting GRPC server for '%s'", m.resourceName) err := m.server.Serve(sock) if err == nil { break } log.Printf("GRPC server for '%s' crashed with error: %v", m.resourceName, err) // restart if it has not been too often // i.e. if server has crashed more than 5 times and it didn't last more than one hour each time if restartCount > 5 { // quit log.Fatal("GRPC server for '%s' has repeatedly crashed recently. Quitting", m.resourceName) } timeSinceLastCrash := time.Since(lastCrashTime).Seconds() lastCrashTime = time.Now() if timeSinceLastCrash > 3600 { // it has been one hour since the last crash.. reset the count // to reflect on the frequency restartCount = 1 } else { restartCount += 1 } } }() // Wait for server to start by launching a blocking connexion conn, err := m.dial(m.socket, 5*time.Second) if err != nil { return err } conn.Close() return nil } // Register registers the device plugin for the given resourceName with Kubelet. func (m *NvidiaDevicePlugin) Register() error { conn, err := m.dial(pluginapi.KubeletSocket, 5*time.Second) if err != nil { return err } defer conn.Close() client := pluginapi.NewRegistrationClient(conn) reqt := &pluginapi.RegisterRequest{ Version: pluginapi.Version, Endpoint: path.Base(m.socket), ResourceName: m.resourceName, } _, err = client.Register(context.Background(), reqt) if err != nil { return err } return nil } func (m *NvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { return &pluginapi.DevicePluginOptions{}, nil } // ListAndWatch lists devices and update that list according to the health status func (m *NvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error { s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs}) for { select { case <-m.stop: return nil case d := <-m.health: // FIXME: there is no way to recover from the Unhealthy state. d.Health = pluginapi.Unhealthy log.Printf("'%s' device marked unhealthy: %s", m.resourceName, d.ID) s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs}) } } } // Allocate which return list of devices. func (m *NvidiaDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) { responses := pluginapi.AllocateResponse{} for _, req := range reqs.ContainerRequests { //for _, id := range req.DevicesIDs { // if !m.deviceExists(id) { // return nil, fmt.Errorf("invalid allocation request for '%s': unknown device: %s", m.resourceName, id) // } // response := pluginapi.ContainerAllocateResponse{ Envs: map[string]string{ m.allocateEnvvar: m.id, }, } if *passDeviceSpecs { response.Devices = m.apiDeviceSpecs(req.DevicesIDs) } responses.ContainerResponses = append(responses.ContainerResponses, &response) } return &responses, nil } func (m *NvidiaDevicePlugin) PreStartContainer(context.Context, *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { return &pluginapi.PreStartContainerResponse{}, nil } // dial establishes the gRPC communication with the registered device plugin. func (m *NvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) { c, err := grpc.Dial(unixSocketPath, grpc.WithInsecure(), grpc.WithBlock(), grpc.WithTimeout(timeout), grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { return net.DialTimeout("unix", addr, timeout) }), ) if err != nil { return nil, err } return c, nil } //func (m *NvidiaDevicePlugin) deviceExists(id string) bool { // for _, d := range m.cachedDevices { // if d.ID == id { // return true // } // } // return false //} //func (m *NvidiaDevicePlugin) apiDevices() []*pluginapi.Device { // var pdevs []*pluginapi.Device // for _, d := range m.cachedDevices { // pdevs = append(pdevs, &d.Device) // } // return pdevs //} func (m *NvidiaDevicePlugin) apiDeviceSpecs(filter []string) []*pluginapi.DeviceSpec { var specs []*pluginapi.DeviceSpec paths := []string{ "/dev/nvidiactl", "/dev/nvidia-uvm", "/dev/nvidia-uvm-tools", "/dev/nvidia-modeset", } for _, p := range paths { if _, err := os.Stat(p); err == nil { spec := &pluginapi.DeviceSpec{ ContainerPath: p, HostPath: p, Permissions: "rw", } specs = append(specs, spec) } } // for _, d := range m.devs { // for _, id := range filter { // if d.ID == id { // spec := &pluginapi.DeviceSpec{ // ContainerPath: d.Path, // HostPath: d.Path, // Permissions: "rw", // } // specs = append(specs, spec) // } // } // } return specs }