-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain_word2vec.py
66 lines (53 loc) · 2.1 KB
/
train_word2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
""" This code is adapted from https://github.com/ChenRocks/fast_abs_rl """
""" pretrain a word2vec on the corpus"""
import argparse
import json
import logging
import os
from os.path import join, exists
from time import time
from datetime import timedelta
from cytoolz import concatv
import gensim
from utils.io import count_data
class Sentences(object):
    """Re-iterable corpus of tokenized sentences for gensim word2vec training.

    gensim walks the corpus more than once (``build_vocab`` plus one pass per
    training epoch), so this is an iterable object rather than a one-shot
    generator: every ``iter()`` call re-reads the files from disk.

    The corpus is the ``train`` sub-directory of ``data_path``; it is read as
    files named ``0.json`` ... ``{n-1}.json``, where ``n`` is whatever
    ``count_data`` reports for that directory.
    """

    def __init__(self, data_path):
        self._path = join(data_path, 'train')
        self._n_data = count_data(self._path)

    def __iter__(self):
        """Yield one lower-cased, whitespace-split sentence (list of str) at a time.

        For each example, every entry of its ``article`` list is yielded first,
        then every entry of its ``abstract`` list. Each sentence is wrapped in
        ``'<s>'`` and ``r'<\\s>'`` boundary tokens so those markers also get
        embeddings.
        """
        for i in range(self._n_data):
            file_path = join(self._path, '{}.json'.format(i))
            # Explicit UTF-8: the default text encoding is locale-dependent,
            # which can garble or reject non-ASCII text in the corpus.
            with open(file_path, encoding='utf-8') as f:
                data = json.load(f)
            # File is closed before yielding; article sentences come first,
            # then abstract sentences.
            for field in ('article', 'abstract'):
                for sent in data[field]:
                    yield ['<s>'] + sent.lower().split() + [r'<\s>']
def _make_word2vec(dim, min_count, workers):
    """Return an untrained skip-gram (``sg=1``) Word2Vec with ``dim``-sized vectors.

    gensim 4.0 renamed the ``size`` constructor argument to ``vector_size``
    and rejects the old name, so the keyword is chosen by installed major version.
    """
    options = dict(min_count=min_count, workers=workers, sg=1)
    major = int(gensim.__version__.split('.')[0])
    if major >= 4:
        return gensim.models.Word2Vec(vector_size=dim, **options)
    return gensim.models.Word2Vec(size=dim, **options)


def _n_epochs(model):
    """Return the model's configured epoch count (``epochs``; ``iter`` on old gensim)."""
    epochs = getattr(model, 'epochs', None)
    return model.iter if epochs is None else epochs


def _vocab_size(model):
    """Return the number of words in the model's vocabulary.

    gensim 4.0 removed ``KeyedVectors.vocab`` (accessing it raises) in favour
    of ``key_to_index``; older versions only have ``vocab``.
    """
    wv = model.wv
    vocab = wv.key_to_index if hasattr(wv, 'key_to_index') else wv.vocab
    return len(vocab)


def main(args):
    """Train a skip-gram word2vec model on the training split and save it.

    Args:
        args: parsed command-line namespace. Reads ``data`` (dataset root,
            passed to ``Sentences``), ``path`` (output directory, created if
            missing) and ``dim`` (embedding dimension).

    Writes two files into ``args.path``, where ``{k}`` is the vocabulary size
    integer-divided by 1000:
        ``word2vec.{dim}d.{k}k.bin`` -- full model via gensim's ``save``.
        ``word2vec.{dim}d.{k}k.w2v`` -- vectors via ``save_word2vec_format``.
    """
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
                        level=logging.INFO)
    start = time()
    save_dir = args.path
    # exist_ok avoids the exists()/makedirs() race and is a no-op if present.
    os.makedirs(save_dir, exist_ok=True)

    sentences = Sentences(args.data)
    # Words seen fewer than 5 times are dropped; training uses 16 worker threads.
    model = _make_word2vec(args.dim, min_count=5, workers=16)
    model.build_vocab(sentences)
    print('vocab built in {}'.format(timedelta(seconds=time()-start)))
    model.train(sentences,
                total_examples=model.corpus_count, epochs=_n_epochs(model))

    # Both output names encode the vector size and the vocab size in thousands.
    n_vocab_k = _vocab_size(model) // 1000
    model.save(join(save_dir, 'word2vec.{}d.{}k.bin'.format(args.dim, n_vocab_k)))
    model.wv.save_word2vec_format(
        join(save_dir, 'word2vec.{}d.{}k.w2v'.format(args.dim, n_vocab_k)))
    # Elapsed time is measured from the start, so it includes vocab building.
    print('word2vec trained in {}'.format(timedelta(seconds=time()-start)))
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand off to main().
    cli = argparse.ArgumentParser(
        description='train word2vec embedding used for model initialization')
    cli.add_argument('-data', required=True, help='path of the dataset')
    cli.add_argument('-path', required=True, help='root of the model')
    # Embedding dimensionality; 128 unless overridden.
    cli.add_argument('-dim', type=int, default=128)
    main(cli.parse_args())