
Add support for Scriptable BERT tokenizer #1707

Merged: 39 commits merged into pytorch:main on May 25, 2022

Conversation

@parmeet (Contributor) commented May 10, 2022

This PR adds support for a scriptable BERT tokenizer.

Initial implementation: Our implementation is derived from https://github.com/LieluoboAi/radish. We have made the following major amendments to their implementation:

  • Replaced usage of utfcpp with utf8proc for converting strings to and from Unicode. This removes the additional dependency on utfcpp.
  • Replaced usage of std::unordered_map with torchtext's Vocab implementation to perform look-ups.
  • Fixed the wordpiece (max_seg_) algorithm to match the HuggingFace (HF) implementation.
  • Fixed a stripping issue by sending stripped text directly from Python (a \u2048 at the end of a string cannot be removed trivially in C++).
  • Replaced their whitespace-based string splitting with torchtext's split_.
  • Perform to_lower directly on Unicode strings; also corrected the logic of combining the flags for lowering and stripping accents to match the HF implementation.
  • Fixed the _is_control implementation to match the HF implementation.
  • Removed the comparison with kChinesePunts to match HF's implementation of _is_punctuation.
  • Changed the UString type from uint16_t to uint32_t: on rare occasions a Unicode code point cannot fit in uint16_t, which caused errors (see the illustration below).
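
On the last point, a quick Python illustration (illustrative only) of a code point that does not fit in a 16-bit type:

cp = ord("😀")      # U+1F600, outside the Basic Multilingual Plane
print(hex(cp))      # 0x1f600
print(cp > 0xFFFF)  # True: needs a 32-bit code unit (uint32_t)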

Testing

Verified that the results match the HF BERT tokenizer on the EnWik9 dataset (13,147,026 rows).
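
A minimal sketch of this kind of parity check, assuming the transformers package is installed (this is illustrative, not the exact benchmarking script):

from transformers import BertTokenizer  # HF reference tokenizer

hf_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def check_parity(torchtext_tokenizer, texts):
    # torchtext_tokenizer: a T.BERTTokenizer constructed with
    # return_tokens=True (see the usage section below)
    for text in texts:
        assert torchtext_tokenizer(text) == hf_tokenizer.tokenize(text)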

Usage

import torchtext.transforms as T
from torchtext.utils import get_asset_local_path

bert_base_uncased_vocab_file = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"

# Instantiate the tokenizer with lower casing and return_tokens=True
# (returning token IDs instead is also supported)
bert_tokenizer = T.BERTTokenizer(
    get_asset_local_path(bert_base_uncased_vocab_file),
    do_lower_case=True,
    strip_accents=None,
    return_tokens=True,
)

# non-batch API
tokens = bert_tokenizer("Hello world")
# out: ['hello', 'world']

# batch API 
tokens = bert_tokenizer(["Hello world","How are you!"])
# out: [['hello', 'world'], ['how', 'are', 'you', '!']]
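
Setting return_tokens=False returns token IDs instead of token strings; a minimal sketch with the same constructor arguments (the IDs come back as strings):

# Return token IDs (as strings) instead of token text
bert_id_tokenizer = T.BERTTokenizer(
    get_asset_local_path(bert_base_uncased_vocab_file),
    do_lower_case=True,
    strip_accents=None,
    return_tokens=False,
)
ids = bert_id_tokenizer("Hello world")
# out: the vocabulary indices for ['hello', 'world']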

Follow-up:

  • Perform batch processing directly in C++ instead of iterating over input sentences in Python

@parmeet parmeet marked this pull request as ready for review May 17, 2022 19:25
@parmeet parmeet changed the title [WIP] Add support for Scriptable BERT tokenizer Add support for Scriptable BERT tokenizer May 17, 2022
@Nayef211 (Contributor) left a comment


Just left a couple of noob questions and suggestions. Overall LGTM. Thanks for adding the BERT tokenizer to torchtext @parmeet! This looks like it was a very complex class to implement 🚀

Comment on lines +597 to +601
for text in input:
    if self._return_tokens:
        tokens.append(self._tokenize(text))
    else:
        tokens.append(self._encode(text))
Contributor

Is there a way we could pass in the batch input directly to the C++ kernel and do the for-loop in the kernel itself? As we've seen in previous benchmarking efforts, a lot of time is spent on passing data back and forth between Python and C++ and we may be able to get significant perf gains just by passing the entire list in one go.

Contributor Author

Yes, that's a great idea. Let's do it in a follow-up PR.
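
A hypothetical sketch of what that follow-up could look like on the Python side (the batch_tokenize name is illustrative, not the final API):

from typing import List

def _batch_tokenize(self, texts: List[str]) -> List[List[str]]:
    # Hypothetical: a single call into the C++ kernel, which iterates
    # over the batch internally instead of crossing the Python/C++
    # boundary once per sentence
    return self.bert_model.batch_tokenize([t.strip() for t in texts])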


namespace torchtext {

typedef std::basic_string<uint32_t> UString;
Contributor

Are we using std::basic_string here because the text being passed in from Python contains unicode which isn't compatible with std::string?

Contributor Author

The string passed from Python is UTF-8 encoded bytes. UString is the container that stores the Unicode code points when converting a string to and from Unicode.
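
To illustrate the distinction on the Python side (illustrative only):

s = "café"
print(list(s.encode("utf-8")))  # 5 UTF-8 bytes: [99, 97, 102, 195, 169]
print([ord(c) for c in s])      # 4 code points: [99, 97, 102, 233]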


namespace torchtext {

std::string BERTEncoder::kUnkToken = "[UNK]";
Contributor

Noob question: why do we make kUnkToken a static property rather than a constant?

Contributor Author

Ohh, just following the original implementation :).

@parmeet (Contributor Author) commented May 25, 2022

Thanks @Nayef211 for the thorough review and feedback. I have addressed the comments :).

@parmeet (Contributor Author) commented May 25, 2022

Torchtext CI is failing for Windows on Python 3.7. It seems the following is the culprit:

error: can't copy 'build\lib.win-amd64-3.7\pyd': doesn't exist or not a regular file

Any suggestions @mthrok, @atalman on what's going on here?

cc: @Nayef211

@Nayef211 (Contributor) commented

@parmeet I'm also seeing failures for test_bert_tokenizer on several platforms. Let's try to fix these before landing

@parmeet (Contributor Author) commented May 25, 2022

> @parmeet I'm also seeing failures for test_bert_tokenizer on several platforms. Let's try to fix these before landing

Yeah, looking at them. Not sure what went wrong; the tests are passing locally. Will look into the CI results.

@parmeet parmeet merged commit da509e1 into pytorch:main May 25, 2022
@parmeet parmeet deleted the bert_tokenizer branch May 25, 2022 15:49
@philschmid commented

Hello 🙋🏻‍♂️

I was trying to script/trace the new BERTTokenizer but without any success. What does scriptable mean to you in this context?

Here is my example of how I tried to trace it:

import torch
import torchtext.transforms as T
from torchtext.utils import get_asset_local_path

bert_base_uncased_vocab_file = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
bert_tokenizer = T.BERTTokenizer(
    get_asset_local_path(bert_base_uncased_vocab_file),
    do_lower_case=True,
    strip_accents=None,
    return_tokens=True,
)

traced_tokenizer = torch.jit.trace(bert_tokenizer, "test")

Here is the error:

RuntimeError: 
Module 'BERTTokenizer' has no attribute 'bert_model' (This attribute exists on the Python module, but we failed to convert Python type: 'torchtext._torchtext.BERTEncoder' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dicts are supported as inputs or outputs of traced functions, but instead got value of type BERTEncoder.. Its type was inferred; try adding a type annotation for the attribute.):
  File "/home/ubuntu/miniconda3/envs/optimum/lib/python3.8/site-packages/torchtext/transforms.py", line 603
    def _batch_encode(self, text: List[str]) -> List[List[str]]:
        """Batch version of _encode i.e operate on list of str"""
        token_ids: List[List[int]] = self.bert_model.batch_encode([t.strip() for t in text])
                                     ~~~~~~~~~~~~~~~ <--- HERE
        tokens_ids_str: List[List[str]] = [[str(t) for t in token_id] for token_id in token_ids]
        return tokens_ids_str

@parmeet (Contributor Author) commented Jul 6, 2022

@philschmid could you try with torch.jit.script for scripting?
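
For reference, a minimal sketch of that suggestion, reusing the tokenizer from the usage example above (the serialization shown is standard TorchScript usage, not specific to this PR):

import torch

scripted_tokenizer = torch.jit.script(bert_tokenizer)
tokens = scripted_tokenizer("Hello world")

# The scripted module can be saved and loaded outside of Python
torch.jit.save(scripted_tokenizer, "bert_tokenizer.pt")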
