Add pad transform, string to int transform #1683
@@ -21,6 +21,8 @@
     "LabelToIndex",
     "Truncate",
     "AddToken",
+    "PadTransform",
+    "StrToIntTransform",
     "GPT2BPETokenizer",
     "Sequential",
 ]
@@ -221,6 +223,33 @@ def forward(self, input: Any) -> Any:
         return F.add_token(input, self.token, self.begin)


+class PadTransform(Module):
+    def __init__(self, max_length: int, pad_value: int):
+        super().__init__()
+        self.max_length = max_length
+        self.pad_value = pad_value
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        max_encoded_length = x.size(-1)
+        if max_encoded_length < self.max_length:
+            pad_amount = self.max_length - max_encoded_length
+            x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
+        return x
Review comment: could you please add doc-strings for PadTransform?
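For context, here is a minimal usage sketch of the new transform. It assumes the code lands in torchtext.transforms (consistent with the __all__ change above); the tensor values are illustrative only:

```python
import torch

from torchtext.transforms import PadTransform

# Pad every sequence in a batch of token ids out to length 10,
# filling the tail with pad index 0. Inputs already at or above
# max_length pass through unchanged (the transform does not truncate).
pad = PadTransform(max_length=10, pad_value=0)

ids = torch.tensor([[101, 2023, 2003, 102]])  # shape (1, 4)
padded = pad(ids)                             # shape (1, 10)
# tensor([[101, 2023, 2003, 102, 0, 0, 0, 0, 0, 0]])
```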
+
+
+class StrToIntTransform(Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input: Union[List[str], List[List[str]]]) -> Union[List[int], List[List[int]]]:
+        if isinstance(input[0], str):
+            return [int(x) for x in input]  # type: ignore
+        if isinstance(input[0], List) and isinstance(input[0][0], str):
+            return [[int(x) for x in ll] for ll in input]
+        else:
+            raise TypeError("Input type not supported")

Review comment (on class StrToIntTransform): Same comment for doc-strings :)
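Assuming the same torchtext.transforms module path, a quick sketch of the expected behavior (inputs are illustrative):

```python
from torchtext.transforms import StrToIntTransform

str_to_int = StrToIntTransform()

# A flat list of numeric strings becomes a list of ints.
str_to_int(["1", "2", "3"])           # [1, 2, 3]

# Nested (batched) lists are converted element-wise.
str_to_int([["1", "2"], ["3", "4"]])  # [[1, 2], [3, 4]]

# Anything else, e.g. a list of ints, raises TypeError.
```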
+
+
 class GPT2BPETokenizer(Module):
     __jit_unused_properties__ = ["is_jitable"]
     """
Review comment: Can you add a brief description of what this test does? In the form of "(under some condition), an API produces a result that satisfies this property".
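The test under discussion is not part of this excerpt, but a hypothetical property-style check for PadTransform, in the form the reviewer describes, might look like this (names and values are illustrative, not taken from the PR):

```python
import torch

from torchtext.transforms import PadTransform

def test_pad_transform_pads_short_inputs():
    # Property: given an input whose last dimension is shorter than
    # max_length, PadTransform returns a tensor whose last dimension
    # equals max_length, whose prefix is the original input, and whose
    # appended tail is entirely pad_value.
    transform = PadTransform(max_length=8, pad_value=0)
    x = torch.randint(1, 100, (2, 5))

    out = transform(x)

    assert out.shape == (2, 8)
    assert torch.equal(out[:, :5], x)
    assert torch.equal(out[:, 5:], torch.zeros(2, 3, dtype=out.dtype))
```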