Commit 27dc287

Clamping beam_idxs, still fails trying to cat past_states to hidden_states
1 parent fc806fc commit 27dc287

File tree

1 file changed: +5 -4 lines changed

torchtext/prototype/generate.py

+5 -4
@@ -303,25 +303,26 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
 
         # We could store this in model_kwargs
         num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]
-
+
         num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
         if num_finished_hyps_in_step > 0:
             beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)
-
+
+        beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)
+
         reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
 
         if num_finished_hyps_in_step > 0:
             sliced_cache = ()
             for states in reordered_cached:
                 sliced_state = ()
                 for state in states:
-                    sliced_state = sliced_state + (state[:len(prev_step_hyp_idxs)],)
+                    sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],)
                 sliced_cache = sliced_cache + (sliced_state,)
             reordered_cached = sliced_cache
 
         model_inputs["past_key_values"] = reordered_cached
 
-
         # Forward pass
         outputs = self.model(**model_inputs)
 

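For context, below is a minimal, self-contained sketch of the cache handling this hunk touches. The toy cache shapes and the reorder_cache helper are assumptions made for illustration; the actual code goes through self.model._reorder_cache and model_kwargs["past"].

import torch
import torch.nn.functional as F


def reorder_cache(past, beam_idxs):
    # Stand-in for self.model._reorder_cache: for every layer's key/value tensor,
    # select the hypothesis rows named by beam_idxs along dim 0.
    return tuple(
        tuple(state.index_select(0, beam_idxs) for state in layer)
        for layer in past
    )


# Toy cache (an assumption for this sketch): 2 layers, each a (key, value)
# pair of shape [num_hyps, num_heads, seq_len, head_dim].
num_hyps_in_prev_step = 4
past = tuple(
    (torch.randn(num_hyps_in_prev_step, 2, 5, 8),
     torch.randn(num_hyps_in_prev_step, 2, 5, 8))
    for _ in range(2)
)

# Suppose one hypothesis finished in the previous step, so only 3 remain active.
prev_step_hyp_idxs = [0, 1, 3]
beam_idxs = torch.tensor([2, 0, 1])

num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
if num_finished_hyps_in_step > 0:
    # Pad beam_idxs back up to the cache's row count so it can index the old cache ...
    beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)

# ... and clamp so no index points past the surviving hypotheses
# (this is the line the commit adds).
beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)

reordered_cached = reorder_cache(past, beam_idxs)

if num_finished_hyps_in_step > 0:
    # Drop the padded rows so the cache matches the number of live hypotheses again.
    reordered_cached = tuple(
        tuple(state[: len(prev_step_hyp_idxs)] for state in states)
        for states in reordered_cached
    )

# Each key/value tensor now has 3 hypothesis rows: (3, 2, 5, 8).
print([tuple(state.shape) for state in reordered_cached[0]])

The clamp guards against beam indices that still refer to hypotheses which finished in the previous step; as the commit message notes, the downstream concatenation of past_states to hidden_states still fails, so this is only a partial fix.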