Skip to content

Commit 8aecbb9

Browse files
authored
Updated vocab and vectors with forward method (pytorch#953)
* Updated vocab and vectors with forward method
* Added tests
1 parent 87f0d44 commit 8aecbb9

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed

test/experimental/test_vectors.py

+18
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,24 @@ def test_vectors_jit(self):
7171
self.assertEqual(vectors_obj['b'], jit_vectors_obj['b'])
7272
self.assertEqual(vectors_obj['not_in_it'], jit_vectors_obj['not_in_it'])
7373

74+
def test_vectors_forward(self):
75+
tensorA = torch.tensor([1, 0], dtype=torch.float)
76+
tensorB = torch.tensor([0, 1], dtype=torch.float)
77+
78+
unk_tensor = torch.tensor([0, 0], dtype=torch.float)
79+
tokens = ['a', 'b']
80+
vecs = torch.stack((tensorA, tensorB), 0)
81+
vectors_obj = vectors(tokens, vecs, unk_tensor=unk_tensor)
82+
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
83+
84+
tokens_to_lookup = ['a', 'b', 'c']
85+
expected_vectors = torch.stack((tensorA, tensorB, unk_tensor), 0)
86+
vectors_by_tokens = vectors_obj(tokens_to_lookup)
87+
jit_vectors_by_tokens = jit_vectors_obj(tokens_to_lookup)
88+
89+
self.assertEqual(expected_vectors, vectors_by_tokens)
90+
self.assertEqual(expected_vectors, jit_vectors_by_tokens)
91+
7492
def test_vectors_lookup_vectors(self):
7593
tensorA = torch.tensor([1, 0], dtype=torch.float)
7694
tensorB = torch.tensor([0, 1], dtype=torch.float)

test/experimental/test_vocab.py

+14
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,20 @@ def test_vocab_jit(self):
119119
self.assertEqual(jit_v.get_itos(), expected_itos)
120120
self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)
121121

122+
def test_vocab_forward(self):
123+
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
124+
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
125+
126+
c = OrderedDict(sorted_by_freq_tuples)
127+
v = vocab(c)
128+
jit_v = torch.jit.script(v.to_ivalue())
129+
130+
tokens = ['b', 'a', 'c']
131+
expected_indices = [2, 1, 3]
132+
133+
self.assertEqual(v(tokens), expected_indices)
134+
self.assertEqual(jit_v(tokens), expected_indices)
135+
122136
def test_vocab_lookup_token(self):
123137
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
124138
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

torchtext/experimental/vectors.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,7 @@ def is_jitable(self):
203203
return not isinstance(self.vectors, VectorsPybind)
204204

205205
@torch.jit.export
206-
def __call__(self, tokens: List[str]) -> Tensor:
206+
def forward(self, tokens: List[str]) -> Tensor:
207207
r"""Calls the `lookup_vectors` method
208208
Args:
209209
tokens: a list of tokens

torchtext/experimental/vocab.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ def is_jitable(self):
128128
return not isinstance(self.vocab, VocabPybind)
129129

130130
@torch.jit.export
131-
def __call__(self, tokens: List[str]) -> List[int]:
131+
def forward(self, tokens: List[str]) -> List[int]:
132132
r"""Calls the `lookup_indices` method
133133
Args:
134134
tokens (List[str]): the tokens used to lookup their corresponding `indices`.

0 commit comments

Comments
 (0)