Commit f9dd5ce

multi-gpu enhance

1 parent: 22feea5

File tree

crnn_main.py
demo.py
models/crnn.py
models/utils.py

4 files changed: +19 -41 lines

crnn_main.py (+5 -6)
@@ -70,13 +70,10 @@
 test_dataset = dataset.lmdbDataset(
     root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))
 
-ngpu = int(opt.ngpu)
-nh = int(opt.nh)
-alphabet = opt.alphabet
-nclass = len(alphabet) + 1
+nclass = len(opt.alphabet) + 1
 nc = 1
 
-converter = utils.strLabelConverter(alphabet)
+converter = utils.strLabelConverter(opt.alphabet)
 criterion = CTCLoss()
 
 
@@ -89,7 +86,8 @@ def weights_init(m):
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)
 
-crnn = crnn.CRNN(opt.imgH, nc, nclass, nh, ngpu)
+
+crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
 crnn.apply(weights_init)
 if opt.crnn != '':
     print('loading pretrained model from %s' % opt.crnn)
@@ -102,6 +100,7 @@ def weights_init(m):
 
 if opt.cuda:
     crnn.cuda()
+    crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
     image = image.cuda()
     criterion = criterion.cuda()
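Note the ordering above: the model is moved to the GPU first and then wrapped, so its parameters live on device 0 before DataParallel replicates them. A minimal sketch of the same pattern on a stand-in module (the toy model and the hardcoded ngpu below are illustrative only; the script reads these from opt):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU())  # stand-in for crnn.CRNN
ngpu = 2  # the script takes this from opt.ngpu

if torch.cuda.is_available():
    model = model.cuda()
    # DataParallel replicates the module, scatters each input batch across
    # device_ids along dim 0, and gathers the outputs back on device 0.
    model = nn.DataParallel(model, device_ids=range(ngpu))

x = torch.randn(16, 1, 32, 100)  # a batch of 16 is split 8 + 8 across 2 GPUs
out = model(x.cuda() if torch.cuda.is_available() else x)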

demo.py (+1 -1)
@@ -11,7 +11,7 @@
 img_path = './data/demo.png'
 alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
 
-model = crnn.CRNN(32, 1, 37, 256, 1).cuda()
+model = crnn.CRNN(32, 1, 37, 256).cuda()
 print('loading pretrained model from %s' % model_path)
 model.load_state_dict(torch.load(model_path))
 
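One side effect of training under torch.nn.DataParallel: checkpoints saved from the wrapped model carry a 'module.' prefix on every state-dict key, so the bare CRNN built here may fail to load them directly. A hedged workaround sketch (the prefix-stripping step is mine, not part of this commit):

import torch

state = torch.load(model_path)
# Keys saved from a DataParallel-wrapped model look like 'module.cnn...';
# strip the prefix so they match the bare CRNN's parameter names.
state = {k.replace('module.', '', 1) if k.startswith('module.') else k: v
         for k, v in state.items()}
model.load_state_dict(state)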

models/crnn.py (+13 -21)
@@ -1,34 +1,29 @@
 import torch.nn as nn
-import utils
 
 
 class BidirectionalLSTM(nn.Module):
 
-    def __init__(self, nIn, nHidden, nOut, ngpu):
+    def __init__(self, nIn, nHidden, nOut):
         super(BidirectionalLSTM, self).__init__()
-        self.ngpu = ngpu
 
         self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
         self.embedding = nn.Linear(nHidden * 2, nOut)
 
     def forward(self, input):
-        recurrent, _ = utils.data_parallel(
-            self.rnn, input, self.ngpu)  # [T, b, h * 2]
-
+        recurrent, _ = self.rnn(input)
         T, b, h = recurrent.size()
         t_rec = recurrent.view(T * b, h)
-        output = utils.data_parallel(
-            self.embedding, t_rec, self.ngpu)  # [T * b, nOut]
+
+        output = self.embedding(t_rec)  # [T * b, nOut]
         output = output.view(T, b, -1)
 
         return output
 
 
 class CRNN(nn.Module):
 
-    def __init__(self, imgH, nc, nclass, nh, ngpu, n_rnn=2, leakyRelu=False):
+    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
         super(CRNN, self).__init__()
-        self.ngpu = ngpu
         assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
 
         ks = [3, 3, 3, 3, 3, 3, 2]
@@ -57,31 +52,28 @@ def convRelu(i, batchNormalization=False):
         cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
         convRelu(2, True)
         convRelu(3)
-        cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2),
-                                                            (2, 1),
-                                                            (0, 1)))  # 256x4x16
+        cnn.add_module('pooling{0}'.format(2),
+                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
         convRelu(4, True)
         convRelu(5)
-        cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2),
-                                                            (2, 1),
-                                                            (0, 1)))  # 512x2x16
+        cnn.add_module('pooling{0}'.format(3),
+                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
         convRelu(6, True)  # 512x1x16
 
         self.cnn = cnn
         self.rnn = nn.Sequential(
-            BidirectionalLSTM(512, nh, nh, ngpu),
-            BidirectionalLSTM(nh, nh, nclass, ngpu)
-        )
+            BidirectionalLSTM(512, nh, nh),
+            BidirectionalLSTM(nh, nh, nclass))
 
     def forward(self, input):
         # conv features
-        conv = utils.data_parallel(self.cnn, input, self.ngpu)
+        conv = self.cnn(input)
         b, c, h, w = conv.size()
         assert h == 1, "the height of conv must be 1"
         conv = conv.squeeze(2)
         conv = conv.permute(2, 0, 1)  # [w, b, c]
 
         # rnn features
-        output = utils.data_parallel(self.rnn, conv, self.ngpu)
+        output = self.rnn(conv)
 
         return output
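With the ngpu plumbing gone, CRNN is a plain single-device module and the shape flow matches the inline comments: the CNN collapses the image height to 1, the remaining width becomes the sequence length T, and the BidirectionalLSTM stack emits per-step class scores. A quick shape check, assuming nclass=37 and nh=256 as in demo.py:

import torch
import models.crnn as crnn

model = crnn.CRNN(imgH=32, nc=1, nclass=37, nh=256)
x = torch.randn(4, 1, 32, 100)  # [b, nc, imgH, W]; imgH must be a multiple of 16
out = model(x)                  # CNN -> squeeze height -> two BidirectionalLSTMs
print(out.size())               # torch.Size([26, 4, 37]), i.e. [T, b, nclass]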

models/utils.py (-13)

This file was deleted.
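The deleted utils.py held the hand-rolled data_parallel helper that every forward pass used to route through; its exact 13 lines are not reproduced here. For reference, PyTorch's functional torch.nn.parallel.data_parallel performs the same per-call scatter/gather, which is roughly what such a helper wraps (a generic sketch, not the deleted file's contents):

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

module = nn.Linear(512, 256).cuda()
x = torch.randn(8, 512).cuda()
# One-shot equivalent of nn.DataParallel for a single forward pass:
# replicate module, scatter x across GPUs 0 and 1, gather on device 0.
out = data_parallel(module, x, device_ids=[0, 1])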
