// smarter-device-manager/nvidia-server.go
// Last modified: 2020-06-08 14:53:33 -05:00 (293 lines, 7.6 KiB, Go)

/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"flag"
"log"
"net"
"os"
"path"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
pluginapi "k8s.io/kubernetes/pkg/kubelet/apis/deviceplugin/v1beta1"
)
// passDeviceSpecs, when set on the command line, makes Allocate also return
// the host device nodes (see apiDeviceSpecs) as DeviceSpecs for each container.
// Defaults to false.
var (
	passDeviceSpecs = flag.Bool(
		"pass-device-specs",
		false,
		"pass the list of DeviceSpecs to the kubelet on Allocate()",
	)
)
// NvidiaDevicePlugin implements the Kubernetes device plugin API
// (pluginapi.DevicePluginServer) for a single resource name: it serves its
// gRPC endpoint on a unix socket, registers with the kubelet, and answers
// ListAndWatch and Allocate requests.
type NvidiaDevicePlugin struct {
// devs is the device list sent to the kubelet by ListAndWatch.
// NOTE(review): not set by NewNvidiaDevicePlugin; presumably populated
// elsewhere in this file — confirm.
devs []*pluginapi.Device
// resourceName is the extended resource name registered with the kubelet;
// it also appears in log messages.
resourceName string
// allocateEnvvar is the name of the environment variable that Allocate sets
// in every container response.
allocateEnvvar string
// socket is the unix socket path the gRPC server listens on. Its base name
// is sent to the kubelet as the registration endpoint.
socket string
// id is the value Allocate assigns to allocateEnvvar.
id string
// server is the plugin's gRPC server, created by initialize.
server *grpc.Server
// stop makes ListAndWatch return once it becomes readable (e.g. is closed).
// NOTE(review): NewNvidiaDevicePlugin does not create this channel, and a nil
// channel never becomes readable — confirm it is made before ListAndWatch runs.
stop chan interface{}
// health carries devices that ListAndWatch marks Unhealthy before
// re-sending the device list. NOTE(review): same nil-channel caveat as stop.
health chan *pluginapi.Device
}
// NewNvidiaDevicePlugin builds a plugin for resourceName that will listen on
// socket and, on Allocate, set the environment variable allocateEnvvar to id.
// Only these identifying fields are filled in; the gRPC server is created
// later by Start.
func NewNvidiaDevicePlugin(resourceName, allocateEnvvar, socket, id string) *NvidiaDevicePlugin {
	plugin := new(NvidiaDevicePlugin)
	plugin.resourceName = resourceName
	plugin.allocateEnvvar = allocateEnvvar
	plugin.socket = socket
	plugin.id = id
	return plugin
}
// initialize prepares the per-run state used by Serve and ListAndWatch: a
// fresh gRPC server plus the stop and health channels.
func (m *NvidiaDevicePlugin) initialize() {
	m.server = grpc.NewServer([]grpc.ServerOption{}...)
	// Without real channels ListAndWatch would select on nil channels and
	// could never observe a stop request or a health event.
	m.stop = make(chan interface{})
	m.health = make(chan *pluginapi.Device)
}

// cleanup releases the per-run state created by initialize. It closes stop
// so that any running ListAndWatch stream returns. It is safe to call more
// than once (Start calls it after a failed Serve, and Stop calls it again).
// It is a no-op if initialize never ran.
func (m *NvidiaDevicePlugin) cleanup() {
	if m.stop == nil {
		return
	}
	select {
	case <-m.stop:
		// Already closed by an earlier cleanup; nothing ever sends on stop,
		// so a successful receive means the channel is closed.
	default:
		close(m.stop)
	}
}
// Start initializes the plugin, starts its gRPC server, and registers it with
// the kubelet. If serving fails, the per-run state is cleaned up. If
// registration fails, the plugin is stopped. In both cases the underlying
// error is returned.
func (m *NvidiaDevicePlugin) Start() error {
	m.initialize()

	if serveErr := m.Serve(); serveErr != nil {
		log.Printf("Could not start device plugin for '%s': %s", m.resourceName, serveErr)
		m.cleanup()
		return serveErr
	}
	log.Printf("Starting to serve '%s' on %s", m.resourceName, m.socket)

	if regErr := m.Register(); regErr != nil {
		log.Printf("Could not register device plugin: %s", regErr)
		// Best effort: the registration error is the one worth reporting.
		_ = m.Stop()
		return regErr
	}
	log.Printf("Registered device plugin for '%s' with Kubelet", m.resourceName)

	return nil
}
// Stop shuts down the gRPC server, removes the plugin's unix socket, and
// calls cleanup. It does nothing for a nil plugin or one whose server was
// never created. If the socket cannot be removed for any reason other than
// it already being gone, that error is returned and cleanup is skipped.
func (m *NvidiaDevicePlugin) Stop() error {
	if running := m != nil && m.server != nil; !running {
		return nil
	}

	log.Printf("Stopping to serve '%s' on %s", m.resourceName, m.socket)
	m.server.Stop()

	removeErr := os.Remove(m.socket)
	if os.IsNotExist(removeErr) {
		removeErr = nil // an already-missing socket is fine
	}
	if removeErr != nil {
		return removeErr
	}

	m.cleanup()
	return nil
}
// Serve starts the plugin's gRPC server on the unix socket m.socket and waits
// (up to 5 seconds) until a client connection to it succeeds.
//
// The server runs in a background goroutine. When the server's Serve method
// returns an error, the goroutine restarts it. restartCount counts crashes that
// each came within an hour of the previous one. Once it exceeds 5, the process
// exits via log.Fatalf. A nil return from the server's Serve method ends the
// loop.
//
// Returns the error from net.Listen or from the readiness dial.
func (m *NvidiaDevicePlugin) Serve() error {
	sock, err := net.Listen("unix", m.socket)
	if err != nil {
		return err
	}

	pluginapi.RegisterDevicePluginServer(m.server, m)

	go func() {
		lastCrashTime := time.Now()
		restartCount := 0
		for {
			log.Printf("Starting GRPC server for '%s'", m.resourceName)
			serveErr := m.server.Serve(sock)
			if serveErr == nil {
				// Per grpc-go, Serve returns nil only after Stop/GracefulStop.
				break
			}
			log.Printf("GRPC server for '%s' crashed with error: %v", m.resourceName, serveErr)

			// Restart unless it has been crashing too often, i.e. more than
			// 5 times in a row without an hour passing between crashes.
			if restartCount > 5 {
				// Fatalf (not Fatal) so the resource name is actually formatted.
				log.Fatalf("GRPC server for '%s' has repeatedly crashed recently. Quitting", m.resourceName)
			}
			timeSinceLastCrash := time.Since(lastCrashTime)
			lastCrashTime = time.Now()
			if timeSinceLastCrash > time.Hour {
				// A quiet hour since the last crash: restart the count so it
				// reflects only recent crash frequency.
				restartCount = 1
			} else {
				restartCount++
			}
		}
	}()

	// Wait for the server to start by making a blocking connection.
	conn, err := m.dial(m.socket, 5*time.Second)
	if err != nil {
		return err
	}
	conn.Close()

	return nil
}
// Register announces this plugin to the kubelet over pluginapi.KubeletSocket.
// It sends the API version, the base name of the plugin's socket as the
// endpoint, and the resource name. Connection and RPC errors are returned
// unchanged.
func (m *NvidiaDevicePlugin) Register() error {
	conn, err := m.dial(pluginapi.KubeletSocket, 5*time.Second)
	if err != nil {
		return err
	}
	defer conn.Close()

	req := &pluginapi.RegisterRequest{
		Version:      pluginapi.Version,
		Endpoint:     path.Base(m.socket),
		ResourceName: m.resourceName,
	}
	_, err = pluginapi.NewRegistrationClient(conn).Register(context.Background(), req)
	return err
}
// GetDevicePluginOptions reports the plugin's options to the kubelet. Every
// option is left at its zero value.
func (m *NvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) {
	opts := new(pluginapi.DevicePluginOptions)
	return opts, nil
}
// ListAndWatch streams the plugin's device list to the kubelet.
//
// It sends m.devs once right away. After that, every time a device arrives on
// m.health, it marks that device Unhealthy and sends m.devs again. It returns
// nil when m.stop becomes readable or the stream's context is done (client
// gone, or server shutting down). It returns the Send error if the stream can
// no longer be written to.
func (m *NvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
	if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs}); err != nil {
		return err
	}
	for {
		select {
		case <-m.stop:
			return nil
		case <-s.Context().Done():
			// Without this case the handler could block forever on
			// m.stop/m.health after the kubelet disconnects.
			return nil
		case d := <-m.health:
			// FIXME: there is no way to recover from the Unhealthy state.
			d.Health = pluginapi.Unhealthy
			log.Printf("'%s' device marked unhealthy: %s", m.resourceName, d.ID)
			if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs}); err != nil {
				return err
			}
		}
	}
}
// Allocate builds one ContainerAllocateResponse per container request. Each
// response sets the environment variable m.allocateEnvvar to m.id. When
// -pass-device-specs is set, it also lists the host device nodes returned by
// apiDeviceSpecs.
//
// NOTE: requested device IDs are not validated against the plugin's devices;
// every container gets the same environment.
func (m *NvidiaDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) {
	resp := &pluginapi.AllocateResponse{}
	for _, creq := range reqs.ContainerRequests {
		cresp := &pluginapi.ContainerAllocateResponse{
			Envs: map[string]string{m.allocateEnvvar: m.id},
		}
		if *passDeviceSpecs {
			cresp.Devices = m.apiDeviceSpecs(creq.DevicesIDs)
		}
		resp.ContainerResponses = append(resp.ContainerResponses, cresp)
	}
	return resp, nil
}
// PreStartContainer does no work and returns an empty response.
func (m *NvidiaDevicePlugin) PreStartContainer(context.Context, *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) {
	resp := new(pluginapi.PreStartContainerResponse)
	return resp, nil
}
// dial opens an insecure, blocking gRPC client connection to the unix socket
// at unixSocketPath. It fails if no connection is established within timeout.
// It is used both to reach the kubelet and to check that this plugin's own
// server is up.
func (m *NvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) {
	unixDialer := func(addr string, d time.Duration) (net.Conn, error) {
		return net.DialTimeout("unix", addr, d)
	}
	opts := []grpc.DialOption{
		grpc.WithInsecure(),
		grpc.WithBlock(),
		grpc.WithTimeout(timeout),
		grpc.WithDialer(unixDialer),
	}

	conn, err := grpc.Dial(unixSocketPath, opts...)
	if err != nil {
		return nil, err
	}
	return conn, nil
}
//func (m *NvidiaDevicePlugin) deviceExists(id string) bool {
// for _, d := range m.cachedDevices {
// if d.ID == id {
// return true
// }
// }
// return false
//}
//func (m *NvidiaDevicePlugin) apiDevices() []*pluginapi.Device {
// var pdevs []*pluginapi.Device
// for _, d := range m.cachedDevices {
// pdevs = append(pdevs, &d.Device)
// }
// return pdevs
//}
// apiDeviceSpecs returns a read-write DeviceSpec for each NVIDIA control node
// (nvidiactl, nvidia-uvm, nvidia-uvm-tools, nvidia-modeset) that exists on the
// host. Each node is mounted at the same path inside the container. It returns
// nil if none exist.
//
// NOTE: filter is currently unused; per-device specs selected by ID are not
// generated.
func (m *NvidiaDevicePlugin) apiDeviceSpecs(filter []string) []*pluginapi.DeviceSpec {
	controlNodes := []string{
		"/dev/nvidiactl",
		"/dev/nvidia-uvm",
		"/dev/nvidia-uvm-tools",
		"/dev/nvidia-modeset",
	}

	var specs []*pluginapi.DeviceSpec
	for _, node := range controlNodes {
		if _, statErr := os.Stat(node); statErr != nil {
			continue // node absent (or inaccessible) on this host
		}
		specs = append(specs, &pluginapi.DeviceSpec{
			ContainerPath: node,
			HostPath:      node,
			Permissions:   "rw",
		})
	}
	return specs
}