@@ -410,6 +410,46 @@ def get_config(self):
         return dict(list(base_config.items()) + list(config.items()))


+
+class PermuteGeneral(Layer):
+    '''Permutes the dimensions of the input according to a given pattern.
+    This is just like the `Permute` layer, but it DOES include the batch dimension.
+
+    # Arguments
+        dims: Tuple of integers. Permutation pattern, INCLUDING the
+            samples dimension. Indexing starts at 0.
+            For instance, `(1, 0, 2)` permutes the batch and first dimension of the input.
+
+    # Input shape
+        Arbitrary. Use the keyword argument `input_shape`
+        (tuple of integers, does not include the samples axis)
+        when using this layer as the first layer in a model.
+
+    # Output shape
+        Same as the input shape, but with the dimensions re-ordered according
+        to the specified pattern.
+    '''
+    def __init__(self, dims, **kwargs):
+        self.dims = tuple(dims)
+        self.supports_masking = True
+        super(PermuteGeneral, self).__init__(**kwargs)
+
+    def get_output_shape_for(self, input_shape):
+        input_shape = list(input_shape)
+        output_shape = copy.copy(input_shape)
+        for i, dim in enumerate(self.dims):
+            output_shape[i] = input_shape[dim]
+        return tuple(output_shape)
+
+    def call(self, x, mask=None):
+        return K.permute_dimensions(x, self.dims)
+
+    def get_config(self):
+        config = {'dims': self.dims}
+        base_config = super(PermuteGeneral, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
 class Flatten(Layer):
     '''Flattens the input. Does not affect the batch size.

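For reference, a minimal sketch of the new layer in use (Keras 1.x-style API; the shapes below are illustrative assumptions):

```python
# PermuteGeneral, unlike Permute, lets the batch axis (index 0) take part
# in the permutation. Here dims=(1, 0, 2) swaps the batch and time axes.
layer = PermuteGeneral(dims=(1, 0, 2))

# Shape inference follows the pattern: output[i] = input[dims[i]].
print(layer.get_output_shape_for((32, 10, 64)))  # -> (10, 32, 64)
```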
@@ -581,7 +621,7 @@ def antirectifier_output_shape(input_shape):
     def __init__(self, function, output_shape=None, arguments=None, **kwargs):
         self.function = function
         self.arguments = arguments if arguments else {}
-        self.supports_masking = False
+        self.supports_masking = True

         if output_shape is None:
             self._output_shape = None
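Flipping `supports_masking` to `True` changes how `Lambda` interacts with masked inputs: the base `Layer.compute_mask` then passes the incoming mask through instead of raising an exception when one is present. A small sketch of the situation this enables (Keras 1.x-style model; the sizes are illustrative):

```python
# With supports_masking = True, a Lambda can sit downstream of a masking
# layer; before this change, feeding it a masked input raised an error.
from keras.models import Sequential
from keras.layers import Embedding, Lambda

model = Sequential()
model.add(Embedding(1000, 64, input_length=10, mask_zero=True))
model.add(Lambda(lambda x: x * 2.0))  # the mask is forwarded unchanged
```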
@@ -1303,6 +1343,28 @@ def get_config(self):
         base_config = super(MaskedMean, self).get_config()
         return dict(list(base_config.items()))

+
+class MaskLayer(Layer):
+    """
+    Applies its mask to the input, zeroing out masked timesteps.
+    """
+    def __init__(self, **kwargs):
+        self.supports_masking = True
+        super(MaskLayer, self).__init__(**kwargs)
+
+    def call(self, x, mask=None):
+        return mask[:, :, None] * x  # broadcast (batch, time) mask over features
+
+    def compute_mask(self, input_shape, input_mask=None):
+        return input_mask
+
+    def get_output_shape_for(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        base_config = super(MaskLayer, self).get_config()
+        return dict(list(base_config.items()))
+
 class WeightedSum(Layer):
     ''' Applies a weighted sum over a set of vectors input[0] and their respective weights input[1].
     First, the weights are tiled to match the length of the input vectors on dim=1.
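In effect, `MaskLayer` turns the implicit timestep mask into explicit zeros in the tensor values while still forwarding the mask downstream via `compute_mask`. A minimal usage sketch (Keras 1.x-style; the embedding sizes are assumptions):

```python
from keras.models import Sequential
from keras.layers import Embedding

model = Sequential()
# Embedding emits (batch, 10, 64) values plus a (batch, 10) mask.
model.add(Embedding(1000, 64, input_length=10, mask_zero=True))
# MaskLayer multiplies the values by the broadcast mask, so padded
# timesteps become exact zeros; later layers still receive the mask.
model.add(MaskLayer())
```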
@@ -1530,4 +1592,4 @@ def compute_mask(self, input, input_mask=None):
     #def get_config(self):
         #base_config = super(LambdaRemoveMask, self).get_config()
         #return dict(list(base_config.items()))
-    """
+    """