
Commit ebc087f: Code refactoring

1 parent 0b69b73 commit ebc087f

4 files changed: +25 −26 lines

RNNSharp/Layers/DropoutLayer.cs (−1 line)

@@ -38,7 +38,6 @@ public override Neuron CopyNeuronTo(Neuron neuron)
     DropoutNeuron dropoutNeuron = neuron as DropoutNeuron;
     mask.CopyTo(dropoutNeuron.mask, 0);
     Cells.CopyTo(dropoutNeuron.Cells, 0);
-    previousCellOutputs.CopyTo(dropoutNeuron.PrevCellOutputs, 0);

     return dropoutNeuron;
 }

RNNSharp/Layers/LSTMLayer.cs (+18 −17 lines)

@@ -129,12 +129,14 @@ public class LSTMLayer : SimpleLayer
     private Vector4 vecNormalLearningRate;
     private Vector3 vecNormalLearningRate3;
     private Vector<float> vecNormalLearningRateFloat;
+    // protected float[] previousCellOutputs;

     LSTMLayerConfig config;

     public LSTMLayer(LSTMLayerConfig config) : base(config)
     {
         this.config = config;
+        // previousCellOutputs = new float[LayerSize];
         LSTMCells = new LSTMCell[LayerSize];
         for (var i = 0; i < LayerSize; i++)
         {

@@ -155,7 +157,7 @@ public override Neuron CopyNeuronTo(Neuron neuron)
     LSTMNeuron lstmNeuron = neuron as LSTMNeuron;

     Cells.CopyTo(lstmNeuron.Cells, 0);
-    previousCellOutputs.CopyTo(lstmNeuron.PrevCellOutputs, 0);
+    // previousCellOutputs.CopyTo(lstmNeuron.PrevCellOutputs, 0);
     for (int i = 0; i < LayerSize; i++)
     {
         lstmNeuron.LSTMCells[i].Set(LSTMCells[i]);

@@ -168,8 +170,8 @@ public override Neuron CopyNeuronTo(Neuron neuron)
 public override void PreUpdateWeights(Neuron neuron, float[] errs)
 {
     LSTMNeuron lstmNeuron = neuron as LSTMNeuron;
-    lstmNeuron.Cells.CopyTo(Cells, 0);
-    lstmNeuron.PrevCellOutputs.CopyTo(previousCellOutputs, 0);
+    // lstmNeuron.Cells.CopyTo(Cells, 0);
+    // lstmNeuron.PrevCellOutputs.CopyTo(previousCellOutputs, 0);
     for (int i = 0; i < LayerSize; i++)
     {
         LSTMCells[i].Set(lstmNeuron.LSTMCells[i]);

@@ -726,7 +728,8 @@ public override void ForwardPass(SparseVector sparseFeature, float[] denseFeatur

     //hidden(t-1) -> hidden(t)
     cell_j.previousCellState = cell_j.cellState;
-    previousCellOutputs[j] = Cells[j];
+    cell_j.previousCellOutput = Cells[j];
+    // previousCellOutputs[j] = Cells[j];

     var vecCell_j = Vector4.Zero;

@@ -787,26 +790,24 @@ public override void ForwardPass(SparseVector sparseFeature, float[] denseFeatur
     //reset each netOut to zero
     cell_j.netOut = vecCell_j.W;

-    var cell_j_previousCellOutput = previousCellOutputs[j];
-
     //include internal connection multiplied by the previous cell state
-    cell_j.netIn += cell_j.previousCellState * cellWeight_j.wPeepholeIn + cell_j_previousCellOutput * cellWeight_j.wCellIn;
+    cell_j.netIn += cell_j.previousCellState * cellWeight_j.wPeepholeIn + cell_j.previousCellOutput * cellWeight_j.wCellIn;
     //squash input
     cell_j.yIn = Sigmoid(cell_j.netIn);

     //include internal connection multiplied by the previous cell state
     cell_j.netForget += cell_j.previousCellState * cellWeight_j.wPeepholeForget +
-                        cell_j_previousCellOutput * cellWeight_j.wCellForget;
+                        cell_j.previousCellOutput * cellWeight_j.wCellForget;
     cell_j.yForget = Sigmoid(cell_j.netForget);

-    cell_j.netCellState += cell_j_previousCellOutput * cellWeight_j.wCellState;
+    cell_j.netCellState += cell_j.previousCellOutput * cellWeight_j.wCellState;
     cell_j.yCellState = TanH(cell_j.netCellState);

     //cell state is equal to the previous cell state multipled by the forget gate and the cell inputs multiplied by the input gate
     cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * cell_j.yCellState;

     ////include the internal connection multiplied by the CURRENT cell state
-    cell_j.netOut += cell_j.cellState * cellWeight_j.wPeepholeOut + cell_j_previousCellOutput * cellWeight_j.wCellOut;
+    cell_j.netOut += cell_j.cellState * cellWeight_j.wPeepholeOut + cell_j.previousCellOutput * cellWeight_j.wCellOut;

     //squash output gate
     cell_j.yOut = Sigmoid(cell_j.netOut);
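
For orientation, the arithmetic in this hunk is the standard peephole-LSTM cell update; the change only moves where the previous output is read from. Below is a minimal per-cell sketch of that update, not code from the commit: the field and weight names mirror the diff, Sigmoid/TanH stand in for the layer's own helpers, and the feed-forward input sums are assumed to already be in the net* fields.

// Sketch only (assumes LSTMCell/LSTMCellWeight shaped as in the diff, with the
// feed-forward sums already accumulated into the net* fields).
static double Sigmoid(double x) => 1.0 / (1.0 + Math.Exp(-x));
static double TanH(double x) => Math.Tanh(x);

static void CellStep(LSTMCell c, LSTMCellWeight w)
{
    // gates read the previous output via the wCell* recurrent weights
    // and the cell state via the wPeephole* connections
    c.yIn = Sigmoid(c.netIn + c.previousCellState * w.wPeepholeIn
                            + c.previousCellOutput * w.wCellIn);
    c.yForget = Sigmoid(c.netForget + c.previousCellState * w.wPeepholeForget
                                    + c.previousCellOutput * w.wCellForget);
    c.yCellState = TanH(c.netCellState + c.previousCellOutput * w.wCellState);

    // new cell state: forget-gated history plus input-gated candidate
    c.cellState = c.yForget * c.previousCellState + c.yIn * c.yCellState;

    // the output gate peeks at the CURRENT cell state
    c.yOut = Sigmoid(c.netOut + c.cellState * w.wPeepholeOut
                              + c.previousCellOutput * w.wCellOut);
}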
@@ -1001,20 +1002,18 @@ public override void BackwardPass()
     cellWeight.wPeepholeOut += vecCellDelta.Z;

     //Update cells weights
-    var c_previousCellOutput = previousCellOutputs[i];
     //partial derivatives for internal connections
     cellWeightDeri.dSWCellIn = cellWeightDeri.dSWCellIn * c.yForget +
-                               Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c_previousCellOutput;
+                               Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.previousCellOutput;

     //partial derivatives for internal connections, initially zero as dS is zero and previous cell state is zero
     cellWeightDeri.dSWCellForget = cellWeightDeri.dSWCellForget * c.yForget +
-                                   ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c_previousCellOutput;
+                                   ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.previousCellOutput;

     cellWeightDeri.dSWCellState = cellWeightDeri.dSWCellState * c.yForget +
-                                  Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c_previousCellOutput;
+                                  Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c.previousCellOutput;

-    var vecCellDelta4 = new Vector4((float)cellWeightDeri.dSWCellIn, (float)cellWeightDeri.dSWCellForget, (float)cellWeightDeri.dSWCellState,
-                                    c_previousCellOutput);
+    var vecCellDelta4 = new Vector4((float)cellWeightDeri.dSWCellIn, (float)cellWeightDeri.dSWCellForget, (float)cellWeightDeri.dSWCellState, (float)c.previousCellOutput);
     vecCellDelta4 = vecErr * vecCellDelta4;

     //Normalize err by gradient cut-off
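
A note on the Vector4 line kept here: System.Numerics.Vector4 multiplies component-wise, so packing the three partial derivatives together with previousCellOutput lets a single SIMD multiply scale all four terms by the cell's error. A self-contained illustration with made-up values (not commit code):

// Illustration only: Vector4's * operator is component-wise.
using System;
using System.Numerics;

class Vector4Demo
{
    static void Main()
    {
        var vecErr = new Vector4(0.5f);                          // error term broadcast to all four lanes
        var vecCellDelta4 = new Vector4(0.1f, 0.2f, 0.3f, 0.4f); // dSWCellIn, dSWCellForget, dSWCellState, previousCellOutput
        vecCellDelta4 = vecErr * vecCellDelta4;                  // one multiply scales all four values
        Console.WriteLine(vecCellDelta4);                        // <0.05, 0.1, 0.15, 0.2>
    }
}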
@@ -1069,7 +1068,7 @@ public override void Reset()

 private void InitializeLSTMCell(LSTMCell c, LSTMCellWeight cw, LSTMCellWeightDeri deri)
 {
-    c.previousCellState = 0;
+    // c.previousCellState = 0;
     c.cellState = 0;

     //partial derivatives

@@ -1139,6 +1138,7 @@ public class LSTMCell
 {
     //The following fields are only for forward
     public double previousCellState;
+    public double previousCellOutput;
     public double cellState;

     public double netCellState;

@@ -1164,6 +1164,7 @@ public LSTMCell(LSTMCell cell)
 public void Set(LSTMCell cell)
 {
     previousCellState = cell.previousCellState;
+    previousCellOutput = cell.previousCellOutput;
     cellState = cell.cellState;
     netCellState = cell.netCellState;
     netForget = cell.netForget;

RNNSharp/Layers/SimpleLayer.cs (−4 lines)

@@ -23,7 +23,6 @@ public class SimpleLayer
     public float[] DenseFeature { get; set; }

     protected ParallelOptions parallelOption = new ParallelOptions();
-    protected float[] previousCellOutputs;
     protected RunningMode runningMode;


@@ -34,7 +33,6 @@ public SimpleLayer(LayerConfig config)
 {
     LayerConfig = config;
     Cells = new float[LayerSize];
-    previousCellOutputs = new float[LayerSize];
     Errs = new float[LayerSize];
     LabelShortList = new List<int>();
 }

@@ -68,7 +66,6 @@ public void SetRunningMode(RunningMode mode)
 public virtual Neuron CopyNeuronTo(Neuron neuron)
 {
     Cells.CopyTo(neuron.Cells, 0);
-    previousCellOutputs.CopyTo(neuron.PrevCellOutputs, 0);

     return neuron;
 }

@@ -99,7 +96,6 @@ public virtual void ShallowCopyWeightTo(SimpleLayer destLayer)
 public virtual void PreUpdateWeights(Neuron neuron, float[] errs)
 {
     neuron.Cells.CopyTo(Cells, 0);
-    neuron.PrevCellOutputs.CopyTo(previousCellOutputs, 0);
     errs.CopyTo(Errs, 0);
 }
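
Read together with the LSTMLayer changes above, this is the whole refactoring: the layer-level previousCellOutputs array, which had to stay index-aligned with LSTMCells and be copied separately in CopyNeuronTo and PreUpdateWeights, becomes a previousCellOutput field on LSTMCell, so LSTMCell.Set moves all of a cell's recurrent state in one call. A schematic of the pattern with hypothetical names (not RNNSharp code):

// Before: per-cell state in a parallel array on the layer; every copy path
// must remember to copy the array alongside the cells.
class LayerBefore
{
    public double[] prevOutputs = new double[4]; // prevOutputs[j] belongs to cells[j]
    public Cell[] cells = new Cell[4];

    public void CopyTo(LayerBefore dst)
    {
        prevOutputs.CopyTo(dst.prevOutputs, 0);  // the extra step this commit removes
        for (int j = 0; j < cells.Length; j++) dst.cells[j].Set(cells[j]);
    }
}

// After: the state is a field of the cell, so Set copies everything at once.
class Cell
{
    public double prevOutput;
    public double state;

    public void Set(Cell other)
    {
        prevOutput = other.prevOutput;
        state = other.state;
    }
}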

RNNSharp/RNNEncoder.cs (+7 −4 lines)

@@ -217,14 +217,17 @@ public void Train()

 }

+var start = DateTime.Now;
 Logger.WriteLine($"Start to training {iter} iteration. learning rate = {RNNHelper.LearningRate}");
 Parallel.For(0, N, i =>
 {
     rnns[i].CleanStatus();
     Process(rnns[i], dataSets[i], RunningMode.Training);
 });

-Logger.WriteLine($"End {iter} iteration.");
+var duration = DateTime.Now.Subtract(start);
+
+Logger.WriteLine($"End {iter} iteration. Time duration = {duration}");
 Logger.WriteLine("");

 if (tknErrCnt >= bestTrainTknErrCnt && lastAlpha != RNNHelper.LearningRate)
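
A side note on the timing added above: subtracting DateTime.Now values works, but System.Diagnostics.Stopwatch is the usual monotonic alternative when system clock adjustments could skew the measurement. A sketch of the same logging with Stopwatch (Logger, iter, and the loop body are from the code above):

// Alternative sketch, not part of the commit: Stopwatch is monotonic.
var sw = System.Diagnostics.Stopwatch.StartNew();
Parallel.For(0, N, i =>
{
    rnns[i].CleanStatus();
    Process(rnns[i], dataSets[i], RunningMode.Training);
});
sw.Stop();
Logger.WriteLine($"End {iter} iteration. Time duration = {sw.Elapsed}");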
@@ -256,7 +259,6 @@ public void Train()
 //We got better result on validated corpus, save this model
 Logger.WriteLine($"Saving better model into file {modelFilePath}, since we got a better result on validation set.");
 Logger.WriteLine($"Error token percent: {(double)tknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");
-Logger.WriteLine("");

 rnn.SaveModel(modelFilePath);
 bestValidTknErrCnt = tknErrCnt;

@@ -268,11 +270,12 @@ public void Train()
 //We got better result on validated corpus, save this model
 Logger.WriteLine($"Saving better model into file {modelFilePath}, although validation set doesn't exist, we have better result on training set.");
 Logger.WriteLine($"Error token percent: {(double)trainTknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");
-Logger.WriteLine("");

 rnn.SaveModel(modelFilePath);
 }
-
+
+Logger.WriteLine("");
+
 if (trainTknErrCnt >= bestTrainTknErrCnt)
 {
     //We don't have better result on training set, so reduce learning rate
