Add pad transform, string to int transform #1683
Conversation
Summary: Rather than raise an exception whenever `head_dim != 64`, we can just infer the scaling value and continue to provide a warning. Also add an assertion in case `embed_dim` is not a multiple of `num_heads` (in which case `forward` will break).

Reviewed By: parmeet

Differential Revision: D32193989

fbshipit-source-id: 30f68c55f3ec37932252c77c355ae55b8bf34ded
Test plan: Add unit tests

```
python3 -m pytest -v -k 'pad_transform or str_to_int'
=========================== test session starts ===========================
platform linux -- Python 3.9.12, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /data/home/ebs/miniconda/envs/torchtext/bin/python3
cachedir: .pytest_cache
rootdir: /data/home/ebs/torchtext, configfile: pytest.ini, testpaths: test/
plugins: pythonpath-0.7.4, cov-3.0.0
collected 1091 items / 1087 deselected / 4 selected

test/test_transforms.py::TestTransforms::test_pad_transform PASSED      [ 25%]
test/test_transforms.py::TestTransforms::test_pad_transform_jit PASSED  [ 50%]
test/test_transforms.py::TestTransforms::test_str_to_int_transform PASSED [ 75%]
test/test_transforms.py::TestTransforms::test_str_to_int_transform_jit PASSED [100%]

==================== 4 passed, 1087 deselected in 10.57s ====================
```
torchtext/transforms.py
Outdated
```python
class PadTransform(Module):
    def __init__(self, max_length: int, pad_value: int):
        super().__init__()
        self.max_length = max_length
        self.pad_value = pad_value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        max_encoded_length = x.size(-1)
        if max_encoded_length < self.max_length:
            pad_amount = self.max_length - max_encoded_length
            x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
        return x
```
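For reviewers trying the transform locally, here is a minimal, self-contained sketch of how this `PadTransform` would be used. The class body mirrors the diff above; the example tensor values are illustrative only.

```python
import torch
from torch.nn import Module


class PadTransform(Module):
    """Mirror of the PadTransform in the diff above, reproduced here so this
    snippet runs on its own (the real class lives in torchtext.transforms)."""

    def __init__(self, max_length: int, pad_value: int):
        super().__init__()
        self.max_length = max_length
        self.pad_value = pad_value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        max_encoded_length = x.size(-1)
        if max_encoded_length < self.max_length:
            # Pad only the last dimension, on the right, up to max_length.
            pad_amount = self.max_length - max_encoded_length
            x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
        return x


pad = PadTransform(max_length=5, pad_value=0)
padded = pad(torch.tensor([7, 8, 9]))
print(padded.tolist())  # [7, 8, 9, 0, 0]
```

Note that inputs already at or beyond `max_length` pass through unchanged, since `forward` only pads when the last dimension is shorter than `max_length`.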
Could you please add doc-strings for the `__init__` and `forward` methods? Also update the rst file to update the documentation.
```python
        return x


class StrToIntTransform(Module):
```
Same comment for doc-strings :)
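The `StrToIntTransform` implementation is not visible in this hunk, so the following is only a guess at its semantics from the class name: converting flat or nested lists of numeric string tokens into ints. A pure-Python sketch, not the PR's actual code:

```python
def str_to_int(tokens):
    """Hypothetical sketch: convert a flat or nested list of numeric
    strings (e.g. token ids serialized as text) into ints."""
    if tokens and isinstance(tokens[0], list):
        # Nested batch: recurse into each inner sequence.
        return [str_to_int(seq) for seq in tokens]
    return [int(tok) for tok in tokens]


print(str_to_int(["1", "2", "3"]))         # [1, 2, 3]
print(str_to_int([["10", "20"], ["30"]]))  # [[10, 20], [30]]
```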
docs/source/transforms.rst
Outdated
```rst
@@ -71,3 +71,17 @@ Sequential

.. autoclass:: Sequential

  .. automethod:: forward

PadTransform
----------
```
Suggested change:

```diff
-----------
+------------
```
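The suggestion above follows the reStructuredText rule that a section underline must be at least as long as its title text; Sphinx emits a "Title underline too short" warning otherwise. A sketch of the corrected stanza, assuming the same autoclass/automethod layout used elsewhere in this file:

```rst
PadTransform
------------

.. autoclass:: PadTransform

  .. automethod:: forward
```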
docs/source/transforms.rst
Outdated
```rst
.. automethod:: forward

StrToIntTransform
----------
```
Suggested change:

```diff
-----------
+-----------------
```
```python
@@ -205,6 +205,76 @@ def test_add_token(self):
    def test_add_token_jit(self):
        self._add_token(test_scripting=True)

    def _pad_transform(self, test_scripting):
```
Can you add a brief description of what this test does, in the form of "(under some condition), an API produces a result that satisfies this property"?
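One way to phrase the requested property is: "given an input sequence shorter than `max_length`, the pad transform extends it to exactly `max_length` using `pad_value`; longer sequences are unchanged." A hedged, pure-Python stand-in for the torch-based helper in the PR (names and values are illustrative):

```python
def pad(ids, max_length, pad_value):
    """Stand-in for PadTransform.forward applied to a 1-D list of ids."""
    if len(ids) < max_length:
        ids = ids + [pad_value] * (max_length - len(ids))
    return ids


def test_pad_transform():
    """Given an input shorter than max_length, pad() extends it to exactly
    max_length using pad_value; inputs at or beyond max_length are returned
    unchanged."""
    assert pad([1, 2, 3], 5, 0) == [1, 2, 3, 0, 0]
    assert pad([1, 2, 3, 4, 5], 5, 0) == [1, 2, 3, 4, 5]


test_pad_transform()
```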
LGTM! Thanks for upstreaming the transforms from torchMM. I think once we land this, the next step could be to use the transforms directly from torchtext and remove them from torchMM :)