add support for llama adapters #528
Comments
So this is not LoRA?
As I understand it, they are inserting a learnt layer into the frozen LLaMA network. I'm a bit confused about why they mention Alpaca-LoRA, since I can't seem to find any low-rank adaptation layers. Maybe they used it for the training data?
According to the paper, LLaMA-Adapter is not LoRA.
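For contrast: a LoRA layer learns a low-rank delta on top of a frozen weight matrix, whereas LLaMA-Adapter (see the diff below) learns extra prompt vectors plus a zero-initialized gate over them. Here is a minimal sketch of the LoRA idea, not taken from either repo, just to show what the paper is *not* doing:

```python
# Minimal LoRA-style layer (illustrative only, not from LLaMA-Adapter or Alpaca-LoRA):
# the pretrained weight W stays frozen and only the low-rank factors A and B train.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)          # freeze the pretrained weights
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + scale * B (A x); only lora_a and lora_b receive gradients
        return self.base(x) + self.scale * (x @ self.lora_a.T @ self.lora_b.T)

# usage: wrapped = LoRALinear(nn.Linear(4096, 4096, bias=False), rank=8)
```

LLaMA-Adapter, by comparison, leaves every weight matrix untouched and instead adds learnable prompt tokens that the upper layers attend to, as the diff below shows.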
Looks like LLaMA-Adapter has a modified copy of the llama package, where:
--- llama/model.py	2023-04-10 18:18:14
+++ ../LLaMA-Adapter/llama/model.py	2023-04-10 18:11:49
@@ -29,7 +28,10 @@
max_batch_size: int = 32
max_seq_len: int = 2048
+ adapter_len: int=10
+ adapter_layer: int=8
+
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@@ -115,8 +117,9 @@
self.cache_v = torch.zeros(
(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
).cuda()
+ self.gate = torch.nn.Parameter(torch.zeros(1))
- def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@@ -135,6 +138,12 @@
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
+ if adapter is not None:
+ adapter_len = adapter.shape[1]
+ adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
+ adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
+ adapter_k = adapter_k.transpose(1, 2)
+ adapter_v = adapter_v.transpose(1, 2)
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
@@ -143,6 +152,10 @@
scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
+ if adapter is not None:
+ adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
+ adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
+ output = output + torch.matmul(adapter_scores, adapter_v)
output = output.transpose(
1, 2
).contiguous().view(bsz, seqlen, -1)
@@ -189,8 +202,8 @@
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
- def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
- h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
+ h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
@@ -218,6 +231,9 @@
self.freqs_cis = precompute_freqs_cis(
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
)
+ self.adapter_query = nn.Embedding(params.adapter_len * params.adapter_layer, params.dim)
+ self.adapter_len = params.adapter_len
+ self.adapter_layer = params.adapter_layer
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
@@ -225,14 +241,18 @@
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
-
+ prompt = self.adapter_query.weight.reshape(self.params.adapter_layer, self.params.adapter_len, self.params.dim).unsqueeze(1)
mask = None
if seqlen > 1:
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
- for layer in self.layers:
+ for layer in self.layers[: -1 * self.params.adapter_layer]:
h = layer(h, start_pos, freqs_cis, mask)
+ layer_index = 0
+ for layer in self.layers[-1 * self.params.adapter_layer:]:
+ h = layer(h, start_pos, freqs_cis, mask, prompt[layer_index])
+ layer_index = layer_index + 1
h = self.norm(h)
output = self.output(h[:, -1, :]) # only compute last logits
return output.float()
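Distilling the diff above: the adapter is a set of learned prompt vectors that the last adapter_layer blocks attend to through an extra attention term, scaled by a per-layer gate that starts at zero, so the frozen model's behaviour is unchanged at the start of fine-tuning. Below is a simplified, self-contained sketch of that mechanism, not the repo's code (single head, no KV cache, no rotary embeddings, no causal mask; names other than gate and adapter are illustrative):

```python
# Simplified sketch of the gated adapter ("prefix") attention used by LLaMA-Adapter.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedAdapterAttention(nn.Module):
    def __init__(self, dim: int, adapter_len: int = 10):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        # Learnable adapter tokens acting as extra keys/values for every query.
        self.adapter = nn.Parameter(torch.randn(adapter_len, dim) * 0.02)
        # Zero-initialized gate: the adapter contributes nothing at first,
        # so training starts from the frozen model's original outputs.
        self.gate = nn.Parameter(torch.zeros(1))
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim)
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        scores = q @ k.transpose(1, 2) / math.sqrt(self.dim)
        out = F.softmax(scores, dim=-1) @ v

        # Same queries also attend to the adapter tokens.
        ak = self.wk(self.adapter)                        # (adapter_len, dim)
        av = self.wv(self.adapter)
        a_scores = q @ ak.T / math.sqrt(self.dim)         # (batch, seq, adapter_len)
        a_out = F.softmax(a_scores, dim=-1) @ av          # (batch, seq, dim)
        return out + self.gate * a_out
```

The zero-initialized gate is the key design choice: it lets the frozen LLaMA behave exactly as before at the start of training and only gradually blends in the adapter's contribution as the gate is learned.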
This issue was closed because it has been inactive for 14 days since being marked as stale.
implement support for running models that use Llama adapter
https://github.com/ZrrSkywalker/LLaMA-Adapter
described here how to get the model
https://github.com/ZrrSkywalker/LLaMA-Adapter#inference
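For reference, a rough, untested sketch of how the released adapter checkpoint could be applied on top of the frozen base weights, assuming the modified Transformer from the diff above; file names and the exact flow follow the LLaMA-Adapter README only loosely and may differ between releases:

```python
# Untested sketch: load the frozen LLaMA weights, then overlay the small adapter
# checkpoint (which only holds the adapter_query embeddings and per-layer gates).
# ModelArgs / Transformer refer to the modified llama package from the diff above,
# not to this repo; file names are illustrative.
import json
import torch

with open("params.json") as f:                 # shipped alongside the LLaMA weights
    params = json.load(f)

base_ckpt = torch.load("consolidated.00.pth", map_location="cpu")   # frozen base weights
adapter_ckpt = torch.load("adapter.pth", map_location="cpu")        # small adapter-only file

model_args = ModelArgs(max_seq_len=512, max_batch_size=1,
                       adapter_len=10, adapter_layer=30,            # must match how the adapter was trained
                       **params)

model = Transformer(model_args)
model.load_state_dict(base_ckpt, strict=False)      # strict=False: base file has no adapter params
model.load_state_dict(adapter_ckpt, strict=False)   # fills in adapter_query and the gates
```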