Skip to content

Commit ac5e6a2

Browse files
committed
Fix single-GPU experiment scripts and add environment variables for capturing CUDA graphs
1 parent 169d55d commit ac5e6a2

File tree

4 files changed

+41
-14
lines changed

4 files changed

+41
-14
lines changed

expr/single_gpu/correctness_test/scripts/run_densenet_correct_test.sh

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,24 @@ do
1212
echo "##########################################################"
1313
echo " Batch size : ${BATCH}"
1414
echo " Growth K : ${GROWTH_K}"
15-
echo " Num iters : ${ITER}"
15+
echo " Num iters : ${ITER}"
1616
echo "##########################################################"
1717
python /workspace/correctness_test/code/dense_base.py ${BATCH} ${GROWTH_K} ${ITER}
1818

1919
export DO_OOO_BACKPROP="true"
2020
export OOO_CAPTURE_OP="cluster_1_1/xla_run"
2121
export OOO_CAPTURE_ITER=2
2222
export OOO_NUM_BLOCK_OVERLAP_FORWARD=88
23+
export OOO_OVERLAP_START="B4"
24+
export OOO_OVERLAP_END="B3"
2325
export OOO_USE_SUB_STREAM="true"
2426
python /workspace/correctness_test/code/dense_ooo.py ${BATCH} ${GROWTH_K} ${ITER}
2527
unset DO_OOO_BACKPROP
2628
unset OOO_CAPTURE_OP
2729
unset OOO_CAPTURE_ITER
2830
unset OOO_NUM_BLOCK_OVERLAP_FORWARD
31+
unset OOO_OVERLAP_START
32+
unset OOO_OVERLAP_END
2933
unset OOO_USE_SUB_STREAM
3034

3135
python /workspace/correctness_test/code/logit_diff.py

expr/single_gpu/scripts/run_densenet_expr_ooo.sh

+4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ export DO_OOO_BACKPROP="true"
66
export OOO_CAPTURE_OP="cluster_1_1/xla_run"
77
export OOO_CAPTURE_ITER=3
88
export OOO_NUM_BLOCK_OVERLAP_FORWARD=88
9+
export OOO_OVERLAP_START="B4"
10+
export OOO_OVERLAP_END="B3"
911
export OOO_USE_SUB_STREAM="true"
1012

1113
BATCH=$1
@@ -25,4 +27,6 @@ unset DO_OOO_BACKPROP
2527
unset OOO_CAPTURE_OP
2628
unset OOO_CAPTURE_ITER
2729
unset OOO_NUM_BLOCK_OVERLAP_FORWARD
30+
unset OOO_OVERLAP_START
31+
unset OOO_OVERLAP_END
2832
unset OOO_USE_SUB_STREAM

tensorflow/tensorflow/compiler/xla/service/gpu/gpu_executable.cc

+8
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ GpuExecutable::GpuExecutable(
101101
if (!str_capture_iter.empty()) {
102102
capture_iter_ = std::stoi(str_capture_iter);
103103
}
104+
105+
const char* cstr_overlap_start_name = std::getenv("OOO_OVERLAP_START");
106+
std::string str_overlap_start_name(cstr_overlap_start_name ? cstr_overlap_start_name : "NONE");
107+
overlap_start_name_ = str_overlap_start_name;
108+
109+
const char* cstr_overlap_end_name = std::getenv("OOO_OVERLAP_END");
110+
std::string str_overlap_end_name(cstr_overlap_end_name ? cstr_overlap_end_name : "NONE");
111+
overlap_end_name_ = str_overlap_end_name;
104112
}
105113
}
106114

tensorflow/tensorflow/compiler/xla/service/gpu/gpu_executable.h

+24-13
Original file line numberDiff line numberDiff line change
@@ -208,23 +208,34 @@ class GpuExecutable : public Executable {
208208
std::string FORWARD_OVERLAP_GRAPH = "FORWARD_OVERLAP_WGRADS";
209209
std::string DEFAULT_GRAPH = "LAST_GRAPH";
210210

211+
std::string overlap_start_name_;
212+
std::string overlap_end_name_;
213+
211214
bool is_overlap_w_grad_op( std::string op_name, std::string hlo_name ){
212-
if ( op_name.find("B4") != std::string::npos &&
213-
op_name.find("Conv2DBackpropFilter") != std::string::npos &&
214-
op_name.find("Dummy") == std::string::npos &&
215-
hlo_name.find("custom") != std::string::npos ) {
216-
return true;
217-
} else {
218-
return false;
219-
}
215+
if (overlap_start_name_ == "NONE") {
216+
return false;
217+
}
218+
219+
if ( op_name.find(overlap_start_name_) != std::string::npos &&
220+
op_name.find("Conv2DBackpropFilter") != std::string::npos &&
221+
op_name.find("Dummy") == std::string::npos &&
222+
hlo_name.find("custom") != std::string::npos ) {
223+
return true;
224+
} else {
225+
return false;
226+
}
220227
}
221228

222229
bool is_remain_graph_start( std::string op_name ){
223-
if ( op_name.find("B3") != std::string::npos ) {
224-
return true;
225-
} else {
226-
return false;
227-
}
230+
if (overlap_end_name_ == "NONE") {
231+
return false;
232+
}
233+
234+
if ( op_name.find(overlap_end_name_) != std::string::npos ) {
235+
return true;
236+
} else {
237+
return false;
238+
}
228239
}
229240

230241
bool is_deleted_op( std::string op_name ){

0 commit comments

Comments
 (0)