Attn_Projection.py
import abc
from functools import cached_property

import torch
from einops import einsum, rearrange, repeat
from torch import nn


class Projection(nn.Module, abc.ABC):
    """Abstract base class for positional projections applied per attention head."""

    def __init__(self, proj_width: int, num_heads: int, **kwargs):
        super().__init__()
        self.proj_width = proj_width
        self.num_heads = num_heads

    @abc.abstractmethod
    def forward(self, x, seq_id): ...


class RotaryProjection(Projection):
    """Rotary position embedding (RoPE) applied to query/key projections."""

    def __init__(
        self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000
    ):
        super().__init__(proj_width, num_heads)
        assert (
            self.proj_width % 2 == 0
        ), f"proj_width must be even, got {self.proj_width}"
        # Inverse frequencies theta_i = base^(-2i / proj_width), one per rotation pair.
        self.register_buffer(
            "theta",
            1.0
            / torch.pow(
                base,
                torch.arange(0, self.proj_width, 2, dtype=torch.float)
                / self.proj_width,
            ),
            persistent=False,
        )
        self.register_buffer("cos", None, persistent=False)
        self.register_buffer("sin", None, persistent=False)
        self._init_freq(max_len=max_len)

    def _init_freq(self, max_len: int):
        # Lazily (re)build the cos/sin tables whenever a longer sequence is seen.
        if self.cos is None or self.cos.size(-2) < max_len:
            position = torch.arange(
                max_len, device=self.theta.device, dtype=self.theta.dtype
            )
            m_theta = einsum(position, self.theta, "length, width -> length width")
            # Repeat each angle twice so it lines up with the interleaved (dim r) layout.
            m_theta = repeat(m_theta, "length width -> length (width 2)")
            self.register_buffer("cos", torch.cos(m_theta), persistent=False)
            self.register_buffer("sin", torch.sin(m_theta), persistent=False)

    @staticmethod
    def _rotate(x):
        # Map each interleaved pair (x1, x2) to (-x2, x1), a 90-degree rotation per pair.
        x1, x2 = rearrange(x, "... (dim r) -> r ... dim", r=2)
        return rearrange([-x2, x1], "r ... dim -> ... (dim r)", r=2)  # noqa

    def forward(self, x, seq_id):
        self._init_freq(max_len=seq_id.max() + 1)
        rot_cos = self.cos[seq_id]
        rot_sin = self.sin[seq_id]
        # Standard RoPE: x * cos + rotate(x) * sin.
        return rot_cos * x + rot_sin * self._rotate(x)


class QueryKeyProjection(nn.Module):
    """Applies a positional projection (e.g. RotaryProjection) to queries and keys.

    If `partial_factor=(lo, hi)` is given, only the slice of each head between
    `lo * head_dim` and `hi * head_dim` is projected; the rest is left untouched.
    """

    def __init__(
        self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None
    ):
        super().__init__()
        if partial_factor is not None:
            assert (
                0.0 <= partial_factor[0] < partial_factor[1] <= 1.0
            ), f"got {partial_factor[0]}, {partial_factor[1]}"
        assert num_heads > 0 and dim % num_heads == 0
        self.head_dim = dim // num_heads
        self.partial_factor = partial_factor
        self.query_proj = proj_layer(
            proj_width=self.proj_width,
            num_heads=num_heads,
            **(kwargs or {}),
        )
        # Queries and keys share the same projection module.
        self.key_proj = self.query_proj

    @cached_property
    def proj_width(self) -> int:
        if self.partial_factor is None:
            return self.head_dim
        return int(self.head_dim * (self.partial_factor[1] - self.partial_factor[0]))

    @cached_property
    def split_sizes(self):
        # Sizes of (unprojected prefix, projected slice, unprojected suffix) per head.
        if self.partial_factor is None:
            return 0, self.head_dim, 0
        return (
            int(self.partial_factor[0] * self.head_dim),
            self.proj_width,
            int((1.0 - self.partial_factor[1]) * self.head_dim),
        )

    def forward(self, query, key, query_id, kv_id):
        if self.partial_factor is not None:
            # Project only the middle slice of the head dimension.
            queries = list(query.split(self.split_sizes, dim=-1))
            keys = list(key.split(self.split_sizes, dim=-1))
            queries[1] = self.query_proj(queries[1], seq_id=query_id)
            keys[1] = self.key_proj(keys[1], seq_id=kv_id)
            query = torch.cat(queries, dim=-1)
            key = torch.cat(keys, dim=-1)
        else:
            query = self.query_proj(query, seq_id=query_id)
            key = self.key_proj(key, seq_id=kv_id)
        return query, key
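

# Minimal usage sketch, not part of the original file. It assumes query/key tensors
# shaped (batch, num_heads, seq_len, head_dim) and integer position ids, which is one
# common layout; the actual caller in the surrounding codebase may pass something else.
if __name__ == "__main__":
    batch, num_heads, seq_len, head_dim = 2, 4, 16, 32
    qk_proj = QueryKeyProjection(
        dim=num_heads * head_dim,
        num_heads=num_heads,
        proj_layer=RotaryProjection,
        kwargs={"max_len": 128},
        partial_factor=(0.0, 0.5),  # rotate only the first half of each head
    )
    query = torch.randn(batch, num_heads, seq_len, head_dim)
    key = torch.randn(batch, num_heads, seq_len, head_dim)
    seq_id = torch.arange(seq_len)  # position ids, broadcast over batch and heads
    query, key = qk_proj(query, key, query_id=seq_id, kv_id=seq_id)
    print(query.shape, key.shape)  # torch.Size([2, 4, 16, 32]) for both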