New version of NVIDIA GPU access support

This commit is contained in:
Alexandre Ferreira 2020-06-08 14:45:14 -05:00
parent 94783dfc37
commit 727594c382
3 changed files with 42 additions and 20 deletions

55
main.go
View File

@ -5,6 +5,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"strings"
"os" "os"
"regexp" "regexp"
"syscall" "syscall"
@ -19,18 +20,20 @@ import (
var confFileName string var confFileName string
const ( const (
deviceFileType int = 0 deviceFileType uint = 0
nvidiaSysType int = 1 nvidiaSysType uint = 1
) )
type DeviceInstance struct { type DeviceInstance struct {
devicePlugin *SmarterDevicePlugin devicePluginSmarter *SmarterDevicePlugin
devicePluginNvidia *NvidiaDevicePlugin
deviceName string deviceName string
socketName string socketName string
deviceFile string deviceFile string
numDevices uint numDevices uint
deviceType uint deviceType uint
deviceId string
} }
type DesiredDevice struct { type DesiredDevice struct {
@ -113,9 +116,9 @@ func main() {
var listDevicesAvailable []DeviceInstance var listDevicesAvailable []DeviceInstance
for _, deviceToTest := range desiredDevices { for _, deviceToTest := range desiredDevices {
if deviceToTest.DeviceMatch = "nvidia-gpu" { if deviceToTest.DeviceMatch == "nvidia-gpu" {
glog.V(0).Infof("Checking nvidia devices") glog.V(0).Infof("Checking nvidia devices")
foundDevices,err := findDevicesPattern(ExistingDevices, "gpu.[0-9]*") foundDevices,err := findDevicesPattern(ExistingDevicesSys, "gpu.[0-9]*")
if err != nil { if err != nil {
glog.Errorf(err.Error()) glog.Errorf(err.Error())
os.Exit(1) os.Exit(1)
@ -125,9 +128,10 @@ func main() {
if len(foundDevices) > 0 { if len(foundDevices) > 0 {
for _, deviceToCreate := range foundDevices { for _, deviceToCreate := range foundDevices {
var newDevice DeviceInstance var newDevice DeviceInstance
deviceId := TrimPrefix(deviceToCreate,"gpu.") deviceId := strings.TrimPrefix(deviceToCreate,"gpu.")
newDevice.deviceName = "smarter-devices/" + "nvidia-gpu" + deviceId newDevice.deviceName = "smarter-devices/" + "nvidia-gpu" + deviceId
newDevice.socketName = pluginapi.DevicePluginPath + "smarter-" + d"nvidia-gpu" + deviceId + ".sock" newDevice.deviceId = deviceId
newDevice.socketName = pluginapi.DevicePluginPath + "smarter-nvidia-gpu" + deviceId + ".sock"
newDevice.deviceFile = deviceId newDevice.deviceFile = deviceId
newDevice.numDevices = deviceToTest.NumMaxDevices newDevice.numDevices = deviceToTest.NumMaxDevices
newDevice.deviceType = nvidiaSysType newDevice.deviceType = nvidiaSysType
@ -135,8 +139,7 @@ func main() {
glog.V(0).Infof("Creating device %s socket and %s name for %s",newDevice.deviceName,newDevice.deviceFile,deviceToTest.DeviceMatch) glog.V(0).Infof("Creating device %s socket and %s name for %s",newDevice.deviceName,newDevice.deviceFile,deviceToTest.DeviceMatch)
} }
} }
} } else {
else {
glog.V(0).Infof("Checking devices %s on /dev",deviceToTest.DeviceMatch) glog.V(0).Infof("Checking devices %s on /dev",deviceToTest.DeviceMatch)
foundDevices,err := findDevicesPattern(ExistingDevices, deviceToTest.DeviceMatch) foundDevices,err := findDevicesPattern(ExistingDevices, deviceToTest.DeviceMatch)
if err != nil { if err != nil {
@ -177,23 +180,30 @@ L:
for { for {
if restart { if restart {
for _, devicesInUse := range listDevicesAvailable { for _, devicesInUse := range listDevicesAvailable {
if devicesInUse.devicePlugin != nil { switch devicesInUse.deviceType {
devicesInUse.devicePlugin.Stop() case deviceFileType :
} if devicesInUse.devicePluginSmarter != nil {
devicesInUse.devicePluginSmarter.Stop()
}
case nvidiaSysType :
if devicesInUse.devicePluginNvidia != nil {
devicesInUse.devicePluginNvidia.Stop()
}
}
} }
var err error var err error
for _, devicesInUse := range listDevicesAvailable { for _, devicesInUse := range listDevicesAvailable {
switch devicesInUse.deviceType { switch devicesInUse.deviceType {
case deviceFileType : case deviceFileType :
devicesInUse.devicePlugin = NewSmarterDevicePlugin(devicesInUse.numDevices, devicesInUse.deviceFile, devicesInUse.deviceName, devicesInUse.socketName) devicesInUse.devicePluginSmarter = NewSmarterDevicePlugin(devicesInUse.numDevices, devicesInUse.deviceFile, devicesInUse.deviceName, devicesInUse.socketName)
if err = devicesInUse.devicePlugin.Serve(); err != nil { if err = devicesInUse.devicePluginSmarter.Serve(); err != nil {
glog.V(0).Info("Could not contact Kubelet, retrying. Did you enable the device plugin feature gate?") glog.V(0).Info("Could not contact Kubelet, retrying. Did you enable the device plugin feature gate?")
break break
} }
case nvidiaSysType : case nvidiaSysType :
devicesInUse.devicePlugin = NewSmarterDevicePlugin(devicesInUse.numDevices, devicesInUse.deviceFile, devicesInUse.deviceName, devicesInUse.socketName) devicesInUse.devicePluginNvidia = NewNvidiaDevicePlugin(devicesInUse.deviceName,"NVIDIA_VISIBLE_DEVICES", devicesInUse.socketName, devicesInUse.deviceId)
if err = devicesInUse.devicePlugin.Serve(); err != nil { if err = devicesInUse.devicePluginNvidia.Serve(); err != nil {
glog.V(0).Info("Could not contact Kubelet, retrying. Did you enable the device plugin feature gate?") glog.V(0).Info("Could not contact Kubelet, retrying. Did you enable the device plugin feature gate?")
break break
} }
@ -224,9 +234,16 @@ L:
default: default:
glog.V(0).Infof("Received signal \"%v\", shutting down.", s) glog.V(0).Infof("Received signal \"%v\", shutting down.", s)
for _, devicesInUse := range listDevicesAvailable { for _, devicesInUse := range listDevicesAvailable {
if devicesInUse.devicePlugin != nil { switch devicesInUse.deviceType {
devicesInUse.devicePlugin.Stop() case deviceFileType :
} if devicesInUse.devicePluginSmarter != nil {
devicesInUse.devicePluginSmarter.Stop()
}
case nvidiaSysType :
if devicesInUse.devicePluginNvidia != nil {
devicesInUse.devicePluginNvidia.Stop()
}
}
} }
break L break L
} }

View File

@ -37,7 +37,7 @@ type SmarterDevicePlugin struct {
// NewSmarterDevicePlugin returns an initialized SmarterDevicePlugin // NewSmarterDevicePlugin returns an initialized SmarterDevicePlugin
func NewSmarterDevicePlugin(nDevices uint, deviceFilename string, resourceIdentification string, serverSock string) *SmarterDevicePlugin { func NewSmarterDevicePlugin(nDevices uint, deviceFilename string, resourceIdentification string, serverSock string) *SmarterDevicePlugin {
return &SmarterDevicePlugin{ return &SmarterDevicePlugin{
devs: getDevices(uint(10)), devs: getDevices(nDevices),
socket: serverSock, socket: serverSock,
deviceFile: deviceFilename, deviceFile: deviceFilename,
resourceName: resourceIdentification, resourceName: resourceIdentification,

View File

@ -33,6 +33,8 @@ spec:
mountPath: /var/lib/kubelet/device-plugins mountPath: /var/lib/kubelet/device-plugins
- name: dev-dir - name: dev-dir
mountPath: /dev mountPath: /dev
- name: sys-dir
mountPath: /sys
volumes: volumes:
- name: device-plugin - name: device-plugin
hostPath: hostPath:
@ -40,4 +42,7 @@ spec:
- name: dev-dir - name: dev-dir
hostPath: hostPath:
path: /dev path: /dev
- name: sys-dir
hostPath:
path: /sys
terminationGracePeriodSeconds: 30 terminationGracePeriodSeconds: 30