Skip to content

Commit 19b1418

Browse files
authored
Enable cuDNN RNNs when dropout is set and training=True (#20983)
1 parent 465a3d2 commit 19b1418

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

keras/src/layers/rnn/gru.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ class GRU(RNN):
354354
355355
1. `activation` == `tanh`
356356
2. `recurrent_activation` == `sigmoid`
357-
3. `dropout` == 0 and `recurrent_dropout` == 0
357+
3. `recurrent_dropout` == 0
358358
4. `unroll` is `False`
359359
5. `use_bias` is `True`
360360
6. `reset_after` is `True`
@@ -553,7 +553,7 @@ def inner_loop(self, sequences, initial_state, mask, training=False):
553553
if self.use_cudnn in ("auto", True):
554554
if not self.recurrent_dropout:
555555
try:
556-
if self.dropout:
556+
if training and self.dropout:
557557
dp_mask = self.cell.get_dropout_mask(sequences[:, 0, :])
558558
dp_mask = ops.expand_dims(dp_mask, axis=1)
559559
dp_mask = ops.broadcast_to(

keras/src/layers/rnn/lstm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ class LSTM(RNN):
343343
344344
1. `activation` == `tanh`
345345
2. `recurrent_activation` == `sigmoid`
346-
3. `dropout` == 0 and `recurrent_dropout` == 0
346+
3. `recurrent_dropout` == 0
347347
4. `unroll` is `False`
348348
5. `use_bias` is `True`
349349
6. Inputs, if use masking, are strictly right-padded.
@@ -534,7 +534,7 @@ def inner_loop(self, sequences, initial_state, mask, training=False):
534534
if self.use_cudnn in ("auto", True):
535535
if not self.recurrent_dropout:
536536
try:
537-
if self.dropout:
537+
if training and self.dropout:
538538
dp_mask = self.cell.get_dropout_mask(sequences[:, 0, :])
539539
dp_mask = ops.expand_dims(dp_mask, axis=1)
540540
dp_mask = ops.broadcast_to(

0 commit comments

Comments
 (0)