@@ -10,7 +10,7 @@ namespace {
10
10
// We don't use the getNet() from predictor_utils.cc here because that file
11
11
// has additional dependencies that we want to avoid bringing in, to keep the
12
12
// binary size as small as possible.
13
- static const NetDef& getNet (const MetaNetDef& def, const std::string& name) {
13
+ const NetDef& getNet (const MetaNetDef& def, const std::string& name) {
14
14
for (const auto & n : def.nets ()) {
15
15
if (n.key () == name) {
16
16
return n.value ();
@@ -19,7 +19,7 @@ static const NetDef& getNet(const MetaNetDef& def, const std::string& name) {
19
19
CAFFE_THROW (" Net not found: " , name);
20
20
}
21
21
22
- static const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs (
22
+ const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs (
23
23
const MetaNetDef& def,
24
24
const std::string& name) {
25
25
for (const auto & b : def.blobs ()) {
@@ -30,60 +30,26 @@ static const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs(
30
30
CAFFE_THROW (" Blob not found: " , name);
31
31
}
32
32
33
- static std::string combine (const std::string& str, const std::string& name) {
34
- if (name.empty ()) {
35
- return std::string (str);
36
- }
37
- return str + " _" + name;
38
- }
39
-
40
- static std::string getNamedPredictNet (const string& name) {
41
- return combine (PredictorConsts::default_instance ().predict_net_type (), name);
42
- }
43
-
44
- static std::string getNamedInitNet (const string& name) {
45
- return combine (
46
- PredictorConsts::default_instance ().predict_init_net_type (), name);
47
- }
48
-
49
- static std::string getNamedInputs (const string& name) {
50
- return combine (PredictorConsts::default_instance ().inputs_blob_type (), name);
51
- }
52
-
53
- static std::string getNamedOutputs (const string& name) {
54
- return combine (PredictorConsts::default_instance ().outputs_blob_type (), name);
55
- }
56
-
57
- static std::string getNamedParams (const string& name) {
58
- return combine (
59
- PredictorConsts::default_instance ().parameters_blob_type (), name);
60
- }
61
-
62
33
} // namespace
63
34
64
- PredictorConfig makePredictorConfig (
65
- const MetaNetDef& def,
66
- Workspace* parent,
67
- bool run_init,
68
- const std::string& net_name) {
69
- const auto & init_net = getNet (def, getNamedInitNet (net_name));
70
- const auto & run_net = getNet (def, getNamedPredictNet (net_name));
35
+ PredictorConfig
36
+ makePredictorConfig (const MetaNetDef& def, Workspace* parent, bool run_init) {
37
+ const auto & init_net =
38
+ getNet (def, PredictorConsts::default_instance ().global_init_net_type ());
39
+ const auto & run_net =
40
+ getNet (def, PredictorConsts::default_instance ().predict_net_type ());
71
41
auto config = makePredictorConfig (init_net, run_net, parent, run_init);
72
- const auto & inputs = getBlobs (def, getNamedInputs (net_name));
42
+ const auto & inputs =
43
+ getBlobs (def, PredictorConsts::default_instance ().inputs_blob_type ());
73
44
for (const auto & input : inputs) {
74
45
config.input_names .emplace_back (input);
75
46
}
76
47
77
- const auto & outputs = getBlobs (def, getNamedOutputs (net_name));
48
+ const auto & outputs =
49
+ getBlobs (def, PredictorConsts::default_instance ().outputs_blob_type ());
78
50
for (const auto & output : outputs) {
79
51
config.output_names .emplace_back (output);
80
52
}
81
-
82
- const auto & params = getBlobs (def, getNamedParams (net_name));
83
- for (const auto & param : params) {
84
- config.parameter_names .emplace_back (param);
85
- }
86
-
87
53
return config;
88
54
}
89
55
0 commit comments