Skip to content

Commit c7513c3

Browse files
authored
Merge pull request #113 from klueska/fix-bug-load-close
Fix bug with wrong instance of lib being called for load/close
2 parents a97d07c + 73ee14c commit c7513c3

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

pkg/nvml/init.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ import "C"
1818

1919
// nvml.Init()
2020
func (l *library) Init() Return {
21-
if err := libnvml.load(); err != nil {
21+
if err := l.load(); err != nil {
2222
return ERROR_LIBRARY_NOT_FOUND
2323
}
2424
return nvmlInit()
2525
}
2626

2727
// nvml.InitWithFlags()
2828
func (l *library) InitWithFlags(flags uint32) Return {
29-
if err := libnvml.load(); err != nil {
29+
if err := l.load(); err != nil {
3030
return ERROR_LIBRARY_NOT_FOUND
3131
}
3232
return nvmlInitWithFlags(flags)
@@ -39,7 +39,7 @@ func (l *library) Shutdown() Return {
3939
return ret
4040
}
4141

42-
err := libnvml.close()
42+
err := l.close()
4343
if err != nil {
4444
return ERROR_UNKNOWN
4545
}

pkg/nvml/lib.go

+22-22
Original file line numberDiff line numberDiff line change
@@ -198,93 +198,93 @@ func (pis ProcessInfo_v2Slice) ToProcessInfoSlice() []ProcessInfo {
198198
// When new versioned symbols are added, these would have to be initialized above and have
199199
// corresponding checks and subsequent assignments added below.
200200
func (l *library) updateVersionedSymbols() {
201-
err := l.LookupSymbol("nvmlInit_v2")
201+
err := l.dl.Lookup("nvmlInit_v2")
202202
if err == nil {
203203
nvmlInit = nvmlInit_v2
204204
}
205-
err = l.LookupSymbol("nvmlDeviceGetPciInfo_v2")
205+
err = l.dl.Lookup("nvmlDeviceGetPciInfo_v2")
206206
if err == nil {
207207
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v2
208208
}
209-
err = l.LookupSymbol("nvmlDeviceGetPciInfo_v3")
209+
err = l.dl.Lookup("nvmlDeviceGetPciInfo_v3")
210210
if err == nil {
211211
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v3
212212
}
213-
err = l.LookupSymbol("nvmlDeviceGetCount_v2")
213+
err = l.dl.Lookup("nvmlDeviceGetCount_v2")
214214
if err == nil {
215215
nvmlDeviceGetCount = nvmlDeviceGetCount_v2
216216
}
217-
err = l.LookupSymbol("nvmlDeviceGetHandleByIndex_v2")
217+
err = l.dl.Lookup("nvmlDeviceGetHandleByIndex_v2")
218218
if err == nil {
219219
nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v2
220220
}
221-
err = l.LookupSymbol("nvmlDeviceGetHandleByPciBusId_v2")
221+
err = l.dl.Lookup("nvmlDeviceGetHandleByPciBusId_v2")
222222
if err == nil {
223223
nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v2
224224
}
225-
err = l.LookupSymbol("nvmlDeviceGetNvLinkRemotePciInfo_v2")
225+
err = l.dl.Lookup("nvmlDeviceGetNvLinkRemotePciInfo_v2")
226226
if err == nil {
227227
nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v2
228228
}
229229
// Unable to overwrite nvmlDeviceRemoveGpu() because the v2 function takes
230230
// a different set of parameters than the v1 function.
231-
//err = l.LookupSymbol("nvmlDeviceRemoveGpu_v2")
231+
//err = l.dl.Lookup("nvmlDeviceRemoveGpu_v2")
232232
//if err == nil {
233233
// nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v2
234234
//}
235-
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v2")
235+
err = l.dl.Lookup("nvmlDeviceGetGridLicensableFeatures_v2")
236236
if err == nil {
237237
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v2
238238
}
239-
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v3")
239+
err = l.dl.Lookup("nvmlDeviceGetGridLicensableFeatures_v3")
240240
if err == nil {
241241
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v3
242242
}
243-
err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v4")
243+
err = l.dl.Lookup("nvmlDeviceGetGridLicensableFeatures_v4")
244244
if err == nil {
245245
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v4
246246
}
247-
err = l.LookupSymbol("nvmlEventSetWait_v2")
247+
err = l.dl.Lookup("nvmlEventSetWait_v2")
248248
if err == nil {
249249
nvmlEventSetWait = nvmlEventSetWait_v2
250250
}
251-
err = l.LookupSymbol("nvmlDeviceGetAttributes_v2")
251+
err = l.dl.Lookup("nvmlDeviceGetAttributes_v2")
252252
if err == nil {
253253
nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v2
254254
}
255-
err = l.LookupSymbol("nvmlComputeInstanceGetInfo_v2")
255+
err = l.dl.Lookup("nvmlComputeInstanceGetInfo_v2")
256256
if err == nil {
257257
nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v2
258258
}
259-
err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v2")
259+
err = l.dl.Lookup("nvmlDeviceGetComputeRunningProcesses_v2")
260260
if err == nil {
261261
deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v2
262262
}
263-
err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v3")
263+
err = l.dl.Lookup("nvmlDeviceGetComputeRunningProcesses_v3")
264264
if err == nil {
265265
deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v3
266266
}
267-
err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v2")
267+
err = l.dl.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v2")
268268
if err == nil {
269269
deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v2
270270
}
271-
err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v3")
271+
err = l.dl.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v3")
272272
if err == nil {
273273
deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v3
274274
}
275-
err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v2")
275+
err = l.dl.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v2")
276276
if err == nil {
277277
deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v2
278278
}
279-
err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v3")
279+
err = l.dl.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v3")
280280
if err == nil {
281281
deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v3
282282
}
283-
err = l.LookupSymbol("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
283+
err = l.dl.Lookup("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
284284
if err == nil {
285285
nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v2
286286
}
287-
err = l.LookupSymbol("nvmlVgpuInstanceGetLicenseInfo_v2")
287+
err = l.dl.Lookup("nvmlVgpuInstanceGetLicenseInfo_v2")
288288
if err == nil {
289289
nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v2
290290
}

0 commit comments

Comments
 (0)