import json
from copy import deepcopy
from functools import lru_cache
from typing import Any, List, Optional, Union
import torch
import torchtext # noqa: F401
from torch import Tensor
from torch.nn import Module
from torchtext._torchtext import CLIPEncoder as CLIPEncoderPyBind, GPT2BPEEncoder as GPT2BPEEncoderPyBind
from torchtext.data.functional import load_sp_model
from torchtext.utils import get_asset_local_path
from torchtext.vocab import Vocab
from . import functional as F
__all__ = [
"SentencePieceTokenizer",
"VocabTransform",
"ToTensor",
"LabelToIndex",
"Truncate",
"AddToken",
"PadTransform",
"StrToIntTransform",
"GPT2BPETokenizer",
"Sequential",
]
class SentencePieceTokenizer(Module):
"""
    Transform for SentencePiece tokenizer from a pre-trained SentencePiece model
    Additional details: https://github.com/google/sentencepiece
:param sp_model_path: Path to pre-trained sentencepiece model
:type sp_model_path: str
Example
        >>> from torchtext.transforms import SentencePieceTokenizer
>>> transform = SentencePieceTokenizer("spm_model")
>>> transform(["hello world", "attention is all you need!"])
"""
def __init__(self, sp_model_path: str):
super().__init__()
self.sp_model = load_sp_model(get_asset_local_path(sp_model_path))
def forward(self, input: Any) -> Any:
"""
:param input: Input sentence or list of sentences on which to apply tokenizer.
:type input: Union[str, List[str]]
:return: tokenized text
        :rtype: Union[List[str], List[List[str]]]
"""
if torch.jit.isinstance(input, List[str]):
tokens: List[List[str]] = []
for text in input:
tokens.append(self.sp_model.EncodeAsPieces(text))
return tokens
elif torch.jit.isinstance(input, str):
return self.sp_model.EncodeAsPieces(input)
else:
raise TypeError("Input type not supported")
class VocabTransform(Module):
r"""Vocab transform to convert input batch of tokens into corresponding token ids
:param vocab: an instance of :class:`torchtext.vocab.Vocab` class.
Example:
>>> import torch
>>> from torchtext.vocab import vocab
>>> from torchtext.transforms import VocabTransform
>>> from collections import OrderedDict
>>> vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
>>> vocab_transform = VocabTransform(vocab_obj)
>>> output = vocab_transform([['a','b'],['a','b','c']])
>>> jit_vocab_transform = torch.jit.script(vocab_transform)
"""
def __init__(self, vocab: Vocab):
super().__init__()
assert isinstance(vocab, Vocab)
self.vocab = vocab
def forward(self, input: Any) -> Any:
"""
        :param input: Input batch of tokens to convert to corresponding token ids
:type input: Union[List[str], List[List[str]]]
:return: Converted input into corresponding token ids
:rtype: Union[List[int], List[List[int]]]
"""
if torch.jit.isinstance(input, List[str]):
return self.vocab.lookup_indices(input)
elif torch.jit.isinstance(input, List[List[str]]):
output: List[List[int]] = []
for tokens in input:
output.append(self.vocab.lookup_indices(tokens))
return output
else:
raise TypeError("Input type not supported")
class ToTensor(Module):
r"""Convert input to torch tensor
    :param padding_value: Pad value used to make each input in the batch equal in length to the longest sequence in the batch.
:type padding_value: Optional[int]
:param dtype: :class:`torch.dtype` of output tensor
:type dtype: :class:`torch.dtype`
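    Example (a minimal usage sketch):
        >>> from torchtext.transforms import ToTensor
        >>> to_tensor = ToTensor(padding_value=0)
        >>> to_tensor([[1, 2], [3, 4, 5]])
        tensor([[1, 2, 0],
                [3, 4, 5]])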
"""
def __init__(self, padding_value: Optional[int] = None, dtype: torch.dtype = torch.long) -> None:
super().__init__()
self.padding_value = padding_value
self.dtype = dtype
def forward(self, input: Any) -> Tensor:
"""
:param input: Sequence or batch of token ids
:type input: Union[List[int], List[List[int]]]
:rtype: Tensor
"""
return F.to_tensor(input, padding_value=self.padding_value, dtype=self.dtype)
class LabelToIndex(Module):
r"""
Transform labels from string names to ids.
:param label_names: a list of unique label names
:type label_names: Optional[List[str]]
    :param label_path: a path to a file containing unique label names, one label per line. Note that either label_names or
        label_path should be supplied, but not both.
:type label_path: Optional[str]
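    Example (a minimal usage sketch):
        >>> from torchtext.transforms import LabelToIndex
        >>> label_transform = LabelToIndex(label_names=["positive", "negative"])
        >>> label_transform("negative")
        1
        >>> label_transform(["positive", "negative"])
        [0, 1]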
"""
def __init__(
self,
label_names: Optional[List[str]] = None,
label_path: Optional[str] = None,
        sort_names: bool = False,
):
assert label_names or label_path, "label_names or label_path is required"
assert not (label_names and label_path), "label_names and label_path are mutually exclusive"
super().__init__()
if label_path:
with open(label_path, "r") as f:
label_names = [line.strip() for line in f if line.strip()]
else:
label_names = label_names
if sort_names:
label_names = sorted(label_names)
self._label_vocab = Vocab(torch.classes.torchtext.Vocab(label_names, None))
self._label_names = self._label_vocab.get_itos()
def forward(self, input: Any) -> Any:
"""
:param input: Input labels to convert to corresponding ids
:type input: Union[str, List[str]]
:rtype: Union[int, List[int]]
"""
if torch.jit.isinstance(input, List[str]):
return self._label_vocab.lookup_indices(input)
elif torch.jit.isinstance(input, str):
return self._label_vocab.__getitem__(input)
else:
raise TypeError("Input type not supported")
@property
def label_names(self) -> List[str]:
return self._label_names
class Truncate(Module):
r"""Truncate input sequence
:param max_seq_len: The maximum allowable length for input sequence
:type max_seq_len: int
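    Example (a minimal usage sketch):
        >>> from torchtext.transforms import Truncate
        >>> truncate = Truncate(max_seq_len=2)
        >>> truncate([[1, 2, 3], [4, 5]])
        [[1, 2], [4, 5]]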
"""
def __init__(self, max_seq_len: int) -> None:
super().__init__()
self.max_seq_len = max_seq_len
def forward(self, input: Any) -> Any:
"""
:param input: Input sequence or batch of sequence to be truncated
:type input: Union[List[Union[str, int]], List[List[Union[str, int]]]]
:return: Truncated sequence
:rtype: Union[List[Union[str, int]], List[List[Union[str, int]]]]
"""
return F.truncate(input, self.max_seq_len)
class AddToken(Module):
"""Add token to beginning or end of sequence
:param token: The token to be added
:type token: Union[int, str]
    :param begin: Whether to insert the token at the start (True) or end (False) of the sequence, defaults to True
:type begin: bool, optional
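    Example (a minimal usage sketch; ``0`` below is a stand-in begin-of-sequence token id):
        >>> from torchtext.transforms import AddToken
        >>> add_bos = AddToken(token=0, begin=True)
        >>> add_bos([[1, 2], [3]])
        [[0, 1, 2], [0, 3]]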
"""
def __init__(self, token: Union[int, str], begin: bool = True) -> None:
super().__init__()
self.token = token
self.begin = begin
def forward(self, input: Any) -> Any:
"""
:param input: Input sequence or batch
:type input: Union[List[Union[str, int]], List[List[Union[str, int]]]]
"""
return F.add_token(input, self.token, self.begin)
class PadTransform(Module):
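    """Pad a tensor along its last dimension up to ``max_length`` using ``pad_value``.
    Inputs already at or beyond ``max_length`` are returned unchanged; no truncation is performed.
    :param max_length: Maximum length to pad to
    :type max_length: int
    :param pad_value: Value used for padding
    :type pad_value: int
    Example (a minimal usage sketch):
        >>> import torch
        >>> from torchtext.transforms import PadTransform
        >>> pad = PadTransform(max_length=5, pad_value=0)
        >>> pad(torch.tensor([1, 2, 3]))
        tensor([1, 2, 3, 0, 0])
    """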
def __init__(self, max_length: int, pad_value: int):
super().__init__()
self.max_length = max_length
self.pad_value = pad_value
def forward(self, x: torch.Tensor) -> torch.Tensor:
max_encoded_length = x.size(-1)
if max_encoded_length < self.max_length:
pad_amount = self.max_length - max_encoded_length
x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
return x
class StrToIntTransform(Module):
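    """Convert a list of strings (or a batch of lists of strings) to integers using ``int()``.
    Example (a minimal usage sketch):
        >>> from torchtext.transforms import StrToIntTransform
        >>> str_to_int = StrToIntTransform()
        >>> str_to_int(["1", "2", "3"])
        [1, 2, 3]
        >>> str_to_int([["1", "2"], ["3"]])
        [[1, 2], [3]]
    """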
def __init__(self):
super().__init__()
def forward(self, input: Union[List[str], List[List[str]]]) -> Union[List[int], List[List[int]]]:
if isinstance(input[0], str):
return [int(x) for x in input] # type: ignore
        if isinstance(input[0], list) and isinstance(input[0][0], str):
return [[int(x) for x in ll] for ll in input]
else:
raise TypeError("Input type not supported")
class GPT2BPETokenizer(Module):
__jit_unused_properties__ = ["is_jitable"]
"""
Transform for GPT-2 BPE Tokenizer.
    Reimplements the OpenAI GPT-2 BPE tokenizer in TorchScript. Original OpenAI implementation:
    https://github.com/openai/gpt-2/blob/master/src/encoder.py
:param encoder_json_path: Path to GPT-2 BPE encoder json file.
:type encoder_json_path: str
:param vocab_bpe_path: Path to bpe vocab file.
:type vocab_bpe_path: str
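    Example (a minimal usage sketch; the asset file names are placeholders for locally downloaded files):
        >>> from torchtext.transforms import GPT2BPETokenizer
        >>> tokenizer = GPT2BPETokenizer("gpt2_bpe_encoder.json", "gpt2_bpe_vocab.bpe")
        >>> tokenizer("hello world")  # BPE token ids of the input, returned as a list of strings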
"""
_seperator: torch.jit.Final[str]
def __init__(
self,
encoder_json_path: str,
vocab_bpe_path: str,
):
super().__init__()
self._seperator = "\u0001"
# load bpe encoder and bpe decoder
with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
bpe_encoder = json.load(f)
# load bpe vocab
with open(get_asset_local_path(vocab_bpe_path), "r", encoding="utf-8") as f:
bpe_vocab = f.read()
bpe_merge_ranks = {
self._seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_vocab.split("\n")[1:-1])
}
# Caching is enabled in Eager mode
self.bpe = GPT2BPEEncoderPyBind(bpe_encoder, bpe_merge_ranks, self._seperator, bytes_to_unicode(), True)
@property
def is_jitable(self):
return isinstance(self.bpe, torch._C.ScriptObject)
@torch.jit.export
def _tokenize(self, text: str) -> List[str]:
"""Encode text into a list of tokens
Args:
text: An input text string.
Returns:
            A list of bpe token ids representing each bpe token
For example: "awesome,awe"
          --> bpe --> bpe tokens: ["aw", "esome"], [","], ["aw", "e"]
--> bpe encode --> bpe token ids: [707, 5927, 11, 707, 68]
"""
bpe_token_ids: List[int] = self.bpe.encode(text)
bpe_tokens: List[str] = []
for bpe_token_id in bpe_token_ids:
bpe_tokens.append(str(bpe_token_id))
return bpe_tokens
def forward(self, input: Any) -> Any:
"""
:param input: Input sentence or list of sentences on which to apply tokenizer.
:type input: Union[str, List[str]]
:return: tokenized text
        :rtype: Union[List[str], List[List[str]]]
"""
if torch.jit.isinstance(input, List[str]):
tokens: List[List[str]] = []
for text in input:
tokens.append(self._tokenize(text))
return tokens
elif torch.jit.isinstance(input, str):
return self._tokenize(input)
else:
raise TypeError("Input type not supported")
def __prepare_scriptable__(self):
r"""Return a JITable tokenizer."""
if not self.is_jitable:
tokenizer_copy = deepcopy(self)
# Disable caching in script mode
tokenizer_copy.bpe = torch.classes.torchtext.GPT2BPEEncoder(
self.bpe.bpe_encoder_, self.bpe.bpe_merge_ranks_, self.bpe.seperator_, self.bpe.byte_encoder_, False
)
return tokenizer_copy
return self
class CLIPTokenizer(Module):
__jit_unused_properties__ = ["is_jitable"]
"""
Transform for CLIP Tokenizer. Based on Byte-Level BPE.
Reimplements CLIP Tokenizer in TorchScript. Original implementation:
    https://github.com/mlfoundations/open_clip/blob/main/src/clip/tokenizer.py
    This tokenizer has been trained to treat spaces like parts of the tokens
    (a bit like sentencepiece) so a word will be encoded differently depending on whether it
    is at the beginning of the sentence (without a space) or not.
The below code snippet shows how to use the CLIP tokenizer with encoder and merges file
taken from the original paper implementation.
Example
>>> from torchtext.transforms import CLIPTokenizer
>>> MERGES_FILE = "http://download.pytorch.org/models/text/clip_merges.bpe"
>>> ENCODER_FILE = "http://download.pytorch.org/models/text/clip_encoder.json"
>>> tokenizer = CLIPTokenizer(merges_path=MERGES_FILE, encoder_json_path=ENCODER_FILE)
>>> tokenizer("the quick brown fox jumped over the lazy dog")
:param merges_path: Path to bpe merges file.
:type merges_path: str
:param encoder_json_path: Optional, path to BPE encoder json file. When specified, this is used
to infer num_merges.
:type encoder_json_path: str
:param num_merges: Optional, number of merges to read from the bpe merges file.
:type num_merges: int
"""
_seperator: torch.jit.Final[str]
def __init__(self, merges_path: str, encoder_json_path: Optional[str] = None, num_merges: Optional[int] = None):
super().__init__()
self._seperator = "\u0001"
# load bpe merges
with open(get_asset_local_path(merges_path), "r", encoding="utf-8") as f:
bpe_merges = f.read().split("\n")[1:]
if encoder_json_path:
# load bpe encoder
with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
bpe_encoder = json.load(f)
            # 256 * 2 byte tokens: for each byte we have two forms, e.g. ['a', 'a</w>']
            # plus an additional 2 tokens for bos and eos
num_merges = len(bpe_encoder) - (256 * 2 + 2)
bpe_merge_ranks = {
self._seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_merges[:num_merges])
}
else:
num_merges = num_merges or len(bpe_merges)
bpe_merge_ranks = {
self._seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_merges[:num_merges])
}
bpe_vocab = list(bytes_to_unicode().values())
bpe_vocab = bpe_vocab + [v + "</w>" for v in bpe_vocab]
bpe_vocab.extend(["".join(merge_pair.split()) for merge_pair in bpe_merges[:num_merges]])
bpe_vocab.extend(["<|startoftext|>", "<|endoftext|>"])
bpe_encoder = {v: i for i, v in enumerate(bpe_vocab)}
# Caching is enabled in Eager mode
self.bpe = CLIPEncoderPyBind(bpe_encoder, bpe_merge_ranks, self._seperator, bytes_to_unicode(), True)
@property
def is_jitable(self):
return isinstance(self.bpe, torch._C.ScriptObject)
@torch.jit.export
def _tokenize(self, text: str) -> List[str]:
"""Encode text into a list of tokens
Args:
text: An input text string.
Returns:
            A list of bpe token ids representing each bpe token
For example: "awesome,awe"
--> bpe --> bpe tokens: ["aw", "esome"], [","], ["aw", "e"]
--> bpe encode --> bpe token ids: [707, 5927, 11, 707, 68]
"""
text = text.lower().strip()
bpe_token_ids: List[int] = self.bpe.encode(text)
bpe_tokens: List[str] = []
for bpe_token_id in bpe_token_ids:
bpe_tokens.append(str(bpe_token_id))
return bpe_tokens
def forward(self, input: Any) -> Any:
"""
:param input: Input sentence or list of sentences on which to apply tokenizer.
:type input: Union[str, List[str]]
:return: tokenized text
        :rtype: Union[List[str], List[List[str]]]
"""
if torch.jit.isinstance(input, List[str]):
tokens: List[List[str]] = []
for text in input:
tokens.append(self._tokenize(text))
return tokens
elif torch.jit.isinstance(input, str):
return self._tokenize(input)
else:
raise TypeError("Input type not supported")
def __prepare_scriptable__(self):
r"""Return a JITable tokenizer."""
if not self.is_jitable:
tokenizer_copy = deepcopy(self)
# Disable caching in script mode
tokenizer_copy.bpe = torch.classes.torchtext.CLIPEncoder(
self.bpe.bpe_encoder_, self.bpe.bpe_merge_ranks_, self.bpe.seperator_, self.bpe.byte_encoder_, False
)
return tokenizer_copy
return self
@lru_cache()
def bytes_to_unicode():
"""
    Original Source: https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
    Returns a dict mapping utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
    while avoiding mapping to whitespace/control characters the bpe code barfs on.
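    Example (illustrative only):
        >>> byte_encoder = bytes_to_unicode()
        >>> byte_encoder[ord("a")]  # printable bytes map to themselves
        'a'
        >>> byte_encoder[ord(" ")]  # whitespace is remapped to a printable character
        'Ġ'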
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
class Sequential(torch.nn.Sequential):
r"""A container to host a sequence of text transforms."""
def forward(self, input: Any) -> Any:
"""
:param input: Input sequence or batch. The input type must be supported by the first transform in the sequence.
:type input: `Any`
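        Example (a minimal usage sketch chaining two transforms from this module):
            >>> from torchtext.transforms import Sequential, Truncate, ToTensor
            >>> pipeline = Sequential(Truncate(max_seq_len=2), ToTensor(padding_value=0))
            >>> pipeline([[1, 2, 3], [4]])
            tensor([[1, 2],
                    [4, 0]])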
"""
for module in self:
input = module(input)
return input