File tree 4 files changed +41 -14 lines changed
tensorflow/tensorflow/compiler/xla/service/gpu
4 files changed +41 -14
lines changed
Original file line number | Diff line number | Diff line change
 12 |  12 |   echo " ##########################################################"
 13 |  13 |   echo " Batch size : ${BATCH} "
 14 |  14 |   echo " Growth K : ${GROWTH_K} "
 15 |     | - echo " Num iters : ${ITER} "
    |  15 | + echo " Num iters : ${ITER} "
 16 |  16 |   echo " ##########################################################"
 17 |  17 |   python /workspace/correctness_test/code/dense_base.py ${BATCH} ${GROWTH_K} ${ITER}
 18 |  18 |
 19 |  19 |   export DO_OOO_BACKPROP="true"
 20 |  20 |   export OOO_CAPTURE_OP="cluster_1_1/xla_run"
 21 |  21 |   export OOO_CAPTURE_ITER=2
 22 |  22 |   export OOO_NUM_BLOCK_OVERLAP_FORWARD=88
    |  23 | + export OOO_OVERLAP_START="B4"
    |  24 | + export OOO_OVERLAP_END="B3"
 23 |  25 |   export OOO_USE_SUB_STREAM="true"
 24 |  26 |   python /workspace/correctness_test/code/dense_ooo.py ${BATCH} ${GROWTH_K} ${ITER}
 25 |  27 |   unset DO_OOO_BACKPROP
 26 |  28 |   unset OOO_CAPTURE_OP
 27 |  29 |   unset OOO_CAPTURE_ITER
 28 |  30 |   unset OOO_NUM_BLOCK_OVERLAP_FORWARD
    |  31 | + unset OOO_OVERLAP_START
    |  32 | + unset OOO_OVERLAP_END
 29 |  33 |   unset OOO_USE_SUB_STREAM
 30 |  34 |
 31 |  35 |   python /workspace/correctness_test/code/logit_diff.py
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@ export DO_OOO_BACKPROP="true"
  6 |   6 |   export OOO_CAPTURE_OP="cluster_1_1/xla_run"
  7 |   7 |   export OOO_CAPTURE_ITER=3
  8 |   8 |   export OOO_NUM_BLOCK_OVERLAP_FORWARD=88
    |   9 | + export OOO_OVERLAP_START="B4"
    |  10 | + export OOO_OVERLAP_END="B3"
  9 |  11 |   export OOO_USE_SUB_STREAM="true"
 10 |  12 |
 11 |  13 |   BATCH=$1
@@ -25,4 +27,6 @@ unset DO_OOO_BACKPROP
 25 |  27 |   unset OOO_CAPTURE_OP
 26 |  28 |   unset OOO_CAPTURE_ITER
 27 |  29 |   unset OOO_NUM_BLOCK_OVERLAP_FORWARD
    |  30 | + unset OOO_OVERLAP_START
    |  31 | + unset OOO_OVERLAP_END
 28 |  32 |   unset OOO_USE_SUB_STREAM
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,14 @@ GpuExecutable::GpuExecutable(
101 | 101 |   if (!str_capture_iter.empty()) {
102 | 102 |   capture_iter_ = std::stoi(str_capture_iter);
103 | 103 |   }
    | 104 | +
    | 105 | + const char* cstr_overlap_start_name = std::getenv("OOO_OVERLAP_START");
    | 106 | + std::string str_overlap_start_name(cstr_overlap_start_name ? cstr_overlap_start_name : "NONE");
    | 107 | + overlap_start_name_ = str_overlap_start_name;
    | 108 | +
    | 109 | + const char* cstr_overlap_end_name = std::getenv("OOO_OVERLAP_END");
    | 110 | + std::string str_overlap_end_name(cstr_overlap_end_name ? cstr_overlap_end_name : "NONE");
    | 111 | + overlap_end_name_ = str_overlap_end_name;
104 | 112 |   }
105 | 113 |   }
106 | 114 |
Original file line number | Diff line number | Diff line change
@@ -208,23 +208,34 @@ class GpuExecutable : public Executable {
208 | 208 |   std::string FORWARD_OVERLAP_GRAPH = "FORWARD_OVERLAP_WGRADS";
209 | 209 |   std::string DEFAULT_GRAPH = "LAST_GRAPH";
210 | 210 |
    | 211 | + std::string overlap_start_name_;
    | 212 | + std::string overlap_end_name_;
    | 213 | +
211 | 214 |   bool is_overlap_w_grad_op( std::string op_name, std::string hlo_name ){
212 |     | - if ( op_name.find("B4") != std::string::npos &&
213 |     | - op_name.find("Conv2DBackpropFilter") != std::string::npos &&
214 |     | - op_name.find("Dummy") == std::string::npos &&
215 |     | - hlo_name.find("custom") != std::string::npos ) {
216 |     | - return true;
217 |     | - } else {
218 |     | - return false;
219 |     | - }
    | 215 | + if (overlap_start_name_ == "NONE") {
    | 216 | + return false;
    | 217 | + }
    | 218 | +
    | 219 | + if ( op_name.find(overlap_start_name_) != std::string::npos &&
    | 220 | + op_name.find("Conv2DBackpropFilter") != std::string::npos &&
    | 221 | + op_name.find("Dummy") == std::string::npos &&
    | 222 | + hlo_name.find("custom") != std::string::npos ) {
    | 223 | + return true;
    | 224 | + } else {
    | 225 | + return false;
    | 226 | + }
220 | 227 |   }
221 | 228 |
222 | 229 |   bool is_remain_graph_start( std::string op_name ){
223 |     | - if ( op_name.find("B3") != std::string::npos ) {
224 |     | - return true;
225 |     | - } else {
226 |     | - return false;
227 |     | - }
    | 230 | + if (overlap_end_name_ == "NONE") {
    | 231 | + return false;
    | 232 | + }
    | 233 | +
    | 234 | + if ( op_name.find(overlap_end_name_) != std::string::npos ) {
    | 235 | + return true;
    | 236 | + } else {
    | 237 | + return false;
    | 238 | + }
228 | 239 |   }
229 | 240 |
230 | 241 |   bool is_deleted_op( std::string op_name ){
You can’t perform that action at this time.
0 commit comments