
Commit c3b42e9

Flake8 / Pep8 fixes
1 parent cfec177 commit c3b42e9

10 files changed: +86, -62 lines
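
The changes are flake8 / PEP8 cleanups across the codebase. For reference, a tree like this can be checked locally through flake8's legacy Python API; the sketch below is illustrative only (the file list and options are assumptions, not taken from this commit):

    # a minimal sketch, assuming flake8 is installed; the file list is
    # illustrative rather than taken from any configuration in this repository
    from flake8.api import legacy as flake8

    style_guide = flake8.get_style_guide()
    report = style_guide.check_files(['config.py', 'main.py', 'model.py',
                                      'trainer.py', 'utils.py'])
    print('remaining violations:', report.total_errors)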

.gitignore (-1)

@@ -9,4 +9,3 @@ job.py
 sftp-config.json
 .ftpconfig
 .ftpignore
-.vscode

LICENSE (+1 -1)

@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+SOFTWARE.

config.py (+3 -2)

@@ -2,8 +2,9 @@


 def parse_args():
-    parser = argparse.ArgumentParser(description='PyTorch TreeLSTM for Sentence Similarity on Dependency Trees')
-    #
+    parser = argparse.ArgumentParser(
+        description='PyTorch TreeLSTM for Sentence Similarity on Dependency Trees')
+    # data arguments
     parser.add_argument('--data', default='data/sick/',
                         help='path to dataset')
     parser.add_argument('--glove', default='data/glove/',
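
The hunk switches the long ArgumentParser call to a hanging indent while the add_argument calls keep a visual indent aligned with the opening parenthesis; flake8 accepts both continuation styles. A tiny standalone illustration (toy arguments, not the project's real ones):

    # illustrative only: two continuation styles that satisfy flake8
    import argparse

    parser = argparse.ArgumentParser(
        description='hanging indent: nothing follows the opening parenthesis')
    parser.add_argument('--data', default='data/sick/',
                        help='visual indent: aligned with the opening delimiter')
    args = parser.parse_args([])
    print(args.data)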

fetch_and_preprocess.sh (+1 -1)

@@ -4,4 +4,4 @@ python2.7 scripts/download.py

 CLASSPATH="lib:lib/stanford-parser/stanford-parser.jar:lib/stanford-parser/stanford-parser-3.5.1-models.jar"
 javac -cp $CLASSPATH lib/*.java
-python2.7 scripts/preprocess-sick.py
+python2.7 scripts/preprocess-sick.py

main.py (+29 -20)

@@ -73,7 +73,9 @@ def main():
     build_vocab(token_files, sick_vocab_file)

     # get vocab object from vocab file previously written
-    vocab = Vocab(filename=sick_vocab_file, data=[Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD])
+    vocab = Vocab(filename=sick_vocab_file,
+                  data=[Constants.PAD_WORD, Constants.UNK_WORD,
+                        Constants.BOS_WORD, Constants.EOS_WORD])
     logger.debug('==> SICK vocabulary size : %d ' % vocab.size())

     # load SICK dataset splits

@@ -101,22 +103,25 @@ def main():

     # initialize model, criterion/loss_function, optimizer
     model = SimilarityTreeLSTM(
-        vocab.size(),
-        args.input_dim,
-        args.mem_dim,
-        args.hidden_dim,
-        args.num_classes,
-        args.sparse,
-        args.freeze_embed)
+        vocab.size(),
+        args.input_dim,
+        args.mem_dim,
+        args.hidden_dim,
+        args.num_classes,
+        args.sparse,
+        args.freeze_embed)
     criterion = nn.KLDivLoss()
     if args.cuda:
         model.cuda(), criterion.cuda()
     if args.optim == 'adam':
-        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)
+        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
+                                      model.parameters()), lr=args.lr, weight_decay=args.wd)
     elif args.optim == 'adagrad':
-        optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)
+        optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
+                                         model.parameters()), lr=args.lr, weight_decay=args.wd)
     elif args.optim == 'sgd':
-        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)
+        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
+                                     model.parameters()), lr=args.lr, weight_decay=args.wd)
     metrics = Metrics(args.num_classes)

     # for words common to dataset vocab and GLOVE, use GLOVE vectors

@@ -130,7 +135,8 @@ def main():
         logger.debug('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
         emb = torch.Tensor(vocab.size(), glove_emb.size(1)).normal_(-0.05, 0.05)
         # zero out the embeddings for padding and other special words if they are absent in vocab
-        for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD]):
+        for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD,
+                                    Constants.BOS_WORD, Constants.EOS_WORD]):
             emb[idx].zero_()
         for word in vocab.labelToIdx.keys():
             if glove_vocab.getIndex(word):

@@ -146,29 +152,32 @@ def main():

     best = -float('inf')
     for epoch in range(args.epochs):
-        train_loss = trainer.train(train_dataset)
+        train_loss = trainer.train(train_dataset)
         train_loss, train_pred = trainer.test(train_dataset)
-        dev_loss, dev_pred = trainer.test(dev_dataset)
-        test_loss, test_pred = trainer.test(test_dataset)
+        dev_loss, dev_pred = trainer.test(dev_dataset)
+        test_loss, test_pred = trainer.test(test_dataset)

         train_pearson = metrics.pearson(train_pred, train_dataset.labels)
         train_mse = metrics.mse(train_pred, train_dataset.labels)
-        logger.info('==> Epoch {}, Train \tLoss: {}\tPearson: {}\tMSE: {}'.format(epoch, train_loss, train_pearson, train_mse))
+        logger.info('==> Epoch {}, Train \tLoss: {}\tPearson: {}\tMSE: {}'.format(
+            epoch, train_loss, train_pearson, train_mse))
         dev_pearson = metrics.pearson(dev_pred, dev_dataset.labels)
         dev_mse = metrics.mse(dev_pred, dev_dataset.labels)
-        logger.info('==> Epoch {}, Dev \tLoss: {}\tPearson: {}\tMSE: {}'.format(epoch, dev_loss, dev_pearson, dev_mse))
+        logger.info('==> Epoch {}, Dev \tLoss: {}\tPearson: {}\tMSE: {}'.format(
+            epoch, dev_loss, dev_pearson, dev_mse))
         test_pearson = metrics.pearson(test_pred, test_dataset.labels)
         test_mse = metrics.mse(test_pred, test_dataset.labels)
-        logger.info('==> Epoch {}, Test \tLoss: {}\tPearson: {}\tMSE: {}'.format(epoch, test_loss, test_pearson, test_mse))
+        logger.info('==> Epoch {}, Test \tLoss: {}\tPearson: {}\tMSE: {}'.format(
+            epoch, test_loss, test_pearson, test_mse))

         if best < test_pearson:
             best = test_pearson
             checkpoint = {
-                'model': trainer.model.state_dict(),
+                'model': trainer.model.state_dict(),
                 'optim': trainer.optimizer,
                 'pearson': test_pearson, 'mse': test_mse,
                 'args': args, 'epoch': epoch
-            }
+            }
             logger.debug('==> New optimum found, checkpointing everything now...')
             torch.save(checkpoint, '%s.pt' % os.path.join(args.save, args.expname))
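
The reflowed optimizer lines keep the existing pattern of handing the optimizer only parameters with requires_grad set, which matters when --freeze_embed is used. A standalone sketch of that pattern, with a toy model standing in for SimilarityTreeLSTM:

    # toy model; only the filter() pattern is the point here -- frozen weights
    # are excluded so the optimizer sees trainable parameters only (older,
    # Variable-era torch.optim versions refused non-trainable parameters)
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 5))
    model[0].weight.requires_grad = False   # e.g. frozen pretrained embeddings

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=0.01, weight_decay=1e-4)
    print(len(optimizer.param_groups[0]['params']), 'of',
          sum(1 for _ in model.parameters()), 'parameters will be updated')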

model.py (+5 -4)

@@ -25,17 +25,18 @@ def node_forward(self, inputs, child_c, child_h):
         i, o, u = F.sigmoid(i), F.sigmoid(o), F.tanh(u)

         f = F.sigmoid(
-            self.fh(child_h) +
-            self.fx(inputs).repeat(len(child_h), 1)
-        )
+            self.fh(child_h) +
+            self.fx(inputs).repeat(len(child_h), 1)
+        )
         fc = torch.mul(f, child_c)

         c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
         h = torch.mul(o, F.tanh(c))
         return c, h

     def forward(self, tree, inputs):
-        _ = [self.forward(tree.children[idx], inputs) for idx in range(tree.num_children)]
+        for idx in range(tree.num_children):
+            self.forward(tree.children[idx], inputs)

         if tree.num_children == 0:
             child_c = Var(inputs[0].data.new(1, self.mem_dim).fill_(0.))
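
For context on the reformatted forget-gate expression: fh(child_h) yields one row per child while fx(inputs) yields a single row, so the latter is repeated before the sum. A shape-only sketch with toy sizes (plain tensors in place of the module's linear layers):

    # toy sizes: 2 children, mem_dim of 3; mirrors
    # self.fh(child_h) + self.fx(inputs).repeat(len(child_h), 1)
    import torch

    child_h = torch.randn(2, 3)   # one hidden state per child
    fx_out = torch.randn(1, 3)    # transformed parent input
    f = torch.sigmoid(child_h + fx_out.repeat(len(child_h), 1))
    print(f.shape)                # torch.Size([2, 3]): one forget gate per child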

scripts/download.py (+15 -9)

@@ -10,27 +10,26 @@
 import urllib2
 import sys
 import os
-import shutil
 import zipfile
-import gzip
+

 def download(url, dirpath):
     filename = url.split('/')[-1]
     filepath = os.path.join(dirpath, filename)
     try:
         u = urllib2.urlopen(url)
-    except:
-        print("URL %s failed to open" %url)
+    except Exception as e:
+        print("URL %s failed to open" % url)
         raise Exception
     try:
         f = open(filepath, 'wb')
-    except:
-        print("Cannot write %s" %filepath)
+    except Exception as e:
+        print("Cannot write %s" % filepath)
         raise Exception
     try:
         filesize = int(u.info().getheaders("Content-Length")[0])
-    except:
-        print("URL %s failed to report length" %url)
+    except Exception as e:
+        print("URL %s failed to report length" % url)
         raise Exception
     print("Downloading: %s Bytes: %s" % (filename, filesize))

@@ -47,19 +46,22 @@ def download(url, dirpath):
         downloaded += len(buf)
         f.write(buf)
         status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
-                  ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize))
+                  ('=' * int(downloaded / filesize * status_width) + '>',
+                   downloaded * 100. / filesize))
         print(status, end='')
         sys.stdout.flush()
     f.close()
     return filepath

+
 def unzip(filepath):
     print("Extracting: " + filepath)
     dirpath = os.path.dirname(filepath)
     with zipfile.ZipFile(filepath) as zf:
         zf.extractall(dirpath)
     os.remove(filepath)

+
 def download_tagger(dirpath):
     tagger_dir = 'stanford-tagger'
     if os.path.exists(os.path.join(dirpath, tagger_dir)):

@@ -74,6 +76,7 @@ def download_tagger(dirpath):
     os.remove(filepath)
     os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, tagger_dir))

+
 def download_parser(dirpath):
     parser_dir = 'stanford-parser'
     if os.path.exists(os.path.join(dirpath, parser_dir)):

@@ -88,6 +91,7 @@ def download_parser(dirpath):
     os.remove(filepath)
     os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, parser_dir))

+
 def download_wordvecs(dirpath):
     if os.path.exists(dirpath):
         print('Found Glove vectors - skip')

@@ -97,6 +101,7 @@ def download_wordvecs(dirpath):
     url = 'http://www-nlp.stanford.edu/data/glove.840B.300d.zip'
     unzip(download(url, dirpath))

+
 def download_sick(dirpath):
     if os.path.exists(dirpath):
         print('Found SICK dataset - skip')

@@ -110,6 +115,7 @@ def download_sick(dirpath):
     unzip(download(trial_url, dirpath))
     unzip(download(test_url, dirpath))

+
 if __name__ == '__main__':
     base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
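
One arithmetic note on the progress-bar hunk: fetch_and_preprocess.sh runs this script with python2.7, where / between two integers truncates, so the removed float() cast is what kept the ratio fractional. A quick check (// mimics Python 2's integer division):

    # made-up sizes, purely to show the truncation
    downloaded, filesize, status_width = 3 * 1024, 10 * 1024, 30

    print(int(downloaded // filesize * status_width))         # 0  (bar stays empty)
    print(int(float(downloaded) / filesize * status_width))   # 9  (partial bar)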

scripts/preprocess-sick.py (+21 -14)

@@ -6,33 +6,37 @@
 import os
 import glob

+
 def make_dirs(dirs):
     for d in dirs:
         if not os.path.exists(d):
             os.makedirs(d)

+
 def dependency_parse(filepath, cp='', tokenize=True):
     print('\nDependency parsing ' + filepath)
     dirpath = os.path.dirname(filepath)
     filepre = os.path.splitext(os.path.basename(filepath))[0]
     tokpath = os.path.join(dirpath, filepre + '.toks')
     parentpath = os.path.join(dirpath, filepre + '.parents')
-    relpath = os.path.join(dirpath, filepre + '.rels')
+    relpath = os.path.join(dirpath, filepre + '.rels')
     tokenize_flag = '-tokenize - ' if tokenize else ''
     cmd = ('java -cp %s DependencyParse -tokpath %s -parentpath %s -relpath %s %s < %s'
-           % (cp, tokpath, parentpath, relpath, tokenize_flag, filepath))
+           % (cp, tokpath, parentpath, relpath, tokenize_flag, filepath))
     os.system(cmd)

+
 def constituency_parse(filepath, cp='', tokenize=True):
     dirpath = os.path.dirname(filepath)
     filepre = os.path.splitext(os.path.basename(filepath))[0]
     tokpath = os.path.join(dirpath, filepre + '.toks')
     parentpath = os.path.join(dirpath, filepre + '.cparents')
     tokenize_flag = '-tokenize - ' if tokenize else ''
     cmd = ('java -cp %s ConstituencyParse -tokpath %s -parentpath %s %s < %s'
-           % (cp, tokpath, parentpath, tokenize_flag, filepath))
+           % (cp, tokpath, parentpath, tokenize_flag, filepath))
     os.system(cmd)

+
 def build_vocab(filepaths, dst_path, lowercase=True):
     vocab = set()
     for filepath in filepaths:

@@ -45,26 +49,29 @@ def build_vocab(filepaths, dst_path, lowercase=True):
         for w in sorted(vocab):
             f.write(w + '\n')

+
 def split(filepath, dst_dir):
     with open(filepath) as datafile, \
-            open(os.path.join(dst_dir, 'a.txt'), 'w') as afile, \
-            open(os.path.join(dst_dir, 'b.txt'), 'w') as bfile, \
-            open(os.path.join(dst_dir, 'id.txt'), 'w') as idfile, \
-            open(os.path.join(dst_dir, 'sim.txt'), 'w') as simfile:
-        datafile.readline()
-        for line in datafile:
-            i, a, b, sim, ent = line.strip().split('\t')
-            idfile.write(i + '\n')
-            afile.write(a + '\n')
-            bfile.write(b + '\n')
-            simfile.write(sim + '\n')
+            open(os.path.join(dst_dir, 'a.txt'), 'w') as afile, \
+            open(os.path.join(dst_dir, 'b.txt'), 'w') as bfile, \
+            open(os.path.join(dst_dir, 'id.txt'), 'w') as idfile, \
+            open(os.path.join(dst_dir, 'sim.txt'), 'w') as simfile:
+        datafile.readline()
+        for line in datafile:
+            i, a, b, sim, ent = line.strip().split('\t')
+            idfile.write(i + '\n')
+            afile.write(a + '\n')
+            bfile.write(b + '\n')
+            simfile.write(sim + '\n')
+

 def parse(dirpath, cp=''):
     dependency_parse(os.path.join(dirpath, 'a.txt'), cp=cp, tokenize=True)
     dependency_parse(os.path.join(dirpath, 'b.txt'), cp=cp, tokenize=True)
     constituency_parse(os.path.join(dirpath, 'a.txt'), cp=cp, tokenize=True)
     constituency_parse(os.path.join(dirpath, 'b.txt'), cp=cp, tokenize=True)

+
 if __name__ == '__main__':
     print('=' * 80)
     print('Preprocessing SICK dataset')
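
split() above unpacks five tab-separated fields per line: pair id, sentence A, sentence B, relatedness score, and entailment label. An illustrative row (the sentences and score are made up; only the column layout follows the unpacking in the code):

    # made-up SICK-style row, matching i, a, b, sim, ent in split()
    line = "1\tA man is playing a guitar\tA person plays an instrument\t4.2\tENTAILMENT\n"
    i, a, b, sim, ent = line.strip().split('\t')
    print(i, sim, ent)   # 1 4.2 ENTAILMENT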

trainer.py (+7 -7)

@@ -9,19 +9,19 @@
 class Trainer(object):
     def __init__(self, args, model, criterion, optimizer):
         super(Trainer, self).__init__()
-        self.args = args
-        self.model = model
-        self.criterion = criterion
-        self.optimizer = optimizer
-        self.epoch = 0
+        self.args = args
+        self.model = model
+        self.criterion = criterion
+        self.optimizer = optimizer
+        self.epoch = 0

     # helper function for training
     def train(self, dataset):
         self.model.train()
         self.optimizer.zero_grad()
         total_loss = 0.0
         indices = torch.randperm(len(dataset))
-        for idx in tqdm(range(len(dataset)),desc='Training epoch ' + str(self.epoch + 1) + ''):
+        for idx in tqdm(range(len(dataset)), desc='Training epoch ' + str(self.epoch + 1) + ''):
             ltree, lsent, rtree, rsent, label = dataset[indices[idx]]
             linput, rinput = Var(lsent), Var(rsent)
             target = Var(map_label_to_target(label, dataset.num_classes))

@@ -44,7 +44,7 @@ def test(self, dataset):
         total_loss = 0
         predictions = torch.zeros(len(dataset))
         indices = torch.arange(1, dataset.num_classes + 1)
-        for idx in tqdm(range(len(dataset)),desc='Testing epoch ' + str(self.epoch) + ''):
+        for idx in tqdm(range(len(dataset)), desc='Testing epoch ' + str(self.epoch) + ''):
             ltree, lsent, rtree, rsent, label = dataset[idx]
             linput, rinput = Var(lsent, volatile=True), Var(rsent, volatile=True)
             target = Var(map_label_to_target(label, dataset.num_classes), volatile=True)

utils.py (+4 -3)

@@ -54,12 +54,13 @@ def build_vocab(filenames, vocabfile):
         for token in sorted(vocab):
             f.write(token + '\n')

+
 # mapping from scalar to vector
-def map_label_to_target(label,num_classes):
-    target = torch.zeros(1,num_classes)
+def map_label_to_target(label, num_classes):
+    target = torch.zeros(1, num_classes)
     ceil = int(math.ceil(label))
     floor = int(math.floor(label))
-    if ceil==floor:
+    if ceil == floor:
         target[0][floor-1] = 1
     else:
         target[0][floor-1] = ceil - label
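
map_label_to_target turns a real-valued relatedness score into the sparse two-entry distribution that the KLDivLoss criterion in main.py expects. A worked example; note that the line assigning label - floor to target[0][ceil - 1] sits just outside the hunk shown above but belongs to the same function:

    # worked example of the sparse target encoding; assumes the function ends
    # with target[0][ceil - 1] = label - floor and then returns target
    import math
    import torch

    def map_label_to_target(label, num_classes):
        target = torch.zeros(1, num_classes)
        ceil = int(math.ceil(label))
        floor = int(math.floor(label))
        if ceil == floor:
            target[0][floor - 1] = 1
        else:
            target[0][floor - 1] = ceil - label
            target[0][ceil - 1] = label - floor
        return target

    print(map_label_to_target(3.6, 5))   # ~ tensor([[0.0, 0.0, 0.4, 0.6, 0.0]])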
