diff --git a/nvidia-server.go b/nvidia-server.go new file mode 100644 index 0000000..baecb4c --- /dev/null +++ b/nvidia-server.go @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "flag" + "log" + "net" + "os" + "path" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc" + pluginapi "k8s.io/kubernetes/pkg/kubelet/apis/deviceplugin/v1beta1" +) + +var passDeviceSpecs = flag.Bool("pass-device-specs", false, "pass the list of DeviceSpecs to the kubelet on Allocate()") + +// NvidiaDevicePlugin implements the Kubernetes device plugin API +type NvidiaDevicePlugin struct { + devs []*pluginapi.Device + resourceName string + allocateEnvvar string + socket string + id string + + server *grpc.Server + stop chan interface{} + health chan *pluginapi.Device + +} + +// NewNvidiaDevicePlugin returns an initialized NvidiaDevicePlugin +func NewNvidiaDevicePlugin(resourceName string, allocateEnvvar string, socket string, id string) *NvidiaDevicePlugin { + return &NvidiaDevicePlugin{ + resourceName: resourceName, + allocateEnvvar: allocateEnvvar, + socket: socket, + id: id, + } +} + +func (m *NvidiaDevicePlugin) initialize() { + m.server = grpc.NewServer([]grpc.ServerOption{}...) +} + +func (m *NvidiaDevicePlugin) cleanup() { +} + +// Start starts the gRPC server, registers the device plugin with the Kubelet, +// and starts the device healthchecks. +func (m *NvidiaDevicePlugin) Start() error { + m.initialize() + + err := m.Serve() + if err != nil { + log.Printf("Could not start device plugin for '%s': %s", m.resourceName, err) + m.cleanup() + return err + } + log.Printf("Starting to serve '%s' on %s", m.resourceName, m.socket) + + err = m.Register() + if err != nil { + log.Printf("Could not register device plugin: %s", err) + m.Stop() + return err + } + log.Printf("Registered device plugin for '%s' with Kubelet", m.resourceName) + + return nil +} + +// Stop stops the gRPC server. +func (m *NvidiaDevicePlugin) Stop() error { + if m == nil || m.server == nil { + return nil + } + log.Printf("Stopping to serve '%s' on %s", m.resourceName, m.socket) + m.server.Stop() + if err := os.Remove(m.socket); err != nil && !os.IsNotExist(err) { + return err + } + m.cleanup() + return nil +} + +// Serve starts the gRPC server of the device plugin. +func (m *NvidiaDevicePlugin) Serve() error { + sock, err := net.Listen("unix", m.socket) + if err != nil { + return err + } + + pluginapi.RegisterDevicePluginServer(m.server, m) + + go func() { + lastCrashTime := time.Now() + restartCount := 0 + for { + log.Printf("Starting GRPC server for '%s'", m.resourceName) + err := m.server.Serve(sock) + if err == nil { + break + } + + log.Printf("GRPC server for '%s' crashed with error: %v", m.resourceName, err) + + // restart if it has not been too often + // i.e. if server has crashed more than 5 times and it didn't last more than one hour each time + if restartCount > 5 { + // quit + log.Fatal("GRPC server for '%s' has repeatedly crashed recently. Quitting", m.resourceName) + } + timeSinceLastCrash := time.Since(lastCrashTime).Seconds() + lastCrashTime = time.Now() + if timeSinceLastCrash > 3600 { + // it has been one hour since the last crash.. reset the count + // to reflect on the frequency + restartCount = 1 + } else { + restartCount += 1 + } + } + }() + + // Wait for server to start by launching a blocking connexion + conn, err := m.dial(m.socket, 5*time.Second) + if err != nil { + return err + } + conn.Close() + + return nil +} + +// Register registers the device plugin for the given resourceName with Kubelet. +func (m *NvidiaDevicePlugin) Register() error { + conn, err := m.dial(pluginapi.KubeletSocket, 5*time.Second) + if err != nil { + return err + } + defer conn.Close() + + client := pluginapi.NewRegistrationClient(conn) + reqt := &pluginapi.RegisterRequest{ + Version: pluginapi.Version, + Endpoint: path.Base(m.socket), + ResourceName: m.resourceName, + } + + _, err = client.Register(context.Background(), reqt) + if err != nil { + return err + } + return nil +} + +func (m *NvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { + return &pluginapi.DevicePluginOptions{}, nil +} + +// ListAndWatch lists devices and update that list according to the health status +func (m *NvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error { + s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs}) + + for { + select { + case <-m.stop: + return nil + case d := <-m.health: + // FIXME: there is no way to recover from the Unhealthy state. + d.Health = pluginapi.Unhealthy + log.Printf("'%s' device marked unhealthy: %s", m.resourceName, d.ID) + s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs}) + } + } +} + +// Allocate which return list of devices. +func (m *NvidiaDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) { + responses := pluginapi.AllocateResponse{} + for _, req := range reqs.ContainerRequests { + //for _, id := range req.DevicesIDs { + // if !m.deviceExists(id) { + // return nil, fmt.Errorf("invalid allocation request for '%s': unknown device: %s", m.resourceName, id) + // } + // + + response := pluginapi.ContainerAllocateResponse{ + Envs: map[string]string{ + m.allocateEnvvar: m.id, + }, + } + if *passDeviceSpecs { + response.Devices = m.apiDeviceSpecs(req.DevicesIDs) + } + + responses.ContainerResponses = append(responses.ContainerResponses, &response) + } + + return &responses, nil +} + +func (m *NvidiaDevicePlugin) PreStartContainer(context.Context, *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { + return &pluginapi.PreStartContainerResponse{}, nil +} + +// dial establishes the gRPC communication with the registered device plugin. +func (m *NvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) { + c, err := grpc.Dial(unixSocketPath, grpc.WithInsecure(), grpc.WithBlock(), + grpc.WithTimeout(timeout), + grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", addr, timeout) + }), + ) + + if err != nil { + return nil, err + } + + return c, nil +} + +//func (m *NvidiaDevicePlugin) deviceExists(id string) bool { +// for _, d := range m.cachedDevices { +// if d.ID == id { +// return true +// } +// } +// return false +//} + +//func (m *NvidiaDevicePlugin) apiDevices() []*pluginapi.Device { +// var pdevs []*pluginapi.Device +// for _, d := range m.cachedDevices { +// pdevs = append(pdevs, &d.Device) +// } +// return pdevs +//} + +func (m *NvidiaDevicePlugin) apiDeviceSpecs(filter []string) []*pluginapi.DeviceSpec { + var specs []*pluginapi.DeviceSpec + + paths := []string{ + "/dev/nvidiactl", + "/dev/nvidia-uvm", + "/dev/nvidia-uvm-tools", + "/dev/nvidia-modeset", + } + + for _, p := range paths { + if _, err := os.Stat(p); err == nil { + spec := &pluginapi.DeviceSpec{ + ContainerPath: p, + HostPath: p, + Permissions: "rw", + } + specs = append(specs, spec) + } + } + +// for _, d := range m.devs { +// for _, id := range filter { +// if d.ID == id { +// spec := &pluginapi.DeviceSpec{ +// ContainerPath: d.Path, +// HostPath: d.Path, +// Permissions: "rw", +// } +// specs = append(specs, spec) +// } +// } +// } + + return specs +}