-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathConfig.h
133 lines (104 loc) · 2.98 KB
/
Config.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/* Copyright 2015,2016 Tao Xu
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#pragma once
#include <algorithm>
#include <unordered_map>
#include <string>
#include <vector>
namespace boosting {

// Loss functions the booster can be configured with. The enumerators keep
// their explicit numeric values (0 and 1).
// NOTE(review): presumably these numeric values are how the loss is encoded
// in the configuration file read by Config::readConfig — confirm before
// renumbering.
enum LossFunction {
  L2Regression = 0,
  L2Logistic = 1
};

// Specifying the training parameters and data format.
//
// All scalar members carry default member initializers, so a
// default-constructed Config (or one whose readConfig() failed before setting
// every field) is well-defined. Its values are:
//   - zero counts, rates and indices;
//   - L2Regression loss;
//   - '\0' delimiter;
//   - empty index lists, column names and name map.
// Without these initializers, those getters would read indeterminate values.
struct Config {
  // Reads configuration file generated by
  // a.feed.scripts.boosting.gen_json.gen_json_file
  // NOTE(review): defined out of line; presumably returns false when the file
  // cannot be read or parsed — confirm against the implementation.
  bool readConfig(const std::string& fileName);

  // Number of training features, i.e. the number of entries in getTrainIdx().
  int getNumFeatures() const {
    // Explicit cast makes the size_t -> int narrowing of this API visible.
    return static_cast<int>(trainIdx_.size());
  }

  // Number of trees to build.
  int getNumTrees() const {
    return numTrees_;
  }

  // Number of leaves per tree.
  int getNumLeaves() const {
    return numLeaves_;
  }

  // Learning rate (shrinkage).
  double getLearningRate() const {
    return learningRate_;
  }

  // Example (row) sampling rate.
  double getExampleSamplingRate() const {
    return exampleSamplingRate_;
  }

  // Feature (column) sampling rate.
  double getFeatureSamplingRate() const {
    return featureSamplingRate_;
  }

  // Index of the target column.
  // NOTE(review): presumably an index into getColumnNames() — confirm.
  int getTargetIdx() const {
    return targetIdx_;
  }

  // Index of the "compare" column.
  // NOTE(review): meaning is set by readConfig(); presumably a column index
  // into getColumnNames() — confirm.
  int getCompareIdx() const {
    return cmpIdx_;
  }

  // Column indices used as training features. Position i in this vector is
  // "training feature i" for isWeakFeature() and getFeatureName(); the stored
  // value indexes into getColumnNames().
  const std::vector<int>& getTrainIdx() const {
    return trainIdx_;
  }

  // True if training feature `fidx` (a position in getTrainIdx()) refers to a
  // column listed in getWeakIdx(). Precondition: 0 <= fidx <
  // getNumFeatures(); the index is not bounds-checked. Linear scan of the
  // weak-index list.
  bool isWeakFeature(const int fidx) const {
    return std::find(weakIdx_.begin(), weakIdx_.end(), trainIdx_[fidx]) !=
        weakIdx_.end();
  }

  // Column name of training feature `fidx`, resolved through getTrainIdx()
  // into getColumnNames(). Precondition: 0 <= fidx < getNumFeatures() and the
  // stored column index is valid for getColumnNames(); neither is checked.
  const std::string& getFeatureName(const int fidx) const {
    return allColumns_[trainIdx_[fidx]];
  }

  // Looks `f` up in the feature-name map.
  // Returns -1 if feature is not found.
  // NOTE(review): the map is filled in by readConfig() (the only member that
  // can modify it); whether the returned index is a column position or a
  // training-feature position is decided there — confirm.
  int getFeatureIndex(const std::string& f) const {
    const auto it = featureToIndexMap_.find(f);
    return it != featureToIndexMap_.end() ? it->second : -1;
  }

  // Column indices marked as weak features (compared against the values
  // stored in getTrainIdx() by isWeakFeature()).
  const std::vector<int>& getWeakIdx() const {
    return weakIdx_;
  }

  // Column indices designated for evaluation.
  // NOTE(review): exact role is defined by readConfig()/callers — confirm.
  const std::vector<int>& getEvalIdx() const {
    return evalIdx_;
  }

  // Names of all columns in the data format.
  const std::vector<std::string>& getColumnNames() const {
    return allColumns_;
  }

  // Field delimiter character.
  // NOTE(review): presumably the column separator of the input data — confirm.
  char getDelimiter() const {
    return delimiter_;
  }

  // Loss function to optimize.
  LossFunction getLossFunction() const {
    return lossFunction_;
  }

 private:
  // Scalars are value-initialized so no getter ever reads an indeterminate
  // value.
  int numTrees_{};
  int numLeaves_{};
  double exampleSamplingRate_{};
  double featureSamplingRate_{};
  double learningRate_{};
  int targetIdx_{};
  int cmpIdx_{};
  LossFunction lossFunction_{L2Regression};
  std::vector<int> trainIdx_;
  std::vector<int> weakIdx_;
  std::vector<int> evalIdx_;
  std::vector<std::string> allColumns_;
  std::unordered_map<std::string, int> featureToIndexMap_;
  char delimiter_{};
};

}  // namespace boosting