
Commit 377f542

Bugfix AttLSTMCond config
1 parent 0d0ce3c commit 377f542


2 files changed: +23 -1 lines changed


keras/layers/core.py: +22 -1

@@ -63,7 +63,6 @@ def get_config(self):
         base_config = super(Masking, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
-
 class Dropout(Layer):
     '''Applies Dropout to the input. Dropout consists in randomly setting
     a fraction `p` of input units to 0 at each update during training time,
@@ -1510,3 +1509,25 @@ def get_config(self):
         config = {'indices': self.indices}
         base_config = super(SetSubtensor, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
+
+
+class RemoveMask(Layer):
+    def __init__(self, **kwargs):
+        super(RemoveMask, self).__init__(**kwargs)
+
+    def compute_mask(self, input, input_mask=None):
+        return None
+
+"""
+class LambdaRemoveMask(Lambda):
+    def __init__(self, lambda_fn):
+        super(LambdaRemoveMask, self).__init__((lambda_fn))
+        #self.supports_masking = True
+
+    def compute_mask(self, input, input_mask=None):
+        return None
+
+    #def get_config(self):
+    #    base_config = super(LambdaRemoveMask, self).get_config()
+    #    return dict(list(base_config.items()))
+"""

keras/layers/recurrent.py: +1

@@ -3666,6 +3666,7 @@ def get_initial_states(self, x):
     def get_config(self):
         config = {'output_dim': self.output_dim,
                   'return_extra_variables': self.return_extra_variables,
+                  'return_states': self.return_states,
                   'init': self.init.__name__,
                   'inner_init': self.inner_init.__name__,
                   'forget_bias_init': self.forget_bias_init.__name__,
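
This one-line change is the config bugfix the commit title refers to: without 'return_states' in get_config(), the flag silently reverted to its default whenever an AttLSTMCond layer was rebuilt from a saved config. A rough sketch of the round trip; the constructor call is an assumption (only output_dim and return_states are taken from the diff, and the layer is assumed to accept keyword defaults for its remaining arguments):

from keras.layers.recurrent import AttLSTMCond

layer = AttLSTMCond(output_dim=128, return_states=True)
config = layer.get_config()                # now contains 'return_states': True
rebuilt = AttLSTMCond.from_config(config)  # reconstruction keeps the flag
assert rebuilt.return_states == layer.return_states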
