How to persist a pytorch lightning module that depends on external data? #1755
Comments
Hi! Thanks for your contribution, great first issue! |
Good points. I've been thinking about this as well. Can you share pseudocode so we can come up with the changes to the API? |
Say that we have a Transformer model which takes a tokenizer as part of its hparams:
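Roughly like this (a minimal sketch; get_tokenizer stands in for whatever loads the trained tokenizer, and hparams.tokenizer_path is a placeholder name):
import torch
from pytorch_lightning import LightningModule

class Transformer(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # The tokenizer is loaded from a path stored in hparams,
        # so the checkpoint only records the path, not the tokenizer itself
        self.tokenizer = get_tokenizer(hparams.tokenizer_path)
        # The inner pytorch modules depend on the tokenizer
        self.embedding = torch.nn.Embedding(
            num_embeddings=self.tokenizer.vocab_size,
            embedding_dim=hparams.dimension,
            padding_idx=self.tokenizer.padding_index,
        )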
Later, we wish to load the transformer:
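Something like (the checkpoint path here is just a placeholder):
model = Transformer.load_from_checkpoint('path/to/checkpoint.ckpt')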
This works perfectly fine, given that the tokenizer is still available at the path stored in hparams. One workaround is to make the tokenizer a kwarg (see the sketch below).
The consequence of which is that the initialization logic needs to take place outside of the Transformer, and that it is no longer self-contained:
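For example (again a sketch, reusing the hypothetical get_tokenizer helper):
class Transformer(LightningModule):
    def __init__(self, hparams, tokenizer=None):
        super().__init__()
        self.hparams = hparams
        # The tokenizer is now passed in explicitly instead of being loaded from a path
        self.tokenizer = tokenizer
        self.embedding = torch.nn.Embedding(
            num_embeddings=self.tokenizer.vocab_size,
            embedding_dim=hparams.dimension,
            padding_idx=self.tokenizer.padding_index,
        )

# ...and the tokenizer loading has to happen in the calling code:
tokenizer = get_tokenizer(hparams)
model = Transformer(hparams, tokenizer=tokenizer)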
|
I personally went with something similar to the workaround. I didn't think it was particularly bad that it wasn't "self-contained". |
FWIW, we did find a workaround that makes the module self-contained. It is built on the kwarg workaround, adding a separate hparams-only classmethod:

import torch
from pytorch_lightning import LightningModule

class Transformer(LightningModule):
    def __init__(self, hparams, tokenizer=None):
        super().__init__()
        ...
        # Set the tokenizer (tokenizer loading logic lives outside of Transformer)
        self.tokenizer = tokenizer
        # Initialize the inner pytorch modules (dependent on the tokenizer)
        self.embedding = torch.nn.Embedding(
            num_embeddings=self.tokenizer.vocab_size,
            embedding_dim=hparams.dimension,
            padding_idx=self.tokenizer.padding_index,
        )
        self.transformer = torch.nn.Transformer(d_model=hparams.dimension)
    def on_save_checkpoint(self, checkpoint):
        # Persist the tokenizer inside the checkpoint itself
        checkpoint['tokenizer'] = self.tokenizer

    @classmethod
    def make(cls, hparams):
        # Essentially a wrapper around __init__, responsible for tokenizer loading
        tokenizer = get_tokenizer(hparams)
        return cls(hparams, tokenizer=tokenizer)

    @classmethod
    def load(cls, checkpoint_path, map_location=None):
        # Copied from load_from_checkpoint, but we extract the tokenizer
        # (saved during on_save_checkpoint) and pass it back in as a kwarg.
        if map_location is not None:
            checkpoint = torch.load(checkpoint_path, map_location=map_location)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        args = []
        tokenizer = checkpoint['tokenizer']  # extract the tokenizer
        kwargs = {'tokenizer': tokenizer}    # make it a kwarg to _load_model_state
        model = cls._load_model_state(checkpoint, *args, **kwargs)
        return model

This is obviously particular to our use case, but I think it is possible to polish it a bit by, for example, making it possible to use a saved kwargs-dictionary in load_from_checkpoint.
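In case it helps, usage then looks roughly like this (a sketch; the Trainer arguments and the checkpoint path are placeholders):
from pytorch_lightning import Trainer

# Training: make() builds the tokenizer and passes it into __init__
model = Transformer.make(hparams)
trainer = Trainer(max_epochs=10)
trainer.fit(model)

# Later, possibly on another machine: the tokenizer comes back out of the
# checkpoint itself, so no external tokenizer files need to be copied over
model = Transformer.load('path/to/checkpoint.ckpt')
|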
It's a completely valid requirement. That's why people wrap up their training experiments in Dockerfiles. A different example: Polyaxon supports this with YAML scripts that help you define the running environment. |
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. |
❓ Questions and Help
What is your question?
Hi! We're using pytorch lightning to train language models and transformers from scratch. This includes training tokenizers and applying them to text data, resulting in binarized data.
The way we've structured the process is to train a tokenizer, apply it to text data (coupling the binarized data and the tokenizer), and train a language model on the binarized data.
Since the language model depends on the tokenizer (number of tokens, special tokens, etc.), the pytorch lightning model needs a tokenizer/vocabulary as part of its hparams. This does not play very nicely with the way hparams and loading work: if we transfer the model from one computer to another, we need to move the tokenizer to the exact same path on the other computer.
Generally, I guess the problem boils down to this: if the inner pytorch modules of the pytorch lightning module depend on some kind of external data (e.g. a vocabulary, or a random sparsity graph) and you then wish to share the pytorch lightning module, we can't find an easy way of doing so.
on_load_checkpoint/on_save_checkpoint do not work, since they take effect after the model has been initialized, whereas we would like to persist data that the initialization logic itself depends on.
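For example, something like this stores the tokenizer fine, but restores it too late (a sketch of the hook-based approach):
from pytorch_lightning import LightningModule

class Transformer(LightningModule):
    def on_save_checkpoint(self, checkpoint):
        # Stashing extra objects in the checkpoint is easy enough...
        checkpoint['tokenizer'] = self.tokenizer

    def on_load_checkpoint(self, checkpoint):
        # ...but this hook only runs after __init__, which already needed
        # the tokenizer to build the inner pytorch modules
        self.tokenizer = checkpoint['tokenizer']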
Is there an elegant way to do this in pytorch lightning?