@@ -129,12 +129,14 @@ public class LSTMLayer : SimpleLayer
129
129
private Vector4 vecNormalLearningRate ;
130
130
private Vector3 vecNormalLearningRate3 ;
131
131
private Vector < float > vecNormalLearningRateFloat ;
132
+ // protected float[] previousCellOutputs;
132
133
133
134
LSTMLayerConfig config ;
134
135
135
136
public LSTMLayer ( LSTMLayerConfig config ) : base ( config )
136
137
{
137
138
this . config = config ;
139
+ // previousCellOutputs = new float[LayerSize];
138
140
LSTMCells = new LSTMCell [ LayerSize ] ;
139
141
for ( var i = 0 ; i < LayerSize ; i ++ )
140
142
{
@@ -155,7 +157,7 @@ public override Neuron CopyNeuronTo(Neuron neuron)
155
157
LSTMNeuron lstmNeuron = neuron as LSTMNeuron ;
156
158
157
159
Cells . CopyTo ( lstmNeuron . Cells , 0 ) ;
158
- previousCellOutputs . CopyTo ( lstmNeuron . PrevCellOutputs , 0 ) ;
160
+ // previousCellOutputs.CopyTo(lstmNeuron.PrevCellOutputs, 0);
159
161
for ( int i = 0 ; i < LayerSize ; i ++ )
160
162
{
161
163
lstmNeuron . LSTMCells [ i ] . Set ( LSTMCells [ i ] ) ;
@@ -168,8 +170,8 @@ public override Neuron CopyNeuronTo(Neuron neuron)
168
170
public override void PreUpdateWeights ( Neuron neuron , float [ ] errs )
169
171
{
170
172
LSTMNeuron lstmNeuron = neuron as LSTMNeuron ;
171
- lstmNeuron . Cells . CopyTo ( Cells , 0 ) ;
172
- lstmNeuron . PrevCellOutputs . CopyTo ( previousCellOutputs , 0 ) ;
173
+ // lstmNeuron.Cells.CopyTo(Cells, 0);
174
+ // lstmNeuron.PrevCellOutputs.CopyTo(previousCellOutputs, 0);
173
175
for ( int i = 0 ; i < LayerSize ; i ++ )
174
176
{
175
177
LSTMCells [ i ] . Set ( lstmNeuron . LSTMCells [ i ] ) ;
@@ -726,7 +728,8 @@ public override void ForwardPass(SparseVector sparseFeature, float[] denseFeatur
726
728
727
729
//hidden(t-1) -> hidden(t)
728
730
cell_j . previousCellState = cell_j . cellState ;
729
- previousCellOutputs [ j ] = Cells [ j ] ;
731
+ cell_j . previousCellOutput = Cells [ j ] ;
732
+ // previousCellOutputs[j] = Cells[j];
730
733
731
734
var vecCell_j = Vector4 . Zero ;
732
735
@@ -787,26 +790,24 @@ public override void ForwardPass(SparseVector sparseFeature, float[] denseFeatur
787
790
//set netOut to the W component of vecCell_j
788
791
cell_j . netOut = vecCell_j . W ;
789
792
790
- var cell_j_previousCellOutput = previousCellOutputs [ j ] ;
791
-
792
793
//include internal connection multiplied by the previous cell state
793
- cell_j . netIn += cell_j . previousCellState * cellWeight_j . wPeepholeIn + cell_j_previousCellOutput * cellWeight_j . wCellIn ;
794
+ cell_j . netIn += cell_j . previousCellState * cellWeight_j . wPeepholeIn + cell_j . previousCellOutput * cellWeight_j . wCellIn ;
794
795
//squash input
795
796
cell_j . yIn = Sigmoid ( cell_j . netIn ) ;
796
797
797
798
//include internal connection multiplied by the previous cell state
798
799
cell_j . netForget += cell_j . previousCellState * cellWeight_j . wPeepholeForget +
799
- cell_j_previousCellOutput * cellWeight_j . wCellForget ;
800
+ cell_j . previousCellOutput * cellWeight_j . wCellForget ;
800
801
cell_j . yForget = Sigmoid ( cell_j . netForget ) ;
801
802
802
- cell_j . netCellState += cell_j_previousCellOutput * cellWeight_j . wCellState ;
803
+ cell_j . netCellState += cell_j . previousCellOutput * cellWeight_j . wCellState ;
803
804
cell_j . yCellState = TanH ( cell_j . netCellState ) ;
804
805
805
806
//cell state is equal to the previous cell state multiplied by the forget gate plus the squashed cell input multiplied by the input gate
806
807
cell_j . cellState = cell_j . yForget * cell_j . previousCellState + cell_j . yIn * cell_j . yCellState ;
807
808
808
809
////include the internal connection multiplied by the CURRENT cell state
809
- cell_j . netOut += cell_j . cellState * cellWeight_j . wPeepholeOut + cell_j_previousCellOutput * cellWeight_j . wCellOut ;
810
+ cell_j . netOut += cell_j . cellState * cellWeight_j . wPeepholeOut + cell_j . previousCellOutput * cellWeight_j . wCellOut ;
810
811
811
812
//squash output gate
812
813
cell_j . yOut = Sigmoid ( cell_j . netOut ) ;
@@ -1001,20 +1002,18 @@ public override void BackwardPass()
1001
1002
cellWeight . wPeepholeOut += vecCellDelta . Z ;
1002
1003
1003
1004
//Update cells weights
1004
- var c_previousCellOutput = previousCellOutputs [ i ] ;
1005
1005
//partial derivatives for internal connections
1006
1006
cellWeightDeri . dSWCellIn = cellWeightDeri . dSWCellIn * c . yForget +
1007
- Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c_previousCellOutput ;
1007
+ Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c . previousCellOutput ;
1008
1008
1009
1009
//partial derivatives for internal connections, initially zero as dS is zero and previous cell state is zero
1010
1010
cellWeightDeri . dSWCellForget = cellWeightDeri . dSWCellForget * c . yForget +
1011
- ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c_previousCellOutput ;
1011
+ ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c . previousCellOutput ;
1012
1012
1013
1013
cellWeightDeri . dSWCellState = cellWeightDeri . dSWCellState * c . yForget +
1014
- Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c_previousCellOutput ;
1014
+ Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c . previousCellOutput ;
1015
1015
1016
- var vecCellDelta4 = new Vector4 ( ( float ) cellWeightDeri . dSWCellIn , ( float ) cellWeightDeri . dSWCellForget , ( float ) cellWeightDeri . dSWCellState ,
1017
- c_previousCellOutput ) ;
1016
+ var vecCellDelta4 = new Vector4 ( ( float ) cellWeightDeri . dSWCellIn , ( float ) cellWeightDeri . dSWCellForget , ( float ) cellWeightDeri . dSWCellState , ( float ) c . previousCellOutput ) ;
1018
1017
vecCellDelta4 = vecErr * vecCellDelta4 ;
1019
1018
1020
1019
//Normalize err by gradient cut-off
@@ -1069,7 +1068,7 @@ public override void Reset()
1069
1068
1070
1069
private void InitializeLSTMCell ( LSTMCell c , LSTMCellWeight cw , LSTMCellWeightDeri deri )
1071
1070
{
1072
- c . previousCellState = 0 ;
1071
+ // c.previousCellState = 0;
1073
1072
c . cellState = 0 ;
1074
1073
1075
1074
//partial derivatives
@@ -1139,6 +1138,7 @@ public class LSTMCell
1139
1138
{
1140
1139
//The following fields are only for forward
1141
1140
public double previousCellState ;
1141
+ public double previousCellOutput ;
1142
1142
public double cellState ;
1143
1143
1144
1144
public double netCellState ;
@@ -1164,6 +1164,7 @@ public LSTMCell(LSTMCell cell)
1164
1164
public void Set ( LSTMCell cell )
1165
1165
{
1166
1166
previousCellState = cell . previousCellState ;
1167
+ previousCellOutput = cell . previousCellOutput ;
1167
1168
cellState = cell . cellState ;
1168
1169
netCellState = cell . netCellState ;
1169
1170
netForget = cell . netForget ;
0 commit comments