
Commit 814dd46

Code refactoring for sentence prediction
committed
1 parent c18b865, commit 814dd46

File tree

6 files changed: +137 -183 lines


RNNSharp.v12.suo

15 KB (binary file not shown)

RNNSharp/BiRNN.cs

+25 -56
@@ -141,11 +141,7 @@ public override void GetHiddenLayer(Matrix<double> m, int curStatus)
 
         public override void initMem()
         {
-            for (int i = 0; i < MAX_RNN_HIST; i++)
-            {
-                m_Diff[i] = new double[L2];
-            }
-
+            m_Diff = new Matrix<double>(MAX_RNN_HIST, L2);
             m_tagBigramTransition = new Matrix<double>(L2, L2);
             m_DeltaBigramLM = new Matrix<double>(L2, L2);
 
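The row-by-row allocation of m_Diff is collapsed into a single Matrix<double>(MAX_RNN_HIST, L2) construction; the same change is repeated in LSTMRNN.cs below. A minimal sketch of the equivalence, assuming RNNSharp's Matrix<T> is backed by a jagged array (the DiffMatrix type below is illustrative, not the library's actual class):

    // Illustrative stand-in for RNNSharp's Matrix<T> (assumed jagged-array layout).
    public class DiffMatrix
    {
        private readonly double[][] rows;

        // One constructor call replaces the removed loop:
        //   for (int i = 0; i < MAX_RNN_HIST; i++) { m_Diff[i] = new double[L2]; }
        public DiffMatrix(int height, int width)
        {
            rows = new double[height][];
            for (int i = 0; i < height; i++)
            {
                rows[i] = new double[width];
            }
        }

        // Row indexer keeps existing m_Diff[i][j] call sites working unchanged.
        public double[] this[int row] => rows[row];
    }

With a constructor like this, m_Diff = new Matrix<double>(MAX_RNN_HIST, L2) preallocates the full MAX_RNN_HIST x L2 buffer in one call while existing m_Diff[i][j] indexing keeps working.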
@@ -166,27 +162,6 @@ public override void initMem()
             }
         }
 
-
-
-        public override Matrix<double> InnerDecode(Sequence pSequence)
-        {
-            Matrix<neuron> mHiddenLayer = null;
-            Matrix<double> mRawOutputLayer = null;
-            neuron[][] outputLayer = InnerDecode(pSequence, out mHiddenLayer, out mRawOutputLayer);
-            int numStates = pSequence.GetSize();
-
-            Matrix<double> m = new Matrix<double>(numStates, L2);
-            for (int currState = 0; currState < numStates; currState++)
-            {
-                for (int i = 0; i < L2; i++)
-                {
-                    m[currState][i] = outputLayer[currState][i].cellOutput;
-                }
-            }
-
-            return m;
-        }
-
         int[] predicted_fnn;
         int[] predicted_bnn;
         public neuron[][] InnerDecode(Sequence pSequence, out Matrix<neuron> outputHiddenLayer, out Matrix<double> rawOutputLayer)
@@ -267,12 +242,10 @@ public override void netFlush()
             backwardRNN.netFlush();
         }
 
-        public override int[] learnSentenceForRNNCRF(Sequence pSequence)
+        public override Matrix<double> learnSentenceForRNNCRF(Sequence pSequence, RunningMode runningMode)
        {
             //Reset the network
             int numStates = pSequence.GetSize();
-            int[] predicted = new int[numStates];
-
             //Predict output
             Matrix<neuron> mergedHiddenLayer = null;
             Matrix<double> rawOutputLayer = null;
@@ -281,12 +254,10 @@ public override int[] learnSentenceForRNNCRF(Sequence pSequence)
             ForwardBackward(numStates, rawOutputLayer);
 
             //Get the best result
-            predicted = new int[numStates];
             for (int i = 0; i < numStates; i++)
             {
                 State state = pSequence.Get(i);
                 logp += Math.Log10(m_Diff[i][state.GetLabel()]);
-                predicted[i] = GetBestZIndex(i);
             }
 
             UpdateBigramTransition(pSequence);
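learnSentenceForRNNCRF no longer fills a predicted[] array via GetBestZIndex(i); it now returns m_Diff itself. A caller that still wants the old int[] result can take the arg-max of each row of the returned matrix, which is presumably what GetBestZIndex did internally. A hedged sketch (the helper name is illustrative, and a plain double[][] stands in for Matrix<double>):

    // Illustrative helper: recover per-state best-tag indices from the returned matrix.
    static int[] BestTagPerState(double[][] diff, int numStates, int numTags)
    {
        int[] predicted = new int[numStates];
        for (int i = 0; i < numStates; i++)
        {
            int best = 0;
            for (int tag = 1; tag < numTags; tag++)
            {
                // Pick the highest-scoring tag for this state (arg-max over the row).
                if (diff[i][tag] > diff[i][best])
                {
                    best = tag;
                }
            }
            predicted[i] = best;
        }
        return predicted;
    }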
@@ -305,44 +276,48 @@ public override int[] learnSentenceForRNNCRF(Sequence pSequence)
 
             LearnTwoRNN(pSequence, mergedHiddenLayer, seqOutput);
 
-            return predicted;
+            return m_Diff;
         }
 
-        public override int[] PredictSentence(Sequence pSequence)
+        public override Matrix<double> PredictSentence(Sequence pSequence, RunningMode runningMode)
         {
             //Reset the network
             int numStates = pSequence.GetSize();
-            int[] predicted = new int[numStates];
 
             //Predict output
             Matrix<neuron> mergedHiddenLayer = null;
             Matrix<double> rawOutputLayer = null;
             neuron[][] seqOutput = InnerDecode(pSequence, out mergedHiddenLayer, out rawOutputLayer);
 
-            //Merge forward and backward
-            for (int curState = 0; curState < numStates; curState++)
+            if (runningMode != RunningMode.Test)
             {
-                State state = pSequence.Get(curState);
-                logp += Math.Log10(seqOutput[curState][state.GetLabel()].cellOutput);
-
-                predicted[curState] = GetBestOutputIndex(seqOutput, curState, L2);
+                //Merge forward and backward
+                for (int curState = 0; curState < numStates; curState++)
+                {
+                    State state = pSequence.Get(curState);
+                    logp += Math.Log10(seqOutput[curState][state.GetLabel()].cellOutput);
+                    counter++;
+                }
             }
 
-            //Update hidden-output layer weights
-            for (int curState = 0; curState < numStates; curState++)
+            if (runningMode == RunningMode.Train)
             {
-                State state = pSequence.Get(curState);
-                //For standard RNN
-                for (int c = 0; c < L2; c++)
+                //Update hidden-output layer weights
+                for (int curState = 0; curState < numStates; curState++)
                 {
-                    seqOutput[curState][c].er = -seqOutput[curState][c].cellOutput;
+                    State state = pSequence.Get(curState);
+                    //For standard RNN
+                    for (int c = 0; c < L2; c++)
+                    {
+                        seqOutput[curState][c].er = -seqOutput[curState][c].cellOutput;
+                    }
+                    seqOutput[curState][state.GetLabel()].er = 1 - seqOutput[curState][state.GetLabel()].cellOutput;
                 }
-                seqOutput[curState][state.GetLabel()].er = 1 - seqOutput[curState][state.GetLabel()].cellOutput;
-            }
 
-            LearnTwoRNN(pSequence, mergedHiddenLayer, seqOutput);
+                LearnTwoRNN(pSequence, mergedHiddenLayer, seqOutput);
+            }
 
-            return predicted;
+            return rawOutputLayer;
         }
 
         private void LearnTwoRNN(Sequence pSequence, Matrix<neuron> mergedHiddenLayer, neuron[][] seqOutput)
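Both entry points now take a RunningMode argument and return matrices instead of label arrays: when the mode is Test, PredictSentence skips the logp/counter accounting, and only Train runs the output-layer error computation and the LearnTwoRNN update. The enum itself is not part of this diff; a minimal sketch consistent with the RunningMode.Train and RunningMode.Test members referenced above (the Validate member is an assumption):

    // Sketch of the mode switch implied by the new signatures.
    // Only Train and Test appear in this diff; Validate is an assumed extra member.
    public enum RunningMode
    {
        Train,
        Validate,
        Test
    }

A caller would then pass RunningMode.Train while training (so the error signals are set and LearnTwoRNN runs) and RunningMode.Test at decode time, reading predictions from the returned rawOutputLayer or m_Diff matrix.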
@@ -353,8 +328,6 @@ private void LearnTwoRNN(Sequence pSequence, Matrix<neuron> mergedHiddenLayer, neuron[][] seqOutput)
             forwardRNN.mat_hidden2output = mat_hidden2output.CopyTo();
             backwardRNN.mat_hidden2output = mat_hidden2output.CopyTo();
 
-
-
             Parallel.Invoke(() =>
             {
                 for (int curState = 0; curState < numStates; curState++)
@@ -377,8 +350,6 @@ private void LearnTwoRNN(Sequence pSequence, Matrix<neuron> mergedHiddenLayer, neuron[][] seqOutput)
                 //Learn forward network
                 for (int curState = 0; curState < numStates; curState++)
                 {
-                    System.Threading.Interlocked.Increment(ref counter);
-
                     // error propogation
                     State state = pSequence.Get(curState);
                     forwardRNN.setInputLayer(state, curState, numStates, predicted_fnn);
@@ -396,8 +367,6 @@ private void LearnTwoRNN(Sequence pSequence, Matrix<neuron> mergedHiddenLayer, neuron[][] seqOutput)
 
                 for (int curState = 0; curState < numStates; curState++)
                 {
-                    System.Threading.Interlocked.Increment(ref counter);
-
                     int curState2 = numStates - 1 - curState;
 
                     // error propogation
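Both parallel loops drop their System.Threading.Interlocked.Increment(ref counter) calls; the count is instead accumulated with a plain counter++ inside PredictSentence's sequential loop shown earlier. A short standalone sketch of why the two forms differ, using an illustrative counter rather than RNNSharp's field:

    using System.Threading;
    using System.Threading.Tasks;

    class CounterSketch
    {
        static int counter;

        static void Main()
        {
            // Two branches of Parallel.Invoke run concurrently and share the counter,
            // so each increment must be atomic (the pattern the old LearnTwoRNN used).
            Parallel.Invoke(
                () => { for (int i = 0; i < 1000; i++) Interlocked.Increment(ref counter); },
                () => { for (int i = 0; i < 1000; i++) Interlocked.Increment(ref counter); });

            // In a single-threaded loop, as in the refactored PredictSentence,
            // an ordinary increment is sufficient.
            for (int i = 0; i < 1000; i++) counter++;
        }
    }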

RNNSharp/LSTMRNN.cs

+2 -9
@@ -184,10 +184,7 @@ public override void loadNetBin(string filename)
         {
             m_tagBigramTransition = loadMatrixBin(br);
 
-            for (int i = 0; i < MAX_RNN_HIST; i++)
-            {
-                m_Diff[i] = new double[L2];
-            }
+            m_Diff = new Matrix<double>(MAX_RNN_HIST, L2);
             m_DeltaBigramLM = new Matrix<double>(L2, L2);
         }
 
@@ -361,11 +358,7 @@ public override void initMem()
         {
             CreateCell(null);
 
-            for (int i = 0; i < MAX_RNN_HIST; i++)
-            {
-                m_Diff[i] = new double[L2];
-            }
-
+            m_Diff = new Matrix<double>(MAX_RNN_HIST, L2);
             m_tagBigramTransition = new Matrix<double>(L2, L2);
             m_DeltaBigramLM = new Matrix<double>(L2, L2);
 