diff --git a/benchmark/benchmark_generation_utils.py b/benchmark/benchmark_generation_utils.py new file mode 100644 index 0000000000..0e10722521 --- /dev/null +++ b/benchmark/benchmark_generation_utils.py @@ -0,0 +1,53 @@ +import time +from functools import partial + +from torch.utils.data import DataLoader +from torcheval.metrics.functional import word_error_rate +from torchtext.datasets import Multi30k +from torchtext.models import T5_BASE_GENERATION, T5_3B_GENERATION +from torchtext.prototype.generate import GenerationUtils + +multi_batch_size = 16 +language_pair = ("en", "de") +multi_datapipe = Multi30k(split="test", language_pair=language_pair) +task = "translate English to German" + + +def apply_prefix(task, x): + return f"{task}: " + x[0], x[1] + + +multi_datapipe = multi_datapipe.map(partial(apply_prefix, task)) +multi_datapipe = multi_datapipe.batch(multi_batch_size) +multi_datapipe = multi_datapipe.rows2columnar(["english", "german"]) +multi_dataloader = DataLoader(multi_datapipe, batch_size=None) + + +def benchmark_beam_search_wer(): + model = T5_BASE_GENERATION.get_model() + transform = T5_BASE_GENERATION.transform() + + seq_generator = GenerationUtils(model) + + batch = next(iter(multi_dataloader)) + input_text = batch["english"] + target = batch["german"] + beam_size = 8 + + model_input = transform(input_text) + model_output = seq_generator.generate( + model_input, + num_beams=beam_size, + beam_threshold=1000, + vocab_size=model.config.vocab_size, + eos_score=-1.0, + eos_idx=1, + pad_idx=0, + ) + output_text = transform.decode(model_output.tolist()) + + print(word_error_rate(output_text, target)) + + +if __name__ == "__main__": + benchmark_beam_search_wer() diff --git a/notebooks/hf_with_torchtext_gen.ipynb b/notebooks/hf_with_torchtext_gen.ipynb index 0df74a4b39..44f4ccc3b3 100644 --- a/notebooks/hf_with_torchtext_gen.ipynb +++ b/notebooks/hf_with_torchtext_gen.ipynb @@ -16,7 +16,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/tqdm-4.64.1-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } @@ -39,14 +39,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", + "/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:163: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", @@ -74,7 +74,55 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.']\n" + ] + } + ], + "source": [ + "# Testing HuggingFace's T5 w/ Beam Search\n", + "tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n", + "print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.'] 9.786320924758911\n", + "['studies have shown that owning a dog is good for you. studies have shown that owning a dog is good for you.'] 1.3000121116638184\n" + ] + } + ], + "source": [ + "# Testing Decoding Speed HuggingFace's T5 w/ TorchText Beam Search vs. HuggingFace Beam Search\n", + "import time\n", + "\n", + "start = time.time()\n", + "tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n", + "end = time.time()\n", + "print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n", + "\n", + "start = time.time()\n", + "tokens = t5.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n", + "end = time.time()\n", + "print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -99,7 +147,54 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['Nearly. 
PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.']\n" + ] + } + ], + "source": [ + "tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id, num_beams=5, beam_size_token=bart.config.vocab_size)\n", + "print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts are expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the'] 58.09997892379761\n", + "['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts were expected to last through at least midday tomorrow.'] 2.456479787826538\n" + ] + } + ], + "source": [ + "# Testing Decoding Speed HuggingFace's BART w/ TorchText Beam Search vs. HuggingFace Beam Search\n", + "import time\n", + "\n", + "start = time.time()\n", + "tokens = generative_hf_bart.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, eos_score=1.0, beam_size_token=t5.config.vocab_size)\n", + "end = time.time()\n", + "print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n", + "\n", + "start = time.time()\n", + "tokens = bart.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n", + "end = time.time()\n", + "print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -119,11 +214,29 @@ "tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n", "print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))" ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['I enjoy walking with my cute dog,\" says Kelli Williams-Petersen. 
The dog loves it so much, that when she']\n" + ] + } + ], + "source": [ + "tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id, num_beams=5, beam_size_token=gpt2.config.vocab_size)\n", + "print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.13 ('torchtext39')", + "display_name": "torchtext", "language": "python", "name": "python3" }, @@ -137,12 +250,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.9.15" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7" + "hash": "1851d106532ddfc6fbd983b9ae95397243fcc3930d811046c990ea169e960650" } } }, diff --git a/test/integration_tests/test_generate.py b/test/integration_tests/test_generate.py index 80acd9dc7e..028cc7a4b1 100644 --- a/test/integration_tests/test_generate.py +++ b/test/integration_tests/test_generate.py @@ -28,7 +28,7 @@ def setUp(self) -> None: def test_greedy_generate_with_t5(self) -> None: generation_model = GenerationUtils(self.model) - tokens = generation_model.generate(self.inputs, num_beams=1, max_length=30) + tokens = generation_model.generate(self.inputs, num_beams=1) generated_text = self.transform.decode(tokens.tolist()) expected_generated_text = [ @@ -41,13 +41,97 @@ def test_greedy_generate_with_t5(self) -> None: self.assertEqual(generated_text, expected_generated_text) + def test_beam_search_generate_t5(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_generated_text = [ + "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for", + "Das ist gut.", + "acceptable", + "4.0", + "a tornado ripped through a swath of a lake in southeastern michigan . a spokesman", + ] + + self.assertEqual(generated_text, expected_generated_text) + + def test_beam_search_generate_t5_small_batch_size(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30, max_inference_batch_size=3 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_generated_text = [ + "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for", + "Das ist gut.", + "acceptable", + "4.0", + "a tornado ripped through a swath of a lake in southeastern michigan . a spokesman", + ] + + self.assertEqual(generated_text, expected_generated_text) + + def test_beam_search_generate_t5_with_small_beam_threshold(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30, beam_threshold=5 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_text = [ + "kate mccartney: a dog is good for you . 
kate mccartney: dogs", + "Das ist gut.", + "acceptable", + "4.0", + "a tornado ripped through a swath of a lake in southeastern mississippi, causing", + ] + + self.assertEqual(generated_text, expected_text) + + def test_beam_search_generate_t5_large_num_beams(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=25, vocab_size=self.model.config.vocab_size, max_length=30 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_text = [ + "aaron carroll, aaron jones, aaron jones and aaron jones", + "Das ist gut.", + "acceptable", + "4.0", + "a blizzard and power outages have prompted a blizzard and power outages, a spokesman says", + ] + + self.assertEqual(generated_text, expected_text) + + def test_beam_search_generate_t5_large_num_beams_eos_score(self) -> None: + generation_model = GenerationUtils(self.model) + + tokens = generation_model.generate( + self.inputs, num_beams=25, vocab_size=self.model.config.vocab_size, max_length=30, eos_score=10.0 + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_text = ["", "Das ist gut.", "acceptable", "4.0", ""] + + self.assertEqual(generated_text, expected_text) + def test_generate_errors_with_incorrect_beams(self) -> None: generation_model = GenerationUtils(self.model, is_encoder_decoder=True) with self.assertRaises(ValueError): generation_model.generate(self.inputs, num_beams=0) - @patch("logging.Logger.warning") + @patch("warnings.warn") def test_warns_when_no_max_len_provided(self, mock) -> None: generation_model = GenerationUtils(self.model) generation_model.generate(self.inputs) diff --git a/test/torchtext_unittest/prototype/test_generate.py b/test/torchtext_unittest/prototype/test_generate.py new file mode 100644 index 0000000000..c7e69338f4 --- /dev/null +++ b/test/torchtext_unittest/prototype/test_generate.py @@ -0,0 +1,109 @@ +from unittest.mock import patch +from torchtext.prototype.generate import GenerationUtil +from torchtext.prototype.models import T5_BASE_GENERATION +from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase +import torch + + +class TestGenerationUtil(TorchtextTestCase): + def setUp(self) -> None: + super().setUp() + t5_base = T5_BASE_GENERATION + self.transform = t5_base.transform() + self.model = t5_base.get_model() + self.model.eval() + # Examples taken from T5 Paper and Huggingface + self.inputs = self.transform( + [ + "summarize: studies have shown that owning a dog is good for you", + "translate English to German: That is good.", + "cola sentence: The course is jumping well.", + "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.", + "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...", + ] + ) + torch.manual_seed(0) + + def test_greedy_generate_with_t5(self) -> None: + generation_model = GenerationUtil(self.model) + + tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30) + generated_text = self.transform.decode(tokens.tolist()) + + expected_generated_text = [ + "a dog is good for you, according to studies . owning a dog is good for you, according to studies .", + "Das ist gut.", + "acceptable", + "4.0", + "mississippi authorities dispatch emergency crews to survey damage . 
severe weather in mississippi has caused extensive damage", + ] + + self.assertEqual(generated_text, expected_generated_text) + + def test_generate_errors_with_incorrect_beams(self) -> None: + generation_model = GenerationUtil(self.model, is_encoder_decoder=True) + + with self.assertRaises(ValueError): + generation_model.generate(self.inputs, num_beams=0) + + @patch("warnings.warn") + def test_warns_when_no_max_len_provided(self, mock) -> None: + generation_model = GenerationUtil(self.model) + generation_model.generate(self.inputs) + mock.assert_called_with("`max_len` was not specified. Defaulting to 256 tokens.") + + def test_warns_when_mp_with_greedy(self, mock) -> None: + pass + + def test_beam_search_with_t5_(self) -> None: + generation_model = GenerationUtil(self.model) + tokens = generation_model.generate( + self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size + ) + generated_text = self.transform.decode(tokens.tolist()) + + expected_generated_text = [ + "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for", + "Das ist gut.", + "acceptable", + "4.0", + "a tornado ripped through a swath of a lake in st. louis . a s", + ] + + self.assertEqual(generated_text, expected_generated_text) + + def test_hf_DELETE(self) -> None: + from transformers import T5ForConditionalGeneration, T5Tokenizer + from torchtext.prototype.generate import GenerationUtil + + t5 = T5ForConditionalGeneration.from_pretrained("t5-base") + test_sequence = [ + "summarize: studies have shown that owning a dog is good for you" + ] # , "Q: what is the capital of Alaska?"] + generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True) + t5_tokenizer = T5Tokenizer.from_pretrained("t5-base") + test_sequence_tk = t5_tokenizer(test_sequence, padding=True, return_tensors="pt").input_ids + import time + + start = time.time() + tokens = generative_hf_t5.generate( + test_sequence_tk, + max_len=100, + pad_idx=t5.config.pad_token_id, + num_beams=7, + beam_size_token=t5.config.vocab_size, + ) + end = time.time() - start + print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end) + exit() + + def test_jit_generate(self) -> None: + # jitted_model = torch.jit.script(self.model) + # encoder = jitted_model.get_encoder() + + + generation_model = GenerationUtil(self.model) + torch.jit.script(generation_model) + + def test_beam_search_speed(self) -> None: + pass diff --git a/torchtext/prototype/generate.py b/torchtext/prototype/generate.py index dd74948c81..4bc3dd763d 100644 --- a/torchtext/prototype/generate.py +++ b/torchtext/prototype/generate.py @@ -1,16 +1,34 @@ -import logging -from typing import Optional +import warnings +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F +from flashlight.lib.text.decoder import ( + LexiconFreeSeq2SeqDecoder, + LexiconFreeSeq2SeqDecoderOptions, + ZeroLM, + create_emitting_model_state, + get_obj_from_emitting_model_state, +) from torch import nn -logger = logging.getLogger(__name__) -DEFAULT_MAX_SEQ_LEN = 256 +MODEL_KWARGS_TYPE = Dict[str, Dict[str, Union[torch.Tensor, List[Optional[torch.Tensor]], List[torch.Tensor], None]]] +DEFAULT_MAX_SEQ_LEN = 128 -class GenerationUtils: + +@dataclass +class Seq2SeqModelState(object): + """Seq2SeqModelState for holding state between beam search rounds.""" + + timestep: int + sequence: torch.Tensor + lm_scores: Optional[torch.Tensor] + + +class 
GenerationUtils(nn.Module): """Wrapper to provide generation utils for encoder/decoder models and decoder models. Example: @@ -34,21 +52,65 @@ class GenerationUtils: More examples can be found in the `notebooks` directory of this repository. """ + _huggingface_model_input_values = {"return_dict": True, "use_cache": True, "output_hidden_states": True} + def __init__(self, model: nn.Module, **kwargs) -> None: + super().__init__() self.model = model self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True) self.is_huggingface_model = kwargs.pop("is_huggingface_model", False) + self.incremental_decoding = kwargs.pop("incremental_decoding", False) + if self.is_huggingface_model: + warnings.warn( + "PyTorch does not make any claims about the stability of HuggingFace APIs so using all models from the `transformers` library are experimental." + ) + + def _prepare_encoder_decoder_kwargs_for_generation(self, inputs: torch.Tensor) -> MODEL_KWARGS_TYPE: + """Runs encoder and adds to model_kwargs for decoding. Modified from https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/generation/utils.py#L592. + + Args: + inputs: (Tensor): Tokenized startings sequence(s). + model_kwargs (Dict[str, Any]): Model keyword arguments to be modified for decoding. + + Returns: + Modified model_kwargs with addition of encoded input sequence(s). + """ + # Get encoder + encoder = self.model.get_encoder() + + # Create copy of encoder kwargs + encoder_kwargs: Dict[str, bool] = {} + + if self.is_huggingface_model: + encoder_kwargs["return_dict"] = True + + # Forward pass + # Explicitly call forward method to assert to assert this is a ScriptModule if JITted + model_kwargs = {"encoder_outputs": encoder.forward(inputs, **encoder_kwargs)} + return model_kwargs def _prepare_decoder_ids_for_generation( - self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs - ): + self, + batch_size: int, + pad_idx: int = 0, + device: Optional[torch.device] = None, + model_kwargs: Optional[MODEL_KWARGS_TYPE] = None, + ) -> torch.Tensor: + """Prepare decoder IDs for generation.""" if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - return model_kwargs.pop("decoder_input_ids") + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + assert torch.jit.isinstance(decoder_input_ids, torch.Tensor) + return decoder_input_ids else: return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx def greedy_search( - self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs + self, + input_ids: torch.Tensor, + max_length: int, + eos_idx: int, + pad_idx: Optional[int] = None, + model_kwargs: Optional[MODEL_KWARGS_TYPE] = {}, ) -> torch.Tensor: """Greedy search decoding for text generation. Takes the most likely next token every time. @@ -57,7 +119,7 @@ def greedy_search( max_length (int): Max length to generate responses. eos_idx (int): End of sequence index. pad_idx (int): Padding index. - **model_kwargs + model_kwargs Returns: Batch of sequences decoded by greedy search. 
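For context between the hunks above and below, a minimal sketch of invoking greedy decoding through the public `generate()` entrypoint. It mirrors `test_greedy_generate_with_t5` in the integration tests and uses the imports from `benchmark/benchmark_generation_utils.py`; the example sentence and the `max_length` value are illustrative, not defaults.

```python
# Greedy search is selected when num_beams is 1 or None.
from torchtext.models import T5_BASE_GENERATION
from torchtext.prototype.generate import GenerationUtils

transform = T5_BASE_GENERATION.transform()
model = T5_BASE_GENERATION.get_model()
model.eval()

seq_generator = GenerationUtils(model)  # is_encoder_decoder defaults to True

inputs = transform(["summarize: studies have shown that owning a dog is good for you"])
tokens = seq_generator.generate(inputs, num_beams=1, max_length=30, eos_idx=1, pad_idx=0)
print(transform.decode(tokens.tolist()))
```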
@@ -67,8 +129,7 @@ def greedy_search( while True: model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) if self.is_huggingface_model: - model_inputs["return_dict"] = True - model_inputs["output_hidden_states"] = True + model_inputs.update(self._huggingface_model_input_values) # Get model output outputs = self.model(**model_inputs) @@ -80,9 +141,8 @@ def greedy_search( _, next_tokens = torch.topk(probs, 1) # For any finished sequences, padding idx should be the last token - if eos_idx is not None: - if pad_idx is not None: - next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences) + if eos_idx is not None and pad_idx is not None: + next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences) # Append the next tokens to the previous tokens input_ids = torch.cat([input_ids, next_tokens], dim=-1) @@ -96,8 +156,224 @@ def greedy_search( return input_ids - def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_length: Optional[int]) -> torch.Tensor: - raise NotImplementedError() + def beam_search( + self, + input_ids: torch.Tensor, + num_beams: int, + max_len: int, + beam_size_token: int, + beam_threshold: int, + eos_score: float, + eos_idx: int, + num_python_workers: int, + max_inference_batch_size: int, + model_kwargs: Dict[str, Any], + ) -> torch.Tensor: + """Beam search implemented using Flashlight Text (https://github.com/flashlight/text). + + Args: + input_ids (Tensor): Tokenized startings sequence(s). + num_beams (int): Number of beams to use in the beam search. + max_len (int): Maximum number of tokens to generate. + beam_size_token (int): Vocab size for the LM being used. + beam_threshold (int): Threshold before pruning. + eos_score (float): Score to input when `eos_idx` is generated. + eos_idx (int): End-of-sequence index. + num_python_workers (int): Number of python workers to use for multiprocessing. + model_kwargs + + Returns: + Tensor of the generated sequences. + """ + device = input_ids.device + + if self.is_encoder_decoder: + encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output" + encoder_output = model_kwargs["encoder_outputs"][encoder_output_key] + + def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep): + # `emissions` and `N` are unused in this current implementation + + i = T # Hacky access to the current seq in inputs + + # For first timestep, create previous step token_idxs and model_states + if timestep == 0: + prev_step_token_idxs = [-1] + prev_step_model_states = [ + create_emitting_model_state( + Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None) + ) + ] + + encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None + prev_model_state_sequences = [ + get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states + ] + out_probs, model_states = [], [] + + start = 0 + # This is the parallelism level at which elements in the beam will be batched + step = int( + min(max_inference_batch_size, 1000 / (timestep + 1)) + ) # many hypotheses will EOS, so increase the batch size gradually + curr_beam_size = len(prev_step_token_idxs) + + # 2. 
Batched inference to get next tokens + while start < curr_beam_size: # catch the remainder + end = start + step + if end > curr_beam_size: + end = curr_beam_size + + num_samples = end - start + + if prev_step_token_idxs != [-1]: + state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0) + token_indices = ( + torch.Tensor(prev_step_token_idxs[start:end]) + .to(dtype=torch.long, device=device) + .reshape(num_samples, 1) + ) + + state_and_tokens = torch.cat( + [state_sequences, token_indices], dim=-1 + ) # [batch_size x (timestep + 1)] + assert state_and_tokens.shape == ( + num_samples, + timestep + 1, + ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}" + else: + assert len(prev_model_state_sequences) == 1 + state_and_tokens = prev_model_state_sequences[0] + + # Cleanup -- combine this with the above + if self.is_encoder_decoder: + # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size + # This is a view-only operation and doesn't copy + model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand( + num_samples if timestep > 0 else 1, -1, -1 + ) + + # Preprocess inputs for generation + model_inputs = self.model.prepare_inputs_for_generation(state_and_tokens, **model_kwargs) + if self.is_huggingface_model: + model_inputs.update(self._huggingface_model_input_values) + + # Forward pass + outputs = self.model(**model_inputs) + + # Collect outputs + output_key = "logits" if self.is_huggingface_model else "decoder_output" + lm_scores = outputs[output_key] + + # Keep track of probabilities over vocab for this pairing + for sample_idx in range(num_samples): + sample_lm_scores = lm_scores[sample_idx, -1] + out_probs.append(sample_lm_scores.tolist()) + # Keep track of sequence and decoder hidden states + model_states.append( + create_emitting_model_state( + Seq2SeqModelState( + timestep=timestep, + sequence=state_and_tokens[sample_idx].unsqueeze(0), + lm_scores=sample_lm_scores, + ) + ) + ) + + start += step + + return out_probs, model_states + + # 3. Initialize options and decoder from Flashlight Text + options = LexiconFreeSeq2SeqDecoderOptions( + beam_size=num_beams, + beam_size_token=beam_size_token, + beam_threshold=beam_threshold, + lm_weight=0.0, # We have no custom LM so score is zero + eos_score=eos_score, + log_add=True, + ) + + decoder = LexiconFreeSeq2SeqDecoder( + options=options, lm=ZeroLM(), eos_idx=eos_idx, update_func=update_func, max_output_length=max_len + ) + + # 4. Process outputs from beam decoder + # TODO: This can definitely be optimized + def beam_decode_step(timestep: int) -> torch.Tensor: + # Create these as function b/c unnamed functions (lambdas) cause problems w/ MP + def select_second_elem_in_tuple(tup: Tuple[List[int], float]) -> float: + return tup[1] + + def is_not_neg_one(elem: int) -> bool: + return elem != -1 + + # Decode step takes ptr to encoder emissions, i, and beam size token + # but actually these aren't currently being used. 
+ decoder.decode_step(encoder_output.data_ptr(), timestep, beam_size_token) + hyps = decoder.get_all_final_hypothesis() + + # import pdb + # pdb.set_trace() + + # Find the best beam + token_scores = [(hyp.tokens, hyp.score) for hyp in hyps] + final_tokens = list(filter(is_not_neg_one, max(token_scores, key=select_second_elem_in_tuple)[0])) + + # Have to prepend the input tokens if decoder-only model + if not self.is_encoder_decoder: + final_tokens = input_ids[timestep].tolist() + final_tokens + + # Makeshift padding so that we can stack the tensors + while len(final_tokens) < max_len: + final_tokens += [0] + + # Convert from list to tensors + final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long) + + import pdb + + pdb.set_trace() + + return final_tokens_as_tensors + + if num_python_workers > 1: + # with multiprocessing.Pool(num_python_workers) as pool: + # all_final_tokens = pool.map(beam_decode_step, range(len(input_ids))) + warnings.warn("Multiprocessing has not yet been implemented.") + else: + all_final_tokens = [beam_decode_step(i) for i in range(len(input_ids))] + + # 5. Return top hypotheses for all input sequences + return torch.stack(all_final_tokens, dim=0) + + def forward( + self, + inputs: Optional[torch.Tensor] = None, + num_beams: Optional[int] = None, + max_len: Optional[int] = None, + pad_idx: int = 0, + eos_idx: int = 1, + beam_threshold: int = 100, + beam_size_token: Optional[int] = None, + eos_score: float = -1.0, + num_python_workers: int = 1, + max_inference_batch_size: int = 16, + ): + """Calls self.generate() method.""" + warnings.warn("Forward method simply calls `GenerationUtils.generate()`. Please use generate method directly.") + return self.generate( + inputs=inputs, + num_beams=num_beams, + max_len=max_len, + pad_idx=pad_idx, + eos_idx=eos_idx, + beam_threshold=beam_threshold, + beam_size_token=beam_size_token, + eos_score=eos_score, + num_python_workers=num_python_workers, + max_inference_batch_size=max_inference_batch_size, + ) def generate( self, @@ -106,39 +382,66 @@ def generate( max_length: Optional[int] = None, pad_idx: int = 0, eos_idx: int = 1, + num_python_workers: int = 1, + beam_threshold: int = 100, + vocab_size: Optional[int] = None, + eos_score: float = 0.0, + max_inference_batch_size: int = 32, ) -> torch.Tensor: - """Generation method. - - `num_beams` == 1 or `num_beams` is None -> greedy search - `num_beams` > 1 -> beam search + """Entrypoint generation method. Args: input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation. num_beams (int): If provided, specifies the number of beams to use in beam search generation. max_length (int): Max length to generate responses. pad_idx (int): Padding index. Defaults to 0. - eos_idx (int): End of sequence index. Defaults to 1. + eos_idx (int): End-of-sequence index. Defaults to 1. + num_python_workers (int): If > 1, using multiprocessing on CPU. + vocab_size (int): Vocab size for the beam search algo to evaluate, can typically default to vocab size of the model. + beam_threshold (int): Threshold before pruning; specific to beam search. + eos_score (float): Score to input when `eos_idx` is generated; specific to beam search. + max_inference_batch_size (int): In beam search, to avoid OOMs, can choose to batch smaller amounts of hypothesis; defaults to 32. Returns: Tensor of Tensors containing output sequences as ids. - `Note`: If one beam is provided or no beams are specified, the generation method will default to greedy search. + Conditions for generation: + 1. 
`num_beams` == 1 or `num_beams` is None -> greedy search + 2. `num_beams` > 1 -> beam search """ - model_kwargs = {} + model_kwargs: MODEL_KWARGS_TYPE = {} if self.is_encoder_decoder: - encoder = self.model.get_encoder() - model_kwargs["encoder_outputs"] = encoder(inputs) - inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs) + assert torch.jit.isinstance(inputs, torch.Tensor) + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs) + inputs = self._prepare_decoder_ids_for_generation( + len(inputs), device=inputs.device, model_kwargs=model_kwargs + ) if max_length is None: # Too hard to try to figure out the exact max_seq_length for each model - logger.warning(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.") + warnings.warn(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.") max_length = DEFAULT_MAX_SEQ_LEN - if num_beams == 1 or num_beams is None: - return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs) + if num_beams is None or num_beams == 1: + if num_python_workers > 1: + warnings.warn(f"Multiprocessing is not implemented for greedy search.") + return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, model_kwargs=model_kwargs) elif num_beams > 1: - return self.beam_search(inputs, num_beams, max_length) + assert ( + vocab_size is not None + ), "`vocab_size` must be specified for beam search. If confused about what to put, you can default to the vocab size of the model you are using." + return self.beam_search( + inputs, + num_beams, + beam_size_token=vocab_size, + max_len=max_length, + beam_threshold=beam_threshold, + eos_score=eos_score, + num_python_workers=num_python_workers, + eos_idx=eos_idx, + max_inference_batch_size=max_inference_batch_size, + model_kwargs=model_kwargs, + ) else: raise ValueError("`num_beams` must be >= 1.")
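
For reference, an end-to-end sketch of the beam-search entrypoint added in this diff, following the call pattern used in `benchmark/benchmark_generation_utils.py` and the integration tests. The concrete argument values (beam width, threshold, EOS score, batch size) mirror the benchmark script and are illustrative rather than required defaults.

```python
# Beam search is selected when num_beams > 1; vocab_size must be supplied
# because the Flashlight Text decoder needs the token-level beam size.
from torchtext.models import T5_BASE_GENERATION
from torchtext.prototype.generate import GenerationUtils

model = T5_BASE_GENERATION.get_model()
transform = T5_BASE_GENERATION.transform()
model.eval()

seq_generator = GenerationUtils(model)

inputs = transform(["translate English to German: That is good."])
tokens = seq_generator.generate(
    inputs,
    num_beams=8,                          # beam width
    max_length=30,                        # maximum tokens to generate
    vocab_size=model.config.vocab_size,   # required for beam search
    beam_threshold=1000,                  # pruning threshold
    eos_score=-1.0,                       # score applied when eos_idx is generated
    eos_idx=1,
    pad_idx=0,
    max_inference_batch_size=16,          # batch hypotheses to avoid OOM
)
print(transform.decode(tokens.tolist()))
```

As in the benchmark, the decoded output can then be compared against references (e.g. with `torcheval`'s `word_error_rate`) to sanity-check beam-search quality against greedy decoding.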