Commit de9605c

Updating to PyTorch 0.4.0
1 parent 044c163 commit de9605c

File tree: 6 files changed, +51 −54 lines
main.py, requirements.txt, treelstm/dataset.py, treelstm/model.py, treelstm/trainer.py, treelstm/utils.py


main.py

+17-17
@@ -47,6 +47,7 @@ def main():
     logger.addHandler(ch)
     # argument validation
     args.cuda = args.cuda and torch.cuda.is_available()
+    device = torch.device("cuda:0" if args.cuda else "cpu")
     if args.sparse and args.wd != 0:
         logger.error('Sparsity and weight decay are incompatible, pick one!')
         exit()
@@ -111,18 +112,6 @@ def main():
                        args.sparse,
                        args.freeze_embed)
     criterion = nn.KLDivLoss()
-    if args.cuda:
-        model.cuda(), criterion.cuda()
-    if args.optim == 'adam':
-        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
-                                      model.parameters()), lr=args.lr, weight_decay=args.wd)
-    elif args.optim == 'adagrad':
-        optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
-                                         model.parameters()), lr=args.lr, weight_decay=args.wd)
-    elif args.optim == 'sgd':
-        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
-                                     model.parameters()), lr=args.lr, weight_decay=args.wd)
-    metrics = Metrics(args.num_classes)
 
     # for words common to dataset vocab and GLOVE, use GLOVE vectors
     # for other words in dataset vocab, use random normal vectors
@@ -134,7 +123,8 @@ def main():
         glove_vocab, glove_emb = utils.load_word_vectors(
             os.path.join(args.glove, 'glove.840B.300d'))
         logger.debug('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
-        emb = torch.Tensor(vocab.size(), glove_emb.size(1)).normal_(-0.05, 0.05)
+        emb = torch.zeros(vocab.size(), glove_emb.size(1), dtype=torch.float, device=device)
+        emb.normal_(0, 0.05)
         # zero out the embeddings for padding and other special words if they are absent in vocab
         for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD,
                                     Constants.BOS_WORD, Constants.EOS_WORD]):
@@ -144,12 +134,22 @@ def main():
                 emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(word)]
         torch.save(emb, emb_file)
     # plug these into embedding matrix inside model
-    if args.cuda:
-        emb = emb.cuda()
-    model.emb.weight.data.copy_(emb)
+    model.emb.weight.copy_(emb)
+
+    model.to(device), criterion.to(device)
+    if args.optim == 'adam':
+        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
+                                      model.parameters()), lr=args.lr, weight_decay=args.wd)
+    elif args.optim == 'adagrad':
+        optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
+                                         model.parameters()), lr=args.lr, weight_decay=args.wd)
+    elif args.optim == 'sgd':
+        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
+                                     model.parameters()), lr=args.lr, weight_decay=args.wd)
+    metrics = Metrics(args.num_classes)
 
     # create trainer object for training and testing
-    trainer = Trainer(args, model, criterion, optimizer)
+    trainer = Trainer(args, model, criterion, optimizer, device)
 
     best = -float('inf')
     for epoch in range(args.epochs):
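Migration note: the main.py changes replace scattered `if args.cuda: ... .cuda()` branches with a single `torch.device` object threaded through the program, the core device-handling idiom introduced in PyTorch 0.4.0. A minimal sketch of the pattern, using a stand-in `nn.Linear` rather than this repo's TreeLSTM:

    import torch
    import torch.nn as nn

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = nn.Linear(10, 2)      # stand-in for SimilarityTreeLSTM
    criterion = nn.KLDivLoss()
    model.to(device)              # moves a Module's parameters/buffers in place
    criterion.to(device)

    x = torch.randn(4, 10).to(device)   # Tensor.to() returns a tensor; reassignment needed
    out = model(x)

This asymmetry is why the trainer below reassigns `linput = linput.to(self.device)` while main.py simply calls `model.to(device)`.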

requirements.txt

+1-1
@@ -1,2 +1,2 @@
-http://download.pytorch.org/whl/cpu/torch-0.3.1-cp36-cp36m-linux_x86_64.whl
+http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl
 tqdm
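A quick way to confirm the upgrade took effect after reinstalling from requirements.txt (a sanity check, not part of the repo):

    import torch
    print(torch.__version__)   # expect '0.4.0'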

treelstm/dataset.py

+2-2
@@ -44,7 +44,7 @@ def read_sentences(self, filename):
 
     def read_sentence(self, line):
         indices = self.vocab.convertToIdx(line.split(), Constants.UNK_WORD)
-        return torch.LongTensor(indices)
+        return torch.tensor(indices, dtype=torch.long, device='cpu')
 
     def read_trees(self, filename):
         with open(filename, 'r') as f:
@@ -82,5 +82,5 @@ def read_tree(self, line):
     def read_labels(self, filename):
         with open(filename, 'r') as f:
             labels = list(map(lambda x: float(x), f.readlines()))
-            labels = torch.Tensor(labels)
+            labels = torch.tensor(labels, dtype=torch.float, device='cpu')
         return labels
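Migration note: PyTorch 0.4.0 added the `torch.tensor()` factory, which copies Python data into a new tensor with explicit `dtype` and `device`, replacing type-specific constructors such as `torch.LongTensor(data)`. A small illustration with made-up values:

    import torch

    indices = [4, 17, 2, 9]
    old = torch.LongTensor(indices)                               # pre-0.4 constructor style
    new = torch.tensor(indices, dtype=torch.long, device='cpu')   # 0.4 factory function
    assert torch.equal(old, new)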

treelstm/model.py

+2-3
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd import Variable as Var
 
 from . import Constants
 
@@ -39,8 +38,8 @@ def forward(self, tree, inputs):
             self.forward(tree.children[idx], inputs)
 
         if tree.num_children == 0:
-            child_c = Var(inputs[0].data.new(1, self.mem_dim).fill_(0.))
-            child_h = Var(inputs[0].data.new(1, self.mem_dim).fill_(0.))
+            child_c = inputs[0].detach().new(1, self.mem_dim).fill_(0.).requires_grad_()
+            child_h = inputs[0].detach().new(1, self.mem_dim).fill_(0.).requires_grad_()
         else:
             child_c, child_h = zip(* map(lambda x: x.state, tree.children))
             child_c, child_h = torch.cat(child_c, dim=0), torch.cat(child_h, dim=0)
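Migration note: 0.4.0 merged Variable into Tensor, so the `Var(...)` wrapper is gone; tensors carry `requires_grad` themselves, `.detach()` returns a tensor cut off from the autograd graph, and `.requires_grad_()` flips the flag in place. A sketch of the leaf-node pattern above (`mem_dim` and the input shape are illustrative):

    import torch

    inputs = torch.randn(5, 150, requires_grad=True)   # stand-in for an embedded sentence
    mem_dim = 150
    # old: child_c = Var(inputs[0].data.new(1, mem_dim).fill_(0.))
    child_c = inputs[0].detach().new(1, mem_dim).fill_(0.).requires_grad_()
    print(child_c.shape, child_c.requires_grad)        # torch.Size([1, 150]) True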

treelstm/trainer.py

+22-25
@@ -1,36 +1,34 @@
 from tqdm import tqdm
 
 import torch
-from torch.autograd import Variable as Var
 
 from . import utils
 
 
 class Trainer(object):
-    def __init__(self, args, model, criterion, optimizer):
+    def __init__(self, args, model, criterion, optimizer, device):
         super(Trainer, self).__init__()
         self.args = args
         self.model = model
         self.criterion = criterion
         self.optimizer = optimizer
+        self.device = device
         self.epoch = 0
 
     # helper function for training
     def train(self, dataset):
         self.model.train()
         self.optimizer.zero_grad()
         total_loss = 0.0
-        indices = torch.randperm(len(dataset))
+        indices = torch.randperm(len(dataset), dtype=torch.long, device='cpu')
         for idx in tqdm(range(len(dataset)), desc='Training epoch ' + str(self.epoch + 1) + ''):
-            ltree, lsent, rtree, rsent, label = dataset[indices[idx]]
-            linput, rinput = Var(lsent), Var(rsent)
-            target = Var(utils.map_label_to_target(label, dataset.num_classes))
-            if self.args.cuda:
-                linput, rinput = linput.cuda(), rinput.cuda()
-                target = target.cuda()
+            ltree, linput, rtree, rinput, label = dataset[indices[idx]]
+            target = utils.map_label_to_target(label, dataset.num_classes)
+            linput, rinput = linput.to(self.device), rinput.to(self.device)
+            target = target.to(self.device)
             output = self.model(ltree, linput, rtree, rinput)
             loss = self.criterion(output, target)
-            total_loss += loss.data[0]
+            total_loss += loss.item()
             loss.backward()
             if idx % self.args.batchsize == 0 and idx > 0:
                 self.optimizer.step()
@@ -41,19 +39,18 @@ def train(self, dataset):
     # helper function for testing
     def test(self, dataset):
         self.model.eval()
-        total_loss = 0
-        predictions = torch.zeros(len(dataset))
-        indices = torch.arange(1, dataset.num_classes + 1)
-        for idx in tqdm(range(len(dataset)), desc='Testing epoch ' + str(self.epoch) + ''):
-            ltree, lsent, rtree, rsent, label = dataset[idx]
-            linput, rinput = Var(lsent, volatile=True), Var(rsent, volatile=True)
-            target = Var(utils.map_label_to_target(label, dataset.num_classes), volatile=True)
-            if self.args.cuda:
-                linput, rinput = linput.cuda(), rinput.cuda()
-                target = target.cuda()
-            output = self.model(ltree, linput, rtree, rinput)
-            loss = self.criterion(output, target)
-            total_loss += loss.data[0]
-            output = output.data.squeeze().cpu()
-            predictions[idx] = torch.dot(indices, torch.exp(output))
+        with torch.no_grad():
+            total_loss = 0.0
+            predictions = torch.zeros(len(dataset), dtype=torch.float, device='cpu')
+            indices = torch.arange(1, dataset.num_classes + 1, dtype=torch.float, device='cpu')
+            for idx in tqdm(range(len(dataset)), desc='Testing epoch ' + str(self.epoch) + ''):
+                ltree, linput, rtree, rinput, label = dataset[idx]
+                target = utils.map_label_to_target(label, dataset.num_classes)
+                linput, rinput = linput.to(self.device), rinput.to(self.device)
+                target = target.to(self.device)
+                output = self.model(ltree, linput, rtree, rinput)
+                loss = self.criterion(output, target)
+                total_loss += loss.item()
+                output = output.squeeze().to('cpu')
+                predictions[idx] = torch.dot(indices, torch.exp(output))
         return total_loss / len(dataset), predictions
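Migration note: two 0.4.0 idioms drive the trainer changes. `loss.item()` replaces `loss.data[0]`, since losses are now 0-dimensional tensors and indexing them raises an error; and `torch.no_grad()` replaces the removed `volatile=True` flag for inference. A minimal sketch with a stand-in model:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    model = nn.Linear(10, 1)
    x, y = torch.randn(3, 10), torch.randn(3, 1)

    loss = F.mse_loss(model(x), y)
    total = loss.item()        # extracts a Python float from the 0-dim loss tensor

    with torch.no_grad():      # no graph is built, so evaluation uses less memory
        out = model(x)
    print(out.requires_grad)   # False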

treelstm/utils.py

+7-6
@@ -26,13 +26,14 @@ def load_word_vectors(path):
         contents = f.readline().rstrip('\n').split(' ')
         dim = len(contents[1:])
     words = [None] * (count)
-    vectors = torch.zeros(count, dim)
+    vectors = torch.zeros(count, dim, dtype=torch.float, device='cpu')
     with open(path + '.txt', 'r', encoding='utf8', errors='ignore') as f:
         idx = 0
         for line in f:
             contents = line.rstrip('\n').split(' ')
             words[idx] = contents[0]
-            vectors[idx] = torch.Tensor(list(map(float, contents[1:])))
+            values = list(map(float, contents[1:]))
+            vectors[idx] = torch.tensor(values, dtype=torch.float, device='cpu')
             idx += 1
     with open(path + '.vocab', 'w', encoding='utf8', errors='ignore') as f:
         for word in words:
@@ -57,12 +58,12 @@ def build_vocab(filenames, vocabfile):
 
 # mapping from scalar to vector
 def map_label_to_target(label, num_classes):
-    target = torch.zeros(1, num_classes)
+    target = torch.zeros(1, num_classes, dtype=torch.float, device='cpu')
     ceil = int(math.ceil(label))
     floor = int(math.floor(label))
     if ceil == floor:
-        target[0][floor-1] = 1
+        target[0, floor-1] = 1
     else:
-        target[0][floor-1] = ceil - label
-        target[0][ceil-1] = label - floor
+        target[0, floor-1] = ceil - label
+        target[0, ceil-1] = label - floor
     return target
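Migration note: `target[0, floor-1]` indexes the element in a single operation, whereas the old `target[0][floor-1]` first creates a view of row 0 and then indexes into it; both write the same element, but the tuple form is the idiomatic one. A quick check of the soft-label mapping with a made-up label:

    import math
    import torch

    label, num_classes = 2.6, 5
    target = torch.zeros(1, num_classes, dtype=torch.float, device='cpu')
    ceil, floor = int(math.ceil(label)), int(math.floor(label))
    target[0, floor-1] = ceil - label   # class 2 gets weight 0.4
    target[0, ceil-1] = label - floor   # class 3 gets weight 0.6
    print(target)   # tensor([[0.0000, 0.4000, 0.6000, 0.0000, 0.0000]])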
