@@ -32,6 +32,7 @@ type Device interface {
32
32
GetMigDevices () ([]MigDevice , error )
33
33
GetMigProfiles () ([]MigProfile , error )
34
34
GetPCIBusID () (string , error )
35
+ IsImexEnabled () (bool , error )
35
36
IsMigCapable () (bool , error )
36
37
IsMigEnabled () (bool , error )
37
38
VisitMigDevices (func (j int , m MigDevice ) error ) error
@@ -208,6 +209,33 @@ func (d *device) IsMigEnabled() (bool, error) {
208
209
return (mode == nvml .DEVICE_MIG_ENABLE ), nil
209
210
}
210
211
212
+ // IsImexEnabled checks if a device has IMEX capabilities.
213
+ func (d * device ) IsImexEnabled () (bool , error ) {
214
+ if d .lib .hasSymbol ("nvmlDeviceGetGpuFabricInfo" ) {
215
+ _ , ret := d .GetGpuFabricInfo ()
216
+ if ret == nvml .ERROR_NOT_SUPPORTED {
217
+ return false , nil
218
+ }
219
+ if ret != nvml .SUCCESS {
220
+ return false , fmt .Errorf ("error getting GPU Fabric Info: %v" , ret )
221
+ }
222
+ return true , nil
223
+ }
224
+
225
+ if d .lib .hasSymbol ("nvmlDeviceGetGpuFabricInfoV" ) {
226
+ _ , ret := d .GetGpuFabricInfoV ().V2 ()
227
+ if ret == nvml .ERROR_NOT_SUPPORTED {
228
+ return false , nil
229
+ }
230
+ if ret != nvml .SUCCESS {
231
+ return false , fmt .Errorf ("error getting GPU Fabric Info: %v" , ret )
232
+ }
233
+ return true , nil
234
+ }
235
+
236
+ return false , nil
237
+ }
238
+
211
239
// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it.
212
240
func (d * device ) VisitMigDevices (visit func (int , MigDevice ) error ) error {
213
241
capable , err := d .IsMigCapable ()
0 commit comments