
Commit 43e94b3

[ROCm] Create torchvision as a HIP Extension (#1928)
* Added code to support creating extension on ROCm
* max -> fmaxf conversion for hipification
* added WITH_HIP flag for hipExtension
* added appropriate headers for HIP build
* use USE_ROCM in condition to build
* change fmaxf and fminf calls
* fminf -> min
* fix the check for ROCM_HOME
* more robust checking for rocm pytorch
* add check for pytorch version before using HIP extensions
* conditional reading of ROCM_HOME
1 parent cca0c77 commit 43e94b3

File tree

9 files changed: 74 additions, 18 deletions

setup.py (+32 -7)
torchvision/csrc/DeformConv.h (+5 -2)
torchvision/csrc/PSROIAlign.h (+5 -2)
torchvision/csrc/PSROIPool.h (+5 -2)
torchvision/csrc/ROIAlign.h (+5 -2)
torchvision/csrc/ROIPool.h (+5 -2)
torchvision/csrc/cuda/vision_cuda.h (+4)
torchvision/csrc/nms.h (+10 -1)
torchvision/csrc/vision.cpp (+3)

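The heart of the change is deciding at build time whether the installed PyTorch is a ROCm (HIP) build. The check that the setup.py diff below introduces boils down to the following; a minimal sketch mirroring the commit (the string comparison against '1.5' is taken verbatim from the diff, which guards the import because older PyTorch releases do not expose ROCM_HOME from torch.utils.cpp_extension):

```python
import torch

# A ROCm build of PyTorch reports a HIP version string here; CUDA and
# CPU-only builds report None.
print("torch.version.hip:", torch.version.hip)

# Guard the ROCM_HOME import exactly as the diff below does.
is_rocm_pytorch = False
if torch.__version__ >= '1.5':
    from torch.utils.cpp_extension import ROCM_HOME
    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

print("Building against ROCm PyTorch:", is_rocm_pytorch)
```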

setup.py (+32 -7)

@@ -13,6 +13,7 @@

 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from torch.utils.hipify import hipify_python


 def read(*names, **kwargs):
@@ -83,7 +84,27 @@ def get_extensions():

     main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
     source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
-    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
+
+    is_rocm_pytorch = False
+    if torch.__version__ >= '1.5':
+        from torch.utils.cpp_extension import ROCM_HOME
+        is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
+
+    if is_rocm_pytorch:
+        hipify_python.hipify(
+            project_directory=this_dir,
+            output_directory=this_dir,
+            includes="torchvision/csrc/cuda/*",
+            show_detailed=True,
+            is_pytorch_extension=True,
+        )
+        source_cuda = glob.glob(os.path.join(extensions_dir, 'hip', '*.hip'))
+        ## Copy over additional files
+        shutil.copy("torchvision/csrc/cuda/cuda_helpers.h", "torchvision/csrc/hip/cuda_helpers.h")
+        shutil.copy("torchvision/csrc/cuda/vision_cuda.h", "torchvision/csrc/hip/vision_cuda.h")
+
+    else:
+        source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

     sources = main_file + source_cpu
     extension = CppExtension
@@ -103,15 +124,19 @@ def get_extensions():
     define_macros = []

     extra_compile_args = {}
-    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
+    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv('FORCE_CUDA', '0') == '1':
         extension = CUDAExtension
         sources += source_cuda
-        define_macros += [('WITH_CUDA', None)]
-        nvcc_flags = os.getenv('NVCC_FLAGS', '')
-        if nvcc_flags == '':
-            nvcc_flags = []
+        if not is_rocm_pytorch:
+            define_macros += [('WITH_CUDA', None)]
+            nvcc_flags = os.getenv('NVCC_FLAGS', '')
+            if nvcc_flags == '':
+                nvcc_flags = []
+            else:
+                nvcc_flags = nvcc_flags.split(' ')
         else:
-            nvcc_flags = nvcc_flags.split(' ')
+            define_macros += [('WITH_HIP', None)]
+            nvcc_flags = []
         extra_compile_args = {
             'cxx': [],
             'nvcc': nvcc_flags,
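Read together, the ROCm branch of get_extensions() amounts to: hipify the CUDA sources in place, copy the two headers the generated files expect next to them, glob the *.hip outputs, and keep building through CUDAExtension while defining WITH_HIP instead of WITH_CUDA (PyTorch's extension machinery drives the HIP toolchain when it is itself a ROCm build). A condensed sketch of that flow, using only names that appear in the diff above:

```python
import glob
import os
import shutil

from torch.utils.hipify import hipify_python


def prepare_rocm_sources(this_dir, extensions_dir):
    """Condensed sketch of the is_rocm_pytorch branch in get_extensions()."""
    # Translate torchvision/csrc/cuda/*.cu into torchvision/csrc/hip/*.hip.
    hipify_python.hipify(
        project_directory=this_dir,
        output_directory=this_dir,
        includes="torchvision/csrc/cuda/*",
        show_detailed=True,
        is_pytorch_extension=True,
    )
    # Copy the headers the generated .hip sources include from the hip/ dir.
    shutil.copy("torchvision/csrc/cuda/cuda_helpers.h",
                "torchvision/csrc/hip/cuda_helpers.h")
    shutil.copy("torchvision/csrc/cuda/vision_cuda.h",
                "torchvision/csrc/hip/vision_cuda.h")
    # The hipified sources replace the .cu files in the build.
    source_cuda = glob.glob(os.path.join(extensions_dir, 'hip', '*.hip'))
    # WITH_HIP (not WITH_CUDA) is defined, and NVCC_FLAGS are not consulted.
    define_macros = [('WITH_HIP', None)]
    nvcc_flags = []
    return source_cuda, define_macros, nvcc_flags
```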

torchvision/csrc/DeformConv.h (+5 -2)

@@ -5,6 +5,9 @@
 #ifdef WITH_CUDA
 #include "cuda/vision_cuda.h"
 #endif
+#ifdef WITH_HIP
+#include "hip/vision_cuda.h"
+#endif

 at::Tensor DeformConv2d_forward(
     const at::Tensor& input,
@@ -17,7 +20,7 @@ at::Tensor DeformConv2d_forward(
     const int groups,
     const int offset_groups) {
   if (input.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return DeformConv2d_forward_cuda(
         input.contiguous(),
         weight.contiguous(),
@@ -56,7 +59,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward
     const int groups,
     const int offset_groups) {
   if (grad.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return DeformConv2d_backward_cuda(
         grad.contiguous(),
         input.contiguous(),

torchvision/csrc/PSROIAlign.h (+5 -2)

@@ -5,6 +5,9 @@
 #ifdef WITH_CUDA
 #include "cuda/vision_cuda.h"
 #endif
+#ifdef WITH_HIP
+#include "hip/vision_cuda.h"
+#endif

 #include <iostream>

@@ -16,7 +19,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward(
     const int pooled_width,
     const int sampling_ratio) {
   if (input.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return PSROIAlign_forward_cuda(
         input,
         rois,
@@ -45,7 +48,7 @@ at::Tensor PSROIAlign_backward(
     const int height,
     const int width) {
   if (grad.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return PSROIAlign_backward_cuda(
         grad,
         rois,

torchvision/csrc/PSROIPool.h (+5 -2)

@@ -5,6 +5,9 @@
 #ifdef WITH_CUDA
 #include "cuda/vision_cuda.h"
 #endif
+#ifdef WITH_HIP
+#include "hip/vision_cuda.h"
+#endif

 std::tuple<at::Tensor, at::Tensor> PSROIPool_forward(
     const at::Tensor& input,
@@ -13,7 +16,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIPool_forward(
     const int pooled_height,
     const int pooled_width) {
   if (input.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return PSROIPool_forward_cuda(
         input, rois, spatial_scale, pooled_height, pooled_width);
 #else
@@ -36,7 +39,7 @@ at::Tensor PSROIPool_backward(
     const int height,
     const int width) {
   if (grad.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return PSROIPool_backward_cuda(
         grad,
         rois,

torchvision/csrc/ROIAlign.h (+5 -2)

@@ -5,6 +5,9 @@
 #ifdef WITH_CUDA
 #include "cuda/vision_cuda.h"
 #endif
+#ifdef WITH_HIP
+#include "hip/vision_cuda.h"
+#endif

 // Interface for Python
 at::Tensor ROIAlign_forward(
@@ -19,7 +22,7 @@ at::Tensor ROIAlign_forward(
     // along each axis.
 {
   if (input.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return ROIAlign_forward_cuda(
         input,
         rois,
@@ -49,7 +52,7 @@ at::Tensor ROIAlign_backward(
     const int sampling_ratio,
     const bool aligned) {
   if (grad.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return ROIAlign_backward_cuda(
         grad,
         rois,

torchvision/csrc/ROIPool.h (+5 -2)

@@ -5,6 +5,9 @@
 #ifdef WITH_CUDA
 #include "cuda/vision_cuda.h"
 #endif
+#ifdef WITH_HIP
+#include "hip/vision_cuda.h"
+#endif

 std::tuple<at::Tensor, at::Tensor> ROIPool_forward(
     const at::Tensor& input,
@@ -13,7 +16,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward(
     const int64_t pooled_height,
     const int64_t pooled_width) {
   if (input.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return ROIPool_forward_cuda(
         input, rois, spatial_scale, pooled_height, pooled_width);
 #else
@@ -36,7 +39,7 @@ at::Tensor ROIPool_backward(
     const int height,
     const int width) {
   if (grad.type().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA) || defined(WITH_HIP)
     return ROIPool_backward_cuda(
         grad,
         rois,

torchvision/csrc/cuda/vision_cuda.h (+4)

@@ -1,5 +1,9 @@
 #pragma once
+#if defined(WITH_CUDA)
 #include <c10/cuda/CUDAGuard.h>
+#elif defined(WITH_HIP)
+#include <c10/hip/HIPGuard.h>
+#endif
 #include <torch/extension.h>

 at::Tensor ROIAlign_forward_cuda(

torchvision/csrc/nms.h (+10 -1)

@@ -4,18 +4,27 @@
 #ifdef WITH_CUDA
 #include "cuda/vision_cuda.h"
 #endif
+#ifdef WITH_HIP
+#include "hip/vision_cuda.h"
+#endif

 at::Tensor nms(
     const at::Tensor& dets,
     const at::Tensor& scores,
     const double iou_threshold) {
   if (dets.device().is_cuda()) {
-#ifdef WITH_CUDA
+#if defined(WITH_CUDA)
     if (dets.numel() == 0) {
       at::cuda::CUDAGuard device_guard(dets.device());
       return at::empty({0}, dets.options().dtype(at::kLong));
     }
     return nms_cuda(dets, scores, iou_threshold);
+#elif defined(WITH_HIP)
+    if (dets.numel() == 0) {
+      at::cuda::HIPGuard device_guard(dets.device());
+      return at::empty({0}, dets.options().dtype(at::kLong));
+    }
+    return nms_cuda(dets, scores, iou_threshold);
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif

torchvision/csrc/vision.cpp (+3)

@@ -4,6 +4,9 @@
 #ifdef WITH_CUDA
 #include <cuda.h>
 #endif
+#ifdef WITH_HIP
+#include <hip/hip_runtime.h>
+#endif

 #include "DeformConv.h"
 #include "PSROIAlign.h"
