Skip to content

Commit e1d0a8e

Browse files
committed
update optimizer and dataflow
1 parent 4f1f958 commit e1d0a8e

File tree

9 files changed

+210
-171
lines changed

9 files changed

+210
-171
lines changed

examples/basic_tutorials/tutorial_tensorlayer_model_load.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
# -*- coding: utf-8 -*-
33

44
import os
5-
os.environ['TL_BACKEND'] = 'tensorflow'
6-
# os.environ['TL_BACKEND'] = 'paddle'
5+
# os.environ['TL_BACKEND'] = 'tensorflow'
6+
os.environ['TL_BACKEND'] = 'paddle'
7+
# os.environ['TL_BACKEND'] = 'mindspore'
78
# os.environ['TL_BACKEND'] = 'torch'
89

910
import tensorlayerx as tlx
@@ -105,7 +106,7 @@ def forward(self, x):
105106
cnn = CNN()
106107
# cnn.save_standard_weights('./model.npz')
107108
# TODO Tensorflow trained parameters are imported to the TensorFlow backend.
108-
cnn.load_standard_weights('./model.npz', skip=False)
109+
cnn.load_standard_weights('./model.npz', skip=False, reshape=True)
109110

110111
# TODO When parameters trained with the TensorFlow backend are imported into PaddlePaddle/PyTorch/MindSpore,
111112
# set the reshape parameter to True to convert the convolution shapes.

tensorlayerx/backend/ops/mindspore_backend.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,8 @@ def floor(x):
11701170

11711171
def gather(params, indices, axis=None):
    """
    Gather slices of ``params`` selected by ``indices`` along ``axis``.

    Parameters
    ----------
    params : tensor
        The tensor to gather values from.
    indices : tensor
        The indices of the slices to pick.
    axis : int or None
        The axis of ``params`` to gather along. ``None`` falls back to axis 0.

    Returns
    -------
        The output of ``P.Gather`` for ``params``, ``indices`` and the resolved axis.

    """
    # P.Gather is always invoked with an explicit axis, so None is mapped to 0 first.
    resolved_axis = 0 if axis is None else axis
    gather_op = P.Gather()
    return gather_op(params, indices, resolved_axis)
11741176

11751177

@@ -1590,10 +1592,7 @@ def reduce_std(x, axis=None, keepdims=False):
15901592

15911593

15921594
def reduce_sum(x, axis=None, keepdims=False):
    """
    Sum the elements of ``x`` over ``axis``.

    Parameters
    ----------
    x : tensor
        The tensor to reduce.
    axis : int, sequence of int or None
        Forwarded unchanged to ``msnp.sum``; the default is ``None``.
    keepdims : bool
        Forwarded unchanged to ``msnp.sum``; the default is ``False``.

    Returns
    -------
        The result of ``msnp.sum``.

    """
    # The whole reduction, including the axis=None case, is delegated to msnp.sum.
    total = msnp.sum(x, axis=axis, keepdims=keepdims)
    return total
15971596

15981597

15991598
def reduce_variance(x, axis=None, keepdims=False):
@@ -1729,11 +1728,15 @@ def tanh(x):
17291728

17301729
def any(x, axis=None, keepdims=False):
    """
    Logical OR reduction of ``x`` using ``P.ReduceAny``.

    Parameters
    ----------
    x : tensor
        The tensor to reduce.
    axis : int, sequence of int or None
        The axis passed to the operator. When ``None`` the operator is called
        with ``x`` only.
    keepdims : bool
        Passed to the operator as ``keep_dims``.

    Returns
    -------
        The output of the ``P.ReduceAny`` operator.

    """
    reduce_any = P.ReduceAny(keep_dims=keepdims)
    # Omit the axis argument entirely when no axis was requested.
    call_args = (x, ) if axis is None else (x, axis)
    return reduce_any(*call_args)
17331734

17341735

17351736
def all(x, axis=None, keepdims=False):
    """
    Logical AND reduction of ``x`` using ``P.ReduceAll``.

    Parameters
    ----------
    x : tensor
        The tensor to reduce.
    axis : int, sequence of int or None
        The axis passed to the operator. When ``None`` the operator is called
        with ``x`` only.
    keepdims : bool
        Passed to the operator as ``keep_dims``.

    Returns
    -------
        The output of the ``P.ReduceAll`` operator.

    """
    reduce_all = P.ReduceAll(keep_dims=keepdims)
    # Omit the axis argument entirely when no axis was requested.
    call_args = (x, ) if axis is None else (x, axis)
    return reduce_all(*call_args)
17381741

17391742

@@ -1779,8 +1782,7 @@ def zeros_like(x, dtype=None):
17791782

17801783

17811784
def squeeze(x, axis=None):
    """
    Squeeze ``x`` by delegating to ``msnp.squeeze``.

    Parameters
    ----------
    x : tensor
        The tensor to squeeze.
    axis : int, sequence of int or None
        Forwarded unchanged to ``msnp.squeeze``; the default is ``None``.

    Returns
    -------
        The result of ``msnp.squeeze``.

    """
    squeezed = msnp.squeeze(x, axis)
    return squeezed
17841786

17851787

17861788
def unsorted_segment_sum(x, segment_ids, num_segments):
@@ -1792,7 +1794,7 @@ def unsorted_segment_sum(x, segment_ids, num_segments):
17921794
def unsorted_segment_mean(x, segment_ids, num_segments):
    """
    Mean of the entries of ``x`` that share the same segment id.

    Parameters
    ----------
    x : tensor
        The values to average.
    segment_ids : array-like
        Segment id of each entry; it is wrapped in ``ms.Tensor`` before use.
    num_segments : int
        Number of segments, passed to ``P.UnsortedSegmentSum``.

    Returns
    -------
        The per-segment sums divided by the per-segment entry counts.

    """
    segment_ids = ms.Tensor(segment_ids)
    segment_sum = P.UnsortedSegmentSum()
    totals = segment_sum(x, segment_ids, num_segments)
    # Summing a tensor of ones under the same ids counts the entries of each segment.
    # NOTE(review): a segment that receives no entries has a count of 0 -- confirm
    # what the division below is expected to yield for it.
    counts = segment_sum(msnp.ones_like(x, dtype=x.dtype), segment_ids, num_segments)
    return totals / counts

tensorlayerx/backend/ops/tensorflow_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,7 @@ def gather(params, indices, axis=None):
12491249
indices : indices
12501250
The index Tensor. Must be one of the following types: int32, int64. The values must be in range [0, params.shape[axis]).
12511251
axis : tensor.
1252-
Must be one of the following types: int32, int64. The axis in params to gather indices from.
1252+
Must be one of the following types: int32, int64. The axis in params to gather indices from. Defaults to None, in which case axis 0 is used.
12531253
12541254
Returns
12551255
-------

tensorlayerx/backend/ops/torch_backend.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,9 @@ def floor(x):
10121012
return torch.floor(x)
10131013

10141014

1015-
def gather(params, indices, axis = 0):
1015+
def gather(params, indices, axis = None):
1016+
if axis is None:
1017+
axis = 0
10161018
if axis < 0:
10171019
axis = len(params.shape) + axis
10181020
if axis == 0:
@@ -1522,11 +1524,16 @@ def tanh(x):
15221524

15231525

15241526
def any(x, axis=None, keepdims=False):
    """
    Logical OR reduction of ``x``.

    Parameters
    ----------
    x : tensor
        The tensor to reduce.
    axis : int or None
        The dimension to reduce. If None (the default), all dimensions are reduced.
    keepdims : bool
        If True, reduced dimensions are retained with length 1. This is also
        honoured when ``axis`` is None, in which case the result has the same
        rank as ``x`` with every dimension of size 1.

    Returns
    -------
        The reduced tensor.

    """
    if axis is not None:
        return torch.any(x, dim=axis, keepdim=keepdims)
    # torch.any(x) has no keepdim option for the full reduction and returns a
    # 0-d tensor, so the kept dimensions are restored by reshaping.
    reduced = torch.any(x)
    if keepdims:
        reduced = reduced.reshape([1] * x.dim())
    return reduced
15271531

15281532
def all(x, axis=None, keepdims=False):
    """
    Logical AND reduction of ``x``.

    Parameters
    ----------
    x : tensor
        The tensor to reduce.
    axis : int or None
        The dimension to reduce. If None (the default), all dimensions are reduced.
    keepdims : bool
        If True, reduced dimensions are retained with length 1. This is also
        honoured when ``axis`` is None, in which case the result has the same
        rank as ``x`` with every dimension of size 1.

    Returns
    -------
        The reduced tensor.

    """
    if axis is not None:
        return torch.all(x, dim=axis, keepdim=keepdims)
    # torch.all(x) has no keepdim option for the full reduction and returns a
    # 0-d tensor, so the kept dimensions are restored by reshaping.
    reduced = torch.all(x)
    if keepdims:
        reduced = reduced.reshape([1] * x.dim())
    return reduced
15301537

15311538

15321539
def logical_and(x, y):

tensorlayerx/dataflow/utils.py

+7
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ def default_convert(data):
2222
data = tf.convert_to_tensor(data)
2323
elif BACKEND == 'torch':
2424
import torch
25+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
2526
data = torch.as_tensor(data)
27+
data = data.to(device)
2628
elif BACKEND == 'paddle':
2729
import paddle
2830
data = paddle.to_tensor(data)
@@ -76,17 +78,22 @@ def default_collate_torch(batch):
7678
data = batch[0]
7779
data_type = type(data)
7880
import torch
81+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
7982
if isinstance(data, torch.Tensor):
8083
batch = torch.stack(batch, 0)
84+
batch = batch.to(device)
8185
return batch
8286
elif isinstance(data, np.ndarray):
8387
batch = np.stack(batch, axis=0)
8488
batch = torch.as_tensor(batch)
89+
batch = batch.to(device)
8590
return batch
8691
elif isinstance(data, numbers.Number):
8792
batch = torch.as_tensor(batch)
93+
batch = batch.to(device)
8894
return batch
8995
elif isinstance(data, (str, bytes)):
96+
batch = batch.to(device)
9097
return batch
9198
elif isinstance(data, collections.abc.Mapping):
9299
return {key: default_collate_torch([d[key] for d in batch]) for key in data}

tensorlayerx/model/core.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -429,19 +429,16 @@ def th_train(
429429
self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,
430430
print_freq, test_dataset
431431
):
432+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
433+
network.to(device)
432434
for epoch in range(n_epoch):
433435
start_time = time.time()
434436

435437
train_loss, train_acc, n_iter = 0, 0, 0
436438
for X_batch, y_batch in train_dataset:
437-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
438439
network.set_train()
439-
X_batch = X_batch.to(device)
440-
y_batch = y_batch.to(device)
441-
network.to(device)
442440
output = network(X_batch)
443441
loss = loss_fn(output, y_batch)
444-
445442
grads = optimizer.gradient(loss, train_weights)
446443
optimizer.apply_gradients(zip(grads, train_weights))
447444

tensorlayerx/model/utils.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -177,15 +177,13 @@ def __call__(self, data, label):
177177
class TrainOneStepWithTH(object):
178178

179179
def __init__(self, net_with_loss, optimizer, train_weights):
180+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
180181
self.net_with_loss = net_with_loss
182+
self.net_with_loss.to(device)
181183
self.optimizer = optimizer
182184
self.train_weights = train_weights
183185

184186
def __call__(self, data, label):
185-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
186-
data = data.to(device)
187-
label = label.to(device)
188-
self.net_with_loss.to(device)
189187
loss = self.net_with_loss(data, label)
190188
grads = self.optimizer.gradient(loss, self.train_weights)
191189
self.optimizer.apply_gradients(zip(grads, self.train_weights))

0 commit comments

Comments
 (0)