Skip to content

Commit 1a3070c

Browse files
committed
#1. Fix Forward-LSTM crash bug
#2. Improve encoding performance by using SIMD instructions
1 parent e390b85 commit 1a3070c

11 files changed

+262
-198
lines changed

RNNSharp/BiRNN.cs

+22-17
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Threading.Tasks;
44
using AdvUtils;
55
using System.Collections.Generic;
6+
using System.Numerics;
67

78
/// <summary>
89
/// RNNSharp written by Zhongkai Fu ([email protected])
@@ -13,6 +14,7 @@ class BiRNN : RNN
1314
{
1415
private RNN forwardRNN;
1516
private RNN backwardRNN;
17+
private Vector<float> vecConst2 = new Vector<float>(2.0f);
1618

1719
public BiRNN(RNN s_forwardRNN, RNN s_backwardRNN)
1820
{
@@ -129,7 +131,7 @@ public override float LearningRate
129131
}
130132
}
131133

132-
public override double GradientCutoff
134+
public override float GradientCutoff
133135
{
134136
get
135137
{
@@ -209,7 +211,7 @@ public override void InitMem()
209211
backwardRNN.InitMem();
210212

211213
//Create and initialise the weights from hidden to output layer, these are just normal weights
212-
Hidden2OutputWeight = new Matrix<double>(L2, L1);
214+
Hidden2OutputWeight = new Matrix<float>(L2, L1);
213215

214216
for (int i = 0; i < Hidden2OutputWeight.Height; i++)
215217
{
@@ -222,7 +224,7 @@ public override void InitMem()
222224
Hidden2OutputWeightLearningRate = new Matrix<float>(L2, L1);
223225
}
224226

225-
public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHiddenLayer, out Matrix<double> rawOutputLayer)
227+
public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHiddenLayer, out Matrix<float> rawOutputLayer)
226228
{
227229
int numStates = pSequence.States.Length;
228230
SimpleLayer[] mForward = null;
@@ -266,14 +268,18 @@ public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHid
266268
SimpleLayer forwardCells = mForward[curState];
267269
SimpleLayer backwardCells = mBackward[curState];
268270

269-
for (int i = 0; i < forwardRNN.L1; i++)
271+
for (int i = 0; i < forwardRNN.L1; i+=Vector<float>.Count)
270272
{
271-
cells.cellOutput[i] = (forwardCells.cellOutput[i] + backwardCells.cellOutput[i]) / 2.0;
273+
Vector<float> v1 = new Vector<float>(forwardCells.cellOutput, i);
274+
Vector<float> v2 = new Vector<float>(backwardCells.cellOutput, i);
275+
Vector<float> v = (v1 + v2) / vecConst2;
276+
277+
v.CopyTo(cells.cellOutput, i);
272278
}
273279
});
274280

275281
//Calculate output layer
276-
Matrix<double> tmp_rawOutputLayer = new Matrix<double>(numStates, L2);
282+
Matrix<float> tmp_rawOutputLayer = new Matrix<float>(numStates, L2);
277283
SimpleLayer[] seqOutput = new SimpleLayer[numStates];
278284
Parallel.For(0, numStates, parallelOption, curState =>
279285
{
@@ -282,7 +288,7 @@ public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHid
282288

283289
matrixXvectorADD(outputCells, mergedHiddenLayer[curState], Hidden2OutputWeight, L2, L1, 0);
284290

285-
double[] tmp_vector = tmp_rawOutputLayer[curState];
291+
float[] tmp_vector = tmp_rawOutputLayer[curState];
286292
outputCells.cellOutput.CopyTo(tmp_vector, 0);
287293

288294
//Activation on output layer
@@ -301,7 +307,7 @@ public override int[] PredictSentenceCRF(Sequence pSequence, RunningMode running
301307
int numStates = pSequence.States.Length;
302308
//Predict output
303309
SimpleLayer[] mergedHiddenLayer = null;
304-
Matrix<double> rawOutputLayer = null;
310+
Matrix<float> rawOutputLayer = null;
305311
SimpleLayer[] seqOutput = InnerDecode(pSequence, out mergedHiddenLayer, out rawOutputLayer);
306312

307313
ForwardBackward(numStates, rawOutputLayer);
@@ -326,7 +332,7 @@ public override int[] PredictSentenceCRF(Sequence pSequence, RunningMode running
326332
{
327333
int label = pSequence.States[curState].Label;
328334
SimpleLayer layer = seqOutput[curState];
329-
double[] CRFOutputLayer = CRFSeqOutput[curState];
335+
float[] CRFOutputLayer = CRFSeqOutput[curState];
330336

331337
//For standard RNN
332338
for (int c = 0; c < L2; c++)
@@ -342,14 +348,14 @@ public override int[] PredictSentenceCRF(Sequence pSequence, RunningMode running
342348
return predict;
343349
}
344350

345-
public override Matrix<double> PredictSentence(Sequence pSequence, RunningMode runningMode)
351+
public override Matrix<float> PredictSentence(Sequence pSequence, RunningMode runningMode)
346352
{
347353
//Reset the network
348354
int numStates = pSequence.States.Length;
349355

350356
//Predict output
351357
SimpleLayer[] mergedHiddenLayer = null;
352-
Matrix<double> rawOutputLayer = null;
358+
Matrix<float> rawOutputLayer = null;
353359
SimpleLayer[] seqOutput = InnerDecode(pSequence, out mergedHiddenLayer, out rawOutputLayer);
354360

355361
if (runningMode != RunningMode.Test)
@@ -374,7 +380,7 @@ public override Matrix<double> PredictSentence(Sequence pSequence, RunningMode r
374380
{
375381
layer.er[c] = -layer.cellOutput[c];
376382
}
377-
layer.er[label] = 1.0 - layer.cellOutput[label];
383+
layer.er[label] = 1.0f - layer.cellOutput[label];
378384
}
379385

380386
LearnTwoRNN(pSequence, mergedHiddenLayer, seqOutput);
@@ -407,18 +413,17 @@ private void LearnTwoRNN(Sequence pSequence, SimpleLayer[] mergedHiddenLayer, Si
407413
for (int i = 0; i < Hidden2OutputWeight.Height; i++)
408414
{
409415
//update weights for hidden to output layer
410-
double er = outputCells.er[i];
411-
double[] vector_i = Hidden2OutputWeight[i];
416+
float er = outputCells.er[i];
417+
float[] vector_i = Hidden2OutputWeight[i];
412418
for (int k = 0; k < Hidden2OutputWeight.Width; k++)
413419
{
414420
double delta = NormalizeGradient(mergedHiddenCells.cellOutput[k] * er);
415421
double newLearningRate = UpdateLearningRate(Hidden2OutputWeightLearningRate, i, k, delta);
416422

417-
vector_i[k] += newLearningRate * delta;
423+
vector_i[k] += (float)(newLearningRate * delta);
418424
}
419425
}
420426
}
421-
422427
},
423428
()=>
424429
{
@@ -485,7 +490,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
485490
throw new NotImplementedException("computeHiddenLayer is not implemented in BiRNN");
486491
}
487492

488-
public override void computeOutput(double[] doutput)
493+
public override void computeOutput(float[] doutput)
489494
{
490495
throw new NotImplementedException("computeOutput is not implemented in BiRNN");
491496
}

RNNSharp/LSTMRNN.cs

+33-58
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@ public class LSTMCell : SimpleCell
3030
public double wCellForget;
3131
public double wCellOut;
3232

33-
public float dCellInLearningRate;
34-
public float dCellForgetLearningRate;
35-
public float dCellOutLearningRate;
36-
3733
//partial derivatives
3834
public double dSWCellIn;
3935
public double dSWCellForget;
@@ -52,22 +48,6 @@ public struct LSTMWeight
5248
public float wInputOutputGate;
5349
}
5450

55-
//public struct LSTMWeightLearningRate
56-
//{
57-
// public float dInputCellLearningRate;
58-
// public float dInputInputGateLearningRate;
59-
// public float dInputForgetGateLearningRate;
60-
// public float dInputOutputGateLearningRate;
61-
//}
62-
63-
//public struct LSTMWeightDerivative
64-
//{
65-
// //partial derivatives. don't need partial derivative for output gate as it uses BP not RTRL
66-
// public double dSInputCell;
67-
// public double dSInputInputGate;
68-
// public double dSInputForgetGate;
69-
//}
70-
7151
public class LSTMRNN : RNN
7252
{
7353
public LSTMCell[] neuHidden; //neurons in hidden layer
@@ -76,10 +56,15 @@ public class LSTMRNN : RNN
7656

7757
protected Vector4[][] Input2HiddenLearningRate;
7858
protected Vector4[][] Feature2HiddenLearningRate;
59+
protected Vector3[] CellLearningRate;
7960

8061
protected Vector3[][] input2hiddenDeri;
8162
protected Vector3[][] feature2hiddenDeri;
8263

64+
private Vector4 vecLearningRate;
65+
private Vector3 vecLearningRate3;
66+
67+
8368
public LSTMRNN()
8469
{
8570
ModelType = MODELTYPE.LSTM;
@@ -368,7 +353,7 @@ public override void SaveModel(string filename)
368353
//weight input->hidden
369354
Logger.WriteLine("Saving input2hidden weights...");
370355
saveLSTMWeight(input2hidden, fo);
371-
356+
372357
if (DenseFeatureSize > 0)
373358
{
374359
//weight fea->hidden
@@ -453,7 +438,7 @@ public override void initWeights()
453438
}
454439

455440
//Create and initialise the weights from hidden to output layer, these are just normal weights
456-
Hidden2OutputWeight = new Matrix<double>(L2, L1);
441+
Hidden2OutputWeight = new Matrix<float>(L2, L1);
457442

458443
for (int i = 0; i < Hidden2OutputWeight.Height; i++)
459444
{
@@ -499,12 +484,9 @@ public override void CleanStatus()
499484
Feature2HiddenLearningRate = new Vector4[L1][];
500485
}
501486

487+
CellLearningRate = new Vector3[L1];
502488
Parallel.For(0, L1, parallelOption, i =>
503489
{
504-
neuHidden[i].dCellForgetLearningRate = 0;
505-
neuHidden[i].dCellInLearningRate = 0;
506-
neuHidden[i].dCellOutLearningRate = 0;
507-
508490
Input2HiddenLearningRate[i] = new Vector4[L0];
509491

510492
if (DenseFeatureSize > 0)
@@ -515,6 +497,8 @@ public override void CleanStatus()
515497
});
516498

517499
Hidden2OutputWeightLearningRate = new Matrix<float>(L2, L1);
500+
vecLearningRate = new Vector4(LearningRate, LearningRate, LearningRate, LearningRate);
501+
vecLearningRate3 = new Vector3(LearningRate, LearningRate, LearningRate);
518502
}
519503

520504
public override void InitMem()
@@ -583,7 +567,7 @@ public override void ComputeHiddenLayerErr()
583567
//find the error by find the product of the output errors and their weight connection.
584568
SimpleCell cell = neuHidden[i];
585569

586-
cell.er = 0.0;
570+
cell.er = 0.0f;
587571

588572
if (cell.mask == false)
589573
{
@@ -600,30 +584,22 @@ public override void LearnOutputWeight()
600584
//update weights for hidden to output layer
601585
Parallel.For(0, L1, parallelOption, i =>
602586
{
603-
double cellOutput = neuHidden[i].cellOutput;
587+
float cellOutput = neuHidden[i].cellOutput;
604588
for (int k = 0; k < L2; k++)
605589
{
606-
double delta = NormalizeGradient(cellOutput * OutputLayer.er[k]);
607-
double newLearningRate = UpdateLearningRate(Hidden2OutputWeightLearningRate, i, k, delta);
590+
float delta = NormalizeGradient(cellOutput * OutputLayer.er[k]);
591+
double newLearningRate = UpdateLearningRate(Hidden2OutputWeightLearningRate, k, i, delta);
608592

609-
Hidden2OutputWeight[k][i] += newLearningRate * delta;
593+
Hidden2OutputWeight[k][i] += (float)(newLearningRate * delta);
610594
}
611595
});
612596
}
613597

614-
public double UpdateLearningRate(ref float mg, double delta)
615-
{
616-
double dg = mg + delta * delta;
617-
mg = (float)dg;
618-
return LearningRate / (1.0 + Math.Sqrt(dg));
619-
}
620-
621598
public override void LearnNet(State state, int numStates, int curState)
622599
{
623600
//Get sparse feature and apply it into hidden layer
624601
var sparse = state.SparseData;
625602
int sparseFeatureSize = sparse.Count;
626-
Vector4 vecLearningRate = new Vector4(LearningRate, LearningRate, LearningRate, LearningRate);
627603

628604
//put variables for derivatives in weight class and cell class
629605
Parallel.For(0, L1, parallelOption, i =>
@@ -650,8 +626,6 @@ public override void LearnNet(State state, int numStates, int curState)
650626
(float)Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn,
651627
(float)ci_previousCellState_mul_SigmoidDerivative_ci_netForget);
652628

653-
double delta = 0;
654-
double newLearningRate = 0;
655629
for (int k = 0; k < sparseFeatureSize; k++)
656630
{
657631
var entry = sparse.GetEntry(k);
@@ -673,9 +647,7 @@ public override void LearnNet(State state, int numStates, int curState)
673647
vecAlpha = wlr + vecAlpha;
674648
wlr_i[entry.Key] = vecAlpha;
675649

676-
vecAlpha = Vector4.SquareRoot(vecAlpha) + Vector4.One;
677-
vecAlpha = vecLearningRate / vecAlpha;
678-
650+
vecAlpha = vecLearningRate / (Vector4.SquareRoot(vecAlpha) + Vector4.One);
679651
vecDelta = vecAlpha * vecDelta;
680652

681653
w.wInputCell += vecDelta.X;
@@ -713,9 +685,7 @@ public override void LearnNet(State state, int numStates, int curState)
713685
vecAlpha = wlr + vecAlpha;
714686
wlr_i[j] = vecAlpha;
715687

716-
vecAlpha = Vector4.SquareRoot(vecAlpha) + Vector4.One;
717-
vecAlpha = vecLearningRate / vecAlpha;
718-
688+
vecAlpha = vecLearningRate / (Vector4.SquareRoot(vecAlpha) + Vector4.One);
719689
vecDelta = vecAlpha * vecDelta;
720690

721691
w.wInputCell += vecDelta.X;
@@ -736,17 +706,22 @@ public override void LearnNet(State state, int numStates, int curState)
736706

737707

738708
//update internal weights
739-
delta = cellStateError * c.dSWCellIn;
740-
newLearningRate = UpdateLearningRate(ref c.dCellInLearningRate, delta);
741-
c.wCellIn += newLearningRate * delta;
709+
Vector3 vecCellDelta = new Vector3((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.cellState);
710+
Vector3 vecCellErr = new Vector3(cellStateError, cellStateError, gradientOutputGate);
711+
Vector3 vecCellLearningRate = CellLearningRate[i];
712+
713+
vecCellDelta = vecCellErr * vecCellDelta;
714+
vecCellLearningRate += (vecCellDelta * vecCellDelta);
715+
CellLearningRate[i] = vecCellLearningRate;
716+
717+
//LearningRate / (1.0 + Math.Sqrt(dg));
718+
vecCellLearningRate = vecLearningRate3 / (Vector3.One + Vector3.SquareRoot(vecCellLearningRate));
719+
vecCellDelta = vecCellLearningRate * vecCellDelta;
742720

743-
delta = cellStateError * c.dSWCellForget;
744-
newLearningRate = UpdateLearningRate(ref c.dCellForgetLearningRate, delta);
745-
c.wCellForget += newLearningRate * delta;
721+
c.wCellIn += vecCellDelta.X;
722+
c.wCellForget += vecCellDelta.Y;
723+
c.wCellOut += vecCellDelta.Z;
746724

747-
delta = gradientOutputGate * c.cellState;
748-
newLearningRate = UpdateLearningRate(ref c.dCellOutLearningRate, delta);
749-
c.wCellOut += newLearningRate * delta;
750725

751726
neuHidden[i] = c;
752727
});
@@ -833,15 +808,15 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
833808
//squash output gate
834809
cell_j.yOut = Sigmoid(cell_j.netOut);
835810

836-
cell_j.cellOutput = cell_j.cellState * cell_j.yOut;
811+
cell_j.cellOutput = (float)(cell_j.cellState * cell_j.yOut);
837812

838813

839814
neuHidden[j] = cell_j;
840815
});
841816
}
842817

843818

844-
public override void computeOutput(double[] doutput)
819+
public override void computeOutput(float[] doutput)
845820
{
846821
matrixXvectorADD(OutputLayer, neuHidden, Hidden2OutputWeight, L2, L1, 0);
847822
if (doutput != null)

RNNSharp/MathUtil.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace RNNSharp
77
{
88
class MathUtil
99
{
10-
public static int GetMaxProbIndex(double [] array)
10+
public static int GetMaxProbIndex(float [] array)
1111
{
1212
int dim = array.Length;
1313
double maxValue = array[0];

0 commit comments

Comments
 (0)