Add pad transform, string to int transform #1683


Merged: 9 commits, Apr 21, 2022
test/test_transforms.py (70 additions, 0 deletions)
@@ -205,6 +205,76 @@ def test_add_token(self):
    def test_add_token_jit(self):
        self._add_token(test_scripting=True)

    def _pad_transform(self, test_scripting):
Contributor:

Can you add a brief description of what this test does?
In the form of "(under some condition), an API produces a result that satisfies this property"
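Purely as an illustration (not something committed in this PR), a description in that form might read:

        """Given a 1D or 2D tensor whose last dimension is shorter than max_length,
        PadTransform produces a tensor whose last dimension is padded with pad_value
        up to exactly max_length; inputs already at or beyond max_length are
        returned unchanged."""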

        input_1d_tensor = torch.ones(5)
        input_2d_tensor = torch.ones((8, 5))
        pad_long = transforms.PadTransform(max_length=7, pad_value=0)
        if test_scripting:
            pad_long = torch.jit.script(pad_long)
        padded_1d_tensor_actual = pad_long(input_1d_tensor)
        padded_1d_tensor_expected = torch.cat([torch.ones(5), torch.zeros(2)])
        torch.testing.assert_close(
            padded_1d_tensor_actual,
            padded_1d_tensor_expected,
            msg=f"actual: {padded_1d_tensor_actual}, expected: {padded_1d_tensor_expected}",
        )

        padded_2d_tensor_actual = pad_long(input_2d_tensor)
        padded_2d_tensor_expected = torch.cat([torch.ones(8, 5), torch.zeros(8, 2)], axis=-1)
        torch.testing.assert_close(
            padded_2d_tensor_actual,
            padded_2d_tensor_expected,
            msg=f"actual: {padded_2d_tensor_actual}, expected: {padded_2d_tensor_expected}",
        )

        pad_short = transforms.PadTransform(max_length=3, pad_value=0)
        if test_scripting:
            pad_short = torch.jit.script(pad_short)
        padded_1d_tensor_actual = pad_short(input_1d_tensor)
        padded_1d_tensor_expected = input_1d_tensor
        torch.testing.assert_close(
            padded_1d_tensor_actual,
            padded_1d_tensor_expected,
            msg=f"actual: {padded_1d_tensor_actual}, expected: {padded_1d_tensor_expected}",
        )

        padded_2d_tensor_actual = pad_short(input_2d_tensor)
        padded_2d_tensor_expected = input_2d_tensor
        torch.testing.assert_close(
            padded_2d_tensor_actual,
            padded_2d_tensor_expected,
            msg=f"actual: {padded_2d_tensor_actual}, expected: {padded_2d_tensor_expected}",
        )

    def test_pad_transform(self):
        self._pad_transform(test_scripting=False)

    def test_pad_transform_jit(self):
        self._pad_transform(test_scripting=True)

    def _str_to_int_transform(self, test_scripting):
        input_1d_string_list = ["1", "2", "3", "4", "5"]
        input_2d_string_list = [["1", "2", "3"], ["4", "5", "6"]]

        str_to_int = transforms.StrToIntTransform()
        if test_scripting:
            str_to_int = torch.jit.script(str_to_int)

        expected_1d_int_list = [1, 2, 3, 4, 5]
        actual_1d_int_list = str_to_int(input_1d_string_list)
        self.assertListEqual(expected_1d_int_list, actual_1d_int_list)

        expected_2d_int_list = [[1, 2, 3], [4, 5, 6]]
        actual_2d_int_list = str_to_int(input_2d_string_list)
        for i in range(len(expected_2d_int_list)):
            self.assertListEqual(expected_2d_int_list[i], actual_2d_int_list[i])

    def test_str_to_int_transform(self):
        self._str_to_int_transform(test_scripting=False)

    def test_str_to_int_transform_jit(self):
        self._str_to_int_transform(test_scripting=True)


class TestSequential(TorchtextTestCase):
    def _sequential(self, test_scripting):
torchtext/transforms.py (29 additions, 0 deletions)
@@ -21,6 +21,8 @@
"LabelToIndex",
"Truncate",
"AddToken",
"PadTransform",
"StrToIntTransform",
"GPT2BPETokenizer",
"Sequential",
]
@@ -221,6 +223,33 @@ def forward(self, input: Any) -> Any:
        return F.add_token(input, self.token, self.begin)


class PadTransform(Module):
    def __init__(self, max_length: int, pad_value: int):
        super().__init__()
        self.max_length = max_length
        self.pad_value = pad_value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        max_encoded_length = x.size(-1)
        if max_encoded_length < self.max_length:
            pad_amount = self.max_length - max_encoded_length
            x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
        return x
Contributor:

Could you please add doc-strings for the __init__ and forward methods? Also update the rst file to update the documentation.
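As a rough sketch only, such docstrings could look something like the following, using Sphinx-style :param: fields (an assumption about the project's docstring convention, not the author's final wording):

class PadTransform(Module):
    """Pad a tensor up to a fixed length along its last dimension.

    :param max_length: Length to pad the last dimension to.
    :type max_length: int
    :param pad_value: Value used for the padded positions.
    :type pad_value: int
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: Tensor to pad; padding is applied only if its last dimension is shorter than max_length.
        :type x: Tensor
        :return: Tensor padded with pad_value up to max_length along the last dimension.
        :rtype: Tensor
        """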



class StrToIntTransform(Module):
Contributor:

Same comment for doc-strings :)

    def __init__(self):
        super().__init__()

    def forward(self, input: Union[List[str], List[List[str]]]) -> Union[List[int], List[List[int]]]:
        if isinstance(input[0], str):
            return [int(x) for x in input]  # type: ignore
        if isinstance(input[0], List) and isinstance(input[0][0], str):
            return [[int(x) for x in ll] for ll in input]
        else:
            raise TypeError("Input type not supported")


class GPT2BPETokenizer(Module):
    __jit_unused_properties__ = ["is_jitable"]
    """
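
For reference, a minimal usage sketch of the two new transforms, consistent with the tests above (illustrative only, assuming the classes are importable from torchtext.transforms as the diff suggests):

import torch
from torchtext import transforms

# Pad the last dimension up to max_length with pad_value; longer inputs pass through unchanged.
pad = transforms.PadTransform(max_length=7, pad_value=0)
print(pad(torch.ones(5)))           # tensor([1., 1., 1., 1., 1., 0., 0.])
print(pad(torch.ones(8, 5)).shape)  # torch.Size([8, 7])

# Convert flat or nested lists of numeric strings to ints.
str_to_int = transforms.StrToIntTransform()
print(str_to_int(["1", "2", "3"]))           # [1, 2, 3]
print(str_to_int([["1", "2"], ["3", "4"]]))  # [[1, 2], [3, 4]]

# Both transforms are TorchScript-compatible, as the *_jit tests above exercise.
scripted_pad = torch.jit.script(pad)
print(scripted_pad(torch.ones(5)))  # same output as the eager module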