1
1
from tqdm import tqdm
2
2
3
3
import torch
4
- from torch .autograd import Variable as Var
5
4
6
5
from . import utils
7
6
8
7
9
8
class Trainer (object ):
10
- def __init__ (self , args , model , criterion , optimizer ):
9
+ def __init__ (self , args , model , criterion , optimizer , device ):
11
10
super (Trainer , self ).__init__ ()
12
11
self .args = args
13
12
self .model = model
14
13
self .criterion = criterion
15
14
self .optimizer = optimizer
15
+ self .device = device
16
16
self .epoch = 0
17
17
18
18
# helper function for training
19
19
def train (self , dataset ):
20
20
self .model .train ()
21
21
self .optimizer .zero_grad ()
22
22
total_loss = 0.0
23
- indices = torch .randperm (len (dataset ))
23
+ indices = torch .randperm (len (dataset ), dtype = torch . long , device = 'cpu' )
24
24
for idx in tqdm (range (len (dataset )), desc = 'Training epoch ' + str (self .epoch + 1 ) + '' ):
25
- ltree , lsent , rtree , rsent , label = dataset [indices [idx ]]
26
- linput , rinput = Var (lsent ), Var (rsent )
27
- target = Var (utils .map_label_to_target (label , dataset .num_classes ))
28
- if self .args .cuda :
29
- linput , rinput = linput .cuda (), rinput .cuda ()
30
- target = target .cuda ()
25
+ ltree , linput , rtree , rinput , label = dataset [indices [idx ]]
26
+ target = utils .map_label_to_target (label , dataset .num_classes )
27
+ linput , rinput = linput .to (self .device ), rinput .to (self .device )
28
+ target = target .to (self .device )
31
29
output = self .model (ltree , linput , rtree , rinput )
32
30
loss = self .criterion (output , target )
33
- total_loss += loss .data [ 0 ]
31
+ total_loss += loss .item ()
34
32
loss .backward ()
35
33
if idx % self .args .batchsize == 0 and idx > 0 :
36
34
self .optimizer .step ()
@@ -41,19 +39,18 @@ def train(self, dataset):
41
39
# helper function for testing
42
40
def test (self , dataset ):
43
41
self .model .eval ()
44
- total_loss = 0
45
- predictions = torch .zeros (len (dataset ))
46
- indices = torch .arange (1 , dataset .num_classes + 1 )
47
- for idx in tqdm (range (len (dataset )), desc = 'Testing epoch ' + str (self .epoch ) + '' ):
48
- ltree , lsent , rtree , rsent , label = dataset [idx ]
49
- linput , rinput = Var (lsent , volatile = True ), Var (rsent , volatile = True )
50
- target = Var (utils .map_label_to_target (label , dataset .num_classes ), volatile = True )
51
- if self .args .cuda :
52
- linput , rinput = linput .cuda (), rinput .cuda ()
53
- target = target .cuda ()
54
- output = self .model (ltree , linput , rtree , rinput )
55
- loss = self .criterion (output , target )
56
- total_loss += loss .data [0 ]
57
- output = output .data .squeeze ().cpu ()
58
- predictions [idx ] = torch .dot (indices , torch .exp (output ))
42
+ with torch .no_grad ():
43
+ total_loss = 0.0
44
+ predictions = torch .zeros (len (dataset ), dtype = torch .float , device = 'cpu' )
45
+ indices = torch .arange (1 , dataset .num_classes + 1 , dtype = torch .float , device = 'cpu' )
46
+ for idx in tqdm (range (len (dataset )), desc = 'Testing epoch ' + str (self .epoch ) + '' ):
47
+ ltree , linput , rtree , rinput , label = dataset [idx ]
48
+ target = utils .map_label_to_target (label , dataset .num_classes )
49
+ linput , rinput = linput .to (self .device ), rinput .to (self .device )
50
+ target = target .to (self .device )
51
+ output = self .model (ltree , linput , rtree , rinput )
52
+ loss = self .criterion (output , target )
53
+ total_loss += loss .item ()
54
+ output = output .squeeze ().to ('cpu' )
55
+ predictions [idx ] = torch .dot (indices , torch .exp (output ))
59
56
return total_loss / len (dataset ), predictions
0 commit comments