
Commit 8448f50

Start on #16, using https://github.com/karpathy/recurrentjs as a basis
1 parent d2cd6a4 commit 8448f50

15 files changed: +2,697 -0 lines changed

index.html (+2,026 lines)
Large diffs are not rendered by default.

lib/recurrent/index.js (+172 lines)
@@ -0,0 +1,172 @@
// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
var Matrix = require('./matrix'),
    RNN = require('./rnn'),
    LSTM = require('./lstm');

// Transformer definitions
function Graph(needsBackprop) {
  if(typeof needsBackprop === 'undefined') { needsBackprop = true; }
  this.needsBackprop = needsBackprop;

  // this will store a list of functions that perform backprop,
  // in their forward pass order. So in backprop we will go
  // backwards and invoke each one
  this.backprop = [];
}
Graph.prototype = {
  backward: function() {
    while(this.backprop.length > 0) {
      this.backprop.pop()(); // tick!
    }
  },
  /**
   *
   * @param {Matrix} m
   * @param ix
   */
  rowPluck: function(m, ix) {
    // pluck a row of m with index ix and return it as a col vector
    if (ix < 0 || ix >= m.n) throw new Error('cannot pluck row: index out of bounds');
    var d = m.d;
    var out = new Matrix(d, 1);
    for(var i=0,n=d;i<n;i++){ out.weights[i] = m.weights[d * ix + i]; } // copy over the data

    if(this.needsBackprop) {
      this.backprop.push(function backward() {
        for(var i=0,n=d;i<n;i++){ m.dw[d * ix + i] += out.dw[i]; }
      });
    }
    return out;
  },

  /**
   *
   * @param {Matrix} m
   */
  tanh: function(m) {
    // tanh nonlinearity
    var out = new Matrix(m.n, m.d);
    var n = m.weights.length;
    for(var i=0;i<n;i++) {
      out.weights[i] = Math.tanh(m.weights[i]);
    }

    if(this.needsBackprop) {
      this.backprop.push(function backward() {
        for(var i=0;i<n;i++) {
          // grad for z = tanh(x) is (1 - z^2)
          var mwi = out.weights[i];
          m.dw[i] += (1.0 - mwi * mwi) * out.dw[i];
        }
      });
    }
    return out;
  },

  /**
   *
   * @param {Matrix} m
   */
  sigmoid: function(m) {
    // sigmoid nonlinearity
    var out = new Matrix(m.n, m.d);
    var n = m.weights.length;
    for(var i=0;i<n;i++) {
      out.weights[i] = sig(m.weights[i]);
    }

    if(this.needsBackprop) {
      this.backprop.push(function backward() {
        for(var i=0;i<n;i++) {
          // grad for z = sigmoid(x) is z * (1 - z)
          var mwi = out.weights[i];
          m.dw[i] += mwi * (1.0 - mwi) * out.dw[i];
        }
      });
    }
    return out;
  },

  /**
   *
   * @param {Matrix} m
   */
  relu: function(m) {
    var out = new Matrix(m.n, m.d);
    var n = m.weights.length;
    for(var i=0;i<n;i++) {
      out.weights[i] = Math.max(0, m.weights[i]); // relu
    }
    if(this.needsBackprop) {
      this.backprop.push(function backward() {
        for(var i=0;i<n;i++) {
          m.dw[i] += m.weights[i] > 0 ? out.dw[i] : 0.0;
        }
      });
    }
    return out;
  }
};

function Solver() {
  this.decayRate = 0.999;
  this.smoothEps = 1e-8;
  this.stepCache = {};
  this.ratioClipped = null;
}
Solver.prototype = {
  step: function(stepSize, regc, clipval) {
    // perform parameter update
    // note: this.model is expected to hold the network's parameter matrices,
    // keyed by name, before step() is called
    var model = this.model;
    var numClipped = 0;
    var numTot = 0;
    for(var k in model) {
      if(model.hasOwnProperty(k)) {
        var m = model[k]; // mat ref
        if(!(k in this.stepCache)) { this.stepCache[k] = new Matrix(m.n, m.d); }
        var s = this.stepCache[k];
        for(var i=0,n=m.weights.length;i<n;i++) {

          // rmsprop adaptive learning rate
          var mdwi = m.dw[i];
          s.weights[i] = s.weights[i] * this.decayRate + (1.0 - this.decayRate) * mdwi * mdwi;

          // gradient clip
          if(mdwi > clipval) {
            mdwi = clipval;
            numClipped++;
          }
          if(mdwi < -clipval) {
            mdwi = -clipval;
            numClipped++;
          }
          numTot++;

          // update (and regularize)
          m.weights[i] += - stepSize * mdwi / Math.sqrt(s.weights[i] + this.smoothEps) - regc * m.weights[i];
          m.dw[i] = 0; // reset gradients for next iteration
        }
      }
    }
    this.ratioClipped = numClipped * 1.0 / numTot;

    return this;
  }
};

function sig(x) {
  // helper function for computing sigmoid
  return 1.0 / (1 + Math.exp(-x));
}

// various utils
module.exports = {
  // classes
  LSTM: LSTM,
  RNN: RNN,

  // optimization
  Solver: Solver,
  Graph: Graph
};
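For orientation, here is a minimal usage sketch (not part of the commit) of how Graph and Solver are meant to fit together: ops routed through a Graph queue their backward functions, graph.backward() pops and runs them, and Solver.step() then applies an RMSProp update with gradient clipping. The solver.model assignment, the hand-seeded output gradient, and the require paths (assumed relative to the repo root) are assumptions for illustration only, since no training loop exists in this commit; RandomMatrix is assumed to behave like a Matrix with randomized weights.

var recurrent = require('./lib/recurrent'),
    RandomMatrix = require('./lib/recurrent/matrix/random');

// toy model: a single named parameter matrix (Solver walks the model by key)
var model = { w: new RandomMatrix(4, 3, 0, 0.08) };

var graph = new recurrent.Graph(true);   // true => record backward functions
var out = graph.tanh(model.w);           // forward op; its backward fn is now queued

// pretend the loss gradient w.r.t. the output is all ones
for(var i=0;i<out.dw.length;i++) { out.dw[i] = 1; }

graph.backward();                        // runs the queued backward fns, filling model.w.dw

var solver = new recurrent.Solver();
solver.model = model;                    // assumed wiring: step() reads this.model
solver.step(0.01, 0.000001, 5);          // stepSize, L2 regc, gradient clip value
console.log('fraction of gradients clipped:', solver.ratioClipped);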

lib/recurrent/lstm.js (+126 lines)
@@ -0,0 +1,126 @@
var Matrix = require('./matrix'),
    RandomMatrix = require('./matrix/random'),
    add = require('./matrix/add'),
    multiply = require('./matrix/multiply'),
    multiplyElement = require('./matrix/multiply-element');

function LSTM(inputSize, hiddenSizes, outputSize) {
  // hiddenSizes should be a list, one entry per stacked layer

  this.model = [];
  this.inputSize = inputSize;
  this.hiddenSizes = hiddenSizes;
  this.outputSize = outputSize;

  for(var d=0;d<hiddenSizes.length;d++) { // loop over depths
    var prevSize = d === 0 ? inputSize : hiddenSizes[d - 1];
    var hiddenSize = hiddenSizes[d];
    this.model.push({
      // input gate parameters
      wix: new RandomMatrix(hiddenSize, prevSize, 0, 0.08),
      wih: new RandomMatrix(hiddenSize, hiddenSize, 0, 0.08),
      bi: new Matrix(hiddenSize, 1),

      // forget gate parameters
      wfx: new RandomMatrix(hiddenSize, prevSize, 0, 0.08),
      wfh: new RandomMatrix(hiddenSize, hiddenSize, 0, 0.08),
      bf: new Matrix(hiddenSize, 1),

      // output gate parameters
      wox: new RandomMatrix(hiddenSize, prevSize, 0, 0.08),
      woh: new RandomMatrix(hiddenSize, hiddenSize, 0, 0.08),
      bo: new Matrix(hiddenSize, 1),

      // cell write params
      wcx: new RandomMatrix(hiddenSize, prevSize, 0, 0.08),
      wch: new RandomMatrix(hiddenSize, hiddenSize, 0, 0.08),
      bc: new Matrix(hiddenSize, 1)
    });
  }
  // decoder params
  this.model.whd = new RandomMatrix(outputSize, hiddenSizes[hiddenSizes.length - 1], 0, 0.08);
  this.model.bd = new Matrix(outputSize, 1);
}

LSTM.prototype = {
  /**
   *
   * @param {Graph} graph
   * @param {Matrix} x
   * @param prev
   * @returns {{hidden: Array, cell: Array, output}}
   */
  forward: function (graph, x, prev) {
    // forward prop for a single tick of LSTM
    // graph is the Graph to append ops to
    // model contains the LSTM parameters
    // x is a 1D column vector with the observation
    // prev is a struct containing hidden and cell
    // from the previous iteration

    var model = this.model,
        hiddenSizes = this.hiddenSizes,
        hiddenPrevs,
        cellPrevs,
        d;

    if(typeof prev === 'undefined' || typeof prev.hidden === 'undefined') {
      hiddenPrevs = [];
      cellPrevs = [];
      for(d=0;d<hiddenSizes.length;d++) {
        hiddenPrevs.push(new Matrix(hiddenSizes[d], 1));
        cellPrevs.push(new Matrix(hiddenSizes[d], 1));
      }
    } else {
      hiddenPrevs = prev.hidden;
      cellPrevs = prev.cell;
    }

    var hidden = [],
        cell = [];
    for(d=0;d<hiddenSizes.length;d++) {

      var inputVector = d === 0 ? x : hidden[d - 1];
      var hiddenPrev = hiddenPrevs[d];
      var cellPrev = cellPrevs[d];

      // input gate
      var h0 = multiply(model[d].wix, inputVector);
      var h1 = multiply(model[d].wih, hiddenPrev);
      var inputGate = graph.sigmoid(add(add(h0, h1), model[d].bi));

      // forget gate
      var h2 = multiply(model[d].wfx, inputVector);
      var h3 = multiply(model[d].wfh, hiddenPrev);
      var forgetGate = graph.sigmoid(add(add(h2, h3), model[d].bf));

      // output gate
      var h4 = multiply(model[d].wox, inputVector);
      var h5 = multiply(model[d].woh, hiddenPrev);
      var outputGate = graph.sigmoid(add(add(h4, h5), model[d].bo));

      // write operation on cells
      var h6 = multiply(model[d].wcx, inputVector);
      var h7 = multiply(model[d].wch, hiddenPrev);
      var cellWrite = graph.tanh(add(add(h6, h7), model[d].bc));

      // compute new cell activation
      var retainCell = multiplyElement(forgetGate, cellPrev); // what do we keep from cell
      var writeCell = multiplyElement(inputGate, cellWrite); // what do we write to cell
      var cellD = add(retainCell, writeCell); // new cell contents

      // compute hidden state as gated, saturated cell activations
      var hiddenD = multiplyElement(outputGate, graph.tanh(cellD));

      hidden.push(hiddenD);
      cell.push(cellD);
    }

    // one decoder to outputs at end
    var output = add(multiply(model.whd, hidden[hidden.length - 1]), model.bd);

    // return cell memory, hidden representation and output
    return {
      hidden: hidden,
      cell: cell,
      output: output
    };
  }
};

module.exports = LSTM;
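A rough sketch of driving this forward pass over a short sequence (illustrative only, not part of the commit): the x argument and the empty initial prev object follow the comments above, the layer sizes and require paths are arbitrary assumptions, and gradient flow through the standalone add/multiply/multiplyElement helpers is not wired into the graph here.

var recurrent = require('./lib/recurrent'),
    Matrix = require('./lib/recurrent/matrix');

// 10-dimensional inputs, two stacked layers of 20 units, 4 output units
var lstm = new recurrent.LSTM(10, [20, 20], 4);
var graph = new recurrent.Graph(true);

var prev = {}; // empty on the first tick; hidden/cell get zero-initialized
for(var t=0;t<5;t++) {
  var x = new Matrix(10, 1);      // observation at time t (left all-zero here)
  prev = lstm.forward(graph, x, prev);
  console.log('tick', t, 'output:', prev.output.weights);
}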

lib/recurrent/matrix/add.js (+25 lines)
@@ -0,0 +1,25 @@
var Matrix = require('./index');
/**
 *
 * @param {Matrix} m1
 * @param {Matrix} m2
 * @param backPropagateArray
 * @returns {Matrix}
 */
module.exports = function add(m1, m2, backPropagateArray) {
  if (m1.weights.length !== m2.weights.length) throw new Error('matrix addition dimensions misaligned');

  var out = new Matrix(m1.n, m1.d);
  for(var i=0,n=m1.weights.length;i<n;i++) {
    out.weights[i] = m1.weights[i] + m2.weights[i];
  }
  if(typeof backPropagateArray !== 'undefined') {
    backPropagateArray.push(function backward() {
      for(var i=0,n=m1.weights.length;i<n;i++) {
        m1.dw[i] += out.dw[i];
        m2.dw[i] += out.dw[i];
      }
    });
  }
  return out;
};
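A small usage sketch (illustrative only; require paths assumed relative to the repo root): when a backprop array is passed, add queues a closure that routes the output gradient back into both operands unchanged, which is the correct gradient for elementwise addition.

var Matrix = require('./lib/recurrent/matrix'),
    add = require('./lib/recurrent/matrix/add');

var a = new Matrix(2, 1), b = new Matrix(2, 1);
a.weights[0] = 1; a.weights[1] = 2;
b.weights[0] = 3; b.weights[1] = 4;

var backprop = [];
var sum = add(a, b, backprop);      // sum.weights is now [4, 6]

sum.dw[0] = 0.5; sum.dw[1] = -1;    // pretend upstream gradient
backprop.pop()();                   // run the queued backward fn
console.log(a.dw, b.dw);            // both accumulate [0.5, -1]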

lib/recurrent/matrix/index.js (+70 lines)
@@ -0,0 +1,70 @@
var zeros = require('./zeros'),
    random = require('./random'),
    randf = random.f,
    randn = random.n;

/**
 * A matrix
 * @param {Number} n
 * @param {Number} d
 * @constructor
 */
function Matrix(n, d) {
  // n is the number of rows, d is the number of columns
  this.n = n;
  this.d = d;
  this.weights = zeros(n * d);
  this.dw = zeros(n * d);
}

Matrix.prototype = {
  getWeights: function(row, col) {
    // slow but careful accessor function
    // we want row-major order
    var ix = (this.d * row) + col;
    if (ix < 0 || ix >= this.weights.length) throw new Error('get accessor is skewed');
    return this.weights[ix];
  },
  setWeights: function(row, col, v) {
    // slow but careful accessor function
    var ix = (this.d * row) + col;
    if (ix < 0 || ix >= this.weights.length) throw new Error('set accessor is skewed');
    this.weights[ix] = v;
  },
  toJSON: function() {
    var weights = [];
    for (var i = 0; i < this.weights.length; i++) {
      weights.push(this.weights[i]);
    }
    return {
      n: this.n,
      d: this.d,
      weights: weights
    };
  },
  fromJSON: function(json) {
    this.n = json.n;
    this.d = json.d;
    this.weights = zeros(this.n * this.d);
    this.dw = zeros(this.n * this.d);
    for(var i=0,n=this.n * this.d;i<n;i++) {
      this.weights[i] = json.weights[i]; // copy over weights
    }
  },

  // fill matrix with random gaussian numbers
  fillRandN: function(mu, std) {
    for(var i=0,n=this.weights.length;i<n;i++) {
      this.weights[i] = randn(mu, std);
    }
  },

  // fill matrix with random uniform numbers
  fillRand: function(lo, hi) {
    for(var i=0,n=this.weights.length;i<n;i++) {
      this.weights[i] = randf(lo, hi);
    }
  }
};

module.exports = Matrix;
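For reference, a short sketch of the accessor and serialization round trip (illustrative only; the require path is assumed relative to the repo root):

var Matrix = require('./lib/recurrent/matrix');

var m = new Matrix(2, 3);        // 2 rows, 3 columns, zero-initialized
m.setWeights(1, 2, 0.25);        // row-major: index = d * row + col = 3 * 1 + 2
console.log(m.getWeights(1, 2)); // 0.25

// serialize and restore
var json = m.toJSON();
var copy = new Matrix(1, 1);
copy.fromJSON(json);             // resizes to 2x3 and copies the weights
console.log(copy.getWeights(1, 2)); // 0.25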
