@@ -1099,7 +1099,7 @@ class AttGRUCond(Recurrent):
     '''
     def __init__(self, output_dim, return_extra_variables=False, return_states=False,
                  init='glorot_uniform', inner_init='orthogonal',
-                 activation='tanh', inner_activation='hard_sigmoid',
+                 activation='tanh', inner_activation='hard_sigmoid', mask_value=0.,
                  W_regularizer=None, U_regularizer=None, V_regularizer=None, b_regularizer=None,
                  wa_regularizer=None, Wa_regularizer=None, Ua_regularizer=None, ba_regularizer=None, ca_regularizer=None,
                  dropout_W=0., dropout_U=0., dropout_V=0., dropout_wa=0., dropout_Wa=0., dropout_Ua=0., **kwargs):
@@ -1120,6 +1120,7 @@ def __init__(self, output_dim, return_extra_variables=False, return_states=False
         self.Ua_regularizer = regularizers.get(Ua_regularizer)
         self.ba_regularizer = regularizers.get(ba_regularizer)
         self.ca_regularizer = regularizers.get(ca_regularizer)
+        self.mask_value = mask_value

         self.dropout_W, self.dropout_U, self.dropout_V = dropout_W, dropout_U, dropout_V
         self.dropout_wa, self.dropout_Wa, self.dropout_Ua = dropout_wa, dropout_Wa, dropout_Ua
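The new mask_value argument is stored on the layer and later used to recover a context mask when none is supplied (see the get_constants hunk further down). A hedged usage sketch, assuming the class is importable from this module and that padded context timesteps are filled with zeros; the output_dim and dropout values are illustrative:

att_gru = AttGRUCond(output_dim=256,
                     return_extra_variables=True,   # also return ctx_ and alphas
                     mask_value=0.,                 # timesteps whose feature sum equals 0. count as padding
                     dropout_W=0.1, dropout_U=0.1)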
@@ -1392,6 +1393,7 @@ def step(self, x, states):
         pctx_ = states[7]       # Projected context (i.e. context * Ua + ba)
         context = states[8]     # Original context
         mask_input = states[9]  # Context mask
+
         if mask_input.ndim > 1:  # Mask the context (only if necessary)
             pctx_ = mask_input[:, :, None] * pctx_
             context = mask_input[:, :, None] * context  # Masked context
@@ -1400,6 +1402,8 @@ def step(self, x, states):
         p_state_ = K.dot(h_tm1 * B_Wa[0], self.Wa)
         pctx_ = K.tanh(pctx_ + p_state_[:, None, :])
         e = K.dot(pctx_ * B_wa[0], self.wa) + self.ca
+        if mask_input.ndim > 1:  # Mask the context (only if necessary)
+            e = mask_input * e
         alphas_shape = e.shape
         alphas = K.softmax(e.reshape([alphas_shape[0], alphas_shape[1]]))
         ctx_ = (context * alphas[:, :, None]).sum(axis=1)  # sum over the in_timesteps dimension resulting in [batch_size, input_dim]
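The two added lines multiply the attention energies by the context mask before the softmax, so padded timesteps contribute an energy of zero. A minimal NumPy sketch of this scoring step, with array shapes and variable names assumed for illustration (the backend dropout masks B_wa/B_Wa are omitted):

import numpy as np

def masked_attention(pctx, p_state, context, mask, wa, ca):
    # pctx: (batch, T, att_dim) projected context; p_state: (batch, att_dim) decoder-state projection
    # context: (batch, T, ctx_dim); mask: (batch, T) with 1 = real timestep, 0 = padded
    act = np.tanh(pctx + p_state[:, None, :])           # combine context and decoder state
    e = act.dot(wa) + ca                                 # energies, (batch, T)
    e = mask * e                                         # zero the energies of padded timesteps
    e = e - e.max(axis=1, keepdims=True)                 # stabilised softmax over timesteps
    alphas = np.exp(e)
    alphas /= alphas.sum(axis=1, keepdims=True)
    return (context * alphas[:, :, None]).sum(axis=1)    # pooled context, (batch, ctx_dim)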
@@ -1419,6 +1423,7 @@ def step(self, x, states):
         x_h = matrix_x[:, 2 * self.output_dim:]
         inner_h = K.dot(r * h_tm1 * B_U[0], self.U[:, 2 * self.output_dim:])
         hh = self.activation(x_h + inner_h)
+
         h = z * h_tm1 + (1 - z) * hh

         return h, [h, ctx_, alphas]
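For reference, the context lines around the added blank line implement the standard GRU update: the candidate hh is blended with the previous state through the update gate z. A minimal NumPy sketch of that cell update, with shapes assumed, dropout masks and the attention term omitted, and a plain sigmoid standing in for the layer's configurable inner_activation:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_update(h_tm1, x_z, x_r, x_h, U, output_dim):
    # h_tm1: (batch, output_dim); x_z/x_r/x_h: precomputed input projections
    # U: (output_dim, 3 * output_dim) recurrent weights, sliced per gate
    z = sigmoid(x_z + h_tm1.dot(U[:, :output_dim]))                 # update gate
    r = sigmoid(x_r + h_tm1.dot(U[:, output_dim:2 * output_dim]))   # reset gate
    hh = np.tanh(x_h + (r * h_tm1).dot(U[:, 2 * output_dim:]))      # candidate state
    return z * h_tm1 + (1 - z) * hh                                 # same blend as h = z * h_tm1 + (1 - z) * hh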
@@ -1490,7 +1495,7 @@ def get_constants(self, x, mask_input):

         # States[9]
         if mask_input is None:
-            mask_input = K.variable([])
+            mask_input = K.not_equal(K.sum(self.context, axis=2), self.mask_value)
         constants.append(mask_input)

         return constants, B_V
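Rather than falling back to an empty K.variable, get_constants now derives the context mask directly from the stored context: a timestep counts as valid when the sum of its features differs from mask_value. A small NumPy sketch of the same rule (array shapes are assumptions):

import numpy as np

def context_mask(context, mask_value=0.):
    # context: (batch, in_timesteps, ctx_dim)
    # mirrors K.not_equal(K.sum(self.context, axis=2), self.mask_value)
    return (context.sum(axis=2) != mask_value).astype('float32')

ctx = np.zeros((1, 4, 3), dtype='float32')
ctx[0, :2] = 1.0                 # only the first two timesteps carry real data
print(context_mask(ctx))         # [[1. 1. 0. 0.]]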
@@ -1523,6 +1528,7 @@ def get_config(self):
                   'inner_init': self.inner_init.__name__,
                   'activation': self.activation.__name__,
                   'inner_activation': self.inner_activation.__name__,
+                  'mask_value': self.mask_value,
                   'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                   'U_regularizer': self.U_regularizer.get_config() if self.U_regularizer else None,
                   'V_regularizer': self.V_regularizer.get_config() if self.V_regularizer else None,
@@ -3237,7 +3243,7 @@ class AttLSTMCond(Recurrent):
     '''
     def __init__(self, output_dim, return_extra_variables=False, return_states=False,
                  init='glorot_uniform', inner_init='orthogonal',
-                 forget_bias_init='one', activation='tanh', inner_activation='sigmoid',
+                 forget_bias_init='one', activation='tanh', inner_activation='sigmoid', mask_value=0.,
                  W_regularizer=None, U_regularizer=None, V_regularizer=None, b_regularizer=None,
                  wa_regularizer=None, Wa_regularizer=None, Ua_regularizer=None, ba_regularizer=None, ca_regularizer=None,
                  dropout_W=0., dropout_U=0., dropout_V=0., dropout_wa=0., dropout_Wa=0., dropout_Ua=0.,
@@ -3260,6 +3266,7 @@ def __init__(self, output_dim, return_extra_variables=False, return_states=False
         self.Ua_regularizer = regularizers.get(Ua_regularizer)
         self.ba_regularizer = regularizers.get(ba_regularizer)
         self.ca_regularizer = regularizers.get(ca_regularizer)
+        self.mask_value = mask_value

         # Dropouts
         self.dropout_W, self.dropout_U, self.dropout_V = dropout_W, dropout_U, dropout_V
@@ -3579,6 +3586,8 @@ def step(self, x, states):
         p_state_ = K.dot(h_tm1 * B_Wa[0], self.Wa)
         pctx_ = K.tanh(pctx_ + p_state_[:, None, :])
         e = K.dot(pctx_ * B_wa[0], self.wa) + self.ca
+        if mask_input.ndim > 1:  # Mask the context (only if necessary)
+            e = mask_input * e
         alphas_shape = e.shape
         alphas = K.softmax(e.reshape([alphas_shape[0], alphas_shape[1]]))
         # sum over the in_timesteps dimension resulting in [batch_size, input_dim]
@@ -3669,7 +3678,7 @@ def get_constants(self, x, mask_input):

         # States[10]
         if mask_input is None:
-            mask_input = K.variable([])
+            mask_input = K.not_equal(K.sum(self.context, axis=2), self.mask_value)
         constants.append(mask_input)

         return constants, B_V
@@ -3710,6 +3719,7 @@ def get_config(self):
                   'forget_bias_init': self.forget_bias_init.__name__,
                   'activation': self.activation.__name__,
                   'inner_activation': self.inner_activation.__name__,
+                  'mask_value': self.mask_value,
                   'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                   'U_regularizer': self.U_regularizer.get_config() if self.U_regularizer else None,
                   'V_regularizer': self.V_regularizer.get_config() if self.V_regularizer else None,