
Commit b37ca92

Create mask_input if None

1 parent 774abd3 commit b37ca92

1 file changed (+14, -4 lines)

keras/layers/recurrent.py (+14, -4)
@@ -1099,7 +1099,7 @@ class AttGRUCond(Recurrent):
     '''
     def __init__(self, output_dim, return_extra_variables=False, return_states=False,
                  init='glorot_uniform', inner_init='orthogonal',
-                 activation='tanh', inner_activation='hard_sigmoid',
+                 activation='tanh', inner_activation='hard_sigmoid', mask_value=0.,
                  W_regularizer=None, U_regularizer=None, V_regularizer=None, b_regularizer=None,
                  wa_regularizer=None, Wa_regularizer=None, Ua_regularizer=None, ba_regularizer=None, ca_regularizer=None,
                  dropout_W=0., dropout_U=0., dropout_V=0., dropout_wa=0., dropout_Wa=0., dropout_Ua=0., **kwargs):
@@ -1120,6 +1120,7 @@ def __init__(self, output_dim, return_extra_variables=False, return_states=False
         self.Ua_regularizer = regularizers.get(Ua_regularizer)
         self.ba_regularizer = regularizers.get(ba_regularizer)
         self.ca_regularizer = regularizers.get(ca_regularizer)
+        self.mask_value = mask_value
 
         self.dropout_W, self.dropout_U, self.dropout_V = dropout_W, dropout_U, dropout_V
         self.dropout_wa, self.dropout_Wa, self.dropout_Ua = dropout_wa, dropout_Wa, dropout_Ua
@@ -1392,6 +1393,7 @@ def step(self, x, states):
         pctx_ = states[7]       # Projected context (i.e. context * Ua + ba)
         context = states[8]     # Original context
         mask_input = states[9]  # Context mask
+
         if mask_input.ndim > 1:  # Mask the context (only if necessary)
             pctx_ = mask_input[:, :, None] * pctx_
             context = mask_input[:, :, None] * context  # Masked context
@@ -1400,6 +1402,8 @@
         p_state_ = K.dot(h_tm1 * B_Wa[0], self.Wa)
         pctx_ = K.tanh(pctx_ + p_state_[:, None, :])
         e = K.dot(pctx_ * B_wa[0], self.wa) + self.ca
+        if mask_input.ndim > 1:  # Mask the context (only if necessary)
+            e = mask_input * e
         alphas_shape = e.shape
         alphas = K.softmax(e.reshape([alphas_shape[0], alphas_shape[1]]))
         ctx_ = (context * alphas[:, :, None]).sum(axis=1)  # sum over the in_timesteps dimension resulting in [batch_size, input_dim]
@@ -1419,6 +1423,7 @@
         x_h = matrix_x[:, 2 * self.output_dim:]
         inner_h = K.dot(r * h_tm1 * B_U[0], self.U[:, 2 * self.output_dim:])
         hh = self.activation(x_h + inner_h)
+
         h = z * h_tm1 + (1 - z) * hh
 
         return h, [h, ctx_, alphas]
@@ -1490,7 +1495,7 @@ def get_constants(self, x, mask_input):
 
         # States[9]
         if mask_input is None:
-            mask_input = K.variable([])
+            mask_input = K.not_equal(K.sum(self.context, axis=2), self.mask_value)
         constants.append(mask_input)
 
         return constants, B_V
@@ -1523,6 +1528,7 @@ def get_config(self):
                   'inner_init': self.inner_init.__name__,
                   'activation': self.activation.__name__,
                   'inner_activation': self.inner_activation.__name__,
+                  'mask_value': self.mask_value,
                   'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                   'U_regularizer': self.U_regularizer.get_config() if self.U_regularizer else None,
                   'V_regularizer': self.V_regularizer.get_config() if self.U_regularizer else None,
@@ -3237,7 +3243,7 @@ class AttLSTMCond(Recurrent):
     '''
     def __init__(self, output_dim, return_extra_variables=False, return_states=False,
                  init='glorot_uniform', inner_init='orthogonal',
-                 forget_bias_init='one', activation='tanh', inner_activation='sigmoid',
+                 forget_bias_init='one', activation='tanh', inner_activation='sigmoid', mask_value=0.,
                  W_regularizer=None, U_regularizer=None, V_regularizer=None, b_regularizer=None,
                  wa_regularizer=None, Wa_regularizer=None, Ua_regularizer=None, ba_regularizer=None, ca_regularizer=None,
                  dropout_W=0., dropout_U=0., dropout_V=0., dropout_wa=0., dropout_Wa=0., dropout_Ua=0.,
@@ -3260,6 +3266,7 @@ def __init__(self, output_dim, return_extra_variables=False, return_states=False
         self.Ua_regularizer = regularizers.get(Ua_regularizer)
         self.ba_regularizer = regularizers.get(ba_regularizer)
         self.ca_regularizer = regularizers.get(ca_regularizer)
+        self.mask_value = mask_value
 
         # Dropouts
         self.dropout_W, self.dropout_U, self.dropout_V = dropout_W, dropout_U, dropout_V
@@ -3579,6 +3586,8 @@ def step(self, x, states):
         p_state_ = K.dot(h_tm1 * B_Wa[0], self.Wa)
         pctx_ = K.tanh(pctx_ + p_state_[:, None, :])
         e = K.dot(pctx_ * B_wa[0], self.wa) + self.ca
+        if mask_input.ndim > 1:  # Mask the context (only if necessary)
+            e = mask_input * e
         alphas_shape = e.shape
         alphas = K.softmax(e.reshape([alphas_shape[0], alphas_shape[1]]))
         # sum over the in_timesteps dimension resulting in [batch_size, input_dim]
@@ -3669,7 +3678,7 @@ def get_constants(self, x, mask_input):
 
         # States[10]
         if mask_input is None:
-            mask_input = K.variable([])
+            mask_input = K.not_equal(K.sum(self.context, axis=2), self.mask_value)
         constants.append(mask_input)
 
         return constants, B_V
@@ -3710,6 +3719,7 @@ def get_config(self):
                   'forget_bias_init': self.forget_bias_init.__name__,
                   'activation': self.activation.__name__,
                   'inner_activation': self.inner_activation.__name__,
+                  'mask_value': self.mask_value,
                   'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                   'U_regularizer': self.U_regularizer.get_config() if self.U_regularizer else None,
                   'V_regularizer': self.V_regularizer.get_config() if self.U_regularizer else None,
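Finally, a hedged construction sketch (not part of the commit), assuming this fork exposes both attention layers from keras/layers/recurrent.py as regular layer classes; only the new keyword argument is shown, and the layers are not wired into a model here:

from keras.layers.recurrent import AttGRUCond, AttLSTMCond

# mask_value defaults to 0.; pass a different value if the context is padded with something else
att_gru = AttGRUCond(output_dim=256, mask_value=0., return_states=False)
att_lstm = AttLSTMCond(output_dim=256, mask_value=0., return_extra_variables=False)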
