Skip to content

Commit 4292ae0

Browse files
committed
Merge remote-tracking branch 'upstream/v4.1.x' into v4.1.x_hpcx
2 parents 8e281f5 + f3bda96 commit 4292ae0

File tree

1 file changed

+117
-10
lines changed

1 file changed

+117
-10
lines changed

opal/mca/common/cuda/common_cuda.c

+117-10
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ struct cudaFunctionTable {
108108
#if OPAL_CUDA_GET_ATTRIBUTES
109109
int (*cuPointerGetAttributes)(unsigned int, CUpointer_attribute *, void **, CUdeviceptr);
110110
#if OPAL_CUDA_VMM_SUPPORT
111+
int (*cuDevicePrimaryCtxRetain)(CUcontext*, CUdevice);
112+
int (*cuDevicePrimaryCtxGetState)(CUdevice, unsigned int*, int*);
113+
int (*cuMemPoolGetAccess)(CUmemAccess_flags*, CUmemoryPool, CUmemLocation*);
114+
int (*cuDeviceGetAttribute)(int*, CUdevice_attribute, CUdevice);
111115
int (*cuDeviceGetCount)(int*);
112116
int (*cuMemRelease)(CUmemGenericAllocationHandle);
113117
int (*cuMemRetainAllocationHandle)(CUmemGenericAllocationHandle*, void*);
@@ -488,6 +492,10 @@ int mca_common_cuda_stage_one_init(void)
488492
OPAL_CUDA_DLSYM(libcuda_handle, cuPointerGetAttributes);
489493
#endif /* OPAL_CUDA_GET_ATTRIBUTES */
490494
#if OPAL_CUDA_VMM_SUPPORT
495+
OPAL_CUDA_DLSYM(libcuda_handle, cuDevicePrimaryCtxRetain);
496+
OPAL_CUDA_DLSYM(libcuda_handle, cuDevicePrimaryCtxGetState);
497+
OPAL_CUDA_DLSYM(libcuda_handle, cuMemPoolGetAccess);
498+
OPAL_CUDA_DLSYM(libcuda_handle, cuDeviceGetAttribute);
491499
OPAL_CUDA_DLSYM(libcuda_handle, cuDeviceGetCount);
492500
OPAL_CUDA_DLSYM(libcuda_handle, cuMemRelease);
493501
OPAL_CUDA_DLSYM(libcuda_handle, cuMemRetainAllocationHandle);
@@ -1745,7 +1753,90 @@ static float mydifftime(opal_timer_t ts_start, opal_timer_t ts_end) {
17451753
}
17461754
#endif /* OPAL_ENABLE_DEBUG */
17471755

1748-
static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
1756+
static int mca_common_cuda_check_mpool(CUdeviceptr dbuf, CUmemorytype *mem_type,
1757+
int *dev_id)
1758+
{
1759+
#if OPAL_CUDA_VMM_SUPPORT
1760+
static int device_count = -1;
1761+
static int mpool_supported = -1;
1762+
CUresult result;
1763+
CUmemoryPool mpool;
1764+
CUmemAccess_flags flags;
1765+
CUmemLocation location;
1766+
1767+
if (mpool_supported <= 0) {
1768+
if (mpool_supported == -1) {
1769+
if (device_count == -1) {
1770+
result = cuFunc.cuDeviceGetCount(&device_count);
1771+
if (result != CUDA_SUCCESS || (0 == device_count)) {
1772+
mpool_supported = 0; /* never check again */
1773+
device_count = 0;
1774+
return 0;
1775+
}
1776+
}
1777+
1778+
/* assume uniformity of devices */
1779+
result = cuFunc.cuDeviceGetAttribute(&mpool_supported,
1780+
CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, 0);
1781+
if (result != CUDA_SUCCESS) {
1782+
mpool_supported = 0;
1783+
}
1784+
}
1785+
if (0 == mpool_supported) {
1786+
return 0;
1787+
}
1788+
}
1789+
1790+
result = cuFunc.cuPointerGetAttribute(&mpool,
1791+
CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE,
1792+
dbuf);
1793+
if (CUDA_SUCCESS != result) {
1794+
return 0;
1795+
}
1796+
1797+
/* check if device has access */
1798+
for (int i = 0; i < device_count; i++) {
1799+
location.type = CU_MEM_LOCATION_TYPE_DEVICE;
1800+
location.id = i;
1801+
result = cuFunc.cuMemPoolGetAccess(&flags, mpool, &location);
1802+
if ((CUDA_SUCCESS == result) &&
1803+
(CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags)) {
1804+
*mem_type = CU_MEMORYTYPE_DEVICE;
1805+
*dev_id = i;
1806+
return 1;
1807+
}
1808+
}
1809+
1810+
/* host must have access as device access possibility is exhausted */
1811+
*mem_type = CU_MEMORYTYPE_HOST;
1812+
*dev_id = -1;
1813+
return 0;
1814+
#endif
1815+
1816+
return 0;
1817+
}
1818+
1819+
static int mca_common_cuda_get_primary_context(CUdevice dev_id, CUcontext *pctx)
1820+
{
1821+
CUresult result;
1822+
unsigned int flags;
1823+
int active;
1824+
1825+
result = cuFunc.cuDevicePrimaryCtxGetState(dev_id, &flags, &active);
1826+
if (CUDA_SUCCESS != result) {
1827+
return OPAL_ERROR;
1828+
}
1829+
1830+
if (active) {
1831+
result = cuFunc.cuDevicePrimaryCtxRetain(pctx, dev_id);
1832+
return OPAL_SUCCESS;
1833+
}
1834+
1835+
return OPAL_ERROR;
1836+
}
1837+
1838+
static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type,
1839+
int *dev_id)
17491840
{
17501841
#if OPAL_CUDA_VMM_SUPPORT
17511842
static int device_count = -1;
@@ -1775,6 +1866,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
17751866

17761867
if (prop.location.type == CU_MEM_LOCATION_TYPE_DEVICE) {
17771868
*mem_type = CU_MEMORYTYPE_DEVICE;
1869+
*dev_id = prop.location.id;
17781870
cuFunc.cuMemRelease(alloc_handle);
17791871
return 1;
17801872
}
@@ -1788,6 +1880,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
17881880
if ((CUDA_SUCCESS == result) &&
17891881
(CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags)) {
17901882
*mem_type = CU_MEMORYTYPE_DEVICE;
1883+
*dev_id = i;
17911884
cuFunc.cuMemRelease(alloc_handle);
17921885
return 1;
17931886
}
@@ -1796,6 +1889,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
17961889

17971890
/* host must have access as device access possibility is exhausted */
17981891
*mem_type = CU_MEMORYTYPE_HOST;
1892+
*dev_id = -1;
17991893
cuFunc.cuMemRelease(alloc_handle);
18001894
return 1;
18011895

@@ -1809,12 +1903,17 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
18091903
{
18101904
int res;
18111905
int is_vmm = 0;
1906+
int is_mpool = 0;
18121907
CUmemorytype vmm_mem_type = 0;
1908+
CUmemorytype mpool_mem_type = 0;
18131909
CUmemorytype memType = 0;
1910+
int vmm_dev_id = -1;
1911+
int mpool_dev_id = -1;
18141912
CUdeviceptr dbuf = (CUdeviceptr)pUserBuf;
18151913
CUcontext ctx = NULL, memCtx = NULL;
18161914

1817-
is_vmm = mca_common_cuda_check_vmm(dbuf, &vmm_mem_type);
1915+
is_vmm = mca_common_cuda_check_vmm(dbuf, &vmm_mem_type, &vmm_dev_id);
1916+
is_mpool = mca_common_cuda_check_mpool(dbuf, &mpool_mem_type, &mpool_dev_id);
18181917

18191918
#if OPAL_CUDA_GET_ATTRIBUTES
18201919
uint32_t isManaged = 0;
@@ -1844,6 +1943,8 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
18441943
} else if (memType == CU_MEMORYTYPE_HOST) {
18451944
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE)) {
18461945
memType = CU_MEMORYTYPE_DEVICE;
1946+
} else if (is_mpool && (mpool_mem_type == CU_MEMORYTYPE_DEVICE)) {
1947+
memType = CU_MEMORYTYPE_DEVICE;
18471948
} else {
18481949
/* Host memory, nothing to do here */
18491950
return 0;
@@ -1864,6 +1965,8 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
18641965
} else if (memType == CU_MEMORYTYPE_HOST) {
18651966
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE)) {
18661967
memType = CU_MEMORYTYPE_DEVICE;
1968+
} else if (is_mpool && (mpool_mem_type == CU_MEMORYTYPE_DEVICE)) {
1969+
memType = CU_MEMORYTYPE_DEVICE;
18671970
} else {
18681971
/* Host memory, nothing to do here */
18691972
return 0;
@@ -1893,14 +1996,18 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
18931996
return OPAL_ERROR;
18941997
}
18951998
#endif /* OPAL_CUDA_GET_ATTRIBUTES */
1896-
if (is_vmm) {
1897-
/* This function is expected to set context if pointer is device
1898-
* accessible but VMM allocations have NULL context associated
1899-
* which cannot be set against the calling thread */
1900-
opal_output(0,
1901-
"CUDA: unable to set context with the given pointer"
1902-
"ptr=%p aborting...", dbuf);
1903-
return OPAL_ERROR;
1999+
if (is_vmm || is_mpool) {
2000+
if (OPAL_SUCCESS ==
2001+
mca_common_cuda_get_primary_context(
2002+
is_vmm ? vmm_dev_id : mpool_dev_id, &memCtx)) {
2003+
/* As VMM/mempool allocations have no context associated
2004+
* with them, check if device primary context can be set */
2005+
} else {
2006+
opal_output(0,
2007+
"CUDA: unable to set ctx with the given pointer"
2008+
"ptr=%p aborting...", pUserBuf);
2009+
return OPAL_ERROR;
2010+
}
19042011
}
19052012

19062013
res = cuFunc.cuCtxSetCurrent(memCtx);

0 commit comments

Comments
 (0)