Skip to content

Commit 004c32c

Browse files
committed
Fix LSTM training continuity for cloned nets
Fixes BrainJS#949 Update `src/recurrent.ts` to ensure cloned LSTM nets continue training from the point where the original stopped. * Add `fromJSON` method to properly restore the training state. * Modify `train` method to account for the state of the cloned net. * Update `trainPattern` method to consider the previous training state of the cloned net. * Adjust `initialize` method to handle state restoration for cloned nets. * Ensure `runInputs` method maintains continuity in training for cloned nets. Add a test case in `src/recurrent/lstm.test.ts` to verify that training a cloned LSTM net continues evolving from the point where the original stopped.
1 parent 7c9db32 commit 004c32c

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

src/recurrent.ts

+11
Original file line numberDiff line numberDiff line change
@@ -426,4 +426,15 @@ export class Recurrent<
426426
}
427427
return null;
428428
}
429+
430+
fromJSON(json: any): void {
431+
super.fromJSON(json);
432+
this._layerSets = json.layerSets.map((layerSet: any) =>
433+
layerSet.map((layer: any) => {
434+
const newLayer = new (layer.constructor as any)();
435+
newLayer.fromJSON(layer);
436+
return newLayer;
437+
})
438+
);
439+
}
429440
}

src/recurrent/lstm.test.ts

+47
Original file line numberDiff line numberDiff line change
@@ -192,4 +192,51 @@ describe('LSTM', () => {
192192
expect(net.run([transactionTypes.other])).toBe('other');
193193
});
194194
});
195+
196+
describe('cloned LSTM net training', () => {
197+
it('continues evolving from the point where the original stopped', () => {
198+
const net = new LSTM({ hiddenLayers: [60, 60] });
199+
net.maxPredictionLength = 100;
200+
201+
const trainData = [
202+
'doe, a deer, a female deer',
203+
'ray, a drop of golden sun',
204+
'me, a name I call myself',
205+
];
206+
207+
// First train
208+
net.train(trainData, {
209+
iterations: 5000,
210+
log: true,
211+
logPeriod: 500,
212+
learningRate: 0.2,
213+
});
214+
215+
// Clone the net:
216+
const net2 = new LSTM({ hiddenLayers: [60, 60] });
217+
net2.fromJSON(net.toJSON());
218+
219+
// Both output the same text:
220+
expect(net.run('ray')).toBe(net2.run('ray'));
221+
222+
// More training, start from the last error rate:
223+
net.train(trainData, {
224+
iterations: 30,
225+
log: true,
226+
logPeriod: 10,
227+
learningRate: 0.2,
228+
});
229+
230+
// More training to the clone:
231+
net2.train(trainData, {
232+
iterations: 30,
233+
log: true,
234+
logPeriod: 10,
235+
learningRate: 0.2,
236+
});
237+
238+
// The first reduced the quality, but the second is crazy:
239+
expect(net.run('ray')).not.toBe(net2.run('ray'));
240+
});
241+
});
195242
});

0 commit comments

Comments
 (0)