from model import Transformer
from tokenizer import Tokenizer
from typing import Optional, Tuple, List
import time
from pathlib import Path
import json
from logging import getLogger

import torch
import torch.nn.functional as F

logger = getLogger(__name__)


class Llama:
    """
    This class is the main entrypoint for doing inference on llama models.

    The `build` method allows you to build the model from a set of checkpoint
    weights and a tokenizer model. The `generate` method performs autoregressive
    token generation for a batch of prompts.
    """

    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int = 4096,
        max_batch_size: int = 1,
        model_parallel_size: Optional[int] = None,
        device: str = "cpu",
        float_type=torch.FloatTensor,
    ) -> "Llama":
        assert float_type in [
            torch.FloatTensor,
            torch.cuda.HalfTensor,
        ], "Only cuda HalfTensor and FloatTensor are supported"
        assert model_parallel_size is None, "This version doesn't support model parallelism"

        # Set the seed to ensure reproducibility.
        torch.manual_seed(42)

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"No checkpoints in {ckpt_dir}"
        assert len(checkpoints) == 1, (
            f"More than one checkpoint found in {ckpt_dir}; "
            "this version of llama only supports a single checkpoint"
        )
        ckpt_path = checkpoints[0]  # we only take the first, because there should only be one
        checkpoint = torch.load(ckpt_path, map_location="cpu")  # we could probably load straight to GPU, because there is only one

        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
        params["max_seq_len"] = max_seq_len
        params["max_batch_size"] = max_batch_size
        # This parameter is None for the smaller model we are working with.
        params["ffn_dim_multiplier"] = None

        # Load the tokenizer, and get the vocab size from it.
        tokenizer = Tokenizer(model_path=tokenizer_path)
        params["vocab_size"] = tokenizer.n_words
        logger.info(f"{params=}")

        # Set the default tensor type. This controls the dtype the model's
        # parameters are created with.
        torch.set_default_tensor_type(float_type)
        model = Transformer(**params)
        logger.info(f"state_dict_map: {list(checkpoint.keys())}")
        missing, unexpected = model.load_state_dict(checkpoint, strict=False)
        logger.info(f"unexpected_keys: {unexpected}")
        logger.info(f"missing_keys: {missing}")
        model.to(device=device)
        print(f"loaded in {time.time() - start_time:.2f} seconds")
        return Llama(model, tokenizer, device)

    def __init__(self, model: Transformer, tokenizer: Tokenizer, device: Optional[str] = "cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        bsz = len(prompt_tokens)
        assert bsz <= self.model.max_batch_size, f"batch size too large: ({bsz}, {self.model.max_batch_size})"

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= self.model.max_seq_len

        # Figure out the longest sequence we are expecting.
        total_len = min(self.model.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        # We create a tensor filled with the pad_id for our prompt/output ...
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
        for k, t in enumerate(prompt_tokens):
            # ... and pack it with the batch of prompts.
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
        if logprobs:
            # If we are looking for the logprobs, we create an output tensor filled with zeros for them.
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device=self.device)
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            logger.info(f"tokens: {tokens.shape}, {tokens}")
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            logger.info(f"logits: {logits.shape}, {logits[:, -1, :300]}")
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                logger.info(f"probs: {probs.shape}, {probs[-1, :300]}")
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            logger.info(f"next_token: {next_token.tolist()}")
            logger.info(f"next_token: {self.tokenizer.decode(next_token.tolist())}")
            # Only use the sampled token where the prompt has run out; otherwise keep the prompt token.
            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
                logger.info(f"logprobs: {token_logprobs.shape}, {token_logprobs}")
            eos_reached |= (~input_text_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id)
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # Cut to eos, if any.
            if self.tokenizer.eos_id in toks:
                eos_idx = toks.index(self.tokenizer.eos_id)
                toks = toks[:eos_idx]
                probs = probs[:eos_idx] if logprobs else None
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ):
        if max_gen_len is None:
            max_gen_len = self.model.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )
        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    logger.info(f"top 10 indexes: {probs_idx[:, :10]}")
    logger.info(f"top 10 temperature probs: {probs_sort[:, :10]}")
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out every token outside the smallest set whose cumulative probability exceeds p,
    # then renormalize what remains.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
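

# Hypothetical usage sketch (not part of the original source): the checkpoint
# directory, tokenizer path, and prompt below are placeholders. It assumes a
# directory containing a single *.pth checkpoint plus a params.json, and a
# tokenizer model file, as `Llama.build` above expects.
if __name__ == "__main__":
    llama = Llama.build(
        ckpt_dir="path/to/checkpoint_dir",      # placeholder
        tokenizer_path="path/to/tokenizer.model",  # placeholder
        max_seq_len=512,
        max_batch_size=1,
        device="cpu",
    )
    results = llama.text_completion(
        ["The capital of France is"],
        temperature=0.6,
        top_p=0.9,
        max_gen_len=32,
    )
    for result in results:
        print(result["generation"])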