
Commit 9192fd9

finished MultiHeadAttention backward; dynamically ensure legal input shapes.
1 parent 8434424 commit 9192fd9

File tree

3 files changed (+77, -59 lines)

mnn/layer.py

+65 -45
@@ -2,6 +2,19 @@
 from mnn.tensor import Tensor
 
 
+def ensure_vector_shape(func):
+    def wrapper(self, inputs, *args, **kwargs):
+        reshaped = False
+        if inputs.shape[-1] != 1:
+            inputs = inputs.unsqueeze(-1)
+            reshaped = True
+        outputs = func(self, inputs, *args, **kwargs)
+        if reshaped:
+            outputs = outputs.squeeze(-1)
+        return outputs
+    return wrapper
+
+
 class BaseLayer():
     def __init__(self):
         self.name = self.__class__.__name__
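
The ensure_vector_shape decorator added here normalizes layer inputs to column-vector form (a trailing dimension of size 1) and undoes the reshape on the way out, so callers may pass either (batch, features) or (batch, features, 1). A minimal sketch of that behaviour, using NumPy arrays in place of mnn's Tensor and a hypothetical Identity layer:

import numpy as np

def ensure_vector_shape(func):
    def wrapper(self, inputs, *args, **kwargs):
        reshaped = False
        if inputs.shape[-1] != 1:
            inputs = np.expand_dims(inputs, -1)    # (B, D) -> (B, D, 1)
            reshaped = True
        outputs = func(self, inputs, *args, **kwargs)
        if reshaped:
            outputs = np.squeeze(outputs, -1)      # (B, D, 1) -> (B, D)
        return outputs
    return wrapper

class Identity:
    @ensure_vector_shape
    def forward(self, inputs, feedbacks=None):
        assert inputs.shape[-1] == 1               # the body always sees column vectors
        return inputs

print(Identity().forward(np.zeros((4, 8))).shape)     # (4, 8): reshaped in, restored out
print(Identity().forward(np.zeros((4, 8, 1))).shape)  # (4, 8, 1): passed through unchanged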
@@ -52,6 +65,7 @@ def __init__(self, *shape, bias=True):
         if self.bias:
             self.params['b'] = Tensor.randn(1, m, 1)
 
+    @ensure_vector_shape
     def forward(self, inputs, feedbacks=None):
         r'''
         An linear layer to compute $y_{m \times 1} = W_{m \times n} x_{n \times 1} + b_{m \times 1}$,
@@ -72,6 +86,7 @@ def forward(self, inputs, feedbacks=None):
         else:
             return self.params['w'] @ inputs
 
+    @ensure_vector_shape
     def backward(self, gradients):
         r'''
         ## Gradients w.r.t. $W$
@@ -227,7 +242,8 @@ def forward(self, inputs, feedbacks=None):
         r'''
         Loss $\ell_i = \sum_{i} (x_i - y_i)^2$
         '''
-        inputs = inputs.squeeze(-1)
+        if inputs.shape[-1] == 1:
+            inputs = inputs.squeeze(-1)
         self.last_error = inputs - feedbacks
         batch_size = inputs.shape[0]
         batch_loss = ((inputs - feedbacks) ** 2).sum(axis=1)
@@ -265,10 +281,14 @@ def stable_softmax(inputs, axis):
         sum_exps = stable_exps.sum(axis=axis, keepdims=True)
         return stable_exps / sum_exps
 
+    @ensure_vector_shape
     def forward(self, inputs, feedbacks=None):
+        if inputs.shape[self.axis] == 1:
+            self.axis -= 1
         self.saved_forward = SoftmaxLayer.stable_softmax(inputs, self.axis)
         return self.saved_forward
 
+    @ensure_vector_shape
     def backward(self, gradients):
         r'''
         When $i = k$,
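
The axis adjustment in SoftmaxLayer.forward guards against column-vector inputs: once a trailing dimension of size 1 is present, axis -1 points at that singleton rather than at the feature axis. A short NumPy illustration (not the repo's Tensor API) of why the shift is needed:

import numpy as np

x = np.random.randn(4, 10, 1)                            # (batch, features, 1)
bad = np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)
print(bad[0, :3, 0])                                     # all 1.0: softmax over a size-1 axis
good = np.exp(x) / np.exp(x).sum(axis=-2, keepdims=True)
print(good[0].sum())                                     # 1.0: normalized over the 10 features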
@@ -318,30 +338,6 @@ def backward(self, gradients):
         return jacob_x @ gradients
 
 
-class LogLayer(BaseLayer):
-    def forward(self, inputs, feedbacks=None):
-        r"""
-        $$
-        y = \log(x)
-        $$
-        """
-        self.last_inputs = inputs
-        return Tensor.log(inputs)
-
-    def backward(self, gradients):
-        r"""
-        Because this layer is element-wise operation, i.e.,
-        $J_x y = \operatorname{diag}(x^{-1})$,
-        the final gradient can be also simplified into Hadamard product:
-
-        $$
-        \nabla_x \ell = J^T_x y \cdot \nabla_y \ell = x^{-1} \odot \nabla_y \ell
-        $$
-        """
-        reciprocal = 1 / self.last_inputs
-        return reciprocal * gradients
-
-
 class LogSoftmaxLayer(BaseLayer):
     def __init__(self, *shape, axis=1):
         super().__init__()
@@ -370,6 +366,7 @@ def stable_log_softmax(inputs, axis):
 
         return log_softmax, softmax
 
+    @ensure_vector_shape
     def forward(self, inputs, feedbacks=None):
         r"""
         $$
@@ -378,12 +375,15 @@ def forward(self, inputs, feedbacks=None):
 
         where $y_i(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$
         """
+        if inputs.shape[self.axis] == 1:
+            self.axis -= 1
         log_softmax, softmax = LogSoftmaxLayer.stable_log_softmax(
             inputs, self.axis
         )
         self.saved_forward = softmax
         return log_softmax
 
+    @ensure_vector_shape
     def backward(self, gradients):
         r'''
         For softmax function, when $i = k$, we have
@@ -429,7 +429,8 @@ def backward(self, gradients):
         softmax_dim = softmax_T.shape[-1]
         stacked = softmax_T.stacked()
         # compute Jacobian matrix wrt. inputs x
-        identity = Tensor.eye(softmax_dim, softmax_dim).unsqueeze(0)
+        identity = Tensor.eye(softmax_dim, softmax_dim)
+        identity = identity.unsqueeze_to_dim_like(stacked)
         jacob_x = identity - stacked
         return jacob_x.T @ gradients
 
@@ -439,6 +440,7 @@ class NllLossLayer(BaseLayer):
     This is to simulate PyTorch NLL layer which computes a negative expectation loss.
     Labels are passed in as integer indices.
     '''
+    @ensure_vector_shape
     def forward(self, inputs, feedbacks=None):
         r'''
         $$
@@ -495,7 +497,9 @@ def forward(self, inputs, feedbacks=None):
             inputs, self.axis
         )
         use_log_softmax = log_softmax[Tensor.arange(batch_size), indices]
-        cross_entropy = - use_log_softmax.squeeze(-1)
+        cross_entropy = - use_log_softmax
+        if cross_entropy.shape[-1] == 1:
+            cross_entropy = cross_entropy.squeeze(-1)
 
         self.saved_context = (batch_size, indices, softmax)
         return self._batch_reduced(cross_entropy)
@@ -667,34 +671,54 @@ def __init__(self, d=64, heads=12, bias=False):
         self.W_val = LinearLayer(d_full, d_full, bias=bias)
         self.scaled_dotproduct = MatrixProduct()
         self.softmax = SoftmaxLayer(axis=-1)
-        self.apply_attention = MatrixProduct()
+        self.attn_product = MatrixProduct()
 
     def split_heads(self, X):
         X = X.squeeze(-1)
         new_shape = X.shape[:-1] + (self.heads, self.d)
         X = X.reshape(new_shape)
         X = X.transpose(0, 2, 1, 3)
-        return X.unsqueeze(-1)
+        return X
 
     def merge_heads(self, X):
-        X = X.squeeze(-1)
         X = X.transpose(0, 2, 1, 3)
         new_shape = X.shape[:-2] + (self.d * self.heads,)
         X = X.reshape(new_shape)
         return X.unsqueeze(-1)
 
+    @ensure_vector_shape
     def forward(self, inputs, feedbacks=None):
         X = inputs
-
-        Q = self.split_heads(self.W_qry.forward(X))
-        K = self.split_heads(self.W_key.forward(X))
-        V = self.split_heads(self.W_val.forward(X))
+        Y_qry = self.W_qry.forward(X)
+        Y_key = self.W_key.forward(X)
+        Y_val = self.W_val.forward(X)
+        Q = self.split_heads(Y_qry)
+        K = self.split_heads(Y_key)
+        V = self.split_heads(Y_val)
 
         K = K / math.sqrt(self.d)
         product = self.scaled_dotproduct.forward((Q, K.T))
         A = self.softmax.forward(product)
-        V_attn = self.apply_attention.forward((A, V))
-        return self.merge_heads(V_attn)
+        V_attn = self.attn_product.forward((A, V))
+        V_attn = self.merge_heads(V_attn)
+        return V_attn
+
+    @ensure_vector_shape
+    def backward(self, gradients):
+        gradients = self.split_heads(gradients)
+        grads_A, grads_V = self.attn_product.backward(gradients)
+        gradients = self.softmax.backward(grads_A)
+        grads_Q, grads_KT = self.scaled_dotproduct.backward(gradients)
+        grads_K = grads_KT.T / math.sqrt(self.d)
+
+        grads_Q = self.merge_heads(grads_Q)
+        grads_K = self.merge_heads(grads_K)
+        grads_V = self.merge_heads(grads_V)
+
+        grads_X = self.W_qry.backward(grads_Q)
+        grads_X = grads_X + self.W_key.backward(grads_K)
+        grads_X = grads_X + self.W_val.backward(grads_V)
+        return grads_X
 
 
 if __name__ == '__main__':
@@ -721,15 +745,9 @@ def forward(self, inputs, feedbacks=None):
     print(gradients.shape)
 
     softmax_layer = SoftmaxLayer()
-    outputs = softmax_layer.forward(inputs)
+    outputs = softmax_layer.forward(inputs.squeeze(-1))
     print(outputs.shape)
-    gradients = softmax_layer.backward(Tensor.randn(B, D, 1))
-    print(gradients.shape)
-
-    log_layer = LogLayer()
-    outputs = log_layer.forward(inputs)
-    print(outputs.shape)
-    gradients = log_layer.backward(Tensor.randn(B, D, 1))
+    gradients = softmax_layer.backward(Tensor.randn(B, D))
     print(gradients.shape)
 
     log_softmax_layer = LogSoftmaxLayer()
@@ -753,6 +771,8 @@ def forward(self, inputs, feedbacks=None):
     print(gradients.shape)
 
     multihead_attn = MultiHeadAttention(d=32, heads=3)
-    inputs = Tensor.randn(2, 128, 32 * 3, 1)
+    inputs = Tensor.randn(2, 128, 32 * 3)
     outputs = multihead_attn.forward(inputs)
    print(multihead_attn.name, outputs.shape)
+    gradients = multihead_attn.backward(Tensor.randn(2, 128, 96))
+    print(gradients.shape)
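
The shape flow of the reworked MultiHeadAttention may be easier to follow outside the diff. Below is a minimal NumPy sketch of the forward pass with the same sizes as the __main__ test above (d=32, heads=3, batch 2, sequence length 128). It reuses one shared input in place of the W_qry/W_key/W_val projections and uses an unstabilized softmax, so it only illustrates the head splitting and merging, not the layer itself:

import numpy as np

B, T, H, d = 2, 128, 3, 32
X = np.random.randn(B, T, H * d)

def split_heads(Y):                                # (B, T, H*d) -> (B, H, T, d)
    return Y.reshape(B, T, H, d).transpose(0, 2, 1, 3)

def merge_heads(Y):                                # (B, H, T, d) -> (B, T, H*d)
    return Y.transpose(0, 2, 1, 3).reshape(B, T, H * d)

Q, K, V = split_heads(X), split_heads(X), split_heads(X)   # projections omitted
K = K / np.sqrt(d)
scores = Q @ K.transpose(0, 1, 3, 2)               # (B, H, T, T) attention scores
A = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
V_attn = A @ V                                     # (B, H, T, d)
print(merge_heads(V_attn).shape)                   # (2, 128, 96)

The new backward walks the same chain in reverse: split_heads on the incoming gradient, then the attention product, softmax, and scaled dot product backward passes, and finally the three linear projections, whose input gradients are summed.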

mnn/seq_layers.py

-13
@@ -176,19 +176,6 @@ def testcase3():
     ])
     return inputs, targets, net
 
-def testcase4():
-    inputs = Tensor.randn(B, 32, 1)
-    targets = Tensor.randint(shape=(B, 1), high=C)
-    net = SequentialLayers([
-        LinearLayer(32, 40),
-        ReluLayer(),
-        LinearLayer(40, C, bias=False),
-        SoftmaxLayer(),
-        LogLayer(),
-        NllLossLayer()
-    ])
-    return inputs, targets, net
-
 inputs, targets, net = globals()['testcase'+ str(testcase)]()
 
 for ep in range(1 if debug else 20):

mnn/tensor.py

+12 -1
@@ -177,6 +177,13 @@ def diag_embed(self):
     def unsqueeze(self, *args, **kwargs):
        return Tensor(cp.expand_dims(self._data, *args, **kwargs))
 
+    def unsqueeze_to_dim_like(self, x):
+        offdim = len(x.shape) - len(self.shape)
+        assert offdim >= 0
+        paddim = (1 for _ in range(offdim))
+        to_dim = (*paddim, *self.shape)
+        return self.reshape(to_dim)
+
     def pr(self, quit_=False):
         import inspect
         l = inspect.stack()[1].frame.f_locals
@@ -201,4 +208,8 @@ def pr(self, quit_=False):
     s = d.sum(axis=1)
     print(s)
     print(s + Tensor([[1.0, 2.0, 3.0]]))
-    s.pr()
+
+    a = Tensor.randn(3, 2)
+    b = Tensor.randn(8, 9, 10, 3, 2)
+    a = a.unsqueeze_to_dim_like(b)
+    a.pr()
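
The new Tensor.unsqueeze_to_dim_like pads leading singleton dimensions so a lower-rank tensor can broadcast against a higher-rank one, which is how the identity matrix is lined up with the stacked softmax batch in LogSoftmaxLayer.backward. The same idea in plain NumPy (illustrative helper, not the repo's API):

import numpy as np

def unsqueeze_to_dim_like(a, x):
    offdim = x.ndim - a.ndim
    assert offdim >= 0
    return a.reshape((1,) * offdim + a.shape)

identity = np.eye(5)                                   # (5, 5)
stacked = np.random.randn(2, 12, 5, 5)                 # e.g. (batch, heads, 5, 5)
identity = unsqueeze_to_dim_like(identity, stacked)    # (1, 1, 5, 5)
print((identity - stacked).shape)                      # broadcasts to (2, 12, 5, 5)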
