from mnn.tensor import Tensor


+ def ensure_vector_shape(func):
+     def wrapper(self, inputs, *args, **kwargs):
+         reshaped = False
+         if inputs.shape[-1] != 1:
+             inputs = inputs.unsqueeze(-1)
+             reshaped = True
+         outputs = func(self, inputs, *args, **kwargs)
+         if reshaped:
+             outputs = outputs.squeeze(-1)
+         return outputs
+     return wrapper
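+
+ # Rough usage sketch (illustrative only, assuming the Tensor/LinearLayer API defined
+ # in this file): the decorator lets a layer's forward/backward accept either a batch
+ # of row vectors (B, D) or of column vectors (B, D, 1) and hands the result back in
+ # the caller's layout:
+ #
+ #     layer = LinearLayer(8, 8, bias=False)
+ #     y = layer.forward(Tensor.randn(2, 8))   # unsqueezed to (2, 8, 1) internally
+ #     assert y.shape == (2, 8)                # squeezed back on the way out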
+
+
class BaseLayer():
    def __init__(self):
        self.name = self.__class__.__name__
@@ -52,6 +65,7 @@ def __init__(self, *shape, bias=True):
        if self.bias:
            self.params['b'] = Tensor.randn(1, m, 1)

+     @ensure_vector_shape
    def forward(self, inputs, feedbacks=None):
        r'''
        A linear layer to compute $y_{m \times 1} = W_{m \times n} x_{n \times 1} + b_{m \times 1}$,
@@ -72,6 +86,7 @@ def forward(self, inputs, feedbacks=None):
        else:
            return self.params['w'] @ inputs

+     @ensure_vector_shape
    def backward(self, gradients):
        r'''
        ## Gradients w.r.t. $W$
@@ -227,7 +242,8 @@ def forward(self, inputs, feedbacks=None):
        r'''
        Loss $\ell = \sum_{i} (x_i - y_i)^2$
        '''
-         inputs = inputs.squeeze(-1)
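+         # only squeeze when a trailing singleton axis is actually present, so both
+         # (B, D) and (B, D, 1) predictions line up with the feedbacks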
+         if inputs.shape[-1] == 1:
+             inputs = inputs.squeeze(-1)
        self.last_error = inputs - feedbacks
        batch_size = inputs.shape[0]
        batch_loss = ((inputs - feedbacks) ** 2).sum(axis=1)
@@ -265,10 +281,14 @@ def stable_softmax(inputs, axis):
        sum_exps = stable_exps.sum(axis=axis, keepdims=True)
        return stable_exps / sum_exps

+     @ensure_vector_shape
    def forward(self, inputs, feedbacks=None):
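+         # ensure_vector_shape may have appended a trailing singleton axis; if
+         # self.axis points at a singleton dimension, shift it onto the real feature axis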
+         if inputs.shape[self.axis] == 1:
+             self.axis -= 1
        self.saved_forward = SoftmaxLayer.stable_softmax(inputs, self.axis)
        return self.saved_forward

+     @ensure_vector_shape
    def backward(self, gradients):
        r'''
        When $i = k$,
@@ -318,30 +338,6 @@ def backward(self, gradients):
        return jacob_x @ gradients


- class LogLayer(BaseLayer):
-     def forward(self, inputs, feedbacks=None):
-         r"""
-         $$
-         y = \log(x)
-         $$
-         """
-         self.last_inputs = inputs
-         return Tensor.log(inputs)
-
-     def backward(self, gradients):
-         r"""
-         Because this layer is element-wise operation, i.e.,
-         $J_x y = \operatorname{diag}(x^{-1})$,
-         the final gradient can be also simplified into Hadamard product:
-
-         $$
-         \nabla_x \ell = J^T_x y \cdot \nabla_y \ell = x^{-1} \odot \nabla_y \ell
-         $$
-         """
-         reciprocal = 1 / self.last_inputs
-         return reciprocal * gradients
-
-
class LogSoftmaxLayer(BaseLayer):
    def __init__(self, *shape, axis=1):
        super().__init__()
@@ -370,6 +366,7 @@ def stable_log_softmax(inputs, axis):

        return log_softmax, softmax

+     @ensure_vector_shape
    def forward(self, inputs, feedbacks=None):
        r"""
        $$
@@ -378,12 +375,15 @@ def forward(self, inputs, feedbacks=None):

        where $y_i(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$
        """
+         if inputs.shape[self.axis] == 1:
+             self.axis -= 1
        log_softmax, softmax = LogSoftmaxLayer.stable_log_softmax(
            inputs, self.axis
        )
        self.saved_forward = softmax
        return log_softmax

+     @ensure_vector_shape
    def backward(self, gradients):
        r'''
        For the softmax function, when $i = k$, we have
@@ -429,7 +429,8 @@ def backward(self, gradients):
        softmax_dim = softmax_T.shape[-1]
        stacked = softmax_T.stacked()
        # compute Jacobian matrix wrt. inputs x
-         identity = Tensor.eye(softmax_dim, softmax_dim).unsqueeze(0)
+         identity = Tensor.eye(softmax_dim, softmax_dim)
+         identity = identity.unsqueeze_to_dim_like(stacked)
        jacob_x = identity - stacked
        return jacob_x.T @ gradients

@@ -439,6 +440,7 @@ class NllLossLayer(BaseLayer):
    This simulates PyTorch's NLL loss layer, which computes a negative log-likelihood loss.
    Labels are passed in as integer indices.
    '''
+     @ensure_vector_shape
    def forward(self, inputs, feedbacks=None):
        r'''
        $$
@@ -495,7 +497,9 @@ def forward(self, inputs, feedbacks=None):
            inputs, self.axis
        )
        use_log_softmax = log_softmax[Tensor.arange(batch_size), indices]
-         cross_entropy = -use_log_softmax.squeeze(-1)
+         cross_entropy = -use_log_softmax
+         if cross_entropy.shape[-1] == 1:
+             cross_entropy = cross_entropy.squeeze(-1)

        self.saved_context = (batch_size, indices, softmax)
        return self._batch_reduced(cross_entropy)
@@ -667,34 +671,54 @@ def __init__(self, d=64, heads=12, bias=False):
        self.W_val = LinearLayer(d_full, d_full, bias=bias)
        self.scaled_dotproduct = MatrixProduct()
        self.softmax = SoftmaxLayer(axis=-1)
-         self.apply_attention = MatrixProduct()
+         self.attn_product = MatrixProduct()

    def split_heads(self, X):
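+         # (B, seq, heads * d, 1): drop the trailing singleton, then reshape and
+         # transpose into per-head batches of shape (B, heads, seq, d)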
        X = X.squeeze(-1)
        new_shape = X.shape[:-1] + (self.heads, self.d)
        X = X.reshape(new_shape)
        X = X.transpose(0, 2, 1, 3)
-         return X.unsqueeze(-1)
+         return X

    def merge_heads(self, X):
-         X = X.squeeze(-1)
        X = X.transpose(0, 2, 1, 3)
        new_shape = X.shape[:-2] + (self.d * self.heads,)
        X = X.reshape(new_shape)
        return X.unsqueeze(-1)

+     @ensure_vector_shape
    def forward(self, inputs, feedbacks=None):
        X = inputs
-
-         Q = self.split_heads(self.W_qry.forward(X))
-         K = self.split_heads(self.W_key.forward(X))
-         V = self.split_heads(self.W_val.forward(X))
+         Y_qry = self.W_qry.forward(X)
+         Y_key = self.W_key.forward(X)
+         Y_val = self.W_val.forward(X)
+         Q = self.split_heads(Y_qry)
+         K = self.split_heads(Y_key)
+         V = self.split_heads(Y_val)

        K = K / math.sqrt(self.d)
        product = self.scaled_dotproduct.forward((Q, K.T))
        A = self.softmax.forward(product)
-         V_attn = self.apply_attention.forward((A, V))
-         return self.merge_heads(V_attn)
+         V_attn = self.attn_product.forward((A, V))
+         V_attn = self.merge_heads(V_attn)
+         return V_attn
+
+     @ensure_vector_shape
+     def backward(self, gradients):
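+         # run the forward pass in reverse: split the incoming gradient into heads,
+         # backprop through the attention product, softmax and scaled dot-product,
+         # then through the Q/K/V projections, summing their gradients w.r.t. X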
+         gradients = self.split_heads(gradients)
+         grads_A, grads_V = self.attn_product.backward(gradients)
+         gradients = self.softmax.backward(grads_A)
+         grads_Q, grads_KT = self.scaled_dotproduct.backward(gradients)
+         grads_K = grads_KT.T / math.sqrt(self.d)
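+         # K was scaled by 1 / sqrt(d) in forward, so its gradient carries the same
+         # factor (and a transpose back from K.T)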
+
+         grads_Q = self.merge_heads(grads_Q)
+         grads_K = self.merge_heads(grads_K)
+         grads_V = self.merge_heads(grads_V)
+
+         grads_X = self.W_qry.backward(grads_Q)
+         grads_X = grads_X + self.W_key.backward(grads_K)
+         grads_X = grads_X + self.W_val.backward(grads_V)
+         return grads_X


if __name__ == '__main__':
@@ -721,15 +745,9 @@ def forward(self, inputs, feedbacks=None):
    print(gradients.shape)

    softmax_layer = SoftmaxLayer()
-     outputs = softmax_layer.forward(inputs)
+     outputs = softmax_layer.forward(inputs.squeeze(-1))
    print(outputs.shape)
-     gradients = softmax_layer.backward(Tensor.randn(B, D, 1))
-     print(gradients.shape)
-
-     log_layer = LogLayer()
-     outputs = log_layer.forward(inputs)
-     print(outputs.shape)
-     gradients = log_layer.backward(Tensor.randn(B, D, 1))
+     gradients = softmax_layer.backward(Tensor.randn(B, D))
    print(gradients.shape)

    log_softmax_layer = LogSoftmaxLayer()
@@ -753,6 +771,8 @@ def forward(self, inputs, feedbacks=None):
    print(gradients.shape)

    multihead_attn = MultiHeadAttention(d=32, heads=3)
-     inputs = Tensor.randn(2, 128, 32 * 3, 1)
+     inputs = Tensor.randn(2, 128, 32 * 3)
    outputs = multihead_attn.forward(inputs)
    print(multihead_attn.name, outputs.shape)
+     gradients = multihead_attn.backward(Tensor.randn(2, 128, 96))
+     print(gradients.shape)