Skip to content

Commit 0927e02

Browse files
Hao Lufacebook-github-bot
Hao Lu
authored andcommitted
[caffe2] Do not run RemoveOpsByType on recurrent networks (pytorch#45986)
Summary: Pull Request resolved: pytorch#45986 Recurrent networks have subnets that are not well supported by `RemoveOpsByType`. Here we exclude recurrent networks by adding the same check as in memonger. Test Plan: ``` buck test //caffe2/caffe2/fb/predictor:black_box_predictor_test ``` AdIndexer canary for sanity check: https://www.internalfb.com/intern/ads/canary/430059485214766620 Differential Revision: D24167284 fbshipit-source-id: fa90d1c1f34af334a599d879af09d4c0bf7c27bd
1 parent c8d76ff commit 0927e02

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

caffe2/predictor/transforms.cc

+10-4
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ void RenameOutputs(
9090
void RenameInputsInChildren(
9191
const string& from,
9292
const string& to,
93-
std::shared_ptr<caffe2::NetDef> net,
93+
caffe2::NetDef* net,
9494
std::unordered_map<std::string, std::unordered_set<int>>& children) {
9595
VLOG(2) << "RenameInputsInChildren (from=" << from << ", to=" << to << ")";
9696
if (children.count(from) == 0) {
@@ -106,7 +106,7 @@ void RenameInputsInChildren(
106106
void RenameOutputInParents(
107107
const std::string& from,
108108
const std::string& to,
109-
std::shared_ptr<caffe2::NetDef> net,
109+
caffe2::NetDef* net,
110110
std::unordered_map<std::string, std::unordered_set<int>>& parents) {
111111
VLOG(2) << "RenameOutputInParents (from=" << from << ", to=" << to << ")";
112112
if (parents.count(from) == 0) {
@@ -225,7 +225,13 @@ bool FoundOpCandidate(
225225
// extra complexity is handled in FoundOpCandidate.
226226
void RemoveOpsByType(InferenceGraph& graph, const std::string& op_type) {
227227
int num_removed = 0;
228-
std::shared_ptr<NetDef> net = graph.predict_net_def;
228+
NetDef* net = graph.predict_net_def.get();
229+
for (auto& op : net->op()) {
230+
if (op.type() == "RecurrentNetwork") {
231+
LOG(INFO) << "RemoveOpsByType does not support RecurrentNetwork yet";
232+
return;
233+
}
234+
}
229235

230236
std::unordered_set<std::string> inputs(
231237
graph.input_names.begin(), graph.input_names.end());
@@ -239,7 +245,7 @@ void RemoveOpsByType(InferenceGraph& graph, const std::string& op_type) {
239245
for (const auto& o : graph.output_names) {
240246
net->add_external_output(o);
241247
}
242-
onnx::SsaRewrite(nullptr, net.get());
248+
onnx::SsaRewrite(nullptr, net);
243249
// clear external_outputs
244250
net->mutable_external_output()->Clear();
245251
graph.predictor_net_ssa_rewritten = true;

0 commit comments

Comments
 (0)