# generate.py
from model import Transformer
from tokenizer import Tokenizer
from typing import Optional, Tuple, List
import time
from pathlib import Path
import json
from logging import getLogger

import torch
import torch.nn.functional as F

logger = getLogger(__name__)

class Llama:
    """
    This class is the main entrypoint for doing inference on llama models.

    The `build` method constructs the model from a set of checkpoint weights
    and a tokenizer model. The `generate` method performs autoregressive,
    top-p sampled decoding over a batch of tokenized prompts, optionally
    returning per-token logprobs.
    """
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int = 4096,
        max_batch_size: int = 1,
        model_parallel_size: Optional[int] = None,
        device: str = "cpu",
        float_type=torch.FloatTensor,
    ) -> "Llama":
        assert float_type in [torch.FloatTensor, torch.cuda.HalfTensor], \
            "Only torch.cuda.HalfTensor and torch.FloatTensor are supported"
        assert model_parallel_size is None, "This version doesn't support model parallelism"

        # Set the seed to ensure reproducibility.
        torch.manual_seed(42)

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoints found in {ckpt_dir}"
        assert len(checkpoints) == 1, (
            f"more than one checkpoint found in {ckpt_dir}; "
            "this version of llama only supports a single checkpoint"
        )
        # There should be exactly one checkpoint, so take the first entry.
        ckpt_path = checkpoints[0]
        # Load to CPU first; with a single checkpoint we could load straight to the GPU.
        checkpoint = torch.load(ckpt_path, map_location="cpu")

        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
        params["max_seq_len"] = max_seq_len
        params["max_batch_size"] = max_batch_size
        # This parameter is None for the smaller model we are working with.
        params["ffn_dim_multiplier"] = None

        # Load the tokenizer and take the vocab size from it.
        tokenizer = Tokenizer(model_path=tokenizer_path)
        params["vocab_size"] = tokenizer.n_words
        logger.info(f"{params=}")

        # Set the default tensor type so the model's parameters are created with the requested dtype.
        torch.set_default_tensor_type(float_type)

        model = Transformer(**params)
        logger.info(f"state_dict keys: {list(checkpoint.keys())}")
        missing, unexpected = model.load_state_dict(checkpoint, strict=False)
        logger.info(f"unexpected_keys: {unexpected}")
        logger.info(f"missing_keys: {missing}")
        model.to(device=device)

        print(f"loaded in {time.time() - start_time:.2f} seconds")
        return Llama(model, tokenizer, device)

    def __init__(self, model: Transformer, tokenizer: Tokenizer, device: Optional[str] = "cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        bsz = len(prompt_tokens)
        assert bsz <= self.model.max_batch_size, f"batch size too large: {bsz} > {self.model.max_batch_size}"
        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= self.model.max_seq_len

        # Figure out the longest sequence we expect to produce.
        total_len = min(self.model.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        # Create a tensor filled with pad_id for the prompts and the generated output...
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
        for k, t in enumerate(prompt_tokens):
            # ...and pack it with the batch of prompts.
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
        if logprobs:
            # If logprobs were requested, allocate a matching tensor of zeros for them.
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device=self.device)
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            logger.info(f"tokens: {tokens.shape}, {tokens}")
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            logger.info(f"logits: {logits.shape}, {logits[:, -1, :300]}")
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                logger.info(f"probs: {probs.shape}, {probs[-1, :300]}")
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)
            next_token = next_token.reshape(-1)
            logger.info(f"next_token: {next_token.tolist()}")
            logger.info(f"next_token: {self.tokenizer.decode(next_token.tolist())}")
            # Keep the prompt token wherever a prompt is still being consumed.
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
                logger.info(f"logprobs: {token_logprobs.shape}, {token_logprobs}")
            eos_reached |= (~input_text_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id)
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # Skip the prompt tokens unless the caller asked to echo them.
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # Cut at the first eos token, if any.
            if self.tokenizer.eos_id in toks:
                eos_idx = toks.index(self.tokenizer.eos_id)
                toks = toks[:eos_idx]
                probs = probs[:eos_idx] if logprobs else None
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ):
        if max_gen_len is None:
            max_gen_len = self.model.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )
        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]


def sample_top_p(probs, p):
    # Sort probabilities in descending order, keeping the original vocabulary indices.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    logger.info(f"top 10 indexes: {probs_idx[:, :10]}")
    logger.info(f"top 10 temperature probs: {probs_sort[:, :10]}")
    # Zero out everything outside the smallest set of tokens whose cumulative
    # probability exceeds p, then renormalize and sample from what remains.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to its vocabulary index.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
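

# Usage sketch: how `Llama.build` and `text_completion` fit together. The
# checkpoint directory and tokenizer path below are placeholders, not files
# shipped with this repo; point them at your own single-checkpoint model.
if __name__ == "__main__":
    llama = Llama.build(
        ckpt_dir="path/to/ckpt_dir",              # placeholder: must contain one *.pth plus params.json
        tokenizer_path="path/to/tokenizer.model",  # placeholder: sentencepiece tokenizer model
        max_seq_len=512,
        max_batch_size=1,
        device="cpu",
    )
    results = llama.text_completion(
        ["Once upon a time"],
        temperature=0.6,
        top_p=0.9,
        max_gen_len=64,
    )
    for result in results:
        print(result["generation"])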