@@ -30,6 +30,10 @@ public class LSTMCell : SimpleCell
        public double wCellForget;
        public double wCellOut;

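+        //accumulated squared gradients (AdaGrad-style) used to derive per-weight learning rates for the internal cell weights below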
+        public float dCellInLearningRate;
+        public float dCellForgetLearningRate;
+        public float dCellOutLearningRate;
+
        //partial derivatives
        public double dSWCellIn;
        public double dSWCellForget;
@@ -48,6 +52,22 @@ public struct LSTMWeight
        public float wInputOutputGate;
    }

+    //public struct LSTMWeightLearningRate
+    //{
+    //    public float dInputCellLearningRate;
+    //    public float dInputInputGateLearningRate;
+    //    public float dInputForgetGateLearningRate;
+    //    public float dInputOutputGateLearningRate;
+    //}
+
+    //public struct LSTMWeightDerivative
+    //{
+    //    //partial derivatives. don't need partial derivative for output gate as it uses BP not RTRL
+    //    public double dSInputCell;
+    //    public double dSInputInputGate;
+    //    public double dSInputForgetGate;
+    //}
+
    public class LSTMRNN : RNN
    {
        public LSTMCell[] neuHidden;        //neurons in hidden layer
@@ -56,15 +76,10 @@ public class LSTMRNN : RNN

        protected Vector4[][] Input2HiddenLearningRate;
        protected Vector4[][] Feature2HiddenLearningRate;
-        protected Vector3[] CellLearningRate;

        protected Vector3[][] input2hiddenDeri;
        protected Vector3[][] feature2hiddenDeri;

-        private Vector4 vecLearningRate;
-        private Vector3 vecLearningRate3;
-
-
        public LSTMRNN()
        {
            ModelType = MODELTYPE.LSTM;
@@ -353,7 +368,7 @@ public override void SaveModel(string filename)
            //weight input->hidden
            Logger.WriteLine("Saving input2hidden weights...");
            saveLSTMWeight(input2hidden, fo);
-
+
            if (DenseFeatureSize > 0)
            {
                //weight fea->hidden
@@ -438,7 +453,7 @@ public override void initWeights()
            }

            //Create and initialise the weights from hidden to output layer, these are just normal weights
-            Hidden2OutputWeight = new Matrix<float>(L2, L1);
+            Hidden2OutputWeight = new Matrix<double>(L2, L1);

            for (int i = 0; i < Hidden2OutputWeight.Height; i++)
            {
@@ -484,9 +499,12 @@ public override void CleanStatus()
                Feature2HiddenLearningRate = new Vector4[L1][];
            }

-            CellLearningRate = new Vector3[L1];
            Parallel.For(0, L1, parallelOption, i =>
            {
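+                //reset the accumulated squared gradients for the internal cell weights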
+                neuHidden[i].dCellForgetLearningRate = 0;
+                neuHidden[i].dCellInLearningRate = 0;
+                neuHidden[i].dCellOutLearningRate = 0;
+
                Input2HiddenLearningRate[i] = new Vector4[L0];

                if (DenseFeatureSize > 0)
@@ -497,8 +515,6 @@ public override void CleanStatus()
            });

            Hidden2OutputWeightLearningRate = new Matrix<float>(L2, L1);
-            vecLearningRate = new Vector4(LearningRate, LearningRate, LearningRate, LearningRate);
-            vecLearningRate3 = new Vector3(LearningRate, LearningRate, LearningRate);
        }

        public override void InitMem()
@@ -567,7 +583,7 @@ public override void ComputeHiddenLayerErr()
                //find the error by finding the product of the output errors and their weight connections.
                SimpleCell cell = neuHidden[i];

-                cell.er = 0.0f;
+                cell.er = 0.0;

                if (cell.mask == false)
                {
@@ -584,22 +600,30 @@ public override void LearnOutputWeight()
            //update weights for hidden to output layer
            Parallel.For(0, L1, parallelOption, i =>
            {
-                float cellOutput = neuHidden[i].cellOutput;
+                double cellOutput = neuHidden[i].cellOutput;
                for (int k = 0; k < L2; k++)
                {
-                    float delta = NormalizeGradient(cellOutput * OutputLayer.er[k]);
-                    double newLearningRate = UpdateLearningRate(Hidden2OutputWeightLearningRate, k, i, delta);
+                    double delta = NormalizeGradient(cellOutput * OutputLayer.er[k]);
+                    double newLearningRate = UpdateLearningRate(Hidden2OutputWeightLearningRate, i, k, delta);

-                    Hidden2OutputWeight[k][i] += (float)(newLearningRate * delta);
+                    Hidden2OutputWeight[k][i] += newLearningRate * delta;
                }
            });
        }

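+        //AdaGrad-style update: accumulate the squared gradient into mg and return
+        //a per-weight step size that shrinks as the accumulated magnitude grows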
+        public double UpdateLearningRate(ref float mg, double delta)
+        {
+            double dg = mg + delta * delta;
+            mg = (float)dg;
+            return LearningRate / (1.0 + Math.Sqrt(dg));
+        }
+
        public override void LearnNet(State state, int numStates, int curState)
        {
            //Get sparse feature and apply it into hidden layer
            var sparse = state.SparseData;
            int sparseFeatureSize = sparse.Count;
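+            //broadcast the base learning rate across the four gate-weight components for SIMD updates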
+            Vector4 vecLearningRate = new Vector4(LearningRate, LearningRate, LearningRate, LearningRate);

            //put variables for derivatives in weight class and cell class
            Parallel.For(0, L1, parallelOption, i =>
@@ -626,6 +650,8 @@ public override void LearnNet(State state, int numStates, int curState)
                    (float)Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn,
                    (float)ci_previousCellState_mul_SigmoidDerivative_ci_netForget);

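+                //scratch variables reused by the adaptive updates of the internal cell weights below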
+                double delta = 0;
+                double newLearningRate = 0;
                for (int k = 0; k < sparseFeatureSize; k++)
                {
                    var entry = sparse.GetEntry(k);
@@ -647,7 +673,9 @@ public override void LearnNet(State state, int numStates, int curState)
                    vecAlpha = wlr + vecAlpha;
                    wlr_i[entry.Key] = vecAlpha;

-                    vecAlpha = vecLearningRate / (Vector4.SquareRoot(vecAlpha) + Vector4.One);
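+                    //element-wise adaptive rate: LearningRate / (1 + sqrt(accumulated squared gradient))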
+                    vecAlpha = Vector4.SquareRoot(vecAlpha) + Vector4.One;
+                    vecAlpha = vecLearningRate / vecAlpha;
+
                    vecDelta = vecAlpha * vecDelta;

                    w.wInputCell += vecDelta.X;
@@ -685,7 +713,9 @@ public override void LearnNet(State state, int numStates, int curState)
                    vecAlpha = wlr + vecAlpha;
                    wlr_i[j] = vecAlpha;

-                    vecAlpha = vecLearningRate / (Vector4.SquareRoot(vecAlpha) + Vector4.One);
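+                    //same element-wise adaptive rate for the dense feature weights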
+                    vecAlpha = Vector4.SquareRoot(vecAlpha) + Vector4.One;
+                    vecAlpha = vecLearningRate / vecAlpha;
+
                    vecDelta = vecAlpha * vecDelta;

                    w.wInputCell += vecDelta.X;
@@ -706,22 +736,17 @@ public override void LearnNet(State state, int numStates, int curState)


                //update internal weights
-                Vector3 vecCellDelta = new Vector3((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.cellState);
-                Vector3 vecCellErr = new Vector3(cellStateError, cellStateError, gradientOutputGate);
-                Vector3 vecCellLearningRate = CellLearningRate[i];
-
-                vecCellDelta = vecCellErr * vecCellDelta;
-                vecCellLearningRate += (vecCellDelta * vecCellDelta);
-                CellLearningRate[i] = vecCellLearningRate;
-
-                //LearningRate / (1.0 + Math.Sqrt(dg));
-                vecCellLearningRate = vecLearningRate3 / (Vector3.One + Vector3.SquareRoot(vecCellLearningRate));
-                vecCellDelta = vecCellLearningRate * vecCellDelta;
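+                //scalar AdaGrad-style updates for the three internal cell weights, driven by the RTRL partial derivatives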
+                delta = cellStateError * c.dSWCellIn;
+                newLearningRate = UpdateLearningRate(ref c.dCellInLearningRate, delta);
+                c.wCellIn += newLearningRate * delta;

-                c.wCellIn += vecCellDelta.X;
-                c.wCellForget += vecCellDelta.Y;
-                c.wCellOut += vecCellDelta.Z;
+                delta = cellStateError * c.dSWCellForget;
+                newLearningRate = UpdateLearningRate(ref c.dCellForgetLearningRate, delta);
+                c.wCellForget += newLearningRate * delta;

+                delta = gradientOutputGate * c.cellState;
+                newLearningRate = UpdateLearningRate(ref c.dCellOutLearningRate, delta);
+                c.wCellOut += newLearningRate * delta;

                neuHidden[i] = c;
            });
@@ -808,15 +833,15 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
                //squash output gate
                cell_j.yOut = Sigmoid(cell_j.netOut);

-                cell_j.cellOutput = (float)(cell_j.cellState * cell_j.yOut);
+                cell_j.cellOutput = cell_j.cellState * cell_j.yOut;


                neuHidden[j] = cell_j;
            });
        }

-        public override void computeOutput(float[] doutput)
+        public override void computeOutput(double[] doutput)
        {
            matrixXvectorADD(OutputLayer, neuHidden, Hidden2OutputWeight, L2, L1, 0);
            if (doutput != null)