@@ -39,6 +39,7 @@ public class LSTMCell
//cell output
public double cellOutput;
+ public bool mask;
}

public struct LSTMWeight
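Note: the new mask field is what drives the dropout support added in this commit. A cell whose mask is set has its state clamped to zero in computeNet, and netReset (re)samples the masks once per sequence; see the sketches after those hunks below.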
@@ -68,10 +69,6 @@ public class LSTMRNN : RNN
protected LSTMWeightDerivative[][] input2hiddenDeri;
protected LSTMWeightDerivative[][] feature2hiddenDeri;

- //for LSTM layer
- const bool NORMAL = true;
- const bool BIAS = false;
-
public LSTMRNN()
{
m_modeltype = MODELTYPE.LSTM;
@@ -248,29 +245,29 @@ public override void saveNetBin(string filename)
}

- double TanH(double x)
+ double Sigmoid2(double x)
{
- return Math.Tanh(x);
+ //sigmoid-like function that returns a bounded output in [-2, 2]
+ return (4.0 / (1.0 + Math.Exp(-x))) - 2.0;
}

- double TanHDerivative(double x)
+ double Sigmoid2Derivative(double x)
{
- double tmp = Math.Tanh(x);
- return 1 - tmp * tmp;
+ return 4.0 * Sigmoid(x) * (1.0 - Sigmoid(x));
}

double Sigmoid(double x)
{
- return (1 / (1 + Math.Exp(-x)));
+ return (1.0 / (1.0 + Math.Exp(-x)));
}

double SigmoidDerivative(double x)
{
- return Sigmoid(x) * (1 - Sigmoid(x));
+ return Sigmoid(x) * (1.0 - Sigmoid(x));
}

- public LSTMWeight LSTMWeightInit(int iL)
+ public LSTMWeight LSTMWeightInit()
{
LSTMWeight w;
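Sigmoid2 is not an arbitrary substitute for Math.Tanh: algebraically 4/(1 + e^(-x)) - 2 = 2*tanh(x/2), a tanh stretched to the range (-2, 2), and its exact derivative is the 4*Sigmoid(x)*(1 - Sigmoid(x)) used above. A standalone check of both identities (my own test harness, not part of the library):

    using System;

    class Sigmoid2Check
    {
        static double Sigmoid(double x) => 1.0 / (1.0 + Math.Exp(-x));
        static double Sigmoid2(double x) => 4.0 / (1.0 + Math.Exp(-x)) - 2.0;

        static void Main()
        {
            foreach (double x in new[] { -3.0, -0.5, 0.0, 0.5, 3.0 })
            {
                // identity: Sigmoid2(x) == 2 * tanh(x / 2)
                Console.WriteLine($"{Sigmoid2(x):F6}  {2.0 * Math.Tanh(x / 2.0):F6}");
                // analytic derivative vs. central difference
                double num = (Sigmoid2(x + 1e-6) - Sigmoid2(x - 1e-6)) / 2e-6;
                Console.WriteLine($"{4.0 * Sigmoid(x) * (1.0 - Sigmoid(x)):F6}  {num:F6}");
            }
        }
    }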
@@ -292,7 +289,7 @@ public override void initWeights()
input2hidden[i] = new LSTMWeight[L0];
for (int j = 0; j < L0; j++)
{
- input2hidden[i][j] = LSTMWeightInit(L0);
+ input2hidden[i][j] = LSTMWeightInit();
}
}
@@ -304,7 +301,7 @@ public override void initWeights()
feature2hidden[i] = new LSTMWeight[fea_size];
for (int j = 0; j < fea_size; j++)
{
- feature2hidden[i][j] = LSTMWeightInit(L0);
+ feature2hidden[i][j] = LSTMWeightInit();
}
}
}
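With the unused iL parameter gone, LSTMWeightInit is self-contained. Its body is not shown in this diff; purely as an illustration of the shape such an initializer takes (the rng helper and the uniform range are my assumptions, not the repo's actual values):

    // hypothetical sketch -- the real LSTMWeightInit body is outside this diff
    static readonly Random rng = new Random();
    static double RandUniform() => (rng.NextDouble() - 0.5) * 0.2; // assumed range [-0.1, 0.1]

    public LSTMWeight LSTMWeightInit()
    {
        LSTMWeight w;
        w.wInputCell = RandUniform();
        w.wInputInputGate = RandUniform();
        w.wInputForgetGate = RandUniform();
        w.wInputOutputGate = RandUniform();
        return w;
    }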
@@ -418,26 +415,14 @@ public void matrixXvectorADD(neuron[] dest, LSTMCell[] srcvec, Matrix<double> sr
//ac mod
Parallel.For(0, (to - from), parallelOption, i =>
{
+ dest[i + from].cellOutput = 0;
for (int j = 0; j < to2 - from2; j++)
{
dest[i + from].cellOutput += srcvec[j + from2].cellOutput * srcmatrix[i][j];
}
});
}

- public void matrixXvectorADD(LSTMCell[] dest, double[] srcvec, LSTMWeight[][] srcmatrix, int from, int to, int from2, int to2)
- {
- //ac mod
- Parallel.For(0, (to - from), parallelOption, i =>
- {
- for (int j = 0; j < to2 - from2; j++)
- {
- dest[i + from].netIn += srcvec[j + from2] * srcmatrix[i][j].wInputInputGate;
- }
- });
- }
-
-
public override void LearnBackTime(State state, int numStates, int curState)
{
}
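Two things happen in this hunk: the surviving overload now zeroes each cellOutput before accumulating, so calling it once per timestep no longer compounds the previous step's activation, and the netIn overload is deleted because that accumulation moves into computeNet's single per-cell loop (later hunks). A standalone sketch of the reset-then-accumulate pattern on plain arrays, independent of RNNSharp's types:

    using System;
    using System.Threading.Tasks;

    class MatVecDemo
    {
        static void MatrixXVectorAdd(double[] dest, double[] src, double[][] m)
        {
            Parallel.For(0, dest.Length, i =>
            {
                dest[i] = 0; // the fix: clear before accumulating
                for (int j = 0; j < src.Length; j++)
                {
                    dest[i] += src[j] * m[i][j];
                }
            });
        }

        static void Main()
        {
            var m = new[] { new[] { 1.0, 2.0 }, new[] { 3.0, 4.0 } };
            var dest = new double[2];
            MatrixXVectorAdd(dest, new[] { 1.0, 1.0 }, m);
            MatrixXVectorAdd(dest, new[] { 1.0, 1.0 }, m);
            Console.WriteLine(string.Join(", ", dest)); // 3, 7 both times, not 6, 14
        }
    }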
@@ -463,8 +448,8 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
{
var entry = sparse.GetEntry(k);
LSTMWeightDerivative w = w_i[entry.Key];
- w_i[entry.Key].dSInputCell = w.dSInputCell * c.yForget + TanHDerivative(c.netCellState) * c.yIn * entry.Value;
- w_i[entry.Key].dSInputInputGate = w.dSInputInputGate * c.yForget + TanH(c.netCellState) * SigmoidDerivative(c.netIn) * entry.Value;
+ w_i[entry.Key].dSInputCell = w.dSInputCell * c.yForget + Sigmoid2Derivative(c.netCellState) * c.yIn * entry.Value;
+ w_i[entry.Key].dSInputInputGate = w.dSInputInputGate * c.yForget + Sigmoid2(c.netCellState) * SigmoidDerivative(c.netIn) * entry.Value;
w_i[entry.Key].dSInputForgetGate = w.dSInputForgetGate * c.yForget + c.previousCellState * SigmoidDerivative(c.netForget) * entry.Value;
}
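These dS* fields carry forward-propagated (RTRL-style) partials of the cell state with respect to each input weight, rather than errors backpropagated through time. Writing g for Sigmoid2 (the cell-input squashing), the recurrence implemented for the cell weight reads, as I read the code:

    \frac{\partial s_t}{\partial w_c} = \frac{\partial s_{t-1}}{\partial w_c}\, y^F_t + g'(\mathrm{net}^c_t)\, y^I_t\, x_t

with analogous forms for the input-gate and forget-gate weights. Swapping TanH for Sigmoid2 in the forward pass therefore has to be mirrored here, or the stored derivatives would no longer match the function actually applied.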
@@ -475,15 +460,15 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
for (int j = 0; j < fea_size; j++)
{
LSTMWeightDerivative w = w_i[j];
- w_i[j].dSInputCell = w.dSInputCell * c.yForget + TanHDerivative(c.netCellState) * c.yIn * neuFeatures[j];
- w_i[j].dSInputInputGate = w.dSInputInputGate * c.yForget + TanH(c.netCellState) * SigmoidDerivative(c.netIn) * neuFeatures[j];
+ w_i[j].dSInputCell = w.dSInputCell * c.yForget + Sigmoid2Derivative(c.netCellState) * c.yIn * neuFeatures[j];
+ w_i[j].dSInputInputGate = w.dSInputInputGate * c.yForget + Sigmoid2(c.netCellState) * SigmoidDerivative(c.netIn) * neuFeatures[j];
w_i[j].dSInputForgetGate = w.dSInputForgetGate * c.yForget + c.previousCellState * SigmoidDerivative(c.netForget) * neuFeatures[j];
}
}

//partial derivatives for internal connections
- c.dSWCellIn = c.dSWCellIn * c.yForget + TanH(c.netCellState) * SigmoidDerivative(c.netIn) * c.cellState;
+ c.dSWCellIn = c.dSWCellIn * c.yForget + Sigmoid2(c.netCellState) * SigmoidDerivative(c.netIn) * c.cellState;

//partial derivatives for internal connections, initially zero as dS is zero and previous cell state is zero
c.dSWCellForget = c.dSWCellForget * c.yForget + c.previousCellState * SigmoidDerivative(c.netForget) * c.previousCellState;
@@ -505,18 +490,12 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
weightedSum = NormalizeErr(weightedSum);

//using the error, find the gradient of the output gate
- double gradientOutputGate = SigmoidDerivative(c.netOut) * TanHDerivative(c.cellState) * weightedSum;
+ double gradientOutputGate = SigmoidDerivative(c.netOut) * c.cellState * weightedSum;

//internal cell state error
double cellStateError = c.yOut * weightedSum;

-
//weight updates
-
- //already done the deltas for the hidden-output connections
-
- //output gates. for each connection to the hidden layer
- //to the input layer
LSTMWeight[] w_i = input2hidden[i];
LSTMWeightDerivative[] wd_i = input2hiddenDeri[i];
for (int k = 0; k < sparseFeatureSize; k++)
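Dropping the TanHDerivative(c.cellState) factor here is consistent with the forward-pass change further down: cellOutput is now cellState * yOut with no output squashing, so

    \frac{\partial\, \mathrm{cellOutput}}{\partial\, \mathrm{net}_O} = \mathrm{cellState} \cdot \sigma'(\mathrm{net}_O)

which is exactly SigmoidDerivative(c.netOut) * c.cellState times the propagated error.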
@@ -545,30 +524,22 @@ public override void learnNet(State state, int timeat, bool biRNN = false)
}
}

- //for the internal connection
- double deltaOutputGateCell = alpha * gradientOutputGate * c.cellState;
-
- //using internal partial derivative
- double deltaInputGateCell = alpha * cellStateError * c.dSWCellIn;
-
- double deltaForgetGateCell = alpha * cellStateError * c.dSWCellForget;
-
//update internal weights
- c.wCellIn += deltaInputGateCell;
- c.wCellForget += deltaForgetGateCell;
- c.wCellOut += deltaOutputGateCell;
+ c.wCellIn += alpha * cellStateError * c.dSWCellIn;
+ c.wCellForget += alpha * cellStateError * c.dSWCellForget;
+ c.wCellOut += alpha * gradientOutputGate * c.cellState;

neuHidden[i] = c;
});

//update weights for hidden to output layer
- for (int i = 0; i < L1; i++)
+ Parallel.For(0, L1, parallelOption, i =>
{
for (int k = 0; k < L2; k++)
{
mat_hidden2output[k][i] += alpha * neuHidden[i].cellOutput * neuOutput[k].er;
}
- }
+ });
}
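Switching the hidden-to-output update to Parallel.For is race-free here because iteration i only ever writes column i of mat_hidden2output; the update itself is the usual delta rule, w_{ki} += \alpha\, y_i\, \delta_k, with y_i the cell output and \delta_k the output-layer error.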
@@ -580,35 +551,16 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
var sparse = state.GetSparseData();
int sparseFeatureSize = sparse.GetNumberOfEntries();

- //loop through all input gates in hidden layer
- //for each hidden neuron
- Parallel.For(0, L1, parallelOption, j =>
- {
- //reset the value of the net input to zero
- neuHidden[j].netIn = 0;
-
- //hidden(t-1) -> hidden(t)
- neuHidden[j].previousCellState = neuHidden[j].cellState;
-
- //for each input neuron
- for (int i = 0; i < sparseFeatureSize; i++)
- {
- var entry = sparse.GetEntry(i);
- neuHidden[j].netIn += entry.Value * input2hidden[j][entry.Key].wInputInputGate;
- }
-
- });
-
- //fea(t) -> hidden(t)
- if (fea_size > 0)
- {
- matrixXvectorADD(neuHidden, neuFeatures, feature2hidden, 0, L1, 0, fea_size);
- }
-
Parallel.For(0, L1, parallelOption, j =>
{
LSTMCell cell_j = neuHidden[j];

+ //hidden(t-1) -> hidden(t)
+ cell_j.previousCellState = cell_j.cellState;
+
+ //reset the value of the net input to zero
+ cell_j.netIn = 0;
+
cell_j.netForget = 0;
//reset each netCell state to zero
cell_j.netCellState = 0;
@@ -619,16 +571,19 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
var entry = sparse.GetEntry(i);
LSTMWeight w = input2hidden[j][entry.Key];
//loop through all gates in the hidden layer
+ cell_j.netIn += entry.Value * w.wInputInputGate;
cell_j.netForget += entry.Value * w.wInputForgetGate;
cell_j.netCellState += entry.Value * w.wInputCell;
cell_j.netOut += entry.Value * w.wInputOutputGate;
}

+ //fea(t) -> hidden(t)
if (fea_size > 0)
{
for (int i = 0; i < fea_size; i++)
{
LSTMWeight w = feature2hidden[j][i];
+ cell_j.netIn += neuFeatures[i] * w.wInputInputGate;
cell_j.netForget += neuFeatures[i] * w.wInputForgetGate;
cell_j.netCellState += neuFeatures[i] * w.wInputCell;
cell_j.netOut += neuFeatures[i] * w.wInputOutputGate;
@@ -643,18 +598,24 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
//include internal connection multiplied by the previous cell state
cell_j.netForget += cell_j.previousCellState * cell_j.wCellForget;
cell_j.yForget = Sigmoid(cell_j.netForget);
-
- //cell state is equal to the previous cell state multiplied by the forget gate plus the cell input multiplied by the input gate
- cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * TanH(cell_j.netCellState);
+ if (cell_j.mask == true)
+ {
+ cell_j.cellState = 0;
+ }
+ else
+ {
+ //cell state is equal to the previous cell state multiplied by the forget gate plus the cell input multiplied by the input gate
+ cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * Sigmoid2(cell_j.netCellState);
+ }

//include the internal connection multiplied by the CURRENT cell state
cell_j.netOut += cell_j.cellState * cell_j.wCellOut;

//squash output gate
cell_j.yOut = Sigmoid(cell_j.netOut);

- cell_j.cellOutput = TanH(cell_j.cellState) * cell_j.yOut;
+ cell_j.cellOutput = cell_j.cellState * cell_j.yOut;

neuHidden[j] = cell_j;
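Taken together, the forward pass for an unmasked cell is now (with \sigma the logistic sigmoid and g = Sigmoid2):

    s_t = y^F_t\, s_{t-1} + y^I_t\, g(\mathrm{net}^c_t), \qquad \mathrm{cellOutput}_t = s_t \cdot \sigma(\mathrm{net}^O_t)

while a masked cell clamps s_t = 0, which is how the dropout sampled in netReset takes effect. Note that removing the TanH on the output side leaves cellOutput unbounded in principle, since s_t itself is no longer squashed.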
@@ -673,18 +634,25 @@ public override void computeNet(State state, double[] doutput, bool isTrain = tr
SoftmaxLayer(neuOutput);
}

- public override void netFlush() //cleans all activations and error vectors
+ public override void netReset(bool updateNet = false) //cleans hidden layer activation + bptt history
{
- neuFeatures = new double[fea_size];
+ for (int a = 0; a < L1; a++)
+ {
+ neuHidden[a].mask = false;
+ }

- for (int i = 0; i < L1; i++)
+ if (updateNet == true)
{
- LSTMCellInit(neuHidden[i]);
+ //Train mode
+ for (int a = 0; a < L1; a++)
+ {
+ if (rand.NextDouble() < dropout)
+ {
+ neuHidden[a].mask = true;
+ }
+ }
}
- }

- public override void netReset(bool updateNet = false) //cleans hidden layer activation + bptt history
- {
Parallel.For(0, L1, parallelOption, i =>
{
LSTMCellInit(neuHidden[i]);
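netReset (which replaces netFlush) now owns the dropout masks: every call clears them, and in training mode (updateNet == true) it resamples one Bernoulli(dropout) draw per cell. Because this happens per reset rather than per timestep, a cell keeps the same dropped/kept decision for the whole following sequence; this per-sequence masking resembles variational dropout, though the commit doesn't name it. A standalone sketch of that mask lifecycle (the 0.25 rate and array size are illustrative, not the repo's values):

    using System;

    class NetResetDemo
    {
        // mirrors the mask handling above: clear, then resample in train mode only
        static void Reset(bool[] mask, Random rand, double dropout, bool updateNet)
        {
            for (int a = 0; a < mask.Length; a++) mask[a] = false;
            if (updateNet)
            {
                for (int a = 0; a < mask.Length; a++)
                {
                    if (rand.NextDouble() < dropout) mask[a] = true;
                }
            }
        }

        static void Main()
        {
            var rand = new Random(1);
            var mask = new bool[16];
            Reset(mask, rand, 0.25, updateNet: true);  // train: ~25% of cells dropped
            Console.WriteLine(string.Join("", Array.ConvertAll(mask, m => m ? "x" : ".")));
            Reset(mask, rand, 0.25, updateNet: false); // evaluation: nothing dropped
            Console.WriteLine(string.Join("", Array.ConvertAll(mask, m => m ? "x" : ".")));
        }
    }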