Skip to content

Commit fec76fc

Browse files
authored
fix parameterlist parameterdict (#10)
1 parent 62c49f6 commit fec76fc

File tree

3 files changed

+73
-72
lines changed

3 files changed

+73
-72
lines changed

examples/basic_tutorials/Parameter_Container.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
2-
os.environ['TL_BACKEND'] = 'tensorflow'
2+
# os.environ['TL_BACKEND'] = 'tensorflow'
33
# os.environ['TL_BACKEND'] = 'mindspore'
44
# os.environ['TL_BACKEND'] = 'paddle'
5-
# os.environ['TL_BACKEND'] = 'torch'
5+
os.environ['TL_BACKEND'] = 'torch'
66

77
import tensorlayerx as tlx
88
from tensorlayerx.nn import Module, Parameter, ParameterList, ParameterDict
@@ -28,6 +28,10 @@ def forward(self, x, choice):
2828

2929
input = tlx.nn.Input(shape=(5,5))
3030
net = MyModule()
31-
31+
trainable_weights = net.trainable_weights
32+
print("-----------------------------trainable_weights-------------------------------")
33+
for weight in trainable_weights:
34+
print(weight)
35+
print("-----------------------------------output------------------------------------")
3236
output = net(input, choice = 'right')
3337
print(output)

tensorlayerx/nn/core/core_tensorflow.py

+65-68
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ class Module(object):
5555
def __init__(self, name=None, act=None, *args, **kwargs):
5656
self._params = OrderedDict()
5757
self._layers = OrderedDict()
58-
self._params_list = OrderedDict()
59-
self._params_dict = OrderedDict()
58+
# self._params_list = OrderedDict()
59+
# self._params_dict = OrderedDict()
6060
self._params_status = OrderedDict()
6161
self._parameter_layout_dict = {}
6262
self._create_time = int(time.time() * 1e9)
@@ -148,11 +148,11 @@ def __setattr__(self, name, value):
148148
raise TypeError("Expected type is Module, but got Parameter.")
149149
self.insert_param_to_layer(name, value)
150150

151-
elif isinstance(value, ParameterList):
152-
self.set_attr_for_parameter_tuple(name, value)
153-
154-
elif isinstance(value, ParameterDict):
155-
self.set_attr_for_parameter_dict(name, value)
151+
# elif isinstance(value, ParameterList):
152+
# self.set_attr_for_parameter_tuple(name, value)
153+
#
154+
# elif isinstance(value, ParameterDict):
155+
# self.set_attr_for_parameter_dict(name, value)
156156

157157
elif isinstance(value, Module):
158158
if layers is None:
@@ -255,46 +255,46 @@ def _set_mode_for_layers(self, is_train):
255255
if isinstance(layer, Module):
256256
layer.is_train = is_train
257257

258-
def set_attr_for_parameter_dict(self, name, value):
259-
"""Set attr for parameter in ParameterDict."""
260-
params = self.__dict__.get('_params')
261-
params_dict = self.__dict__.get('_params_dict')
262-
if params is None:
263-
raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
264-
exist_names = set("")
265-
for item in value:
266-
self.insert_param_to_layer(item, value[item], check_name=False)
267-
if item in exist_names:
268-
raise ValueError("The value {} , its name '{}' already exists.".
269-
format(value[item], item))
270-
exist_names.add(item)
271-
272-
if name in self.__dict__:
273-
del self.__dict__[name]
274-
if name in params:
275-
del params[name]
276-
params_dict[name] = value
277-
278-
def set_attr_for_parameter_tuple(self, name, value):
279-
"""Set attr for parameter in ParameterTuple."""
280-
params = self.__dict__.get('_params')
281-
params_list = self.__dict__.get('_params_list')
282-
if params is None:
283-
raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
284-
exist_names = set("")
285-
286-
for item in value:
287-
self.insert_param_to_layer(item.name, item, check_name=False)
288-
if item.name in exist_names:
289-
raise ValueError("The value {} , its name '{}' already exists.".
290-
format(value, item.name))
291-
exist_names.add(item.name)
292-
293-
if name in self.__dict__:
294-
del self.__dict__[name]
295-
if name in params:
296-
del params[name]
297-
params_list[name] = value
258+
# def set_attr_for_parameter_dict(self, name, value):
259+
# """Set attr for parameter in ParameterDict."""
260+
# params = self.__dict__.get('_params')
261+
# params_dict = self.__dict__.get('_params_dict')
262+
# if params is None:
263+
# raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
264+
# exist_names = set("")
265+
# for item in value:
266+
# self.insert_param_to_layer(item, value[item], check_name=False)
267+
# if item in exist_names:
268+
# raise ValueError("The value {} , its name '{}' already exists.".
269+
# format(value[item], item))
270+
# exist_names.add(item)
271+
#
272+
# if name in self.__dict__:
273+
# del self.__dict__[name]
274+
# if name in params:
275+
# del params[name]
276+
# params_dict[name] = value
277+
#
278+
# def set_attr_for_parameter_tuple(self, name, value):
279+
# """Set attr for parameter in ParameterTuple."""
280+
# params = self.__dict__.get('_params')
281+
# params_list = self.__dict__.get('_params_list')
282+
# if params is None:
283+
# raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
284+
# exist_names = set("")
285+
#
286+
# for item in value:
287+
# self.insert_param_to_layer(item.name, item, check_name=False)
288+
# if item.name in exist_names:
289+
# raise ValueError("The value {} , its name '{}' already exists.".
290+
# format(value, item.name))
291+
# exist_names.add(item.name)
292+
#
293+
# if name in self.__dict__:
294+
# del self.__dict__[name]
295+
# if name in params:
296+
# del params[name]
297+
# params_list[name] = value
298298

299299
def set_train(self):
300300
"""Set this network in training mode. After calling this method,
@@ -354,7 +354,6 @@ def insert_param_to_layer(self, param_name, param, check_name=True):
354354
Determines whether the name input is compatible. Default: True.
355355
356356
"""
357-
358357
if not param_name:
359358
raise KeyError("The name of parameter should not be null.")
360359
if check_name and '.' in param_name:
@@ -388,15 +387,15 @@ def __getattr__(self, name):
388387
params_status = self.__dict__['_params_status']
389388
if name in params_status:
390389
return params_status[name]
391-
if '_params_list' in self.__dict__:
392-
params_list = self.__dict__['_params_list']
393-
if name in params_list:
394-
para_list = params_list[name]
395-
return para_list
396-
if '_params_dict' in self.__dict__:
397-
params_dict = self.__dict__['_params_dict']
398-
if name in params_dict:
399-
return params_dict[name]
390+
# if '_params_list' in self.__dict__:
391+
# params_list = self.__dict__['_params_list']
392+
# if name in params_list:
393+
# para_list = params_list[name]
394+
# return para_list
395+
# if '_params_dict' in self.__dict__:
396+
# params_dict = self.__dict__['_params_dict']
397+
# if name in params_dict:
398+
# return params_dict[name]
400399
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
401400

402401
def __delattr__(self, name):
@@ -1142,10 +1141,10 @@ def __setitem__(self, idx, parameter):
11421141
idx = self._get_abs_string_index(idx)
11431142
self._params[str(idx)] = parameter
11441143

1145-
# def __setattr__(self, key, value):
1146-
# if not hasattr(self, key) and not isinstance(value, tf.Variable):
1147-
# warnings.warn("Setting attributes on ParameterList is not supported.")
1148-
# super(ParameterList, self).__setattr__(key, value)
1144+
def __setattr__(self, key, value):
1145+
# if not hasattr(self, key) and not isinstance(value, tf.Variable):
1146+
# warnings.warn("Setting attributes on ParameterList is not supported.")
1147+
super(ParameterList, self).__setattr__(key, value)
11491148

11501149
def __len__(self):
11511150
return len(self._params)
@@ -1162,7 +1161,7 @@ def __dir__(self):
11621161
return keys
11631162

11641163
def append(self, parameter):
1165-
self._params[str(len(self))] = parameter
1164+
self.insert_param_to_layer(str(len(self)), parameter)
11661165
return self
11671166

11681167
def extend(self, parameters):
@@ -1173,7 +1172,7 @@ def extend(self, parameters):
11731172
)
11741173
offset = len(self)
11751174
for i, para in enumerate(parameters):
1176-
self._params[str(offset + i)] = para
1175+
self.insert_param_to_layer(str(offset + i), para)
11771176
return self
11781177

11791178
def __call__(self, input):
@@ -1248,15 +1247,13 @@ def __getitem__(self, key):
12481247
return self._params[key]
12491248

12501249
def __setitem__(self, key, parameter):
1251-
self._params[key] = parameter
1250+
self.insert_param_to_layer(key, parameter)
12521251

12531252
def __delitem__(self, key):
12541253
del self._params[key]
12551254

1256-
# def __setattr__(self, key, value):
1257-
# if not hasattr(self, key) and not isinstance(value, tf.Variable):
1258-
# warnings.warn("Setting attributes on ParameterDict is not supported.")
1259-
# super(ParameterDict, self).__setattr__(key, value)
1255+
def __setattr__(self, key, value):
1256+
super(ParameterDict, self).__setattr__(key, value)
12601257

12611258
def __len__(self) -> int:
12621259
return len(self._params)

tensorlayerx/nn/core/core_torch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def __dir__(self):
617617
keys = [key for key in keys if not key.isdigit()]
618618
return keys
619619

620-
def append(self, parameter: 'Parameter') -> 'ParameterList':
620+
def append(self, parameter):
621621

622622
self.register_parameter(str(len(self)), parameter)
623623
return self

0 commit comments

Comments (0)