Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Scriptable BERT tokenizer #1707

Merged
merged 39 commits into from
May 25, 2022
Merged
Changes from 1 commit
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
873fc43
add classes
parmeet Apr 18, 2022
914cf71
Merge branch 'main' of github.com:pytorch/text into bert_tokenizer
parmeet Apr 24, 2022
82ac4d2
Merge branch 'main' of github.com:pytorch/text into bert_tokenizer
parmeet Apr 27, 2022
f17cc21
add basic functions
parmeet Apr 27, 2022
06ec41b
Merge branch 'main' of github.com:pytorch/text into bert_tokenizer
parmeet May 4, 2022
4276030
minor updates
parmeet May 4, 2022
e39826e
minor update
parmeet May 4, 2022
aaad788
added submodule
parmeet May 5, 2022
4f892e1
initial run
parmeet May 6, 2022
ef1b1f7
added pybinded transform and test structure
parmeet May 6, 2022
9414cda
fixed few bugs
parmeet May 8, 2022
ae6206b
fixed _is_control
parmeet May 9, 2022
aa543af
using python strip and removing it from C++
parmeet May 9, 2022
67c452f
fix UString type
parmeet May 10, 2022
d442d86
partially add code for scripting
parmeet May 10, 2022
434c002
add support for scripting
parmeet May 10, 2022
74f1231
minor edit
parmeet May 10, 2022
e67b513
minor edit
parmeet May 10, 2022
70d0fc8
remove chinese punctuation
parmeet May 11, 2022
3a27236
fix lint
parmeet May 11, 2022
1a77b8b
fix lint
parmeet May 11, 2022
a0caeb1
adding to_lower option
parmeet May 11, 2022
5cf80a1
Revert "adding to_lower option"
parmeet May 11, 2022
d0b4e7d
add to_lower option, need to fix unit test
parmeet May 11, 2022
653515c
update to_lower
parmeet May 12, 2022
806d67e
fix upper case tests
parmeet May 12, 2022
fc0a608
modify test suite
parmeet May 12, 2022
c26fb4b
Merge branch 'main' of github.com:pytorch/text into bert_tokenizer
parmeet May 16, 2022
2dd48ac
minor edits
parmeet May 16, 2022
9e4c098
fixed linter
parmeet May 16, 2022
ab177ea
undo changes in clip test
parmeet May 16, 2022
fa498e5
add fix for 3332 code point
parmeet May 17, 2022
6de32f1
fix lint
parmeet May 17, 2022
9b2038b
Merge branch 'main' of github.com:pytorch/text into bert_tokenizer
parmeet May 23, 2022
0a64f89
fix doc strings and C++ constructor initializer list
parmeet May 25, 2022
2387924
fix lint
parmeet May 25, 2022
5900bdf
Merge branch 'main' of github.com:pytorch/text into bert_tokenizer
parmeet May 25, 2022
abd81fc
Revert "fix doc strings and C++ contructor initializer list"
parmeet May 25, 2022
d940ecb
re-address comments w.r.t revert
parmeet May 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Revert "adding to_lower option"
This reverts commit a0caeb1.
parmeet committed May 11, 2022

Verified

This commit was signed with the committer’s verified signature.
kemuru Marino
commit 5cf80a1ab0c5487137031ab68a827aa1f63f0e50
4 changes: 1 addition & 3 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
@@ -586,10 +586,8 @@ def test_clip_tokenizer_save_load_torchscript(self):
class TestBERTTokenizer(TorchtextTestCase):
def _load_tokenizer(self, test_scripting: bool, return_tokens: bool):
vocab_file = "bert_base_uncased_vocab.txt"
to_lower = True
tokenizer = transforms.BERTTokenizer(
vocab_path=get_asset_path(vocab_file),
to_lower=to_lower,
return_tokens=return_tokens,
)
if test_scripting:
@@ -639,6 +637,6 @@ def test_bert_tokenizer_save_load_torchscript(self):
tokenizer_path = os.path.join(self.test_dir, "bert_tokenizer_torchscript.pt")
# Call the __prepare_scriptable__() func and convert the building block to the torchbind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
torch.save(torch.jit.script(tokenizer), tokenizer_path)
torch.save(tokenizer.__prepare_scriptable__(), tokenizer_path)
loaded_tokenizer = torch.load(tokenizer_path)
self._bert_tokenizer((loaded_tokenizer))
47 changes: 27 additions & 20 deletions torchtext/csrc/bert_tokenizer.cpp
Original file line number Diff line number Diff line change
@@ -86,15 +86,10 @@ static void _to_lower(UString& text) {
}
}

BERTEncoder::BERTEncoder(const std::string& vocab_file, bool to_lower)
: vocab_{_load_vocab_from_file(vocab_file, 1, 1)} {
to_lower_ = to_lower;
}
BERTEncoder::BERTEncoder(const std::string& vocab_file)
: vocab_{_load_vocab_from_file(vocab_file, 1, 1)} {}

BERTEncoder::BERTEncoder(std::vector<std::string> tokens, bool to_lower)
: vocab_{Vocab(tokens)} {
to_lower = to_lower_;
}
BERTEncoder::BERTEncoder(Vocab vocab) : vocab_{vocab} {}

UString BERTEncoder::_clean(UString text) {
/* This function combines:
@@ -219,8 +214,7 @@ std::vector<std::string> BERTEncoder::Tokenize(std::string text) {
unicodes = _basic_tokenize(unicodes);

// Convert text to lower-case
if (to_lower_)
_to_lower(unicodes);
_to_lower(unicodes);

// Convert back to string from code-points
std::string newtext = _convert_from_unicode(unicodes);
@@ -250,24 +244,37 @@ std::vector<int64_t> BERTEncoder::Encode(std::string text) {
return indices;
}

BERTEncoderStates _serialize_bert_encoder(
VocabStates _serialize_bert_encoder(
const c10::intrusive_ptr<BERTEncoder>& self) {
auto strings = self->vocab_.itos_;
return std::make_tuple(self->to_lower_, std::move(strings));
return _serialize_vocab(c10::make_intrusive<Vocab>(self->vocab_));
}

c10::intrusive_ptr<BERTEncoder> _deserialize_bert_encoder(
BERTEncoderStates states) {
c10::intrusive_ptr<BERTEncoder> _deserialize_bert_encoder(VocabStates states) {
auto state_size = std::tuple_size<decltype(states)>::value;
TORCH_CHECK(
state_size == 2,
"Expected deserialized BERTEncoder to have 2 states but found " +
state_size == 4,
"Expected deserialized Vocab to have 4 states but found " +
std::to_string(state_size) + " states");

auto& to_lower = std::get<0>(states);
auto& strings = std::get<1>(states);
auto& version_str = std::get<0>(states);
auto& integers = std::get<1>(states);
auto& strings = std::get<2>(states);
auto& tensors = std::get<3>(states);

// check tensors are empty
TORCH_CHECK(tensors.size() == 0, "Expected `tensors` states to be empty");

return c10::make_intrusive<BERTEncoder>(std::move(strings), to_lower);
// throw error if version is not compatible
TORCH_CHECK(
version_str.compare("0.0.2") >= 0,
"Found unexpected version for serialized Vocab: " + version_str);

c10::optional<int64_t> default_index = {};
if (integers.size() > 0) {
default_index = integers[0];
}
return c10::make_intrusive<BERTEncoder>(
Vocab(std::move(strings), default_index));
}

} // namespace torchtext
12 changes: 4 additions & 8 deletions torchtext/csrc/bert_tokenizer.h
Original file line number Diff line number Diff line change
@@ -6,15 +6,12 @@ namespace torchtext {

typedef std::basic_string<uint32_t> UString;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we using std::basic_string here because the text being passed in from Python contains unicode which isn't compatible with std::string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The string passed from python is UTF-8 encoded bytes. UString is the container to store the unicode code points when converting string to unicode and vice-versa.


typedef std::tuple<bool, std::vector<std::string>> BERTEncoderStates;

struct BERTEncoder : torch::CustomClassHolder {
BERTEncoder(const std::string& vocab_file, bool to_lower);
BERTEncoder(std::vector<std::string> tokens, bool to_lower);
BERTEncoder(const std::string& vocab_file);
BERTEncoder(Vocab vocab);
std::vector<std::string> Tokenize(std::string text);
std::vector<int64_t> Encode(std::string text);
Vocab vocab_;
bool to_lower_;

protected:
UString _clean(UString text);
@@ -27,8 +24,7 @@ struct BERTEncoder : torch::CustomClassHolder {
static std::string kUnkToken;
};

BERTEncoderStates _serialize_bert_encoder(
VocabStates _serialize_bert_encoder(
const c10::intrusive_ptr<BERTEncoder>& self);
c10::intrusive_ptr<BERTEncoder> _deserialize_bert_encoder(
BERTEncoderStates states);
c10::intrusive_ptr<BERTEncoder> _deserialize_bert_encoder(VocabStates states);
} // namespace torchtext
6 changes: 3 additions & 3 deletions torchtext/csrc/register_pybindings.cpp
Original file line number Diff line number Diff line change
@@ -217,16 +217,16 @@ PYBIND11_MODULE(_torchtext, m) {
}));

py::class_<BERTEncoder, c10::intrusive_ptr<BERTEncoder>>(m, "BERTEncoder")
.def(py::init<const std::string, bool>())
.def(py::init<const std::string>())
.def("encode", &BERTEncoder::Encode)
.def("tokenize", &BERTEncoder::Tokenize)
.def(py::pickle(
// __getstate__
[](const c10::intrusive_ptr<BERTEncoder>& self) -> BERTEncoderStates {
[](const c10::intrusive_ptr<BERTEncoder>& self) -> VocabStates {
return _serialize_bert_encoder(self);
},
// __setstate__
[](BERTEncoderStates states) -> c10::intrusive_ptr<BERTEncoder> {
[](VocabStates states) -> c10::intrusive_ptr<BERTEncoder> {
return _deserialize_bert_encoder(states);
}));

6 changes: 3 additions & 3 deletions torchtext/csrc/register_torchbindings.cpp
Original file line number Diff line number Diff line change
@@ -174,16 +174,16 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
});

m.class_<BERTEncoder>("BERTEncoder")
.def(torch::init<const std::string, bool>())
.def(torch::init<const std::string>())
.def("encode", &BERTEncoder::Encode)
.def("tokenize", &BERTEncoder::Tokenize)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<BERTEncoder>& self) -> BERTEncoderStates {
[](const c10::intrusive_ptr<BERTEncoder>& self) -> VocabStates {
return _serialize_bert_encoder(self);
},
// __setstate__
[](BERTEncoderStates states) -> c10::intrusive_ptr<BERTEncoder> {
[](VocabStates states) -> c10::intrusive_ptr<BERTEncoder> {
return _deserialize_bert_encoder(states);
});
;
7 changes: 3 additions & 4 deletions torchtext/transforms.py
Original file line number Diff line number Diff line change
@@ -539,12 +539,11 @@ class BERTTokenizer(Module):
Transform for BERT Tokenizer.
"""

def __init__(self, vocab_path: str, to_lower:bool, return_tokens=False) -> None:
def __init__(self, vocab_path: str, return_tokens=False) -> None:
super().__init__()
self.bert_model = BERTEncoderPyBind(vocab_path, to_lower)
self.bert_model = BERTEncoderPyBind(vocab_path)
self._return_tokens = return_tokens
self._vocab_path = vocab_path
self._to_lower = to_lower

@property
def is_jitable(self):
@@ -609,7 +608,7 @@ def __prepare_scriptable__(self):

if not self.is_jitable:
tokenizer_copy = deepcopy(self)
tokenizer_copy.bert_model = torch.classes.torchtext.BERTEncoder(self._vocab_path, self._to_lower)
tokenizer_copy.bert_model = torch.classes.torchtext.BERTEncoder(self._vocab_path)
return tokenizer_copy

return self