-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathSelfAttention_Family.py
142 lines (112 loc) · 5.23 KB
/
SelfAttention_Family.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
import torch.nn as nn
import numpy as np
from math import sqrt
from einops import repeat
from layers.Attn_Bias import BinaryAttentionBias
from layers.Attn_Projection import QueryKeyProjection, RotaryProjection
from utils.masking import TriangularCausalMask, TimerMultivariateMask, TimerCovariateMask
class FullAttention(nn.Module):
    """Plain scaled dot-product attention over per-head tensors.

    Args:
        mask_flag: When true, masked positions are filled with ``-inf``
            before the softmax.
        scale: Multiplier applied to the raw scores. A falsy value (``None``
            or ``0``) selects ``1 / sqrt(E)`` where ``E`` is the query head
            width.
        attention_dropout: Dropout probability applied to the attention
            weights.
        output_attention: When true, ``forward`` also returns the
            (post-dropout) attention weights.
    """

    def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False):
        super().__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None):
        """Attend ``queries`` over ``keys`` / ``values``.

        Args:
            queries: Tensor laid out as ``[B, L, H, E]``.
            keys: Tensor laid out as ``[B, S, H, E]``.
            values: Tensor laid out as ``[B, S, H, D]``.
            attn_mask: Object exposing a boolean ``.mask`` tensor (true means
                "hide"), or ``None`` to fall back to a causal
                ``TriangularCausalMask``. Only consulted when ``mask_flag``
                is set.
            n_vars, n_tokens, tau, delta: Accepted for signature parity with
                the other attention modules; not used here.

        Returns:
            ``(context, weights)`` where ``context`` is ``[B, L, H, D]`` and
            ``weights`` is ``[B, H, L, S]`` or ``None`` when
            ``output_attention`` is false.
        """
        batch, q_len, _, head_dim = queries.shape
        # Not used below; the 4-way unpacking insists that values is 4-D.
        _, kv_len, _, value_dim = values.shape

        softmax_scale = self.scale if self.scale else 1. / sqrt(head_dim)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            mask = attn_mask
            if mask is None:
                mask = TriangularCausalMask(batch, q_len, device=queries.device)
            # In place: `scores` is a fresh einsum result owned by this call.
            scores.masked_fill_(mask.mask, float("-inf"))

        weights = torch.softmax(softmax_scale * scores, dim=-1)
        weights = self.dropout(weights)
        context = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()

        return context, (weights if self.output_attention else None)
class TimeAttention(nn.Module):
    """Attention over a flattened (variable, token) sequence.

    Queries and keys are passed through ``QueryKeyProjection`` (configured
    with ``RotaryProjection``) using the flattened position index, and an
    additive bias from ``BinaryAttentionBias`` is computed from the variable
    index of every position.

    Args:
        mask_flag: When true, masked positions of the bias are set to
            ``-inf`` before it is added to the scores.
        scale: Score multiplier for the einsum path. A falsy value selects
            ``1 / sqrt(E)``.
        attention_dropout: Dropout probability applied to the attention
            weights on the einsum path.
        output_attention: When true, ``forward`` returns the attention
            weights if the chosen path materialises them.
        d_model, num_heads, max_len: Forwarded to ``QueryKeyProjection`` /
            ``BinaryAttentionBias``.
        covariate: Selects ``TimerCovariateMask`` instead of
            ``TimerMultivariateMask`` when no mask is supplied.
        flash_attention: Use
            ``torch.nn.functional.scaled_dot_product_attention`` instead of
            the explicit einsum path.
    """

    def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False, d_model=512, num_heads=8, max_len=100, covariate=False, flash_attention=False):
        super(TimeAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.covariate = covariate
        self.flash_attention = flash_attention
        self.qk_proj = QueryKeyProjection(dim=d_model, num_heads=num_heads, proj_layer=RotaryProjection, kwargs=dict(max_len=max_len),
                                          partial_factor=(0.0, 0.5),)
        self.attn_bias = BinaryAttentionBias(dim=d_model, num_heads=num_heads)

    def forward(self, queries, keys, values, attn_mask, n_vars, n_tokens, tau=None, delta=None):
        """Attend ``queries`` over ``keys`` / ``values``.

        Args:
            queries: Tensor laid out as ``[B, L, H, E]``.
            keys: Tensor laid out as ``[B, S, H, E]``.
            values: Tensor laid out as ``[B, S, H, D]``.
            attn_mask: Object exposing a boolean ``.mask`` tensor, or ``None``
                to build a Timer mask from ``n_vars`` / ``n_tokens``. Only
                consulted when ``mask_flag`` is set.
            n_vars: Number of variables in the flattened sequence.
            n_tokens: Number of tokens per variable. Position and variable
                ids are built with length ``n_vars * n_tokens``, so the
                sequence length is presumably expected to equal that product.
            tau, delta: Unused; accepted for signature parity.

        Returns:
            ``(context, weights)``. ``context`` is ``[B, L, H, D]`` on both
            paths. ``weights`` is ``[B, H, L, S]`` when ``output_attention``
            is set and the einsum path is used; otherwise ``None`` (the fused
            kernel does not expose the weights).
        """
        B, _, H, E = queries.shape
        device = queries.device

        # [B, L, H, E] -> [B, H, L, E]
        queries = queries.permute(0, 2, 1, 3)
        keys = keys.permute(0, 2, 1, 3)
        if self.flash_attention:
            # The fused kernel wants values head-major as well.
            values = values.permute(0, 2, 1, 3)

        # Flattened position index, built directly on the input device and
        # broadcast to [B, H, n_vars * n_tokens].
        seq_id = torch.arange(n_tokens * n_vars, device=device).expand(B, H, -1)
        queries, keys = self.qk_proj(
            queries, keys, query_id=seq_id, kv_id=seq_id)

        # Variable index of every position (variable-major layout:
        # 0,0,...,1,1,...), broadcast to [B, 1, n_vars * n_tokens].
        var_id = torch.arange(n_vars, device=device).repeat_interleave(n_tokens)
        var_id = var_id.expand(B, 1, -1)
        attn_bias = self.attn_bias(var_id, var_id)

        if self.mask_flag:
            if attn_mask is None:
                mask_cls = TimerCovariateMask if self.covariate else TimerMultivariateMask
                attn_mask = mask_cls(B, n_vars, n_tokens, device=device)
            attn_mask = attn_bias.masked_fill(attn_mask.mask, float("-inf"))
        else:
            attn_mask = attn_bias

        weights = None
        if self.flash_attention:
            # NOTE(review): this path ignores self.dropout and self.scale and
            # adds the bias *after* scaling, whereas the einsum path scales
            # (scores + bias). Left as-is; confirm which is intended.
            context = torch.nn.functional.scaled_dot_product_attention(
                queries, keys, values, attn_mask)
            # Kernel output is [B, H, L, D]; callers (and the einsum path)
            # use [B, L, H, D], so move the head axis back.
            context = context.permute(0, 2, 1, 3)
        else:
            scale = self.scale or 1. / sqrt(E)
            scores = torch.einsum("bhle,bhse->bhls", queries, keys)
            scores = scores + attn_mask
            weights = self.dropout(torch.softmax(scale * scores, dim=-1))
            context = torch.einsum("bhls,bshd->blhd", weights, values)

        return context.contiguous(), (weights if self.output_attention else None)
class AttentionLayer(nn.Module):
    """Multi-head wrapper around an inner attention module.

    Projects the inputs to per-head queries, keys and values, delegates the
    attention computation to ``attention``, then merges the heads and maps
    the result back to ``d_model``.

    Args:
        attention: Callable invoked as
            ``attention(q, k, v, attn_mask, n_vars=..., n_tokens=..., tau=..., delta=...)``
            returning ``(context, weights)``.
        d_model: Width of the input and output features.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key width; a falsy value selects
            ``d_model // n_heads``.
        d_values: Per-head value width; a falsy value selects
            ``d_model // n_heads``.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super().__init__()

        key_width = d_keys if d_keys else d_model // n_heads
        value_width = d_values if d_values else d_model // n_heads

        # Sub-modules are created in a fixed order so that parameter
        # initialisation and state_dict ordering stay stable.
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, key_width * n_heads)
        self.key_projection = nn.Linear(d_model, key_width * n_heads)
        self.value_projection = nn.Linear(d_model, value_width * n_heads)
        self.out_projection = nn.Linear(value_width * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None):
        """Run multi-head attention.

        Args:
            queries: Tensor of shape ``[B, L, d_model]``.
            keys: Tensor of shape ``[B, S, d_model]``.
            values: Tensor of shape ``[B, S, d_model]``.
            attn_mask: Passed through unchanged to the inner attention.
            n_vars, n_tokens, tau, delta: Passed through unchanged as keyword
                arguments to the inner attention.

        Returns:
            ``(output, weights)`` where ``output`` is ``[B, L, d_model]`` and
            ``weights`` is whatever the inner attention returned as its
            second element.
        """
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # Split the projected feature axis into (heads, per-head width).
        q_heads = self.query_projection(queries).view(batch, q_len, heads, -1)
        k_heads = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v_heads = self.value_projection(values).view(batch, kv_len, heads, -1)

        passthrough = dict(n_vars=n_vars, n_tokens=n_tokens, tau=tau, delta=delta)
        context, attn = self.inner_attention(q_heads, k_heads, v_heads, attn_mask, **passthrough)

        # Merge heads back into a single feature axis.
        merged = context.view(batch, q_len, -1)
        return self.out_projection(merged), attn