diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go
index 26f3d87..777a9da 100644
--- a/pkg/nvml/device.go
+++ b/pkg/nvml/device.go
@@ -207,3 +207,9 @@ func (d nvmlDevice) GetNvLinkRemotePciInfo(link int) (PciInfo, Return) {
 	p, r := nvml.Device(d).GetNvLinkRemotePciInfo(link)
 	return PciInfo(p), Return(r)
 }
+
+// SetComputeMode sets the compute mode for the device.
+func (d nvmlDevice) SetComputeMode(mode ComputeMode) Return {
+	r := nvml.Device(d).SetComputeMode(nvml.ComputeMode(mode))
+	return Return(r)
+}
diff --git a/pkg/nvml/device_mock.go b/pkg/nvml/device_mock.go
index 203676c..84f1472 100644
--- a/pkg/nvml/device_mock.go
+++ b/pkg/nvml/device_mock.go
@@ -99,6 +99,9 @@ var _ Device = &DeviceMock{}
 //			RegisterEventsFunc: func(v uint64, eventSet EventSet) Return {
 //				panic("mock out the RegisterEvents method")
 //			},
+//			SetComputeModeFunc: func(computeMode ComputeMode) Return {
+//				panic("mock out the SetComputeMode method")
+//			},
 //			SetMigModeFunc: func(Mode int) (Return, Return) {
 //				panic("mock out the SetMigMode method")
 //			},
@@ -193,6 +196,9 @@ type DeviceMock struct {
 	// RegisterEventsFunc mocks the RegisterEvents method.
 	RegisterEventsFunc func(v uint64, eventSet EventSet) Return
 
+	// SetComputeModeFunc mocks the SetComputeMode method.
+	SetComputeModeFunc func(computeMode ComputeMode) Return
+
 	// SetMigModeFunc mocks the SetMigMode method.
 	SetMigModeFunc func(Mode int) (Return, Return)
 
@@ -306,6 +312,11 @@ type DeviceMock struct {
 			// EventSet is the eventSet argument value.
 			EventSet EventSet
 		}
+		// SetComputeMode holds details about calls to the SetComputeMode method.
+		SetComputeMode []struct {
+			// ComputeMode is the computeMode argument value.
+			ComputeMode ComputeMode
+		}
 		// SetMigMode holds details about calls to the SetMigMode method.
 		SetMigMode []struct {
 			// Mode is the Mode argument value.
@@ -342,6 +353,7 @@ type DeviceMock struct {
 	lockGetUUID                            sync.RWMutex
 	lockIsMigDeviceHandle                  sync.RWMutex
 	lockRegisterEvents                     sync.RWMutex
+	lockSetComputeMode                     sync.RWMutex
 	lockSetMigMode                         sync.RWMutex
 	locknvmlDeviceHandle                   sync.RWMutex
 }
@@ -1133,6 +1145,38 @@ func (mock *DeviceMock) RegisterEventsCalls() []struct {
 	return calls
 }
 
+// SetComputeMode calls SetComputeModeFunc.
+func (mock *DeviceMock) SetComputeMode(computeMode ComputeMode) Return {
+	if mock.SetComputeModeFunc == nil {
+		panic("DeviceMock.SetComputeModeFunc: method is nil but Device.SetComputeMode was just called")
+	}
+	callInfo := struct {
+		ComputeMode ComputeMode
+	}{
+		ComputeMode: computeMode,
+	}
+	mock.lockSetComputeMode.Lock()
+	mock.calls.SetComputeMode = append(mock.calls.SetComputeMode, callInfo)
+	mock.lockSetComputeMode.Unlock()
+	return mock.SetComputeModeFunc(computeMode)
+}
+
+// SetComputeModeCalls gets all the calls that were made to SetComputeMode.
+// Check the length with:
+//
+//	len(mockedDevice.SetComputeModeCalls())
+func (mock *DeviceMock) SetComputeModeCalls() []struct {
+	ComputeMode ComputeMode
+} {
+	var calls []struct {
+		ComputeMode ComputeMode
+	}
+	mock.lockSetComputeMode.RLock()
+	calls = mock.calls.SetComputeMode
+	mock.lockSetComputeMode.RUnlock()
+	return calls
+}
+
 // SetMigMode calls SetMigModeFunc.
 func (mock *DeviceMock) SetMigMode(Mode int) (Return, Return) {
 	if mock.SetMigModeFunc == nil {
diff --git a/pkg/nvml/types.go b/pkg/nvml/types.go
index 02dbab3..d515097 100644
--- a/pkg/nvml/types.go
+++ b/pkg/nvml/types.go
@@ -67,6 +67,7 @@ type Device interface {
 	GetUUID() (string, Return)
 	IsMigDeviceHandle() (bool, Return)
 	RegisterEvents(uint64, EventSet) Return
+	SetComputeMode(ComputeMode) Return
 	SetMigMode(Mode int) (Return, Return)
 	// nvmlDeviceHandle returns a pointer to the underlying NVML device.
 	nvmlDeviceHandle() *nvml.Device
@@ -156,3 +157,6 @@ type GpuTopologyLevel nvml.GpuTopologyLevel
 
 // EnableState represents a generic enable/disable enum
 type EnableState nvml.EnableState
+
+// ComputeMode represents the compute mode for a device
+type ComputeMode nvml.ComputeMode