Skip to content

Add pad transform, string to int transform #1683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 21, 2022

Conversation

ebsmothers
Copy link
Contributor

Test plan:

Add unit tests

python3 -m pytest -v -k 'pad_transform or str_to_int'
====================================================================================================== test session starts =======================================================================================================
platform linux -- Python 3.9.12, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /data/home/ebs/miniconda/envs/torchtext/bin/python3
cachedir: .pytest_cache
rootdir: /data/home/ebs/torchtext, configfile: pytest.ini, testpaths: test/
plugins: pythonpath-0.7.4, cov-3.0.0
collected 1091 items / 1087 deselected / 4 selected

test/test_transforms.py::TestTransforms::test_pad_transform PASSED                                                                                                                                                         [ 25%]
test/test_transforms.py::TestTransforms::test_pad_transform_jit PASSED                                                                                                                                                     [ 50%]
test/test_transforms.py::TestTransforms::test_str_to_int_transform PASSED                                                                                                                                                  [ 75%]
test/test_transforms.py::TestTransforms::test_str_to_int_transform_jit PASSED                                                                                                                                              [100%]

============================================================================================== 4 passed, 1087 deselected in 10.57s ===============================================================================================

Summary:
Rather than raise an exception whenever head_dim != 64, we can just infer the scaling value and continue to provide a warning.

Also add an assertion in case embed_dim is not a multiple of num_heads (in which case forward will break).

Reviewed By: parmeet

Differential Revision: D32193989

fbshipit-source-id: 30f68c55f3ec37932252c77c355ae55b8bf34ded
Test plan:

Add unit tests

```
python3 -m pytest -v -k 'pad_transform or str_to_int'
====================================================================================================== test session starts =======================================================================================================
platform linux -- Python 3.9.12, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /data/home/ebs/miniconda/envs/torchtext/bin/python3
cachedir: .pytest_cache
rootdir: /data/home/ebs/torchtext, configfile: pytest.ini, testpaths: test/
plugins: pythonpath-0.7.4, cov-3.0.0
collected 1091 items / 1087 deselected / 4 selected

test/test_transforms.py::TestTransforms::test_pad_transform PASSED                                                                                                                                                         [ 25%]
test/test_transforms.py::TestTransforms::test_pad_transform_jit PASSED                                                                                                                                                     [ 50%]
test/test_transforms.py::TestTransforms::test_str_to_int_transform PASSED                                                                                                                                                  [ 75%]
test/test_transforms.py::TestTransforms::test_str_to_int_transform_jit PASSED                                                                                                                                              [100%]

============================================================================================== 4 passed, 1087 deselected in 10.57s ===============================================================================================
```
Comment on lines 226 to 237
class PadTransform(Module):
def __init__(self, max_length: int, pad_value: int):
super().__init__()
self.max_length = max_length
self.pad_value = pad_value

def forward(self, x: torch.Tensor) -> torch.Tensor:
max_encoded_length = x.size(-1)
if max_encoded_length < self.max_length:
pad_amount = self.max_length - max_encoded_length
x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
return x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please add doc-strings for __init__ and forward methods? Also update the the rst file to update the documentation.

return x


class StrToIntTransform(Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment for doc-strings :)

@ebsmothers ebsmothers requested a review from parmeet April 19, 2022 17:44
@@ -71,3 +71,17 @@ Sequential
.. autoclass:: Sequential

.. automethod:: forward

PadTransform
----------
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
----------
------------

.. automethod:: forward

StrToIntTransform
----------
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
----------
-----------------

@@ -205,6 +205,76 @@ def test_add_token(self):
def test_add_token_jit(self):
self._add_token(test_scripting=True)

def _pad_transform(self, test_scripting):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a brief description of what this test does?
In the form of "(under some condition), an API produces a result that satisfies this property"

@ebsmothers ebsmothers marked this pull request as draft April 20, 2022 22:37
@ebsmothers ebsmothers removed the request for review from parmeet April 20, 2022 22:37
@ebsmothers ebsmothers marked this pull request as ready for review April 20, 2022 23:40
@ebsmothers ebsmothers requested review from mthrok and parmeet April 20, 2022 23:52
Copy link
Contributor

@parmeet parmeet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for upstreaming the transforms from torchMM. I think once we land this, the next step could be use the transforms directly from torchtext and remove them from torchMM :)

@parmeet parmeet merged commit 38f520c into pytorch:main Apr 21, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants