using RCNet.Extensions;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
namespace RCNet.Neural.Network.NonRecurrent.FF
{
/// <summary>
/// Implements the resilient backpropagation (iRPROP+) trainer of a feed-forward network.
/// </summary>
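/// <remarks>
/// iRPROP+ (Igel and Hüsken) maintains an individual step size for every weight and adapts it from
/// the sign of the product of the current and previous gradients: agreement grows the step size, a
/// sign flip shrinks it and, when the error has worsened, the previous weight change is reverted.
/// </remarks>
/// <example>
/// A minimal training-loop sketch. The construction of net, inputs, ideals and settings is
/// illustrative only; see FeedForwardNetwork and RPropTrainerSettings for the real options.
/// <code>
/// RPropTrainer trainer = new RPropTrainer(net, inputs, ideals, settings, new Random(0));
/// while (trainer.Iteration())
/// {
///     if (trainer.MSE &lt;= 1e-6) break; //Stop once the error is small enough
/// }
/// </code>
/// </example>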
[Serializable]
public class RPropTrainer : INonRecurrentNetworkTrainer
{
//Attribute properties
/// <inheritdoc/>
public double MSE { get; private set; }
/// <inheritdoc/>
public int MaxAttempt { get; private set; }
/// <inheritdoc/>
public int Attempt { get; private set; }
/// <inheritdoc/>
public int MaxAttemptEpoch { get; private set; }
/// <inheritdoc/>
public int AttemptEpoch { get; private set; }
/// <inheritdoc/>
public string InfoMessage { get; private set; }
//Attributes
private readonly RPropTrainerSettings _cfg;
private readonly FeedForwardNetwork _net;
private readonly List<double[]> _inputVectorCollection;
private readonly List<double[]> _outputVectorCollection;
private readonly Random _rand;
private readonly double[] _weigthsGradsAcc;
private readonly double[] _weigthsPrevGradsAcc;
private readonly double[] _weigthsPrevDeltas;
private readonly double[] _weigthsPrevChanges;
private double _prevMSE;
private readonly GradientWorkerData[] _gradientWorkerDataCollection;
//Constructor
/// <summary>
/// Instantiates the RPropTrainer.
/// </summary>
/// <param name="net">The FF network to be trained.</param>
/// <param name="inputVectorCollection">The input vectors (input).</param>
/// <param name="outputVectorCollection">The output vectors (ideal).</param>
/// <param name="cfg">The configuration of the trainer.</param>
/// <param name="rand">The random object to be used.</param>
public RPropTrainer(FeedForwardNetwork net,
List<double[]> inputVectorCollection,
List<double[]> outputVectorCollection,
RPropTrainerSettings cfg,
Random rand
)
{
if (!net.Finalized)
{
throw new InvalidOperationException("Can't create the trainer. The network structure has not been finalized.");
}
_cfg = cfg;
MaxAttempt = _cfg.NumOfAttempts;
MaxAttemptEpoch = _cfg.NumOfAttemptEpochs;
_net = net;
_rand = rand;
_inputVectorCollection = inputVectorCollection;
_outputVectorCollection = outputVectorCollection;
_weigthsGradsAcc = new double[_net.NumOfWeights];
_weigthsPrevGradsAcc = new double[_net.NumOfWeights];
_weigthsPrevDeltas = new double[_net.NumOfWeights];
_weigthsPrevChanges = new double[_net.NumOfWeights];
//Prepare the parallel gradient workers (batch ranges)
int numOfWorkers = Math.Max(1, Math.Min(Environment.ProcessorCount - 1, _inputVectorCollection.Count));
_gradientWorkerDataCollection = new GradientWorkerData[numOfWorkers];
int workerBatchSize = _inputVectorCollection.Count / numOfWorkers;
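//Integer division: any remainder rows are taken by the last worker (see the toRow computation below)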
for (int workerIdx = 0, fromRow = 0; workerIdx < numOfWorkers; workerIdx++, fromRow += workerBatchSize)
{
GradientWorkerData gwd = new GradientWorkerData
(
fromRow: fromRow,
toRow: (workerIdx == numOfWorkers - 1 ? _inputVectorCollection.Count - 1 : (fromRow + workerBatchSize) - 1),
numOfWeights: _net.NumOfWeights
);
_gradientWorkerDataCollection[workerIdx] = gwd;
}
InfoMessage = string.Empty;
//Start training attempt
Attempt = 0;
NextAttempt();
return;
}
//Properties
/// <inheritdoc/>
public INonRecurrentNetwork Net { get { return _net; } }
//Methods
/// <summary>
/// Determines the sign of a value. A value whose magnitude is within the zero tolerance is considered zero.
/// </summary>
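/// <remarks>
/// The zero tolerance prevents tiny numeric noise from being interpreted as a gradient sign flip.
/// </remarks>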
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private double Sign(double value)
{
if (Math.Abs(value) <= _cfg.ZeroTolerance)
{
return 0;
}
else if (value > 0)
{
return 1;
}
else
{
return -1;
}
}
/// <summary>
/// iRPROP+ variant of weight update.
/// </summary>
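/// <remarks>
/// When the gradient keeps its sign, the step size (delta) grows up to MaxDelta. When the sign
/// flips, the step size shrinks down to MinDelta and, if the error worsened, the previous weight
/// change is reverted (the "+" backtracking). A zero product leaves the step size unchanged.
/// </remarks>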
private void AdjustWeight(double[] flatWeights, int weightFlatIndex)
{
double weightChange = 0;
double gradMulSign = Sign(_weigthsPrevGradsAcc[weightFlatIndex] * _weigthsGradsAcc[weightFlatIndex]);
if (gradMulSign > 0)
{
//No sign change, increase delta
_weigthsPrevDeltas[weightFlatIndex] = Math.Min(_weigthsPrevDeltas[weightFlatIndex] * _cfg.PositiveEta, _cfg.MaxDelta);
weightChange = Sign(_weigthsGradsAcc[weightFlatIndex]) * _weigthsPrevDeltas[weightFlatIndex];
}
else if (gradMulSign < 0)
{
//Sign changed, decrease delta
_weigthsPrevDeltas[weightFlatIndex] = Math.Max(_weigthsPrevDeltas[weightFlatIndex] * _cfg.NegativeEta, _cfg.MinDelta);
//Ensure no change to delta in the next iteration
_weigthsGradsAcc[weightFlatIndex] = 0;
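//iRPROP+ backtracking: revert the previous weight change only when the error worsened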
weightChange = (MSE > _prevMSE) ? -_weigthsPrevChanges[weightFlatIndex] : 0;
}
else
{
//gradMulSign == 0 -> No change to delta
weightChange = Sign(_weigthsGradsAcc[weightFlatIndex]) * _weigthsPrevDeltas[weightFlatIndex];
}
flatWeights[weightFlatIndex] += weightChange;
_weigthsPrevChanges[weightFlatIndex] = weightChange;
return;
}
//Must not run in parallel, to keep the summation order (and therefore the results) deterministic
private void ProcessGradientWorkersData()
{
MSE = 0;
foreach (GradientWorkerData worker in _gradientWorkerDataCollection)
{
for (int i = 0; i < _weigthsGradsAcc.Length; i++)
{
_weigthsGradsAcc[i] += worker._weigthsGradsAcc[i];
}
MSE += worker._sumSquaredErr;
}
//Finish the MSE computation (mean over all samples and output values)
MSE /= (double)(_inputVectorCollection.Count * _net.NumOfOutputValues);
return;
}
/// <inheritdoc/>
public bool NextAttempt()
{
if (Attempt < MaxAttempt)
{
//Next attempt is allowed
++Attempt;
//Reset
_net.RandomizeWeights(_rand);
_weigthsGradsAcc.Populate(0);
_weigthsPrevGradsAcc.Populate(0);
_weigthsPrevDeltas.Populate(_cfg.IniDelta);
_weigthsPrevChanges.Populate(0);
_prevMSE = 0;
MSE = 0;
AttemptEpoch = 0;
return true;
}
else
{
//Max attempts reached -> do nothing and return false
return false;
}
}
/// <inheritdoc/>
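/// <remarks>
/// One epoch: the gradients are accumulated over all samples by parallel worker batches, merged
/// sequentially (to keep the results deterministic) and every weight is then adjusted by the
/// iRPROP+ rule.
/// </remarks>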
public bool Iteration()
{
if (AttemptEpoch == MaxAttemptEpoch)
{
//Max epochs reached, try a new attempt
if (!NextAttempt())
{
//Next attempt is not available
return false;
}
}
//Next epoch
++AttemptEpoch;
//Store previous iteration error
_prevMSE = MSE;
//Store previously accumulated weight gradients
_weigthsGradsAcc.CopyTo(_weigthsPrevGradsAcc, 0);
//Reset accumulated weight gradients
_weigthsGradsAcc.Populate(0);
//Get copy of the network weights
double[] networkFlatWeights = _net.GetWeightsCopy();
//Network output layer shortcut
FeedForwardNetwork.Layer outputLayer = _net.LayerCollection[_net.LayerCollection.Count - 1];
//Run the gradient workers in parallel
Parallel.ForEach(_gradientWorkerDataCollection, worker =>
{
//----------------------------------------------------------------------------------------------------
//Gradient worker local variables
List<double[]> layerInputCollection = new List<double[]>(_net.LayerCollection.Count);
double[] gradients = new double[_net.NumOfNeurons];
double[] derivatives = new double[_net.NumOfNeurons];
//Reset gradient worker data
worker.Reset();
//Loop over the planned range of samples
for (int row = worker._fromRow; row <= worker._toRow; row++)
{
//----------------------------------------------------------------------------------------------------
//Reset the row-dependent buffers
layerInputCollection.Clear();
gradients.Populate(0);
derivatives.Populate(0);
//----------------------------------------------------------------------------------------------------
//Network computation (collects the layer inputs and activation derivatives)
double[] computedOutputs = _net.Compute(_inputVectorCollection[row], layerInputCollection, derivatives);
//----------------------------------------------------------------------------------------------------
//Compute output layer gradients and update error
int outputLayerNeuronsFlatIdx = outputLayer.NeuronsStartFlatIdx;
for (int neuronIdx = 0; neuronIdx < outputLayer.NumOfLayerNeurons; neuronIdx++)
{
double error = _outputVectorCollection[row][neuronIdx] - computedOutputs[neuronIdx];
gradients[outputLayerNeuronsFlatIdx] = derivatives[outputLayerNeuronsFlatIdx] * error;
//Accumulate error
worker._sumSquaredErr += error * error;
++outputLayerNeuronsFlatIdx;
}//neuronIdx
//----------------------------------------------------------------------------------------------------
//Hidden layers' gradients (error backpropagation)
for (int layerIdx = _net.LayerCollection.Count - 2; layerIdx >= 0; layerIdx--)
{
FeedForwardNetwork.Layer currLayer = _net.LayerCollection[layerIdx];
FeedForwardNetwork.Layer nextLayer = _net.LayerCollection[layerIdx + 1];
int currLayerNeuronFlatIdx = currLayer.NeuronsStartFlatIdx;
for (int currLayerNeuronIdx = 0; currLayerNeuronIdx < currLayer.NumOfLayerNeurons; currLayerNeuronIdx++)
{
double sum = 0;
for (int nextLayerNeuronIdx = 0; nextLayerNeuronIdx < nextLayer.NumOfLayerNeurons; nextLayerNeuronIdx++)
{
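//Flat index of the weight connecting the current-layer neuron to this next-layer neuron
//(each next-layer neuron owns a contiguous block of NumOfInputNodes weights)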
int weightFlatIdx = nextLayer.WeightsStartFlatIdx + nextLayerNeuronIdx * nextLayer.NumOfInputNodes + currLayerNeuronIdx;
sum += gradients[nextLayer.NeuronsStartFlatIdx + nextLayerNeuronIdx] * networkFlatWeights[weightFlatIdx];
}//nextLayerNeuronIdx
gradients[currLayerNeuronFlatIdx] = derivatives[currLayerNeuronFlatIdx] * sum;
++currLayerNeuronFlatIdx;
}//currLayerNeuronIdx
}//layerIdx
//----------------------------------------------------------------------------------------------------
//Compute the increments for the gradient accumulators
for (int layerIdx = 0; layerIdx < _net.LayerCollection.Count; layerIdx++)
{
FeedForwardNetwork.Layer layer = _net.LayerCollection[layerIdx];
double[] layerInputs = layerInputCollection[layerIdx];
int neuronFlatIdx = layer.NeuronsStartFlatIdx;
int weightFlatIdx = layer.WeightsStartFlatIdx;
int biasFlatIdx = layer.BiasesStartFlatIdx;
for (int neuronIdx = 0; neuronIdx < layer.NumOfLayerNeurons; neuronIdx++)
{
//Weights gradients accumulation
//Layer's inputs
for (int inputIdx = 0; inputIdx < layer.NumOfInputNodes; inputIdx++)
{
worker._weigthsGradsAcc[weightFlatIdx] += layerInputs[inputIdx] * gradients[neuronFlatIdx];
++weightFlatIdx;
}
//Layer's input bias
worker._weigthsGradsAcc[biasFlatIdx] += FeedForwardNetwork.BiasValue * gradients[neuronFlatIdx];
++neuronFlatIdx;
++biasFlatIdx;
}//neuronIdx
}//layerIdx
}//Worker main loop
});//Worker finish
//Merge the workers' gradient accumulators and finalize the MSE
ProcessGradientWorkersData();
//Update all weights and biases
Parallel.For(0, networkFlatWeights.Length, weightFlatIdx =>
{
AdjustWeight(networkFlatWeights, weightFlatIdx);
});
//Set adjusted weights back into the network under training
_net.SetWeights(networkFlatWeights);
return true;
}
//Inner classes
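/// <summary>
/// Holds one worker's batch range, its partial sum of squared errors and its partial weight-gradient accumulator.
/// </summary>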
[Serializable]
internal class GradientWorkerData
{
//Attribute properties
public int _fromRow;
public int _toRow;
public double _sumSquaredErr;
public double[] _weigthsGradsAcc;
//Constructor
internal GradientWorkerData(int fromRow, int toRow, int numOfWeights)
{
_fromRow = fromRow;
_toRow = toRow;
_weigthsGradsAcc = new double[numOfWeights];
Reset();
return;
}
//Methods
/// <summary>
/// Resets the gradient worker data to its initial state.
/// </summary>
internal void Reset()
{
_sumSquaredErr = 0;
_weigthsGradsAcc.Populate(0);
return;
}
}//WorkerData
}//RPropTrainer
}//Namespace