Skip to content
This repository was archived by the owner on Aug 18, 2021. It is now read-only.

Commit 0cc55f5

Browse files
committed
dedupe creating RNN
1 parent 1b9a484 commit 0cc55f5

File tree

1 file changed

+9
-21
lines changed

1 file changed

+9
-21
lines changed

char-rnn-classification/train.py

+9-21
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
import math
77

88
n_hidden = 128
9-
rnn = RNN(n_letters, n_hidden, n_categories)
9+
n_epochs = 100000
10+
print_every = 5000
11+
plot_every = 1000
12+
learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
1013

1114
def categoryFromOutput(output):
1215
top_n, top_i = output.data.topk(1) # Tensor out of Variable with .data
@@ -16,46 +19,31 @@ def categoryFromOutput(output):
1619
def randomChoice(l):
1720
return l[random.randint(0, len(l) - 1)]
1821

19-
def randomTrainingPair():
22+
def randomTrainingPair():
2023
category = randomChoice(all_categories)
2124
line = randomChoice(category_lines[category])
2225
category_tensor = Variable(torch.LongTensor([all_categories.index(category)]))
2326
line_tensor = Variable(lineToTensor(line))
2427
return category, line, category_tensor, line_tensor
2528

29+
rnn = RNN(n_letters, n_hidden, n_categories)
30+
optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)
2631
criterion = nn.NLLLoss()
2732

28-
learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
29-
3033
def train(category_tensor, line_tensor):
3134
hidden = rnn.initHidden()
32-
33-
rnn.zero_grad()
35+
optimizer.zero_grad()
3436

3537
for i in range(line_tensor.size()[0]):
3638
output, hidden = rnn(line_tensor[i], hidden)
3739

3840
loss = criterion(output, category_tensor)
3941
loss.backward()
4042

41-
# Add parameters' gradients to their values, multiplied by learning rate
42-
for p in rnn.parameters():
43-
p.data.add_(-learning_rate, p.grad.data)
43+
optimizer.step()
4444

4545
return output, loss.data[0]
4646

47-
n_epochs = 100000
48-
print_every = 5000
49-
plot_every = 1000
50-
51-
rnn = RNN(n_letters, n_hidden, n_categories)
52-
53-
n_epochs = 100000
54-
print_every = 5000
55-
plot_every = 1000
56-
57-
rnn = RNN(n_letters, n_hidden, n_categories)
58-
5947
# Keep track of losses for plotting
6048
current_loss = 0
6149
all_losses = []

0 commit comments

Comments
 (0)