Skip to content

Commit e9efd28

Browse files
committed
2 parents 377f542 + b37ca92 commit e9efd28

File tree

5 files changed

+479
-368
lines changed

5 files changed

+479
-368
lines changed

keras/engine/topology.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -365,15 +365,15 @@ def non_trainable_weights(self, weights):
365365

366366
@property
367367
def regularizers(self):
368-
warnings.warn('The `regularizers` property of '
368+
warnings.warn('Layer ' + self.name + ': The `regularizers` property of '
369369
'layers/models is deprecated. '
370370
'Regularization losses are now managed via the `losses` '
371371
'layer/model property.')
372372
return []
373373

374374
@regularizers.setter
375375
def regularizers(self, _):
376-
warnings.warn('The `regularizers` property of layers/models '
376+
warnings.warn('Layer ' + self.name + ': The `regularizers` property of layers/models '
377377
'is deprecated. '
378378
'Regularization losses are now managed via the `losses` '
379379
'layer/model property.')

keras/initializations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def norm_weight(shape, scale=0.01, ortho=True, name=None):
102102
"""
103103
Random weights drawn from a Gaussian
104104
"""
105-
assert len(shape)==2, 'shape must have length 2'
105+
assert len(shape)>0, 'shape must have length > 0. Currently, it has length == ' + str(len(shape))
106106
if shape[0] == shape[1] and ortho:
107107
W = ortho_weight(shape)
108108
else:

keras/layers/core.py

+64-2
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,46 @@ def get_config(self):
410410
return dict(list(base_config.items()) + list(config.items()))
411411

412412

413+
414+
class PermuteGeneral(Layer):
415+
'''Permutes the dimensions of the input according to a given pattern.
416+
This is just like the layer Permute, but DOES INCLUDE the batch dimension.
417+
418+
# Arguments
419+
dims: Tuple of integers. Permutation pattern, INCLUDING the
420+
samples dimension. Indexing starts at 0.
421+
For instance, `(1, 0, 2)` permutes the batch and first dimension of the input.
422+
423+
# Input shape
424+
Arbitrary. Use the keyword argument `input_shape`
425+
(tuple of integers, does not include the samples axis)
426+
when using this layer as the first layer in a model.
427+
428+
# Output shape
429+
Same as the input shape, but with the dimensions re-ordered according
430+
to the specified pattern.
431+
'''
432+
def __init__(self, dims, **kwargs):
433+
self.dims = tuple(dims)
434+
self.supports_masking = True
435+
super(PermuteGeneral, self).__init__(**kwargs)
436+
437+
def get_output_shape_for(self, input_shape):
438+
input_shape = list(input_shape)
439+
output_shape = copy.copy(input_shape)
440+
for i, dim in enumerate(self.dims):
441+
output_shape[i] = input_shape[dim]
442+
return tuple(output_shape)
443+
444+
def call(self, x, mask=None):
445+
return K.permute_dimensions(x, self.dims)
446+
447+
def get_config(self):
448+
config = {'dims': self.dims}
449+
base_config = super(PermuteGeneral, self).get_config()
450+
return dict(list(base_config.items()) + list(config.items()))
451+
452+
413453
class Flatten(Layer):
414454
'''Flattens the input. Does not affect the batch size.
415455
@@ -581,7 +621,7 @@ def antirectifier_output_shape(input_shape):
581621
def __init__(self, function, output_shape=None, arguments=None, **kwargs):
582622
self.function = function
583623
self.arguments = arguments if arguments else {}
584-
self.supports_masking = False
624+
self.supports_masking = True
585625

586626
if output_shape is None:
587627
self._output_shape = None
@@ -1303,6 +1343,28 @@ def get_config(self):
13031343
base_config = super(MaskedMean, self).get_config()
13041344
return dict(list(base_config.items()))
13051345

1346+
1347+
class MaskLayer(Layer):
1348+
"""
1349+
Applies to the input layer its mask
1350+
"""
1351+
def __init__(self, **kwargs):
1352+
self.support_mask = True
1353+
super(MaskLayer, self).__init__(**kwargs)
1354+
1355+
def call(self, x, mask=None):
1356+
return mask[:, :, None] * x
1357+
1358+
def compute_mask(self, input_shape, input_mask=None):
1359+
return input_mask
1360+
1361+
def get_output_shape_for(self, input_shape):
1362+
return input_shape
1363+
1364+
def get_config(self):
1365+
base_config = super(MaskLayer, self).get_config()
1366+
return dict(list(base_config.items()))
1367+
13061368
class WeightedSum(Layer):
13071369
''' Applies a weighted sum over a set of vectors input[0] and their respective weights input[1].
13081370
First, the weights are tiled for matching the length of the input vectors on dim=1.
@@ -1530,4 +1592,4 @@ def compute_mask(self, input, input_mask=None):
15301592
#def get_config(self):
15311593
#base_config = super(LambdaRemoveMask, self).get_config()
15321594
#return dict(list(base_config.items()))
1533-
"""
1595+
"""

0 commit comments

Comments
 (0)