@@ -108,6 +108,10 @@ struct cudaFunctionTable {
108
108
#if OPAL_CUDA_GET_ATTRIBUTES
109
109
int (* cuPointerGetAttributes )(unsigned int , CUpointer_attribute * , void * * , CUdeviceptr );
110
110
#if OPAL_CUDA_VMM_SUPPORT
111
+ int (* cuDevicePrimaryCtxRetain )(CUcontext * , CUdevice );
112
+ int (* cuDevicePrimaryCtxGetState )(CUdevice , unsigned int * , int * );
113
+ int (* cuMemPoolGetAccess )(CUmemAccess_flags * , CUmemoryPool , CUmemLocation * );
114
+ int (* cuDeviceGetAttribute )(int * , CUdevice_attribute , CUdevice );
111
115
int (* cuDeviceGetCount )(int * );
112
116
int (* cuMemRelease )(CUmemGenericAllocationHandle );
113
117
int (* cuMemRetainAllocationHandle )(CUmemGenericAllocationHandle * , void * );
@@ -488,6 +492,10 @@ int mca_common_cuda_stage_one_init(void)
488
492
OPAL_CUDA_DLSYM (libcuda_handle , cuPointerGetAttributes );
489
493
#endif /* OPAL_CUDA_GET_ATTRIBUTES */
490
494
#if OPAL_CUDA_VMM_SUPPORT
495
+ OPAL_CUDA_DLSYM (libcuda_handle , cuDevicePrimaryCtxRetain );
496
+ OPAL_CUDA_DLSYM (libcuda_handle , cuDevicePrimaryCtxGetState );
497
+ OPAL_CUDA_DLSYM (libcuda_handle , cuMemPoolGetAccess );
498
+ OPAL_CUDA_DLSYM (libcuda_handle , cuDeviceGetAttribute );
491
499
OPAL_CUDA_DLSYM (libcuda_handle , cuDeviceGetCount );
492
500
OPAL_CUDA_DLSYM (libcuda_handle , cuMemRelease );
493
501
OPAL_CUDA_DLSYM (libcuda_handle , cuMemRetainAllocationHandle );
@@ -1745,7 +1753,90 @@ static float mydifftime(opal_timer_t ts_start, opal_timer_t ts_end) {
1745
1753
}
1746
1754
#endif /* OPAL_ENABLE_DEBUG */
1747
1755
1748
- static int mca_common_cuda_check_vmm (CUdeviceptr dbuf , CUmemorytype * mem_type )
1756
+ static int mca_common_cuda_check_mpool (CUdeviceptr dbuf , CUmemorytype * mem_type ,
1757
+ int * dev_id )
1758
+ {
1759
+ #if OPAL_CUDA_VMM_SUPPORT
1760
+ static int device_count = -1 ;
1761
+ static int mpool_supported = -1 ;
1762
+ CUresult result ;
1763
+ CUmemoryPool mpool ;
1764
+ CUmemAccess_flags flags ;
1765
+ CUmemLocation location ;
1766
+
1767
+ if (mpool_supported <= 0 ) {
1768
+ if (mpool_supported == -1 ) {
1769
+ if (device_count == -1 ) {
1770
+ result = cuFunc .cuDeviceGetCount (& device_count );
1771
+ if (result != CUDA_SUCCESS || (0 == device_count )) {
1772
+ mpool_supported = 0 ; /* never check again */
1773
+ device_count = 0 ;
1774
+ return 0 ;
1775
+ }
1776
+ }
1777
+
1778
+ /* assume uniformity of devices */
1779
+ result = cuFunc .cuDeviceGetAttribute (& mpool_supported ,
1780
+ CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED , 0 );
1781
+ if (result != CUDA_SUCCESS ) {
1782
+ mpool_supported = 0 ;
1783
+ }
1784
+ }
1785
+ if (0 == mpool_supported ) {
1786
+ return 0 ;
1787
+ }
1788
+ }
1789
+
1790
+ result = cuFunc .cuPointerGetAttribute (& mpool ,
1791
+ CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE ,
1792
+ dbuf );
1793
+ if (CUDA_SUCCESS != result ) {
1794
+ return 0 ;
1795
+ }
1796
+
1797
+ /* check if device has access */
1798
+ for (int i = 0 ; i < device_count ; i ++ ) {
1799
+ location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
1800
+ location .id = i ;
1801
+ result = cuFunc .cuMemPoolGetAccess (& flags , mpool , & location );
1802
+ if ((CUDA_SUCCESS == result ) &&
1803
+ (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags )) {
1804
+ * mem_type = CU_MEMORYTYPE_DEVICE ;
1805
+ * dev_id = i ;
1806
+ return 1 ;
1807
+ }
1808
+ }
1809
+
1810
+ /* host must have access as device access possibility is exhausted */
1811
+ * mem_type = CU_MEMORYTYPE_HOST ;
1812
+ * dev_id = -1 ;
1813
+ return 0 ;
1814
+ #endif
1815
+
1816
+ return 0 ;
1817
+ }
1818
+
1819
+ static int mca_common_cuda_get_primary_context (CUdevice dev_id , CUcontext * pctx )
1820
+ {
1821
+ CUresult result ;
1822
+ unsigned int flags ;
1823
+ int active ;
1824
+
1825
+ result = cuFunc .cuDevicePrimaryCtxGetState (dev_id , & flags , & active );
1826
+ if (CUDA_SUCCESS != result ) {
1827
+ return OPAL_ERROR ;
1828
+ }
1829
+
1830
+ if (active ) {
1831
+ result = cuFunc .cuDevicePrimaryCtxRetain (pctx , dev_id );
1832
+ return OPAL_SUCCESS ;
1833
+ }
1834
+
1835
+ return OPAL_ERROR ;
1836
+ }
1837
+
1838
+ static int mca_common_cuda_check_vmm (CUdeviceptr dbuf , CUmemorytype * mem_type ,
1839
+ int * dev_id )
1749
1840
{
1750
1841
#if OPAL_CUDA_VMM_SUPPORT
1751
1842
static int device_count = -1 ;
@@ -1775,6 +1866,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
1775
1866
1776
1867
if (prop .location .type == CU_MEM_LOCATION_TYPE_DEVICE ) {
1777
1868
* mem_type = CU_MEMORYTYPE_DEVICE ;
1869
+ * dev_id = prop .location .id ;
1778
1870
cuFunc .cuMemRelease (alloc_handle );
1779
1871
return 1 ;
1780
1872
}
@@ -1788,6 +1880,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
1788
1880
if ((CUDA_SUCCESS == result ) &&
1789
1881
(CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags )) {
1790
1882
* mem_type = CU_MEMORYTYPE_DEVICE ;
1883
+ * dev_id = i ;
1791
1884
cuFunc .cuMemRelease (alloc_handle );
1792
1885
return 1 ;
1793
1886
}
@@ -1796,6 +1889,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
1796
1889
1797
1890
/* host must have access as device access possibility is exhausted */
1798
1891
* mem_type = CU_MEMORYTYPE_HOST ;
1892
+ * dev_id = -1 ;
1799
1893
cuFunc .cuMemRelease (alloc_handle );
1800
1894
return 1 ;
1801
1895
@@ -1809,12 +1903,17 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
1809
1903
{
1810
1904
int res ;
1811
1905
int is_vmm = 0 ;
1906
+ int is_mpool = 0 ;
1812
1907
CUmemorytype vmm_mem_type = 0 ;
1908
+ CUmemorytype mpool_mem_type = 0 ;
1813
1909
CUmemorytype memType = 0 ;
1910
+ int vmm_dev_id = -1 ;
1911
+ int mpool_dev_id = -1 ;
1814
1912
CUdeviceptr dbuf = (CUdeviceptr )pUserBuf ;
1815
1913
CUcontext ctx = NULL , memCtx = NULL ;
1816
1914
1817
- is_vmm = mca_common_cuda_check_vmm (dbuf , & vmm_mem_type );
1915
+ is_vmm = mca_common_cuda_check_vmm (dbuf , & vmm_mem_type , & vmm_dev_id );
1916
+ is_mpool = mca_common_cuda_check_mpool (dbuf , & mpool_mem_type , & mpool_dev_id );
1818
1917
1819
1918
#if OPAL_CUDA_GET_ATTRIBUTES
1820
1919
uint32_t isManaged = 0 ;
@@ -1844,6 +1943,8 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
1844
1943
} else if (memType == CU_MEMORYTYPE_HOST ) {
1845
1944
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
1846
1945
memType = CU_MEMORYTYPE_DEVICE ;
1946
+ } else if (is_mpool && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
1947
+ memType = CU_MEMORYTYPE_DEVICE ;
1847
1948
} else {
1848
1949
/* Host memory, nothing to do here */
1849
1950
return 0 ;
@@ -1864,6 +1965,8 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
1864
1965
} else if (memType == CU_MEMORYTYPE_HOST ) {
1865
1966
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
1866
1967
memType = CU_MEMORYTYPE_DEVICE ;
1968
+ } else if (is_mpool && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
1969
+ memType = CU_MEMORYTYPE_DEVICE ;
1867
1970
} else {
1868
1971
/* Host memory, nothing to do here */
1869
1972
return 0 ;
@@ -1893,14 +1996,18 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
1893
1996
return OPAL_ERROR ;
1894
1997
}
1895
1998
#endif /* OPAL_CUDA_GET_ATTRIBUTES */
1896
- if (is_vmm ) {
1897
- /* This function is expected to set context if pointer is device
1898
- * accessible but VMM allocations have NULL context associated
1899
- * which cannot be set against the calling thread */
1900
- opal_output (0 ,
1901
- "CUDA: unable to set context with the given pointer"
1902
- "ptr=%p aborting..." , dbuf );
1903
- return OPAL_ERROR ;
1999
+ if (is_vmm || is_mpool ) {
2000
+ if (OPAL_SUCCESS ==
2001
+ mca_common_cuda_get_primary_context (
2002
+ is_vmm ? vmm_dev_id : mpool_dev_id , & memCtx )) {
2003
+ /* As VMM/mempool allocations have no context associated
2004
+ * with them, check if device primary context can be set */
2005
+ } else {
2006
+ opal_output (0 ,
2007
+ "CUDA: unable to set ctx with the given pointer"
2008
+ "ptr=%p aborting..." , pUserBuf );
2009
+ return OPAL_ERROR ;
2010
+ }
1904
2011
}
1905
2012
1906
2013
res = cuFunc .cuCtxSetCurrent (memCtx );
0 commit comments