From 79af4ae402f4df1274b639ba4826c4c39756aa32 Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Mon, 19 Dec 2022 10:16:19 -0500 Subject: [PATCH 1/4] Add beam search generation w/ Flashlight Text --- notebooks/hf_with_torchtext_gen.ipynb | 129 ++++++- test/integration_tests/test_generate.py | 88 ++++- .../prototype/test_generate.py | 109 ++++++ torchtext/prototype/generate.py | 356 ++++++++++++++++-- 4 files changed, 641 insertions(+), 41 deletions(-) create mode 100644 test/torchtext_unittest/prototype/test_generate.py diff --git a/notebooks/hf_with_torchtext_gen.ipynb b/notebooks/hf_with_torchtext_gen.ipynb index 0df74a4b39..44f4ccc3b3 100644 --- a/notebooks/hf_with_torchtext_gen.ipynb +++ b/notebooks/hf_with_torchtext_gen.ipynb @@ -16,7 +16,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/tqdm-4.64.1-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } @@ -39,14 +39,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", + "/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:163: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", @@ -74,7 +74,55 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.']\n" + ] + } + ], + "source": [ + "# Testing HuggingFace's T5 w/ Beam Search\n", + "tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n", + "print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.'] 9.786320924758911\n", + "['studies have shown that owning a dog is good for you. studies have shown that owning a dog is good for you.'] 1.3000121116638184\n" + ] + } + ], + "source": [ + "# Testing Decoding Speed HuggingFace's T5 w/ TorchText Beam Search vs. HuggingFace Beam Search\n", + "import time\n", + "\n", + "start = time.time()\n", + "tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n", + "end = time.time()\n", + "print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n", + "\n", + "start = time.time()\n", + "tokens = t5.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n", + "end = time.time()\n", + "print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -99,7 +147,54 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['Nearly. PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.']\n" + ] + } + ], + "source": [ + "tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id, num_beams=5, beam_size_token=bart.config.vocab_size)\n", + "print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts are expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the'] 58.09997892379761\n", + "['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts were expected to last through at least midday tomorrow.'] 2.456479787826538\n" + ] + } + ], + "source": [ + "# Testing Decoding Speed HuggingFace's BART w/ TorchText Beam Search vs. HuggingFace Beam Search\n", + "import time\n", + "\n", + "start = time.time()\n", + "tokens = generative_hf_bart.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, eos_score=1.0, beam_size_token=t5.config.vocab_size)\n", + "end = time.time()\n", + "print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n", + "\n", + "start = time.time()\n", + "tokens = bart.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n", + "end = time.time()\n", + "print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -119,11 +214,29 @@ "tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n", "print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))" ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['I enjoy walking with my cute dog,\" says Kelli Williams-Petersen. The dog loves it so much, that when she']\n" + ] + } + ], + "source": [ + "tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id, num_beams=5, beam_size_token=gpt2.config.vocab_size)\n", + "print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.13 ('torchtext39')", + "display_name": "torchtext", "language": "python", "name": "python3" }, @@ -137,12 +250,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.9.15" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7" + "hash": "1851d106532ddfc6fbd983b9ae95397243fcc3930d811046c990ea169e960650" } } }, diff --git a/test/integration_tests/test_generate.py b/test/integration_tests/test_generate.py index 80acd9dc7e..028cc7a4b1 100644 --- a/test/integration_tests/test_generate.py +++ b/test/integration_tests/test_generate.py @@ -28,7 +28,7 @@ def setUp(self) -> None: def test_greedy_generate_with_t5(self) -> None: generation_model = GenerationUtils(self.model) - tokens = generation_model.generate(self.inputs, num_beams=1, max_length=30) + tokens = generation_model.generate(self.inputs, num_beams=1) generated_text = self.transform.decode(tokens.tolist()) expected_generated_text = [ @@ -41,13 +41,97 @@ def test_greedy_generate_with_t5(self) -> None: self.assertEqual(generated_text, expected_generated_text) + def test_beam_search_generate_t5(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_generated_text = [ + "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for", + "Das ist gut.", + "acceptable", + "4.0", + "a tornado ripped through a swath of a lake in southeastern michigan . a spokesman", + ] + + self.assertEqual(generated_text, expected_generated_text) + + def test_beam_search_generate_t5_small_batch_size(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30, max_inference_batch_size=3 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_generated_text = [ + "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for", + "Das ist gut.", + "acceptable", + "4.0", + "a tornado ripped through a swath of a lake in southeastern michigan . a spokesman", + ] + + self.assertEqual(generated_text, expected_generated_text) + + def test_beam_search_generate_t5_with_small_beam_threshold(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30, beam_threshold=5 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_text = [ + "kate mccartney: a dog is good for you . kate mccartney: dogs", + "Das ist gut.", + "acceptable", + "4.0", + "a tornado ripped through a swath of a lake in southeastern mississippi, causing", + ] + + self.assertEqual(generated_text, expected_text) + + def test_beam_search_generate_t5_large_num_beams(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=25, vocab_size=self.model.config.vocab_size, max_length=30 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_text = [ + "aaron carroll, aaron jones, aaron jones and aaron jones", + "Das ist gut.", + "acceptable", + "4.0", + "a blizzard and power outages have prompted a blizzard and power outages, a spokesman says", + ] + + self.assertEqual(generated_text, expected_text) + + def test_beam_search_generate_t5_large_num_beams_eos_score(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=25, vocab_size=self.model.config.vocab_size, max_length=30, eos_score=10.0 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_text = ["", "Das ist gut.", "acceptable", "4.0", ""] + + self.assertEqual(generated_text, expected_text) + def test_generate_errors_with_incorrect_beams(self) -> None: generation_model = GenerationUtils(self.model, is_encoder_decoder=True) with self.assertRaises(ValueError): generation_model.generate(self.inputs, num_beams=0) - @patch("logging.Logger.warning") + @patch("warnings.warn") def test_warns_when_no_max_len_provided(self, mock) -> None: generation_model = GenerationUtils(self.model) generation_model.generate(self.inputs) diff --git a/test/torchtext_unittest/prototype/test_generate.py b/test/torchtext_unittest/prototype/test_generate.py new file mode 100644 index 0000000000..c7e69338f4 --- /dev/null +++ b/test/torchtext_unittest/prototype/test_generate.py @@ -0,0 +1,109 @@ +from unittest.mock import patch +from torchtext.prototype.generate import GenerationUtil +from torchtext.prototype.models import T5_BASE_GENERATION +from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase +import torch + + +class TestGenerationUtil(TorchtextTestCase): + def setUp(self) -> None: + super().setUp() + t5_base = T5_BASE_GENERATION + self.transform = t5_base.transform() + self.model = t5_base.get_model() + self.model.eval() + # Examples taken from T5 Paper and Huggingface + self.inputs = self.transform( + [ + "summarize: studies have shown that owning a dog is good for you", + "translate English to German: That is good.", + "cola sentence: The course is jumping well.", + "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.", + "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...", + ] + ) + torch.manual_seed(0) + + def test_greedy_generate_with_t5(self) -> None: + generation_model = GenerationUtil(self.model) + + tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30) + generated_text = self.transform.decode(tokens.tolist()) + + expected_generated_text = [ + "a dog is good for you, according to studies . owning a dog is good for you, according to studies .", + "Das ist gut.", + "acceptable", + "4.0", + "mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage", + ] + + self.assertEqual(generated_text, expected_generated_text) + + def test_generate_errors_with_incorrect_beams(self) -> None: + generation_model = GenerationUtil(self.model, is_encoder_decoder=True) + + with self.assertRaises(ValueError): + generation_model.generate(self.inputs, num_beams=0) + + @patch("warnings.warn") + def test_warns_when_no_max_len_provided(self, mock) -> None: + generation_model = GenerationUtil(self.model) + generation_model.generate(self.inputs) + mock.assert_called_with("`max_len` was not specified. Defaulting to 256 tokens.") + + def test_warns_when_mp_with_greedy(self, mock) -> None: + pass + + def test_beam_search_with_t5_(self) -> None: + generation_model = GenerationUtil(self.model) + tokens = generation_model.generate( + self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_generated_text = [ + "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for", + "Das ist gut.", + "acceptable", + "4.0", + "a tornado ripped through a swath of a lake in st. louis . a s", + ] + + self.assertEqual(generated_text, expected_generated_text) + + def test_hf_DELETE(self) -> None: + from transformers import T5ForConditionalGeneration, T5Tokenizer + from torchtext.prototype.generate import GenerationUtil + + t5 = T5ForConditionalGeneration.from_pretrained("t5-base") + test_sequence = [ + "summarize: studies have shown that owning a dog is good for you" + ] # , "Q: what is the capital of Alaska?"] + generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True) + t5_tokenizer = T5Tokenizer.from_pretrained("t5-base") + test_sequence_tk = t5_tokenizer(test_sequence, padding=True, return_tensors="pt").input_ids + import time + + start = time.time() + tokens = generative_hf_t5.generate( + test_sequence_tk, + max_len=100, + pad_idx=t5.config.pad_token_id, + num_beams=7, + beam_size_token=t5.config.vocab_size, + ) + end = time.time() - start + print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end) + exit() + + def test_jit_generate(self) -> None: + # jitted_model = torch.jit.script(self.model) + # encoder = jitted_model.get_encoder() + + + generation_model = GenerationUtil(self.model) + torch.jit.script(generation_model) + + def test_beam_search_speed(self) -> None: + pass diff --git a/torchtext/prototype/generate.py b/torchtext/prototype/generate.py index dd74948c81..158072ae7e 100644 --- a/torchtext/prototype/generate.py +++ b/torchtext/prototype/generate.py @@ -1,16 +1,34 @@ -import logging -from typing import Optional +import warnings +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F +from flashlight.lib.text.decoder import ( + LexiconFreeSeq2SeqDecoder, + LexiconFreeSeq2SeqDecoderOptions, + ZeroLM, + create_emitting_model_state, + get_obj_from_emitting_model_state, +) from torch import nn -logger = logging.getLogger(__name__) + +MODEL_KWARGS_TYPE = Dict[str, Dict[str, Union[torch.Tensor, List[Optional[torch.Tensor]], List[torch.Tensor], None]]] DEFAULT_MAX_SEQ_LEN = 256 -class GenerationUtils: +@dataclass +class Seq2SeqModelState(object): + """Seq2SeqModelState for holding state between beam search rounds.""" + + timestep: int + sequence: torch.Tensor + lm_scores: Optional[torch.Tensor] + + +class GenerationUtils(nn.Module): """Wrapper to provide generation utils for encoder/decoder models and decoder models. Example: @@ -34,21 +52,65 @@ class GenerationUtils: More examples can be found in the `notebooks` directory of this repository. """ + _huggingface_model_input_values = {"return_dict": True, "use_cache": True, "output_hidden_states": True} + def __init__(self, model: nn.Module, **kwargs) -> None: + super().__init__() self.model = model self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True) self.is_huggingface_model = kwargs.pop("is_huggingface_model", False) + self.incremental_decoding = kwargs.pop("incremental_decoding", False) + if self.is_huggingface_model: + warnings.warn( + "PyTorch does not make any claims about the stability of HuggingFace APIs so using all models from the `transformers` library are experimental." + ) + + def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -> MODEL_KWARGS_TYPE: + """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592. + + Args: + inputs: (Tensor): Tokenized startings sequence(s). + model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding. + + Returns: + Modified model_kwargs with addition of encoded input sequence(s). + """ + # Get encoder + encoder = self.model.get_encoder() + + # Create copy of encoder kwargs + encoder_kwargs: Dict[str, bool] = {} + + if self.is_huggingface_model: + encoder_kwargs["return_dict"] = True + + # Forward pass + # Explicitly call forward method to assert to assert this is a ScriptModule if JITted + model_kwargs = {"encoder_outputs": encoder.forward(inputs, **encoder_kwargs)} + return model_kwargs def _prepare_decoder_ids_for_generation( - self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs - ): + self, + batch_size: int, + pad_idx: int = 0, + device: Optional[torch.device] = None, + model_kwargs: Optional[MODEL_KWARGS_TYPE] = None, + ) -> torch.Tensor: + """Prepare decoder IDs for generation.""" if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - return model_kwargs.pop("decoder_input_ids") + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + assert torch.jit.isinstance(decoder_input_ids, torch.Tensor) + return decoder_input_ids else: return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx def greedy_search( - self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs + self, + input_ids: torch.Tensor, + max_length: int, + eos_idx: int, + pad_idx: Optional[int] = None, + model_kwargs: Optional[MODEL_KWARGS_TYPE] = {}, ) -> torch.Tensor: """Greedy search decoding for text generation. Takes the most likely next token every time. @@ -57,7 +119,7 @@ def greedy_search( max_length (int): Max length to generate responses. eos_idx (int): End of sequence index. pad_idx (int): Padding index. - **model_kwargs + model_kwargs Returns: Batch of sequences decoded by greedy search. @@ -67,8 +129,7 @@ def greedy_search( while True: model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) if self.is_huggingface_model: - model_inputs["return_dict"] = True - model_inputs["output_hidden_states"] = True + model_inputs.update(self._huggingface_model_input_values) # Get model output outputs = self.model(**model_inputs) @@ -80,9 +141,8 @@ def greedy_search( _, next_tokens = torch.topk(probs, 1) # For any finished sequences, padding idx should be the last token - if eos_idx is not None: - if pad_idx is not None: - next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences) + if eos_idx is not None and pad_idx is not None: + next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences) # Append the next tokens to the previous tokens input_ids = torch.cat([input_ids, next_tokens], dim=-1) @@ -96,8 +156,215 @@ def greedy_search( return input_ids - def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_length: Optional[int]) -> torch.Tensor: - raise NotImplementedError() + def beam_search( + self, + input_ids: torch.Tensor, + num_beams: int, + max_len: int, + beam_size_token: int, + beam_threshold: int, + eos_score: float, + eos_idx: int, + num_python_workers: int, + max_inference_batch_size: int, + model_kwargs: Dict[str, Any], + ) -> torch.Tensor: + """Beam search implemented using Flashlight Text (https://github.com/flashlight/text). + + Args: + input_ids (Tensor): Tokenized startings sequence(s). + num_beams (int): Number of beams to use in the beam search. + max_len (int): Maximum number of tokens to generate. + beam_size_token (int): Vocab size for the LM being used. + beam_threshold (int): Threshold before pruning. + eos_score (float): Score to input when `eos_idx` is generated. + eos_idx (int): End-of-sequence index. + num_python_workers (int): Number of python workers to use for multiprocessing. + model_kwargs + + Returns: + Tensor of the generated sequences. + """ + device = input_ids.device + + if self.is_encoder_decoder: + encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output" + encoder_output = model_kwargs["encoder_outputs"][encoder_output_key] + + def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep): + # `emissions` and `N` are unused in this current implementation + + i = T # Hacky access to the current seq in inputs + + # For first timestep, create previous step token_idxs and model_states + if timestep == 0: + prev_step_token_idxs = [-1] + prev_step_model_states = [ + create_emitting_model_state( + Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None) + ) + ] + + encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None + prev_model_state_sequences = [ + get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states + ] + out_probs, model_states = [], [] + + start = 0 + # This is the parallelism level at which elements in the beam will be batched + step = min( + max_inference_batch_size, 1000 / (timestep + 1) + ) # many hypotheses will EOS, so increase the batch size gradually + curr_beam_size = len(prev_step_token_idxs) + + # 2. Batched inference to get next tokens + while start < curr_beam_size: # catch the remainder + end = start + step + if end > curr_beam_size: + end = curr_beam_size + + num_samples = end - start + + if prev_step_token_idxs != [-1]: + state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0) + token_indices = ( + torch.Tensor(prev_step_token_idxs[start:end]) + .to(dtype=torch.long, device=device) + .reshape(num_samples, 1) + ) + + state_and_tokens = torch.cat( + [state_sequences, token_indices], dim=-1 + ) # [batch_size x (timestep + 1)] + assert state_and_tokens.shape == ( + num_samples, + timestep + 1, + ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}" + else: + assert len(prev_model_state_sequences) == 1 + state_and_tokens = prev_model_state_sequences[0] + + # Cleanup -- combine this with the above + if self.is_encoder_decoder: + # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size + # This is a view-only operation and doesn't copy + model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand( + num_samples if timestep > 0 else 1, -1, -1 + ) + + # Preprocess inputs for generation + model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **model_kwargs) + if self.is_huggingface_model: + model_inputs.update(self._huggingface_model_input_values) + + # Forward pass + outputs = self.model(**model_inputs) + + # Collect outputs + output_key = "logits" if self.is_huggingface_model else "decoder_output" + lm_scores = outputs[output_key] + + # Keep track of probabilities over vocab for this pairing + for i in range(lm_scores.shape[0]): + sample_lm_scores = lm_scores[i, -1] + out_probs.append(sample_lm_scores.tolist()) + # Keep track of sequence and decoder hidden states + model_states.append( + create_emitting_model_state( + Seq2SeqModelState( + timestep=timestep, + sequence=state_and_tokens[i].unsqueeze(0), + lm_scores=sample_lm_scores, + ) + ) + ) + + start += step + + return out_probs, model_states + + # 3. Initialize options and decoder from Flashlight Text + options = LexiconFreeSeq2SeqDecoderOptions( + beam_size=num_beams, + beam_size_token=beam_size_token, + beam_threshold=beam_threshold, + lm_weight=0.0, # We have no custom LM so score is zero + eos_score=eos_score, + log_add=True, + ) + + decoder = LexiconFreeSeq2SeqDecoder( + options=options, lm=ZeroLM(), eos_idx=eos_idx, update_func=update_func, max_output_length=max_len + ) + + # 4. Process outputs from beam decoder + # TODO: This can definitely be optimized + def beam_decode_step(timestep: int) -> torch.Tensor: + # Create these as function b/c unnamed functions (lambdas) cause problems w/ MP + def select_second_elem_in_tuple(tup: Tuple[List[int], float]) -> float: + return tup[1] + + def is_not_neg_one(elem: int) -> bool: + return elem != -1 + + # Decode step takes ptr to encoder emissions, i, and beam size token + # but actually these aren't currently being used. + decoder.decode_step(0, timestep, 0) + hyps = decoder.get_all_final_hypothesis() + + # Find the best beam + token_scores = [(hyp.tokens, hyp.score) for hyp in hyps] + final_tokens = list(filter(is_not_neg_one, max(token_scores, key=select_second_elem_in_tuple)[0])) + + # Have to prepend the input tokens if decoder-only model + if not self.is_encoder_decoder: + final_tokens = input_ids[timestep].tolist() + final_tokens + + # Makeshift padding so that we can stack the tensors + while len(final_tokens) < max_len: + final_tokens += [0] + + # Convert from list to tensors + final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long) + + return final_tokens_as_tensors + + if num_python_workers > 1: + warnings.warn("Multiprocessing has not yet been implemented.") + + all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))] + + # 5. Return top hypotheses for all input sequences + return torch.stack(all_final_tokens, dim=0) + + def forward( + self, + inputs: Optional[torch.Tensor] = None, + num_beams: Optional[int] = None, + max_len: Optional[int] = None, + pad_idx: int = 0, + eos_idx: int = 1, + beam_threshold: int = 100, + beam_size_token: Optional[int] = None, + eos_score: float = -1.0, + num_python_workers: int = 1, + max_inference_batch_size: int = 16, + ): + """Calls self.generate() method.""" + warnings.warn("Forward method simply calls `GenerationUtils.generate()`. Please use generate method directly.") + return self.generate( + inputs=inputs, + num_beams=num_beams, + max_len=max_len, + pad_idx=pad_idx, + eos_idx=eos_idx, + beam_threshold=beam_threshold, + beam_size_token=beam_size_token, + eos_score=eos_score, + num_python_workers=num_python_workers, + max_inference_batch_size=max_inference_batch_size, + ) def generate( self, @@ -106,39 +373,66 @@ def generate( max_length: Optional[int] = None, pad_idx: int = 0, eos_idx: int = 1, + num_python_workers: int = 1, + beam_threshold: int = 100, + vocab_size: Optional[int] = None, + eos_score: float = -1.0, + max_inference_batch_size: int = 16, ) -> torch.Tensor: - """Generation method. - - `num_beams` == 1 or `num_beams` is None -> greedy search - `num_beams` > 1 -> beam search + """Entrypoint generation method. Args: input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation. num_beams (int): If provided, specifies the number of beams to use in beam search generation. max_length (int): Max length to generate responses. pad_idx (int): Padding index. Defaults to 0. - eos_idx (int): End of sequence index. Defaults to 1. + eos_idx (int): End-of-sequence index. Defaults to 1. + num_python_workers (int): If > 1, using multiprocessing on CPU. + vocab_size (int): Vocab size for the beam search algo to evaluate, can typically default to vocab size of the model. + beam_threshold (int): Threshold before pruning; specific to beam search. + eos_score (float): Score to input when `eos_idx` is generated; specific to beam search. + max_inference_batch_size (int): In beam search, to avoid OOMs, can choose to batch smaller amounts of hypothesis; defaults to 16. Returns: Tensor of Tensors containing output sequences as ids. - `Note`: If one beam is provided or no beams are specified, the generation method will default to greedy search. + Conditions for generation: + 1. `num_beams` == 1 or `num_beams` is None -> greedy search + 2. `num_beams` > 1 -> beam search """ - model_kwargs = {} + model_kwargs: MODEL_KWARGS_TYPE = {} if self.is_encoder_decoder: - encoder = self.model.get_encoder() - model_kwargs["encoder_outputs"] = encoder(inputs) - inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs) + assert torch.jit.isinstance(inputs, torch.Tensor) + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs) + inputs = self._prepare_decoder_ids_for_generation( + len(inputs), device=inputs.device, model_kwargs=model_kwargs + ) if max_length is None: # Too hard to try to figure out the exact max_seq_length for each model - logger.warning(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.") - max_length = DEFAULT_MAX_SEQ_LEN + warnings.warn("`max_length` was not specified. Defaulting to 256 tokens.") + max_length = 256 - if num_beams == 1 or num_beams is None: - return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs) + if num_beams is None or num_beams == 1: + if num_python_workers > 1: + warnings.warn(f"Multiprocessing is not implemented for greedy search.") + return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs) elif num_beams > 1: - return self.beam_search(inputs, num_beams, max_length) + assert ( + vocab_size is not None + ), "`vocab_size` must be specified for beam search. If confused about what to put, you can default to the vocab size of the model you are using." + return self.beam_search( + inputs, + num_beams, + beam_size_token=vocab_size, + max_len=max_length, + beam_threshold=beam_threshold, + eos_score=eos_score, + num_python_workers=num_python_workers, + eos_idx=eos_idx, + max_inference_batch_size=max_inference_batch_size, + model_kwargs=model_kwargs, + ) else: raise ValueError("`num_beams` must be >= 1.") From 3fd765ed86485824236625d7bdfe0655267f60f3 Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Tue, 28 Feb 2023 17:09:27 -0500 Subject: [PATCH 2/4] Add benchmarking script for generation utils --- benchmark/benchmark_generation_utils.py | 47 +++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 benchmark/benchmark_generation_utils.py diff --git a/benchmark/benchmark_generation_utils.py b/benchmark/benchmark_generation_utils.py new file mode 100644 index 0000000000..dacc1a367c --- /dev/null +++ b/benchmark/benchmark_generation_utils.py @@ -0,0 +1,47 @@ +import time +from functools import partial + +from torch.utils.data import DataLoader +from torcheval.metrics.functional import word_error_rate +from torchtext.data.metrics import bleu_score +from torchtext.datasets import CNNDM +from torchtext.datasets import Multi30k +from torchtext.models import T5_BASE_GENERATION +from torchtext.prototype.generate import GenerationUtils + +multi_batch_size = 5 +language_pair = ("en", "de") +multi_datapipe = Multi30k(split="test", language_pair=language_pair) +task = "translate English to German" + + +def apply_prefix(task, x): + return f"{task}: " + x[0], x[1] + + +multi_datapipe = multi_datapipe.map(partial(apply_prefix, task)) +multi_datapipe = multi_datapipe.batch(multi_batch_size) +multi_datapipe = multi_datapipe.rows2columnar(["english", "german"]) +multi_dataloader = DataLoader(multi_datapipe, batch_size=None) + + +def benchmark_beam_search_wer(): + model = T5_BASE_GENERATION.get_model() + transform = T5_BASE_GENERATION.transform() + + seq_generator = GenerationUtils(model) + + batch = next(iter(multi_dataloader)) + input_text = batch["english"] + target = batch["german"] + beam_size = 4 + + model_input = transform(input_text) + model_output = seq_generator.generate(model_input, num_beams=beam_size, vocab_size=model.config.vocab_size) + output_text = transform.decode(model_output.tolist()) + + print(word_error_rate(output_text, target)) + + +if __name__ == "__main__": + benchmark_beam_search_wer() From f506ec1ba46512aafa82a6e96e13420bc5ad93dd Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Tue, 28 Feb 2023 17:32:46 -0500 Subject: [PATCH 3/4] Update sample idx --- torchtext/prototype/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtext/prototype/generate.py b/torchtext/prototype/generate.py index 158072ae7e..ad0d326589 100644 --- a/torchtext/prototype/generate.py +++ b/torchtext/prototype/generate.py @@ -266,15 +266,15 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ lm_scores = outputs[output_key] # Keep track of probabilities over vocab for this pairing - for i in range(lm_scores.shape[0]): - sample_lm_scores = lm_scores[i, -1] + for sample_idx in range(num_samples): + sample_lm_scores = lm_scores[sample_idx, -1] out_probs.append(sample_lm_scores.tolist()) # Keep track of sequence and decoder hidden states model_states.append( create_emitting_model_state( Seq2SeqModelState( timestep=timestep, - sequence=state_and_tokens[i].unsqueeze(0), + sequence=state_and_tokens[sample_idx].unsqueeze(0), lm_scores=sample_lm_scores, ) ) From 2f2f1aa524a4cbe4e312547c4b1bc710dc34d70a Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Thu, 2 Mar 2023 13:45:02 -0500 Subject: [PATCH 4/4] WIP --- benchmark/benchmark_generation_utils.py | 18 +++++++++----- torchtext/prototype/generate.py | 31 ++++++++++++++++--------- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/benchmark/benchmark_generation_utils.py b/benchmark/benchmark_generation_utils.py index dacc1a367c..0e10722521 100644 --- a/benchmark/benchmark_generation_utils.py +++ b/benchmark/benchmark_generation_utils.py @@ -3,13 +3,11 @@ from torch.utils.data import DataLoader from torcheval.metrics.functional import word_error_rate -from torchtext.data.metrics import bleu_score -from torchtext.datasets import CNNDM from torchtext.datasets import Multi30k -from torchtext.models import T5_BASE_GENERATION +from torchtext.models import T5_BASE_GENERATION, T5_3B_GENERATION from torchtext.prototype.generate import GenerationUtils -multi_batch_size = 5 +multi_batch_size = 16 language_pair = ("en", "de") multi_datapipe = Multi30k(split="test", language_pair=language_pair) task = "translate English to German" @@ -34,10 +32,18 @@ def benchmark_beam_search_wer(): batch = next(iter(multi_dataloader)) input_text = batch["english"] target = batch["german"] - beam_size = 4 + beam_size = 8 model_input = transform(input_text) - model_output = seq_generator.generate(model_input, num_beams=beam_size, vocab_size=model.config.vocab_size) + model_output = seq_generator.generate( + model_input, + num_beams=beam_size, + beam_threshold=1000, + vocab_size=model.config.vocab_size, + eos_score=-1.0, + eos_idx=1, + pad_idx=0, + ) output_text = transform.decode(model_output.tolist()) print(word_error_rate(output_text, target)) diff --git a/torchtext/prototype/generate.py b/torchtext/prototype/generate.py index ad0d326589..4bc3dd763d 100644 --- a/torchtext/prototype/generate.py +++ b/torchtext/prototype/generate.py @@ -16,7 +16,7 @@ MODEL_KWARGS_TYPE = Dict[str, Dict[str, Union[torch.Tensor, List[Optional[torch.Tensor]], List[torch.Tensor], None]]] -DEFAULT_MAX_SEQ_LEN = 256 +DEFAULT_MAX_SEQ_LEN = 128 @dataclass @@ -213,8 +213,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ start = 0 # This is the parallelism level at which elements in the beam will be batched - step = min( - max_inference_batch_size, 1000 / (timestep + 1) + step = int( + min(max_inference_batch_size, 1000 / (timestep + 1)) ) # many hypotheses will EOS, so increase the batch size gradually curr_beam_size = len(prev_step_token_idxs) @@ -310,9 +310,12 @@ def is_not_neg_one(elem: int) -> bool: # Decode step takes ptr to encoder emissions, i, and beam size token # but actually these aren't currently being used. - decoder.decode_step(0, timestep, 0) + decoder.decode_step(encoder_output.data_ptr(), timestep, beam_size_token) hyps = decoder.get_all_final_hypothesis() + # import pdb + # pdb.set_trace() + # Find the best beam token_scores = [(hyp.tokens, hyp.score) for hyp in hyps] final_tokens = list(filter(is_not_neg_one, max(token_scores, key=select_second_elem_in_tuple)[0])) @@ -328,12 +331,18 @@ def is_not_neg_one(elem: int) -> bool: # Convert from list to tensors final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long) + import pdb + + pdb.set_trace() + return final_tokens_as_tensors if num_python_workers > 1: + # with multiprocessing.Pool(num_python_workers) as pool: + # all_final_tokens = pool.map(beam_decode_step, range(len(input_ids))) warnings.warn("Multiprocessing has not yet been implemented.") - - all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))] + else: + all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))] # 5. Return top hypotheses for all input sequences return torch.stack(all_final_tokens, dim=0) @@ -376,8 +385,8 @@ def generate( num_python_workers: int = 1, beam_threshold: int = 100, vocab_size: Optional[int] = None, - eos_score: float = -1.0, - max_inference_batch_size: int = 16, + eos_score: float = 0.0, + max_inference_batch_size: int = 32, ) -> torch.Tensor: """Entrypoint generation method. @@ -391,7 +400,7 @@ def generate( vocab_size (int): Vocab size for the beam search algo to evaluate, can typically default to vocab size of the model. beam_threshold (int): Threshold before pruning; specific to beam search. eos_score (float): Score to input when `eos_idx` is generated; specific to beam search. - max_inference_batch_size (int): In beam search, to avoid OOMs, can choose to batch smaller amounts of hypothesis; defaults to 16. + max_inference_batch_size (int): In beam search, to avoid OOMs, can choose to batch smaller amounts of hypothesis; defaults to 32. Returns: Tensor of Tensors containing output sequences as ids. @@ -411,8 +420,8 @@ def generate( if max_length is None: # Too hard to try to figure out the exact max_seq_length for each model - warnings.warn("`max_length` was not specified. Defaulting to 256 tokens.") - max_length = 256 + warnings.warn(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.") + max_length = DEFAULT_MAX_SEQ_LEN if num_beams is None or num_beams == 1: if num_python_workers > 1: