
Commit 1d2b3be (parent: fe925d1)

#1. Bug fix: output layer is not cleaned before calculating new values
#2. Add dropout for LSTM

5 files changed: +62, -114 lines

RNNSharp/BiRNN.cs (-6 lines)

@@ -198,12 +198,6 @@ public neuron[][] InnerDecode(Sequence pSequence, out Matrix<neuron> outputHidde
         return seqOutput;
     }
 
-    public override void netFlush()
-    {
-        forwardRNN.netFlush();
-        backwardRNN.netFlush();
-    }
-
     public override Matrix<double> learnSentenceForRNNCRF(Sequence pSequence, RunningMode runningMode)
     {
         //Reset the network
RNNSharp/LSTMRNN.cs (+56, -88 lines)

@@ -39,6 +39,7 @@ public class LSTMCell
 
     //cell output
     public double cellOutput;
+    public bool mask;
 }
 
 public struct LSTMWeight
@@ -68,10 +69,6 @@ public class LSTMRNN : RNN
     protected LSTMWeightDerivative[][] input2hiddenDeri;
     protected LSTMWeightDerivative[][] feature2hiddenDeri;
 
-    //for LSTM layer
-    const bool NORMAL = true;
-    const bool BIAS = false;
-
     public LSTMRNN()
     {
         m_modeltype = MODELTYPE.LSTM;
@@ -248,29 +245,29 @@ public override void saveNetBin(string filename)
     }
 
 
-    double TanH(double x)
+    double Sigmoid2(double x)
     {
-        return Math.Tanh(x);
+        //sigmoid function return a bounded output between [-2,2]
+        return (4.0 / (1.0 + Math.Exp(-x))) - 2.0;
     }
 
-    double TanHDerivative(double x)
+    double Sigmoid2Derivative(double x)
     {
-        double tmp = Math.Tanh(x);
-        return 1 - tmp * tmp;
+        return 4.0 * Sigmoid(x) * (1.0 - Sigmoid(x));
     }
 
     double Sigmoid(double x)
     {
-        return (1 / (1 + Math.Exp(-x)));
+        return (1.0 / (1.0 + Math.Exp(-x)));
     }
 
     double SigmoidDerivative(double x)
     {
-        return Sigmoid(x) * (1 - Sigmoid(x));
+        return Sigmoid(x) * (1.0 - Sigmoid(x));
     }
 
 
-    public LSTMWeight LSTMWeightInit(int iL)
+    public LSTMWeight LSTMWeightInit()
     {
         LSTMWeight w;
 
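Side note on the activation swap: Sigmoid2 is the logistic function rescaled to the range (-2, 2), Sigmoid2(x) = 4*Sigmoid(x) - 2 (the cell-input squashing function g in the original LSTM paper had the same [-2, 2] range), so its derivative is 4*Sigmoid(x)*(1 - Sigmoid(x)), which is exactly what Sigmoid2Derivative returns. A minimal standalone check, not part of the commit, that the closed form agrees with a finite-difference estimate:

    using System;

    class Sigmoid2Check
    {
        static double Sigmoid(double x) => 1.0 / (1.0 + Math.Exp(-x));
        static double Sigmoid2(double x) => 4.0 * Sigmoid(x) - 2.0;
        static double Sigmoid2Derivative(double x) => 4.0 * Sigmoid(x) * (1.0 - Sigmoid(x));

        static void Main()
        {
            const double h = 1e-6;
            foreach (double x in new[] { -3.0, -1.0, 0.0, 1.0, 3.0 })
            {
                // central difference should match the analytic derivative closely
                double numeric = (Sigmoid2(x + h) - Sigmoid2(x - h)) / (2.0 * h);
                Console.WriteLine($"x={x}: analytic={Sigmoid2Derivative(x):F9}, numeric={numeric:F9}");
            }
        }
    }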
@@ -292,7 +289,7 @@ public override void initWeights()
             input2hidden[i] = new LSTMWeight[L0];
             for (int j = 0; j < L0; j++)
             {
-                input2hidden[i][j] = LSTMWeightInit(L0);
+                input2hidden[i][j] = LSTMWeightInit();
             }
         }
 
@@ -304,7 +301,7 @@ public override void initWeights()
             feature2hidden[i] = new LSTMWeight[fea_size];
             for (int j = 0; j < fea_size; j++)
             {
-                feature2hidden[i][j] = LSTMWeightInit(L0);
+                feature2hidden[i][j] = LSTMWeightInit();
             }
         }
     }
@@ -418,26 +415,14 @@ public void matrixXvectorADD(neuron[] dest, LSTMCell[] srcvec, Matrix<double> sr
         //ac mod
         Parallel.For(0, (to - from), parallelOption, i =>
         {
+            dest[i + from].cellOutput = 0;
             for (int j = 0; j < to2 - from2; j++)
             {
                 dest[i + from].cellOutput += srcvec[j + from2].cellOutput * srcmatrix[i][j];
             }
         });
     }
 
-    public void matrixXvectorADD(LSTMCell[] dest, double[] srcvec, LSTMWeight[][] srcmatrix, int from, int to, int from2, int to2)
-    {
-        //ac mod
-        Parallel.For(0, (to - from), parallelOption, i =>
-        {
-            for (int j = 0; j < to2 - from2; j++)
-            {
-                dest[i + from].netIn += srcvec[j + from2] * srcmatrix[i][j].wInputInputGate;
-            }
-        });
-    }
-
-
     public override void LearnBackTime(State state, int numStates, int curState)
     {
     }
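This hunk is fix #1 from the commit message: matrixXvectorADD accumulates into dest[...].cellOutput with +=, so without the new zeroing line every call added fresh activations on top of whatever the previous call left in the output layer. The deleted LSTMCell overload loses its call site in computeNet (below), where the feature-to-hidden accumulation is now inlined. A minimal sketch of the failure mode with plain arrays (hypothetical sizes, not the RNNSharp types):

    using System;

    class StaleAccumulation
    {
        static void Main()
        {
            double[] dest = new double[2];                  // stands in for the output layer
            double[,] w = { { 1.0, 2.0 }, { 3.0, 4.0 } };   // hidden-to-output weights
            double[] src = { 1.0, 1.0 };                    // hidden activations

            for (int pass = 0; pass < 2; pass++)
            {
                for (int i = 0; i < 2; i++)
                {
                    // dest[i] = 0;   // <-- the fix: clear before accumulating
                    for (int j = 0; j < 2; j++)
                        dest[i] += src[j] * w[i, j];
                }
                // buggy output: pass 0 -> [3, 7], pass 1 -> [6, 14] (stale values carried over)
                Console.WriteLine($"pass {pass}: [{dest[0]}, {dest[1]}]");
            }
        }
    }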
@@ -463,8 +448,8 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
                 {
                     var entry = sparse.GetEntry(k);
                     LSTMWeightDerivative w = w_i[entry.Key];
-                    w_i[entry.Key].dSInputCell = w.dSInputCell * c.yForget + TanHDerivative(c.netCellState) * c.yIn * entry.Value;
-                    w_i[entry.Key].dSInputInputGate = w.dSInputInputGate * c.yForget + TanH(c.netCellState) * SigmoidDerivative(c.netIn) * entry.Value;
+                    w_i[entry.Key].dSInputCell = w.dSInputCell * c.yForget + Sigmoid2Derivative(c.netCellState) * c.yIn * entry.Value;
+                    w_i[entry.Key].dSInputInputGate = w.dSInputInputGate * c.yForget + Sigmoid2(c.netCellState) * SigmoidDerivative(c.netIn) * entry.Value;
                     w_i[entry.Key].dSInputForgetGate = w.dSInputForgetGate * c.yForget + c.previousCellState * SigmoidDerivative(c.netForget) * entry.Value;
 
                 }
@@ -475,15 +460,15 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
                 for (int j = 0; j < fea_size; j++)
                 {
                     LSTMWeightDerivative w = w_i[j];
-                    w_i[j].dSInputCell = w.dSInputCell * c.yForget + TanHDerivative(c.netCellState) * c.yIn * neuFeatures[j];
-                    w_i[j].dSInputInputGate = w.dSInputInputGate * c.yForget + TanH(c.netCellState) * SigmoidDerivative(c.netIn) * neuFeatures[j];
+                    w_i[j].dSInputCell = w.dSInputCell * c.yForget + Sigmoid2Derivative(c.netCellState) * c.yIn * neuFeatures[j];
+                    w_i[j].dSInputInputGate = w.dSInputInputGate * c.yForget + Sigmoid2(c.netCellState) * SigmoidDerivative(c.netIn) * neuFeatures[j];
                     w_i[j].dSInputForgetGate = w.dSInputForgetGate * c.yForget + c.previousCellState * SigmoidDerivative(c.netForget) * neuFeatures[j];
 
                 }
             }
 
             //partial derivatives for internal connections
-            c.dSWCellIn = c.dSWCellIn * c.yForget + TanH(c.netCellState) * SigmoidDerivative(c.netIn) * c.cellState;
+            c.dSWCellIn = c.dSWCellIn * c.yForget + Sigmoid2(c.netCellState) * SigmoidDerivative(c.netIn) * c.cellState;
 
             //partial derivatives for internal connections, initially zero as dS is zero and previous cell state is zero
             c.dSWCellForget = c.dSWCellForget * c.yForget + c.previousCellState * SigmoidDerivative(c.netForget) * c.previousCellState;
@@ -505,18 +490,12 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
             weightedSum = NormalizeErr(weightedSum);
 
             //using the error find the gradient of the output gate
-            double gradientOutputGate = SigmoidDerivative(c.netOut) * TanHDerivative(c.cellState) * weightedSum;
+            double gradientOutputGate = SigmoidDerivative(c.netOut) * c.cellState * weightedSum;
 
             //internal cell state error
             double cellStateError = c.yOut * weightedSum;
 
-
             //weight updates
-
-            //already done the deltas for the hidden-output connections
-
-            //output gates. for each connection to the hidden layer
-            //to the input layer
             LSTMWeight[] w_i = input2hidden[i];
             LSTMWeightDerivative[] wd_i = input2hiddenDeri[i];
             for (int k = 0; k < sparseFeatureSize; k++)
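Why the TanHDerivative(c.cellState) factor disappears from gradientOutputGate: this commit also stops squashing the cell state on output (in computeNet below, cellOutput becomes cellState * yOut instead of TanH(cellState) * yOut). With that forward pass, the chain rule gives

    cellOutput = cellState * Sigmoid(netOut)   =>   d(cellOutput)/d(netOut) = cellState * SigmoidDerivative(netOut)

which is exactly the new SigmoidDerivative(c.netOut) * c.cellState term.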
@@ -545,30 +524,22 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
                 }
             }
 
-            //for the internal connection
-            double deltaOutputGateCell = alpha * gradientOutputGate * c.cellState;
-
-            //using internal partial derivative
-            double deltaInputGateCell = alpha * cellStateError * c.dSWCellIn;
-
-            double deltaForgetGateCell = alpha * cellStateError * c.dSWCellForget;
-
             //update internal weights
-            c.wCellIn += deltaInputGateCell;
-            c.wCellForget += deltaForgetGateCell;
-            c.wCellOut += deltaOutputGateCell;
+            c.wCellIn += alpha * cellStateError * c.dSWCellIn;
+            c.wCellForget += alpha * cellStateError * c.dSWCellForget;
+            c.wCellOut += alpha * gradientOutputGate * c.cellState;
 
             neuHidden[i] = c;
         });
 
         //update weights for hidden to output layer
-        for (int i = 0; i < L1; i++)
+        Parallel.For(0, L1, parallelOption, i =>
         {
             for (int k = 0; k < L2; k++)
             {
                 mat_hidden2output[k][i] += alpha * neuHidden[i].cellOutput * neuOutput[k].er;
             }
-        }
+        });
     }
 
 
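The hidden-to-output update also moves from a plain for loop onto Parallel.For. This is race-free because iteration i writes only column i of mat_hidden2output; no two iterations ever touch the same element. The same disjoint-write pattern in miniature (hypothetical sizes, plain arrays, not the RNNSharp types):

    using System;
    using System.Threading.Tasks;

    class DisjointColumnUpdate
    {
        static void Main()
        {
            int L1 = 4, L2 = 3;
            double[][] weights = new double[L2][];
            for (int k = 0; k < L2; k++) weights[k] = new double[L1];

            double[] hidden = { 0.1, 0.2, 0.3, 0.4 };   // hidden activations
            double[] outputErr = { 1.0, -1.0, 0.5 };    // output-layer errors
            double alpha = 0.1;                         // learning rate

            // safe: iteration i is the only writer of column i
            Parallel.For(0, L1, i =>
            {
                for (int k = 0; k < L2; k++)
                    weights[k][i] += alpha * hidden[i] * outputErr[k];
            });

            Console.WriteLine(string.Join(", ", weights[0]));
        }
    }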
@@ -580,35 +551,16 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
         var sparse = state.GetSparseData();
         int sparseFeatureSize = sparse.GetNumberOfEntries();
 
-        //loop through all input gates in hidden layer
-        //for each hidden neuron
-        Parallel.For(0, L1, parallelOption, j =>
-        {
-            //rest the value of the net input to zero
-            neuHidden[j].netIn = 0;
-
-            //hidden(t-1) -> hidden(t)
-            neuHidden[j].previousCellState = neuHidden[j].cellState;
-
-            //for each input neuron
-            for (int i = 0; i < sparseFeatureSize; i++)
-            {
-                var entry = sparse.GetEntry(i);
-                neuHidden[j].netIn += entry.Value * input2hidden[j][entry.Key].wInputInputGate;
-            }
-
-        });
-
-        //fea(t) -> hidden(t)
-        if (fea_size > 0)
-        {
-            matrixXvectorADD(neuHidden, neuFeatures, feature2hidden, 0, L1, 0, fea_size);
-        }
-
         Parallel.For(0, L1, parallelOption, j =>
         {
             LSTMCell cell_j = neuHidden[j];
 
+            //hidden(t-1) -> hidden(t)
+            cell_j.previousCellState = cell_j.cellState;
+
+            //rest the value of the net input to zero
+            cell_j.netIn = 0;
+
             cell_j.netForget = 0;
             //reset each netCell state to zero
             cell_j.netCellState = 0;
@@ -619,16 +571,19 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
                 var entry = sparse.GetEntry(i);
                 LSTMWeight w = input2hidden[j][entry.Key];
                 //loop through all forget gates in hiddden layer
+                cell_j.netIn += entry.Value * w.wInputInputGate;
                 cell_j.netForget += entry.Value * w.wInputForgetGate;
                 cell_j.netCellState += entry.Value * w.wInputCell;
                 cell_j.netOut += entry.Value * w.wInputOutputGate;
             }
 
+            //fea(t) -> hidden(t)
             if (fea_size > 0)
             {
                 for (int i = 0; i < fea_size; i++)
                 {
                     LSTMWeight w = feature2hidden[j][i];
+                    cell_j.netIn += neuFeatures[i] * w.wInputInputGate;
                     cell_j.netForget += neuFeatures[i] * w.wInputForgetGate;
                     cell_j.netCellState += neuFeatures[i] * w.wInputCell;
                     cell_j.netOut += neuFeatures[i] * w.wInputOutputGate;
@@ -643,18 +598,24 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
             //include internal connection multiplied by the previous cell state
             cell_j.netForget += cell_j.previousCellState * cell_j.wCellForget;
             cell_j.yForget = Sigmoid(cell_j.netForget);
-
 
-            //cell state is equal to the previous cell state multipled by the forget gate and the cell inputs multiplied by the input gate
-            cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * TanH(cell_j.netCellState);
+            if (cell_j.mask == true)
+            {
+                cell_j.cellState = 0;
+            }
+            else
+            {
+                //cell state is equal to the previous cell state multipled by the forget gate and the cell inputs multiplied by the input gate
+                cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * Sigmoid2(cell_j.netCellState);
+            }
 
             ////include the internal connection multiplied by the CURRENT cell state
             cell_j.netOut += cell_j.cellState * cell_j.wCellOut;
 
             //squash output gate
             cell_j.yOut = Sigmoid(cell_j.netOut);
 
-            cell_j.cellOutput = TanH(cell_j.cellState) * cell_j.yOut;
+            cell_j.cellOutput = cell_j.cellState * cell_j.yOut;
 
 
             neuHidden[j] = cell_j;
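This is change #2, dropout for the LSTM layer: a cell whose mask flag is set has its state forced to zero, and because the flags are sampled once per reset (see netReset below) the same cells stay dropped for the whole sequence rather than being re-drawn at every timestep. The idea in miniature (hypothetical names, plain arrays, not the RNNSharp API):

    using System;

    class SequenceDropoutSketch
    {
        static void Main()
        {
            var rand = new Random(7);
            double dropout = 0.5;   // drop probability (the commit reads this from a dropout field)
            int cells = 8;

            // sampled once per training sequence (cf. netReset)
            bool[] mask = new bool[cells];
            for (int a = 0; a < cells; a++)
                mask[a] = rand.NextDouble() < dropout;

            // applied at every timestep of that sequence (cf. computeNet)
            double[] cellState = { 0.5, -0.2, 0.9, 0.1, -0.7, 0.3, 0.6, -0.4 };
            for (int a = 0; a < cells; a++)
                if (mask[a]) cellState[a] = 0;

            Console.WriteLine(string.Join(" ", cellState));
        }
    }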
@@ -673,18 +634,25 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
             SoftmaxLayer(neuOutput);
     }
 
-    public override void netFlush() //cleans all activations and error vectors
+    public override void netReset(bool updateNet = false) //cleans hidden layer activation + bptt history
     {
-        neuFeatures = new double[fea_size];
+        for (int a = 0; a < L1; a++)
+        {
+            neuHidden[a].mask = false;
+        }
 
-        for (int i = 0; i < L1; i++)
+        if (updateNet == true)
         {
-            LSTMCellInit(neuHidden[i]);
+            //Train mode
+            for (int a = 0; a < L1; a++)
+            {
+                if (rand.NextDouble() < dropout)
+                {
+                    neuHidden[a].mask = true;
+                }
+            }
         }
-    }
 
-    public override void netReset(bool updateNet = false) //cleans hidden layer activation + bptt history
-    {
         Parallel.For(0, L1, parallelOption, i =>
         {
             LSTMCellInit(neuHidden[i]);
RNNSharp/RNN.cs (+1, -6 lines)

@@ -491,8 +491,6 @@ public virtual double TrainNet(DataSet trainingSet, int iter)
         logp = 0;
         counterTokenForLM = 0;
 
-        netFlush();
-
         //Shffle training corpus
         trainingSet.Shuffle();
 
@@ -620,8 +618,6 @@ public void matrixXvectorADD(neuron[] dest, neuron[] srcvec, Matrix<double> srcm
         }
     }
 
-    public abstract void netFlush();
-
     public int[] DecodeNN(Sequence seq)
     {
         Matrix<double> ys = PredictSentence(seq, RunningMode.Test);
@@ -841,8 +837,7 @@ public virtual bool ValidateNet(DataSet validationSet)
         counter = 0;
         logp = 0;
         counterTokenForLM = 0;
-
-        netFlush();
+
         int numSequence = validationSet.GetSize();
         for (int curSequence = 0; curSequence < numSequence; curSequence++)
         {

RNNSharp/RNNEncoder.cs (+5, -5 lines)

@@ -129,11 +129,11 @@ public void Train()
 
                 betterValidateNet = true;
             }
-            else
-            {
-                Logger.WriteLine(Logger.Level.info, "Loading previous best model from file {0}...", m_modelSetting.GetModelFile());
-                rnn.loadNetBin(m_modelSetting.GetModelFile());
-            }
+            //else
+            //{
+            //    Logger.WriteLine(Logger.Level.info, "Loading previous best model from file {0}...", m_modelSetting.GetModelFile());
+            //    rnn.loadNetBin(m_modelSetting.GetModelFile());
+            //}
 
 
             if (ppl >= lastPPL && lastAlpha != rnn.Alpha)

RNNSharp/SimpleRNN.cs (-9 lines)

@@ -390,13 +390,6 @@ public override void LearnBackTime(State state, int numStates, int curState)
         }
     }
 
-
-    public override void netFlush() //cleans all activations and error vectors
-    {
-        neuHidden = new neuron[L1];
-        neuOutput = new neuron[L2];
-    }
-
     public override void loadNetBin(string filename)
     {
         Logger.WriteLine(Logger.Level.info, "Loading SimpleRNN model: {0}", filename);
@@ -460,9 +453,7 @@ public override void saveNetBin(string filename)
         StreamWriter sw = new StreamWriter(filename);
         BinaryWriter fo = new BinaryWriter(sw.BaseStream);
 
-
         fo.Write((int)m_modeltype);
-
         fo.Write((int)m_modeldirection);
 
         // Signiture , 0 is for RNN or 1 is for RNN-CRF
