@@ -228,6 +228,12 @@ def beam_search(
        encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
        encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]

+        num_sequences = input_ids.shape[0]
+
+        # Pre-allocate everything
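+        # Both buffers have shape (num_sequences, num_beams, 1) and are reused
+        # across decoding steps rather than re-allocated per step. token_idxs is
+        # filled with eos_idx so that beam slots which receive no live hypothesis
+        # decode as EOS; beam_idxs defaults every slot to beam 0.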
+        token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx).to(dtype=torch.long, device=device)
+        beam_idxs = torch.zeros((num_sequences, num_beams, 1)).to(dtype=torch.long, device=device)
+
        def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
            # `emissions` and `N` are unused in this current implementation

@@ -236,16 +242,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
            # For first timestep, create previous step token_idxs and model_states
            if timestep == 0:
                prev_step_token_idxs = [-1]
-                prev_step_model_states = [
-                    create_emitting_model_state(
-                        Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None)
-                    )
-                ]

            encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
-            prev_model_state_sequences = [
-                get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
-            ]
            out_probs, model_states = [], []

            start = 0
@@ -261,66 +259,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
                if end > curr_beam_size:
                    end = curr_beam_size

-                num_samples = end - start
-
                if prev_step_token_idxs != [-1]:
-                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = (
-                        torch.Tensor(prev_step_token_idxs[start:end])
-                        .to(dtype=torch.long, device=device)
-                        .reshape(num_samples, 1)
-                    )
-
-                    state_and_tokens = torch.cat(
-                        [state_sequences, token_indices], dim=-1
-                    )  # [batch_size x (timestep + 1)]
-                    assert state_and_tokens.shape == (
-                        num_samples,
-                        timestep + 1,
-                    ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
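+                    # Write the surviving hypotheses' latest tokens into the
+                    # pre-allocated buffer; slots beyond len(token_indices) keep
+                    # their eos_idx fill. The (num_beams, 1) view below becomes
+                    # the decoder input for this step.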
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=device)
+                    token_idxs[i, : len(token_indices), 0] = token_indices
+                    curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)
                else:
-                    assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
-                        num_beams, -1
-                    )  # TODO: Make this more robust
-
-                # Cleanup -- combine this with the above
-                if self.is_encoder_decoder:
-                    # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
-                    # This is a view-only operation and doesn't copy
-                    model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                        num_samples if timestep > 0 else num_beams, -1, -1
-                    )
+                    if self.is_encoder_decoder:
+                        # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
+                        # This is a view-only operation and doesn't copy
+                        model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                            num_beams, -1, -1
+                        )
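+                    # First timestep: every beam starts from token id 0 (this
+                    # assumes 0 is the decoder start token, as in e.g. HF T5)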
+                    curr_token_idxs = torch.zeros((num_beams, 1)).to(dtype=torch.long, device=device)
+

                # Preprocess inputs for generation
                model_inputs = self.model.prepare_inputs_for_generation(
-                    token_indices, **model_kwargs
+                    curr_token_idxs, **model_kwargs
                )  # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
                if self.is_huggingface_model:
                    model_inputs.update(self._huggingface_model_input_values)
                    if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
-                        beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
-
-                        # We could store this in model_kwargs
-                        num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]
-
-                        num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
-                        if num_finished_hyps_in_step > 0:
-                            beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)
-
-                        beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)
-
-                        reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
-
-                        if num_finished_hyps_in_step > 0:
-                            sliced_cache = ()
-                            for states in reordered_cached:
-                                sliced_state = ()
-                                for state in states:
-                                    sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],)
-                                sliced_cache = sliced_cache + (sliced_state,)
-                            reordered_cached = sliced_cache
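+                        # Write the surviving hypotheses' beam indices into the
+                        # pre-allocated buffer; slots beyond len(prev_step_hyp_idxs)
+                        # keep their 0 fill, so finished slots simply gather beam
+                        # 0's past key/values when the cache is reordered below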
+                        beam_indices = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
+                        beam_idxs[i, : len(prev_step_hyp_idxs), 0] = beam_indices
+                        curr_beam_idxs = beam_idxs[i, :, 0]

+                        reordered_cached = self.model._reorder_cache(model_kwargs["past"], curr_beam_idxs)
                        model_inputs["past_key_values"] = reordered_cached

                # Forward pass
@@ -334,18 +298,21 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
                if self.is_huggingface_model:
                    self._update_model_kwargs_for_generation(outputs, model_kwargs)

+                # Reset the buffers so the next step starts from eos_idx-filled
+                # token slots and beam index 0
+                token_idxs[i, :, 0] = eos_idx
+                beam_idxs[i, :, 0] = 0
+
                # Keep track of probabilities over vocab for this pairing
-                # TODO: fix how we track the number here?
-                for i in range(lm_scores.shape[0]):
+                for i in range(num_beams):
                    sample_lm_scores = lm_scores[i, -1]
                    out_probs.append(sample_lm_scores.tolist())
                    # Keep track of sequence and decoder hidden states
                    model_states.append(
                        create_emitting_model_state(
                            Seq2SeqModelState(
                                timestep=timestep,
-                                sequence=state_and_tokens[i].unsqueeze(0),
-                                lm_scores=sample_lm_scores,
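+                                # Placeholders: the model state no longer carries
+                                # the running sequence or per-step scores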
+                                sequence=[],
+                                lm_scores=0,
                            )
                        )
                    )
@@ -391,10 +358,6 @@ def is_not_neg_one(elem: int) -> bool:
            if not self.is_encoder_decoder:
                final_tokens = input_ids[timestep].tolist() + final_tokens

-            # Makeshift padding so that we can stack the tensors
-            while len(final_tokens) < max_len:
-                final_tokens += [0]
-
            # Convert from list to tensors
            final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)