diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index 26f3d87..777a9da 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -207,3 +207,9 @@ func (d nvmlDevice) GetNvLinkRemotePciInfo(link int) (PciInfo, Return) { p, r := nvml.Device(d).GetNvLinkRemotePciInfo(link) return PciInfo(p), Return(r) } + +// SetComputeMode sets the compute mode for the device. +func (d nvmlDevice) SetComputeMode(mode ComputeMode) Return { + r := nvml.Device(d).SetComputeMode(nvml.ComputeMode(mode)) + return Return(r) +} diff --git a/pkg/nvml/device_mock.go b/pkg/nvml/device_mock.go index 203676c..84f1472 100644 --- a/pkg/nvml/device_mock.go +++ b/pkg/nvml/device_mock.go @@ -99,6 +99,9 @@ var _ Device = &DeviceMock{} // RegisterEventsFunc: func(v uint64, eventSet EventSet) Return { // panic("mock out the RegisterEvents method") // }, +// SetComputeModeFunc: func(computeMode ComputeMode) Return { +// panic("mock out the SetComputeMode method") +// }, // SetMigModeFunc: func(Mode int) (Return, Return) { // panic("mock out the SetMigMode method") // }, @@ -193,6 +196,9 @@ type DeviceMock struct { // RegisterEventsFunc mocks the RegisterEvents method. RegisterEventsFunc func(v uint64, eventSet EventSet) Return + // SetComputeModeFunc mocks the SetComputeMode method. + SetComputeModeFunc func(computeMode ComputeMode) Return + // SetMigModeFunc mocks the SetMigMode method. SetMigModeFunc func(Mode int) (Return, Return) @@ -306,6 +312,11 @@ type DeviceMock struct { // EventSet is the eventSet argument value. EventSet EventSet } + // SetComputeMode holds details about calls to the SetComputeMode method. + SetComputeMode []struct { + // ComputeMode is the computeMode argument value. + ComputeMode ComputeMode + } // SetMigMode holds details about calls to the SetMigMode method. SetMigMode []struct { // Mode is the Mode argument value. @@ -342,6 +353,7 @@ type DeviceMock struct { lockGetUUID sync.RWMutex lockIsMigDeviceHandle sync.RWMutex lockRegisterEvents sync.RWMutex + lockSetComputeMode sync.RWMutex lockSetMigMode sync.RWMutex locknvmlDeviceHandle sync.RWMutex } @@ -1133,6 +1145,38 @@ func (mock *DeviceMock) RegisterEventsCalls() []struct { return calls } +// SetComputeMode calls SetComputeModeFunc. +func (mock *DeviceMock) SetComputeMode(computeMode ComputeMode) Return { + if mock.SetComputeModeFunc == nil { + panic("DeviceMock.SetComputeModeFunc: method is nil but Device.SetComputeMode was just called") + } + callInfo := struct { + ComputeMode ComputeMode + }{ + ComputeMode: computeMode, + } + mock.lockSetComputeMode.Lock() + mock.calls.SetComputeMode = append(mock.calls.SetComputeMode, callInfo) + mock.lockSetComputeMode.Unlock() + return mock.SetComputeModeFunc(computeMode) +} + +// SetComputeModeCalls gets all the calls that were made to SetComputeMode. +// Check the length with: +// +// len(mockedDevice.SetComputeModeCalls()) +func (mock *DeviceMock) SetComputeModeCalls() []struct { + ComputeMode ComputeMode +} { + var calls []struct { + ComputeMode ComputeMode + } + mock.lockSetComputeMode.RLock() + calls = mock.calls.SetComputeMode + mock.lockSetComputeMode.RUnlock() + return calls +} + // SetMigMode calls SetMigModeFunc. func (mock *DeviceMock) SetMigMode(Mode int) (Return, Return) { if mock.SetMigModeFunc == nil { diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go index 02dbab3..d515097 100644 --- a/pkg/nvml/types.go +++ b/pkg/nvml/types.go @@ -67,6 +67,7 @@ type Device interface { GetUUID() (string, Return) IsMigDeviceHandle() (bool, Return) RegisterEvents(uint64, EventSet) Return + SetComputeMode(ComputeMode) Return SetMigMode(Mode int) (Return, Return) // nvmlDeviceHandle returns a pointer to the underlying NVML device. nvmlDeviceHandle() *nvml.Device @@ -156,3 +157,6 @@ type GpuTopologyLevel nvml.GpuTopologyLevel // EnableState represents a generic enable/disable enum type EnableState nvml.EnableState + +// ComputeMode represents the compute mode for a device +type ComputeMode nvml.ComputeMode