add support for llama adapters #528

Closed · redthing1 opened this issue Mar 26, 2023 · 5 comments

Labels: enhancement (New feature or request), model (Model specific), stale

Comments

@redthing1

Implement support for running models that use LLaMA-Adapter:
https://github.com/ZrrSkywalker/LLaMA-Adapter

How to obtain the model is described here:

https://github.com/ZrrSkywalker/LLaMA-Adapter#inference

@Green-Sky (Collaborator)

So this is not LoRA?

Green-Sky added the enhancement (New feature or request) and model (Model specific) labels on Mar 26, 2023
@Piezoid (Contributor) commented Mar 26, 2023

As I understand it, they are inserting a learnt layer in the frozen LLaMA network:
https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/main/llama/model.py#L141-L158

I'm a bit confused about why they mentioned Alpaca-LoRA, since I can't seem to find any low-rank adapter layers. Maybe they used it for the training data?
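
For reference, a rough sketch of that mechanism (illustrative only, not code from either repository; the function and tensor names below are made up): the learnt adapter keys/values get their own attention term, and its contribution is scaled by a gate that starts at zero, so the frozen network's output is unchanged before fine-tuning.

import math
import torch
import torch.nn.functional as F

def gated_adapter_attention(xq, keys, values, adapter_k, adapter_v, gate):
    # xq, keys, values: (bsz, n_heads, seqlen, head_dim) from the frozen model
    # adapter_k, adapter_v: (bsz, n_heads, adapter_len, head_dim) from learnt prompts
    # gate: scalar parameter, zero-initialised (causal mask omitted for brevity)
    head_dim = xq.shape[-1]
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
    output = torch.matmul(F.softmax(scores.float(), dim=-1).type_as(xq), values)
    adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(head_dim)
    adapter_scores = gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
    return output + torch.matmul(adapter_scores, adapter_v)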

@maaaax commented Apr 1, 2023

As I understand it, they are inserting a learnt layer in the frozen LLaMA network: https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/main/llama/model.py#L141-L158

I'm a bit confused about why they mentioned Alpaca-LoRA, since I can't seem to find any low-rank adapter layers. Maybe they used it for the training data?

According to the paper, LLaMA-Adapter is not LoRA.
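
For anyone comparing the two approaches: LoRA learns a low-rank update to the frozen weight matrices, while LLaMA-Adapter leaves every weight untouched and instead learns a small set of prompt vectors plus a zero-initialised gate for the adapted layers. A rough side-by-side sketch (names, ranks and dims here are illustrative assumptions, not code from either repo):

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # LoRA: keep the base weight frozen, learn a low-rank delta B @ A on top of it.
    def __init__(self, base: nn.Linear, r: int = 8):
        super().__init__()
        self.base = base
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))

    def forward(self, x):
        return self.base(x) + x @ (self.B @ self.A).T

# LLaMA-Adapter: no weight deltas at all; just learnable prompt embeddings
# (adapter_len per adapted layer) and a zero-initialised gate per layer.
adapter_len, adapter_layer, dim = 10, 8, 4096
adapter_query = nn.Embedding(adapter_len * adapter_layer, dim)
gate = nn.Parameter(torch.zeros(1))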

@tmm1 commented Apr 11, 2023

Looks like LLaMA-Adapter ships a modified copy of the llama package, where ModelArgs gains the adapter hyperparameters and the model classes are extended to use the adapter weights:

-- llama/model.py      2023-04-10 18:18:14                      
+++ ../LLaMA-Adapter/llama/model.py     2023-04-10 18:11:49      
@@ -29,7 +28,10 @@              
     max_batch_size: int = 32   
     max_seq_len: int = 2048    
 
+    adapter_len: int=10        
+    adapter_layer: int=8       
 
+                               
 class RMSNorm(torch.nn.Module):                                 
     def __init__(self, dim: int, eps: float = 1e-6):            
         super().__init__()     
@@ -115,8 +117,9 @@             
         self.cache_v = torch.zeros(                             
             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)           
         ).cuda()               
+        self.gate = torch.nn.Parameter(torch.zeros(1))          
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):                    
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):      
         bsz, seqlen, _ = x.shape                                
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)         
 
@@ -135,6 +138,12 @@            
         keys = self.cache_k[:bsz, : start_pos + seqlen]         
         values = self.cache_v[:bsz, : start_pos + seqlen]       
 
+        if adapter is not None:                                 
+           adapter_len = adapter.shape[1]                       
+           adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)              
+           adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)              
+           adapter_k = adapter_k.transpose(1, 2)                
+           adapter_v = adapter_v.transpose(1, 2)                
         xq = xq.transpose(1, 2)                                 
         keys = keys.transpose(1, 2)                             
         values = values.transpose(1, 2)                         
@@ -143,6 +152,10 @@            
             scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)                
         scores = F.softmax(scores.float(), dim=-1).type_as(xq)  
         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)             
+        if adapter is not None:                                 
+            adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)                               
+            adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)   
+            output = output + torch.matmul(adapter_scores, adapter_v)                            
         output = output.transpose(                              
             1, 2               
         ).contiguous().view(bsz, seqlen, -1)                    
@@ -189,8 +202,8 @@             
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)                               
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)    
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):                    
-        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)       
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):      
+        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter)                               
         out = h + self.feed_forward.forward(self.ffn_norm(h))   
         return out             
 
@@ -218,6 +231,9 @@             
         self.freqs_cis = precompute_freqs_cis(                  
             self.params.dim // self.params.n_heads, self.params.max_seq_len * 2                  
         )                      
+        self.adapter_query = nn.Embedding(params.adapter_len * params.adapter_layer, params.dim) 
+        self.adapter_len = params.adapter_len                   
+        self.adapter_layer = params.adapter_layer               
 
     @torch.inference_mode()    
     def forward(self, tokens: torch.Tensor, start_pos: int):    
@@ -225,14 +241,18 @@           
         h = self.tok_embeddings(tokens)                         
         self.freqs_cis = self.freqs_cis.to(h.device)            
         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]                               
-                               
+        prompt = self.adapter_query.weight.reshape(self.params.adapter_layer, self.params.adapter_len, self.params.dim).unsqueeze(1)                  
         mask = None            
         if seqlen > 1:         
             mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)       
             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)                           
 
-        for layer in self.layers:                               
+        for layer in self.layers[: -1 * self.params.adapter_layer]:                              
             h = layer(h, start_pos, freqs_cis, mask)            
+        layer_index = 0        
+        for layer in self.layers[-1 * self.params.adapter_layer:]:                               
+            h = layer(h, start_pos, freqs_cis, mask, prompt[layer_index])                        
+            layer_index = layer_index + 1                       
         h = self.norm(h)       
         output = self.output(h[:, -1, :])  # only compute last logits                            
         return output.float()  
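
Summarising the diff, the extra state is small: two hyperparameters (adapter_len, adapter_layer), one adapter_query embedding with adapter_len * adapter_layer rows, and a zero-initialised gate per attention block; only the last adapter_layer transformer blocks receive a slice of the prompt. A toy version of that dispatch loop (the block function and shapes below are placeholders, not the real model):

import torch
import torch.nn as nn

# Toy dispatch mirroring the loop in the diff: only the last `adapter_layer`
# blocks receive a slice of the learnt prompt.
adapter_len, adapter_layer, n_layers, dim = 10, 8, 32, 64
adapter_query = nn.Embedding(adapter_len * adapter_layer, dim)
prompt = adapter_query.weight.reshape(adapter_layer, adapter_len, dim).unsqueeze(1)

def block(h, adapter=None):      # placeholder for a TransformerBlock.forward
    return h

h = torch.zeros(1, 5, dim)       # placeholder token embeddings (bsz=1, seqlen=5)
for _ in range(n_layers - adapter_layer):
    h = block(h)                 # frozen blocks run exactly as before
for i in range(adapter_layer):
    h = block(h, adapter=prompt[i])  # adapted blocks get prompt[i]: (1, adapter_len, dim)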

github-actions bot added the stale label on Mar 25, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.
