
Commit 513cb0c

Revert "#1. Fix Forward-LSTM crash bug #2. Improve encoding performance by SIMD instructions"
This reverts commit 1a3070c.
1 parent 4fad1b6 commit 513cb0c

11 files changed: +198 −262 lines

RNNSharp/BiRNN.cs (+17 −22)
@@ -3,7 +3,6 @@
 using System.Threading.Tasks;
 using AdvUtils;
 using System.Collections.Generic;
-using System.Numerics;
 
 /// <summary>
 /// RNNSharp written by Zhongkai Fu ([email protected])
@@ -14,7 +13,6 @@ class BiRNN : RNN
 {
     private RNN forwardRNN;
     private RNN backwardRNN;
-    private Vector<float> vecConst2 = new Vector<float>(2.0f);
 
     public BiRNN(RNN s_forwardRNN, RNN s_backwardRNN)
     {
@@ -131,7 +129,7 @@ public override float LearningRate
         }
     }
 
-    public override float GradientCutoff
+    public override double GradientCutoff
     {
         get
         {
@@ -211,7 +209,7 @@ public override void InitMem()
         backwardRNN.InitMem();
 
         //Create and intialise the weights from hidden to output layer, these are just normal weights
-        Hidden2OutputWeight = new Matrix<float>(L2, L1);
+        Hidden2OutputWeight = new Matrix<double>(L2, L1);
 
         for (int i = 0; i < Hidden2OutputWeight.Height; i++)
         {
@@ -224,7 +222,7 @@ public override void InitMem()
         Hidden2OutputWeightLearningRate = new Matrix<float>(L2, L1);
     }
 
-    public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHiddenLayer, out Matrix<float> rawOutputLayer)
+    public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHiddenLayer, out Matrix<double> rawOutputLayer)
     {
         int numStates = pSequence.States.Length;
         SimpleLayer[] mForward = null;
@@ -268,18 +266,14 @@ public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHid
             SimpleLayer forwardCells = mForward[curState];
             SimpleLayer backwardCells = mBackward[curState];
 
-            for (int i = 0; i < forwardRNN.L1; i += Vector<float>.Count)
+            for (int i = 0; i < forwardRNN.L1; i++)
             {
-                Vector<float> v1 = new Vector<float>(forwardCells.cellOutput, i);
-                Vector<float> v2 = new Vector<float>(backwardCells.cellOutput, i);
-                Vector<float> v = (v1 + v2) / vecConst2;
-
-                v.CopyTo(cells.cellOutput, i);
+                cells.cellOutput[i] = (forwardCells.cellOutput[i] + backwardCells.cellOutput[i]) / 2.0;
             }
         });
 
         //Calculate output layer
-        Matrix<float> tmp_rawOutputLayer = new Matrix<float>(numStates, L2);
+        Matrix<double> tmp_rawOutputLayer = new Matrix<double>(numStates, L2);
         SimpleLayer[] seqOutput = new SimpleLayer[numStates];
         Parallel.For(0, numStates, parallelOption, curState =>
         {
@@ -288,7 +282,7 @@ public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHid
 
             matrixXvectorADD(outputCells, mergedHiddenLayer[curState], Hidden2OutputWeight, L2, L1, 0);
 
-            float[] tmp_vector = tmp_rawOutputLayer[curState];
+            double[] tmp_vector = tmp_rawOutputLayer[curState];
             outputCells.cellOutput.CopyTo(tmp_vector, 0);
 
             //Activation on output layer
@@ -307,7 +301,7 @@ public override int[] PredictSentenceCRF(Sequence pSequence, RunningMode running
         int numStates = pSequence.States.Length;
         //Predict output
         SimpleLayer[] mergedHiddenLayer = null;
-        Matrix<float> rawOutputLayer = null;
+        Matrix<double> rawOutputLayer = null;
         SimpleLayer[] seqOutput = InnerDecode(pSequence, out mergedHiddenLayer, out rawOutputLayer);
 
         ForwardBackward(numStates, rawOutputLayer);
@@ -332,7 +326,7 @@ public override int[] PredictSentenceCRF(Sequence pSequence, RunningMode running
         {
             int label = pSequence.States[curState].Label;
             SimpleLayer layer = seqOutput[curState];
-            float[] CRFOutputLayer = CRFSeqOutput[curState];
+            double[] CRFOutputLayer = CRFSeqOutput[curState];
 
             //For standard RNN
             for (int c = 0; c < L2; c++)
@@ -348,14 +342,14 @@ public override int[] PredictSentenceCRF(Sequence pSequence, RunningMode running
         return predict;
     }
 
-    public override Matrix<float> PredictSentence(Sequence pSequence, RunningMode runningMode)
+    public override Matrix<double> PredictSentence(Sequence pSequence, RunningMode runningMode)
     {
         //Reset the network
         int numStates = pSequence.States.Length;
 
         //Predict output
         SimpleLayer[] mergedHiddenLayer = null;
-        Matrix<float> rawOutputLayer = null;
+        Matrix<double> rawOutputLayer = null;
         SimpleLayer[] seqOutput = InnerDecode(pSequence, out mergedHiddenLayer, out rawOutputLayer);
 
         if (runningMode != RunningMode.Test)
@@ -380,7 +374,7 @@ public override Matrix<float> PredictSentence(Sequence pSequence, RunningMode ru
             {
                 layer.er[c] = -layer.cellOutput[c];
             }
-            layer.er[label] = 1.0f - layer.cellOutput[label];
+            layer.er[label] = 1.0 - layer.cellOutput[label];
         }
 
         LearnTwoRNN(pSequence, mergedHiddenLayer, seqOutput);
@@ -413,17 +407,18 @@ private void LearnTwoRNN(Sequence pSequence, SimpleLayer[] mergedHiddenLayer, Si
             for (int i = 0; i < Hidden2OutputWeight.Height; i++)
             {
                 //update weights for hidden to output layer
-                float er = outputCells.er[i];
-                float[] vector_i = Hidden2OutputWeight[i];
+                double er = outputCells.er[i];
+                double[] vector_i = Hidden2OutputWeight[i];
                 for (int k = 0; k < Hidden2OutputWeight.Width; k++)
                 {
                     double delta = NormalizeGradient(mergedHiddenCells.cellOutput[k] * er);
                     double newLearningRate = UpdateLearningRate(Hidden2OutputWeightLearningRate, i, k, delta);
 
-                    vector_i[k] += (float)(newLearningRate * delta);
+                    vector_i[k] += newLearningRate * delta;
                 }
             }
         }
+
     },
     ()=>
     {
@@ -490,7 +485,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
         throw new NotImplementedException("computeHiddenLayer is not implemented in BiRNN");
     }
 
-    public override void computeOutput(float[] doutput)
+    public override void computeOutput(double[] doutput)
     {
         throw new NotImplementedException("computeOutput is not implemented in BiRNN");
    }

RNNSharp/LSTMRNN.cs (+58 −33)
@@ -30,6 +30,10 @@ public class LSTMCell : SimpleCell
     public double wCellForget;
     public double wCellOut;
 
+    public float dCellInLearningRate;
+    public float dCellForgetLearningRate;
+    public float dCellOutLearningRate;
+
     //partial derivatives
     public double dSWCellIn;
     public double dSWCellForget;
@@ -48,6 +52,22 @@ public struct LSTMWeight
     public float wInputOutputGate;
 }
 
+//public struct LSTMWeightLearningRate
+//{
+//    public float dInputCellLearningRate;
+//    public float dInputInputGateLearningRate;
+//    public float dInputForgetGateLearningRate;
+//    public float dInputOutputGateLearningRate;
+//}
+
+//public struct LSTMWeightDerivative
+//{
+//    //partial derivatives. dont need partial derivative for output gate as it uses BP not RTRL
+//    public double dSInputCell;
+//    public double dSInputInputGate;
+//    public double dSInputForgetGate;
+//}
+
 public class LSTMRNN : RNN
 {
     public LSTMCell[] neuHidden; //neurons in hidden layer
@@ -56,15 +76,10 @@ public class LSTMRNN : RNN
 
     protected Vector4[][] Input2HiddenLearningRate;
     protected Vector4[][] Feature2HiddenLearningRate;
-    protected Vector3[] CellLearningRate;
 
     protected Vector3[][] input2hiddenDeri;
     protected Vector3[][] feature2hiddenDeri;
 
-    private Vector4 vecLearningRate;
-    private Vector3 vecLearningRate3;
-
-
     public LSTMRNN()
     {
         ModelType = MODELTYPE.LSTM;
@@ -353,7 +368,7 @@ public override void SaveModel(string filename)
         //weight input->hidden
         Logger.WriteLine("Saving input2hidden weights...");
         saveLSTMWeight(input2hidden, fo);
-
+
         if (DenseFeatureSize > 0)
         {
             //weight fea->hidden
@@ -438,7 +453,7 @@ public override void initWeights()
         }
 
         //Create and intialise the weights from hidden to output layer, these are just normal weights
-        Hidden2OutputWeight = new Matrix<float>(L2, L1);
+        Hidden2OutputWeight = new Matrix<double>(L2, L1);
 
         for (int i = 0; i < Hidden2OutputWeight.Height; i++)
         {
@@ -484,9 +499,12 @@ public override void CleanStatus()
             Feature2HiddenLearningRate = new Vector4[L1][];
         }
 
-        CellLearningRate = new Vector3[L1];
         Parallel.For(0, L1, parallelOption, i =>
         {
+            neuHidden[i].dCellForgetLearningRate = 0;
+            neuHidden[i].dCellInLearningRate = 0;
+            neuHidden[i].dCellOutLearningRate = 0;
+
             Input2HiddenLearningRate[i] = new Vector4[L0];
 
             if (DenseFeatureSize > 0)
@@ -497,8 +515,6 @@ public override void CleanStatus()
         });
 
         Hidden2OutputWeightLearningRate = new Matrix<float>(L2, L1);
-        vecLearningRate = new Vector4(LearningRate, LearningRate, LearningRate, LearningRate);
-        vecLearningRate3 = new Vector3(LearningRate, LearningRate, LearningRate);
     }
 
     public override void InitMem()
@@ -567,7 +583,7 @@ public override void ComputeHiddenLayerErr()
             //find the error by find the product of the output errors and their weight connection.
             SimpleCell cell = neuHidden[i];
 
-            cell.er = 0.0f;
+            cell.er = 0.0;
 
             if (cell.mask == false)
            {
@@ -584,22 +600,30 @@ public override void LearnOutputWeight()
         //update weights for hidden to output layer
         Parallel.For(0, L1, parallelOption, i =>
         {
-            float cellOutput = neuHidden[i].cellOutput;
+            double cellOutput = neuHidden[i].cellOutput;
             for (int k = 0; k < L2; k++)
             {
-                float delta = NormalizeGradient(cellOutput * OutputLayer.er[k]);
-                double newLearningRate = UpdateLearningRate(Hidden2OutputWeightLearningRate, k, i, delta);
+                double delta = NormalizeGradient(cellOutput * OutputLayer.er[k]);
+                double newLearningRate = UpdateLearningRate(Hidden2OutputWeightLearningRate, i, k, delta);
 
-                Hidden2OutputWeight[k][i] += (float)(newLearningRate * delta);
+                Hidden2OutputWeight[k][i] += newLearningRate * delta;
             }
         });
     }
 
+    public double UpdateLearningRate(ref float mg, double delta)
+    {
+        double dg = mg + delta * delta;
+        mg = (float)dg;
+        return LearningRate / (1.0 + Math.Sqrt(dg));
+    }
+
     public override void LearnNet(State state, int numStates, int curState)
     {
         //Get sparse feature and apply it into hidden layer
         var sparse = state.SparseData;
         int sparseFeatureSize = sparse.Count;
+        Vector4 vecLearningRate = new Vector4(LearningRate, LearningRate, LearningRate, LearningRate);
 
         //put variables for derivaties in weight class and cell class
         Parallel.For(0, L1, parallelOption, i =>
@@ -626,6 +650,8 @@ public override void LearnNet(State state, int numStates, int curState)
                 (float)Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn,
                 (float)ci_previousCellState_mul_SigmoidDerivative_ci_netForget);
 
+            double delta = 0;
+            double newLearningRate = 0;
             for (int k = 0; k < sparseFeatureSize; k++)
             {
                 var entry = sparse.GetEntry(k);
@@ -647,7 +673,9 @@ public override void LearnNet(State state, int numStates, int curState)
                 vecAlpha = wlr + vecAlpha;
                 wlr_i[entry.Key] = vecAlpha;
 
-                vecAlpha = vecLearningRate / (Vector4.SquareRoot(vecAlpha) + Vector4.One);
+                vecAlpha = Vector4.SquareRoot(vecAlpha) + Vector4.One;
+                vecAlpha = vecLearningRate / vecAlpha;
+
                 vecDelta = vecAlpha * vecDelta;
 
                 w.wInputCell += vecDelta.X;
@@ -685,7 +713,9 @@ public override void LearnNet(State state, int numStates, int curState)
                 vecAlpha = wlr + vecAlpha;
                 wlr_i[j] = vecAlpha;
 
-                vecAlpha = vecLearningRate / (Vector4.SquareRoot(vecAlpha) + Vector4.One);
+                vecAlpha = Vector4.SquareRoot(vecAlpha) + Vector4.One;
+                vecAlpha = vecLearningRate / vecAlpha;
+
                 vecDelta = vecAlpha * vecDelta;
 
                 w.wInputCell += vecDelta.X;
@@ -706,22 +736,17 @@ public override void LearnNet(State state, int numStates, int curState)
 
 
             //update internal weights
-            Vector3 vecCellDelta = new Vector3((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.cellState);
-            Vector3 vecCellErr = new Vector3(cellStateError, cellStateError, gradientOutputGate);
-            Vector3 vecCellLearningRate = CellLearningRate[i];
-
-            vecCellDelta = vecCellErr * vecCellDelta;
-            vecCellLearningRate += (vecCellDelta * vecCellDelta);
-            CellLearningRate[i] = vecCellLearningRate;
-
-            //LearningRate / (1.0 + Math.Sqrt(dg));
-            vecCellLearningRate = vecLearningRate3 / (Vector3.One + Vector3.SquareRoot(vecCellLearningRate));
-            vecCellDelta = vecCellLearningRate * vecCellDelta;
+            delta = cellStateError * c.dSWCellIn;
+            newLearningRate = UpdateLearningRate(ref c.dCellInLearningRate, delta);
+            c.wCellIn += newLearningRate * delta;
 
-            c.wCellIn += vecCellDelta.X;
-            c.wCellForget += vecCellDelta.Y;
-            c.wCellOut += vecCellDelta.Z;
+            delta = cellStateError * c.dSWCellForget;
+            newLearningRate = UpdateLearningRate(ref c.dCellForgetLearningRate, delta);
+            c.wCellForget += newLearningRate * delta;
 
+            delta = gradientOutputGate * c.cellState;
+            newLearningRate = UpdateLearningRate(ref c.dCellOutLearningRate, delta);
+            c.wCellOut += newLearningRate * delta;
 
             neuHidden[i] = c;
         });
@@ -808,15 +833,15 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
             //squash output gate
             cell_j.yOut = Sigmoid(cell_j.netOut);
 
-            cell_j.cellOutput = (float)(cell_j.cellState * cell_j.yOut);
+            cell_j.cellOutput = cell_j.cellState * cell_j.yOut;
 
 
             neuHidden[j] = cell_j;
         });
     }
 
 
-    public override void computeOutput(float[] doutput)
+    public override void computeOutput(double[] doutput)
     {
         matrixXvectorADD(OutputLayer, neuHidden, Hidden2OutputWeight, L2, L1, 0);
         if (doutput != null)
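
Note: the UpdateLearningRate(ref float mg, double delta) helper shown above is an AdaGrad-style schedule: each weight keeps a running sum of squared gradients, and its effective step size shrinks as LearningRate / (1 + sqrt(sum)). The revert applies it three times per cell, with gradients cellStateError * dSWCellIn, cellStateError * dSWCellForget, and gradientOutputGate * cellState, in place of the packed Vector3 update. A self-contained sketch of the same rule (class and member names here are illustrative, not from the repository):

using System;

// AdaGrad-style per-weight step size, mirroring UpdateLearningRate:
// accumulate the squared gradient, then scale the base learning rate
// by 1 / (1 + sqrt(accumulated sum)).
class AdaGradWeight
{
    public double Weight;
    private float squaredGradSum; // running sum of squared gradients

    public void Update(double baseRate, double grad)
    {
        double sum = squaredGradSum + grad * grad;
        squaredGradSum = (float)sum;  // stored as float, as in the diff
        double rate = baseRate / (1.0 + Math.Sqrt(sum));
        Weight += rate * grad;
    }
}

Because the accumulator only grows, frequently-updated weights see their step size decay fastest; keeping it as a float halves the memory per weight at a small precision cost.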

RNNSharp/MathUtil.cs (+1 −1)
@@ -7,7 +7,7 @@ namespace RNNSharp
 {
     class MathUtil
     {
-        public static int GetMaxProbIndex(float [] array)
+        public static int GetMaxProbIndex(double [] array)
         {
             int dim = array.Length;
             double maxValue = array[0];
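
Note: only the signature changes here (float[] to double[]), keeping MathUtil in step with the double-based output layer. For reference, a minimal argmax consistent with the visible lines (a sketch; the method body beyond the lines shown is assumed):

// Returns the index of the largest value in the array, i.e. the label
// with the highest output activation.
public static int GetMaxProbIndex(double[] array)
{
    int dim = array.Length;
    double maxValue = array[0];
    int maxIdx = 0;
    for (int i = 1; i < dim; i++)
    {
        if (array[i] > maxValue)
        {
            maxValue = array[i];
            maxIdx = i;
        }
    }
    return maxIdx;
}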
