Commit 3cdebdc

Add beam search generation w/ Flashlight Text
1 parent 183b80c

3 files changed (+232 -53 lines changed)

notebooks/hf_with_torchtext_gen.ipynb (+121 -8)
@@ -16,7 +16,7 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/tqdm-4.64.1-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
 " from .autonotebook import tqdm as notebook_tqdm\n"
 ]
 }
@@ -39,14 +39,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 5,
 "metadata": {},
 "outputs": [
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
+"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:163: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
 "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
 "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
 "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
@@ -74,7 +74,55 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 6,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.']\n"
+]
+}
+],
+"source": [
+"# Testing HuggingFace's T5 w/ Beam Search\n",
+"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
+"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.'] 9.786320924758911\n",
+"['studies have shown that owning a dog is good for you. studies have shown that owning a dog is good for you.'] 1.3000121116638184\n"
+]
+}
+],
+"source": [
+"# Testing Decoding Speed HuggingFace's T5 w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
+"import time\n",
+"\n",
+"start = time.time()\n",
+"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
+"end = time.time()\n",
+"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
+"\n",
+"start = time.time()\n",
+"tokens = t5.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
+"end = time.time()\n",
+"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
 "metadata": {},
 "outputs": [
 {
@@ -99,7 +147,54 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 9,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['Nearly. PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.']\n"
+]
+}
+],
+"source": [
+"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id, num_beams=5, beam_size_token=bart.config.vocab_size)\n",
+"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts are expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the'] 58.09997892379761\n",
+"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts were expected to last through at least midday tomorrow.'] 2.456479787826538\n"
+]
+}
+],
+"source": [
+"# Testing Decoding Speed HuggingFace's BART w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
+"import time\n",
+"\n",
+"start = time.time()\n",
+"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, eos_score=1.0, beam_size_token=t5.config.vocab_size)\n",
+"end = time.time()\n",
+"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
+"\n",
+"start = time.time()\n",
+"tokens = bart.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
+"end = time.time()\n",
+"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 3,
 "metadata": {},
 "outputs": [
 {
@@ -119,11 +214,29 @@
 "tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n",
 "print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['I enjoy walking with my cute dog,\" says Kelli Williams-Petersen. The dog loves it so much, that when she']\n"
+]
+}
+],
+"source": [
+"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id, num_beams=5, beam_size_token=gpt2.config.vocab_size)\n",
+"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
+]
 }
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "Python 3.9.13 ('torchtext39')",
+"display_name": "torchtext",
 "language": "python",
 "name": "python3"
 },
@@ -137,12 +250,12 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.13"
+"version": "3.9.15"
 },
 "orig_nbformat": 4,
 "vscode": {
 "interpreter": {
-"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
+"hash": "1851d106532ddfc6fbd983b9ae95397243fcc3930d811046c990ea169e960650"
 }
 }
 },
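Note on the cells above: the generative_hf_t5, generative_hf_bart, and generative_hf_gpt2 wrappers are built in earlier notebook cells that this commit does not touch. A minimal sketch of that setup follows; the module path and constructor flags are assumptions (only GenerationUtil(model) appears verbatim in the test file below), and the sample input is hypothetical:

# Sketch of the wrapper setup assumed by the notebook cells above.
# Assumed (not shown in this diff): the torchtext.prototype.generate module
# path, the is_encoder_decoder / is_huggingface_model flags, and the input text.
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torchtext.prototype.generate import GenerationUtil

t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)

test_sequence_tk = t5_tokenizer(
    ["summarize: studies have shown that owning a dog is good for you"],
    return_tensors="pt",
).input_ids

# num_beams > 1 selects the new Flashlight Text beam-search path;
# beam_size_token bounds how many next-token candidates are scored per step
# (the notebook passes the full vocabulary size).
tokens = generative_hf_t5.generate(
    test_sequence_tk,
    max_len=100,
    pad_idx=t5.config.pad_token_id,
    num_beams=5,
    beam_size_token=t5.config.vocab_size,
)
print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))

The timing cells benchmark this path against the Hugging Face reference, t5.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False). Per the outputs recorded above, the new path is still much slower (9.79 s vs. 1.30 s on T5; 58.10 s vs. 2.46 s on BART); output quality cuts both ways, with the T5 beam avoiding the repetition seen in the reference output while the BART beam degenerates into repetition where the reference does not.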

test/torchtext_unittest/prototype/test_generate.py (+10 -3)
@@ -55,11 +55,18 @@ def test_warns_when_no_max_len_provided(self, mock) -> None:
     def test_beam_search(self) -> None:
         generation_model = GenerationUtil(self.model)
 
-        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30)
+        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size)
 
         generated_text = self.transform.decode(tokens.tolist())
 
-        import pdb
-        pdb.set_trace()
+        expected_generated_text = [
+            'kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for',
+            'Das ist gut.',
+            'acceptable',
+            '4.0',
+            'a tornado ripped through a swath of a lake in st. louis . a s'
+        ]
+
+        self.assertEqual(generated_text, expected_generated_text)

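The beam_size_token argument threaded through generate() here (and eos_score in the notebook's BART cell) maps onto Flashlight Text's lexicon-free seq2seq decoder options. Below is a rough sketch of the options object the generation utility presumably builds internally; the binding names and signature are assumptions based on the flashlight-text Python package, not code shown in this commit:

# Hypothetical mapping from generate() kwargs to Flashlight Text decoder options.
from flashlight.lib.text.decoder import LexiconFreeSeq2SeqDecoderOptions

options = LexiconFreeSeq2SeqDecoderOptions(
    beam_size=5,            # generate(num_beams=...)
    beam_size_token=32128,  # generate(beam_size_token=...), e.g. t5.config.vocab_size
    beam_threshold=50,      # assumed default: prune hypotheses far below the best
    lm_weight=0.0,          # no external language model is mixed in
    eos_score=1.0,          # generate(eos_score=...): bonus when a beam emits EOS
    log_add=True,           # merge duplicate hypotheses with log-sum-exp
)

The decoder built from these options would be a LexiconFreeSeq2SeqDecoder driven by a per-step model update function that returns token scores; restricting beam_size_token below the full vocabulary size trades search breadth for decoding speed.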