6
6
import math
7
7
8
8
# Training hyperparameters (post-patch state: the model itself is constructed
# further down, next to its optimizer).
n_hidden = 128  # hidden-size argument handed to RNN(...)
n_epochs = 100000
print_every = 5000
plot_every = 1000
learning_rate = 0.005  # If you set this too high, it might explode. If too low, it might not learn
10
13
11
14
def categoryFromOutput (output ):
12
15
top_n , top_i = output .data .topk (1 ) # Tensor out of Variable with .data
@@ -16,46 +19,31 @@ def categoryFromOutput(output):
16
19
def randomChoice(l):
    """Return one element of the sequence *l*, picked uniformly at random.

    Raises:
        IndexError: if *l* is empty.
    """
    # random.choice is the stdlib equivalent of indexing with
    # random.randint(0, len(l) - 1).
    return random.choice(l)
18
21
19
def randomTrainingPair():
    """Draw one random training example.

    A category is picked at random from ``all_categories``, then a line is
    picked at random from that category's entry in ``category_lines``.

    Returns:
        A 4-tuple ``(category, line, category_tensor, line_tensor)``.
        ``category_tensor`` wraps a one-element ``LongTensor`` holding the
        category's index in ``all_categories``; ``line_tensor`` wraps the
        result of ``lineToTensor(line)``.
    """
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    # The training target is the category's position in all_categories.
    category_tensor = Variable(torch.LongTensor([all_categories.index(category)]))
    line_tensor = Variable(lineToTensor(line))
    return category, line, category_tensor, line_tensor
25
28
29
# Build the model first so the optimizer can be bound to its parameters.
rnn = RNN(n_letters, n_hidden, n_categories)
optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)
criterion = nn.NLLLoss()
30
33
def train(category_tensor, line_tensor):
    """Run a single optimisation step on one (category, line) example.

    The line is fed to the module-level ``rnn`` one slice at a time along its
    first dimension, carrying the hidden state forward. Only the output of the
    final step is scored against ``category_tensor`` with ``criterion``; the
    gradients are then applied by the module-level ``optimizer``.

    Args:
        category_tensor: target passed to ``criterion``.
        line_tensor: input sequence, indexed along its first dimension.

    Returns:
        A tuple ``(output, loss_value)``: the network output from the last
        step and the loss as a plain Python number.
    """
    hidden = rnn.initHidden()
    # Clear gradients left over from the previous step before accumulating.
    optimizer.zero_grad()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()
    optimizer.step()

    # Indexing a 0-dim loss with [0] fails on newer PyTorch releases; prefer
    # .item() and fall back to .data[0] only where .item() does not exist.
    loss_value = loss.item() if hasattr(loss, "item") else loss.data[0]
    return output, loss_value
46
46
47
# Keep track of losses for plotting
current_loss = 0
all_losses = []
0 commit comments