-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathencoding.lua
67 lines (55 loc) · 1.77 KB
/
encoding.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
require 'table'
require 'torch'
local encoding = {}
--- Build a character vocabulary for `text` and encode the text with it.
-- Every distinct character receives an id in order of first appearance,
-- starting at 1.
-- @param text string to encode (processed byte by byte via string.sub)
-- @return alphabet table holding both directions of the mapping:
--   alphabet[id] == char (array part) and alphabet[char] == id (hash part)
-- @return torch.Tensor of length #text whose i-th entry is the id of the
--   i-th character of `text`
function char_to_ints(text)
  local length = #text
  local alphabet = {}
  local encoded = torch.Tensor(length)
  for pos = 1, length do
    local ch = text:sub(pos, pos)
    local code = alphabet[ch]
    if code == nil then
      -- first occurrence: the next free id is one past the current sequence
      code = #alphabet + 1
      alphabet[code] = ch
      alphabet[ch] = code
    end
    encoded[pos] = code
  end
  return alphabet, encoded
end
--- Build the reverse lookup of a table: each value becomes a key that maps
-- back to the key it was stored under.
-- If several keys share a value, only one of them survives, and which one is
-- unspecified because `pairs` order is unspecified.
-- @param alphabet table to invert
-- @return new table mapping value -> key
function invert_alphabet(alphabet)
  local reversed = {}
  for key, value in pairs(alphabet) do
    reversed[value] = key
  end
  return reversed
end
--- Decode a tensor of character ids back into characters.
-- @param alphabet table mapping char -> id and/or id -> char (the alphabet
--   returned by char_to_ints holds both directions)
-- @param ints tensor of ids, of any shape; it is read in row-major order,
--   so an n x 1 column tensor decodes the same as a 1-D tensor
-- @return table whose i-th entry is the character for the i-th id. An id that
--   is missing from the alphabet leaves a nil hole at that position.
function ints_to_chars(alphabet, ints)
  local decoded = {}
  local count = ints:nElement()
  if count == 0 then
    return decoded
  end
  -- Alphabets built by char_to_ints already contain id -> char entries, so the
  -- inversion is redundant for them. It still lets a plain char -> id map be
  -- passed in.
  local decoder = invert_alphabet(alphabet)
  -- Flatten so that multi-dimensional inputs index to numbers, not sub-tensors;
  -- contiguous() is a no-op for tensors that are already contiguous.
  local flat = ints:contiguous():view(count)
  for i = 1, count do
    decoded[i] = decoder[flat[i]]
  end
  return decoded
end
--- One-hot encode a tensor of integer codes.
-- Row r of the result is all zeros except for a 1 in column ints[r].
-- @param ints tensor of codes; its first dimension gives the number of rows
--   (codes are presumably expected to lie in 1..width -- TODO confirm)
-- @param width number of columns of the result
-- @return rows x width tensor created by torch.zeros and filled via scatter
function ints_to_one_hot(ints, width)
  local rows = ints:size(1)
  local canvas = torch.zeros(rows, width)
  -- scatter along dimension 2 takes the column positions as an n x 1 LongTensor
  local columns = ints:view(-1, 1):long()
  return canvas:scatter(2, columns, 1)
end
--- Convert a one-hot matrix back into integer codes.
-- @param one_hot 2-D tensor with one row per encoded item (for example the
--   output of ints_to_one_hot)
-- @return tensor holding, for each row, the column index of that row's
--   largest entry. torch.max keeps the reduced dimension, so the result
--   has shape n x 1.
function one_hot_to_ints(one_hot)
  -- torch.max(x, 2) reduces over dimension 2, i.e. it finds the maximum
  -- within each row, and returns (values, indices); only the indices matter.
  local _, ints = torch.max(one_hot, 2)
  return ints
end
-- Public API of the module: attach every encoding helper to the table
-- returned by require. The functions are also defined as globals earlier in
-- this file, so callers that use the globals directly keep working.
encoding.char_to_ints = char_to_ints
encoding.ints_to_chars = ints_to_chars
encoding.invert_alphabet = invert_alphabet
encoding.ints_to_one_hot = ints_to_one_hot
encoding.one_hot_to_ints = one_hot_to_ints
return encoding