Skip to content

Commit 267affb

Browse files
committed
[ROCm][CUDA] StreamExecutor logic for ROCm / CUDA platform
Work in progress. Addressed requests from @timshen91: - only contains changes in stream_executor/.... - does not remove any stream_executor/cuda/*.h, so that things outside of stream_executor don't break. All the types and functions in the namespace cuda now alias to namespace gpu counterparts. For example, namespace cuda { using CUDADriver = gpu::GpuDriver; }. - all stream_executor/gpu/BUILD targets should be only visible to //third_party/tensorflow/stream_executor:__subpackages__. - target stream_executor/gpu:X should be only used by stream_executor/cuda:cuda_X or stream_executor/rocm:rocm_X, not cuda_Y. For example, cuda:cuda_platform should depend on cuda:cuda_driver, not gpu:gpu_driver.
1 parent f6b3d83 commit 267affb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+6783
-2263
lines changed

tensorflow/BUILD

+7
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,13 @@ config_setting(
343343
},
344344
)
345345

346+
config_setting(
347+
name = "using_rocm_hipcc",
348+
define_values = {
349+
"using_rocm_hipcc": "true",
350+
},
351+
)
352+
346353
config_setting(
347354
name = "with_mpi_support",
348355
values = {"define": "with_mpi_support=true"},

tensorflow/core/BUILD

+8
Original file line numberDiff line numberDiff line change
@@ -1963,6 +1963,14 @@ cc_library(
19631963
],
19641964
)
19651965

1966+
cc_library(
1967+
name = "rocm",
1968+
visibility = ["//visibility:public"],
1969+
deps = [
1970+
"//tensorflow/core/platform/default/build_config:rocm",
1971+
],
1972+
)
1973+
19661974
# -----------------------------------------------------------------------------
19671975
# Clif-related proto libraries.
19681976

tensorflow/core/platform/default/build_config.bzl

+6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ load("//tensorflow:tensorflow.bzl", "if_windows")
66
load("//tensorflow:tensorflow.bzl", "if_not_windows")
77
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
88
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
9+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
910
load(
1011
"//third_party/mkl:build_defs.bzl",
1112
"if_mkl_ml",
@@ -735,6 +736,11 @@ def tf_additional_binary_deps():
735736
"//tensorflow/stream_executor:cuda_platform",
736737
"//tensorflow/core/platform/default/build_config:cuda",
737738
],
739+
) + if_rocm(
740+
[
741+
"//tensorflow/stream_executor:rocm_platform",
742+
"//tensorflow/core/platform/default/build_config:rocm",
743+
],
738744
) + [
739745
# TODO(allenl): Split these out into their own shared objects (they are
740746
# here because they are shared between contrib/ op shared objects and

tensorflow/core/platform/default/build_config/BUILD

+27
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ licenses(["notice"]) # Apache 2.0
88
exports_files(["LICENSE"])
99

1010
load("//tensorflow:tensorflow.bzl", "if_cuda")
11+
load("//tensorflow:tensorflow.bzl", "if_rocm")
1112
load("//tensorflow:tensorflow.bzl", "tf_copts")
1213
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
1314
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
@@ -42,6 +43,7 @@ tf_cuda_library(
4243
"//tensorflow/stream_executor/cuda:cuda_platform_id",
4344
"//tensorflow/stream_executor/host:host_platform_id",
4445
"//tensorflow/stream_executor/platform:dso_loader",
46+
"//tensorflow/stream_executor/rocm:rocm_platform_id",
4547
] + select({
4648
"@local_config_cuda//cuda:darwin": ["IOKit"],
4749
"//conditions:default": [],
@@ -50,6 +52,7 @@ tf_cuda_library(
5052
"//tensorflow:using_cuda_nvcc": ["//tensorflow/stream_executor/cuda:all_runtime"],
5153
"//tensorflow:using_cuda_clang_with_dynamic_build": [],
5254
"//tensorflow:using_cuda_nvcc_with_dynamic_build": [],
55+
"//tensorflow:using_rocm_hipcc": ["//tensorflow/stream_executor/rocm:all_runtime"],
5356
"//conditions:default": [],
5457
}),
5558
)
@@ -67,6 +70,18 @@ cc_library(
6770
}),
6871
)
6972

73+
cc_library(
74+
name = "stream_executor_rocm",
75+
deps = [
76+
":stream_executor_no_cuda",
77+
":rocm",
78+
] + if_static(
79+
["//tensorflow/stream_executor/rocm:all_runtime"],
80+
) + select({
81+
"//conditions:default": [],
82+
}),
83+
)
84+
7085
cc_library(
7186
name = "stream_executor_no_cuda",
7287
deps = [
@@ -79,6 +94,7 @@ cc_library(
7994
"//tensorflow/stream_executor/host:host_platform",
8095
"//tensorflow/stream_executor/host:host_platform_id",
8196
"//tensorflow/stream_executor/platform:dso_loader",
97+
"//tensorflow/stream_executor/rocm:rocm_platform_id",
8298
],
8399
)
84100

@@ -267,6 +283,17 @@ cc_library(
267283
],
268284
)
269285

286+
cc_library(
287+
name = "rocm",
288+
data = [],
289+
linkopts = select({
290+
"//conditions:default": [
291+
"-Wl,-rpath,../local_config_rocm/rocm/rocm/lib",
292+
],
293+
}),
294+
deps = [],
295+
)
296+
270297
cc_library(
271298
name = "sycl",
272299
data = if_ccpp([

tensorflow/core/platform/stream_executor.h

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include "tensorflow/stream_executor/multi_platform_manager.h"
2828
#include "tensorflow/stream_executor/platform.h"
2929
#include "tensorflow/stream_executor/platform/dso_loader.h"
30+
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
3031
#include "tensorflow/stream_executor/scratch_allocator.h"
3132
#include "tensorflow/stream_executor/stream.h"
3233
#include "tensorflow/stream_executor/stream_executor.h"

tensorflow/core/platform/stream_executor_no_cuda.h

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include "tensorflow/stream_executor/multi_platform_manager.h"
2828
#include "tensorflow/stream_executor/platform.h"
2929
#include "tensorflow/stream_executor/platform/dso_loader.h"
30+
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
3031
#include "tensorflow/stream_executor/scratch_allocator.h"
3132
#include "tensorflow/stream_executor/stream.h"
3233
#include "tensorflow/stream_executor/stream_executor.h"

tensorflow/stream_executor/BUILD

+7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
licenses(["notice"]) # Apache 2.0
88

99
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
10+
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
11+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
1012
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
1113
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
1214
load("//tensorflow/stream_executor:build_defs.bzl", "stream_executor_friends")
@@ -654,3 +656,8 @@ alias(
654656
name = "cuda_platform",
655657
actual = "//tensorflow/stream_executor/cuda:all_runtime",
656658
)
659+
660+
alias(
661+
name = "rocm_platform",
662+
actual = "//tensorflow/stream_executor/rocm:all_runtime",
663+
)

tensorflow/stream_executor/cuda/BUILD

+19-5
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ cc_library(
6666
deps = if_cuda_is_configured([
6767
"@com_google_absl//absl/container:inlined_vector",
6868
"@com_google_absl//absl/strings",
69+
"//tensorflow/stream_executor/gpu:gpu_diagnostics_header",
6970
"//tensorflow/stream_executor/lib",
7071
"//tensorflow/stream_executor/platform",
7172
]),
@@ -85,6 +86,7 @@ cc_library(
8586
"@com_google_absl//absl/strings",
8687
"@local_config_cuda//cuda:cuda_headers",
8788
"//tensorflow/stream_executor:device_options",
89+
"//tensorflow/stream_executor/gpu:gpu_driver_header",
8890
"//tensorflow/stream_executor/lib",
8991
"//tensorflow/stream_executor/platform",
9092
"//tensorflow/stream_executor/platform:dso_loader",
@@ -97,18 +99,22 @@ cc_library(
9799
name = "cuda_activation_header",
98100
hdrs = if_cuda_is_configured(["cuda_activation.h"]),
99101
visibility = ["//visibility:public"],
100-
deps = if_cuda_is_configured(["//tensorflow/stream_executor/platform"]),
102+
deps = if_cuda_is_configured([
103+
"//tensorflow/stream_executor/gpu:gpu_activation_header",
104+
"//tensorflow/stream_executor/platform",
105+
]),
101106
)
102107

103108
cc_library(
104109
name = "cuda_activation",
105-
srcs = if_cuda_is_configured(["cuda_activation.cc"]),
110+
srcs = [],
106111
hdrs = if_cuda_is_configured(["cuda_activation.h"]),
107112
deps = if_cuda_is_configured([
108113
":cuda_driver",
109114
"@local_config_cuda//cuda:cuda_headers",
110115
"//tensorflow/stream_executor",
111116
"//tensorflow/stream_executor:stream_executor_internal",
117+
"//tensorflow/stream_executor/gpu:gpu_activation",
112118
"//tensorflow/stream_executor/platform",
113119
]),
114120
)
@@ -120,6 +126,7 @@ cc_library(
120126
deps = if_cuda_is_configured([
121127
":cuda_kernel",
122128
"//tensorflow/stream_executor:event",
129+
"//tensorflow/stream_executor/gpu:gpu_executor_header",
123130
"//tensorflow/stream_executor/lib",
124131
"//tensorflow/stream_executor/platform",
125132
]),
@@ -230,6 +237,7 @@ cc_library(
230237
"//tensorflow/stream_executor:event",
231238
"//tensorflow/stream_executor:plugin_registry",
232239
"//tensorflow/stream_executor:rng",
240+
"//tensorflow/stream_executor/gpu:gpu_rng_header",
233241
"//tensorflow/stream_executor/lib",
234242
"//tensorflow/stream_executor/platform",
235243
"//tensorflow/stream_executor/platform:dso_loader",
@@ -239,12 +247,14 @@ cc_library(
239247

240248
cc_library(
241249
name = "cuda_kernel",
250+
srcs = if_cuda_is_configured(["cuda_kernel.cc"]),
242251
hdrs = if_cuda_is_configured(["cuda_kernel.h"]),
243252
deps = if_cuda_is_configured([
244253
":cuda_driver",
245254
"@local_config_cuda//cuda:cuda_headers",
246255
"//tensorflow/stream_executor:event",
247256
"//tensorflow/stream_executor:stream_executor_pimpl_header",
257+
"//tensorflow/stream_executor/gpu:gpu_kernel_header",
248258
"//tensorflow/stream_executor/lib",
249259
"//tensorflow/stream_executor/platform",
250260
]),
@@ -265,38 +275,41 @@ cc_library(
265275
":cuda_gpu_executor_header",
266276
":cuda_stream",
267277
"//tensorflow/stream_executor:stream_executor_headers",
278+
"//tensorflow/stream_executor/gpu:gpu_event",
279+
"//tensorflow/stream_executor/gpu:gpu_stream_header",
268280
"//tensorflow/stream_executor/lib",
269281
]),
270282
)
271283

272284
cc_library(
273285
name = "cuda_stream",
274-
srcs = if_cuda_is_configured(["cuda_stream.cc"]),
286+
srcs = [],
275287
hdrs = if_cuda_is_configured(["cuda_stream.h"]),
276288
deps = if_cuda_is_configured([
277289
":cuda_driver",
278290
":cuda_gpu_executor_header",
279291
"//tensorflow/stream_executor:stream_executor_headers",
280292
"//tensorflow/stream_executor:stream_header",
293+
"//tensorflow/stream_executor/gpu:gpu_stream",
281294
"//tensorflow/stream_executor/lib",
282295
"//tensorflow/stream_executor/platform",
283296
]),
284297
)
285298

286299
cc_library(
287300
name = "cuda_timer",
288-
srcs = if_cuda_is_configured(["cuda_timer.cc"]),
301+
srcs = [],
289302
hdrs = if_cuda_is_configured(["cuda_timer.h"]),
290303
deps = if_cuda_is_configured([
291304
":cuda_driver",
292305
":cuda_gpu_executor_header",
293306
":cuda_stream",
294307
"//tensorflow/stream_executor:stream_executor_headers",
308+
"//tensorflow/stream_executor/gpu:gpu_timer",
295309
"//tensorflow/stream_executor/lib",
296310
]),
297311
)
298312

299-
# It implements :cuda_gpu_executor_header
300313
cc_library(
301314
name = "cuda_gpu_executor",
302315
srcs = if_cuda_is_configured(["cuda_gpu_executor.cc"]),
@@ -316,6 +329,7 @@ cc_library(
316329
"//tensorflow/stream_executor:stream_executor_internal",
317330
"//tensorflow/stream_executor:stream_executor_pimpl_header",
318331
"//tensorflow/stream_executor:timer",
332+
"//tensorflow/stream_executor/gpu:gpu_executor_header",
319333
"//tensorflow/stream_executor/lib",
320334
"//tensorflow/stream_executor/platform",
321335
"//tensorflow/stream_executor/platform:dso_loader",

tensorflow/stream_executor/cuda/cuda_activation.h

+3-25
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,21 @@ limitations under the License.
1717
// It reaches into the CUDA implementation to activate an underlying CUDA
1818
// context.
1919
//
20-
// Having this file separate from cuda_gpu_executor.h means that dependent
20+
// Having this file separate from cuda/cuda_gpu_executor.h means that dependent
2121
// code does not also have to depend on cuda.h.
2222

2323
#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_ACTIVATION_H_
2424
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_ACTIVATION_H_
2525

26-
#include "tensorflow/stream_executor/platform/port.h"
26+
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
2727

2828
namespace stream_executor {
2929

3030
class StreamExecutor;
3131

3232
namespace cuda {
3333

34-
class CUDAExecutor;
35-
class ScopedActivateContext;
36-
37-
// Activates a CUDA context within an enclosing scope.
38-
class ScopedActivateExecutorContext {
39-
public:
40-
// Form that takes a CUDA executor implementation.
41-
explicit ScopedActivateExecutorContext(CUDAExecutor* cuda_exec);
42-
43-
// Form that takes a pImpl executor and extracts a CUDA implementation --
44-
// fatal failure if it is not CUDA inside.
45-
explicit ScopedActivateExecutorContext(StreamExecutor* stream_exec);
46-
47-
ScopedActivateExecutorContext(ScopedActivateExecutorContext&& other);
48-
49-
~ScopedActivateExecutorContext();
50-
51-
private:
52-
// The cuda.h-using datatype that we wrap.
53-
ScopedActivateContext* driver_scoped_activate_context_;
54-
55-
SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivateExecutorContext);
56-
};
34+
using ScopedActivateExecutorContext = gpu::ScopedActivateExecutorContext;
5735

5836
} // namespace cuda
5937
} // namespace stream_executor

0 commit comments

Comments
 (0)