
How to persist a pytorch lightning module that depends on external data? #1755

Closed · Apsod opened this issue May 7, 2020 · 7 comments
Labels: question (Further information is requested), won't fix (This will not be worked on)

Comments

Apsod commented May 7, 2020

❓ Questions and Help

What is your question?

Hi! We're using pytorch lightning to train language models and transformers from scratch. This includes training tokenizers and applying them to text data, resulting in binarized data.
The way we've structured the process is to train a tokenizer, apply it to the text data (coupling the binarized data and the tokenizer), and then train a language model on the binarized data.

Since the language model depends on the tokenizer (number of tokens, special tokens, etc.), the pytorch lightning model needs a tokenizer/vocabulary as part of its hparams. This does not play very nicely with the way hparams and loading work: if we transfer the model from one computer to another, we would need to move the tokenizer to the exact same path on the other computer.

Generally, I guess the problem boils down to this: if the inner pytorch modules of the pytorch lightning module depend on some kind of external data (e.g. a vocabulary, or a random sparsity graph), and you then wish to share the pytorch lightning module, we can't find an easy way of doing this.

on_load_checkpoint/on_save_checkpoint do not work, since they take effect after the model has been initialized, whereas we would like to persist data that the initialization logic itself depends on.
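
For example, something along these lines cannot work, because __init__ has already run (and already needed the tokenizer) by the time the hook sees the checkpoint (load_tokenizer is a hypothetical helper of ours, sketched here for illustration):

from pytorch_lightning import LightningModule

class Transformer(LightningModule):
  def __init__(self, hparams):
    super().__init__()
    # __init__ already needs the tokenizer here (vocab size, padding index, ...)
    self.tokenizer = load_tokenizer(hparams.tokenizer_path)  # hypothetical helper

  def on_load_checkpoint(self, checkpoint):
    # Too late: the module above has already been constructed.
    self.tokenizer = checkpoint.get('tokenizer')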

Is there an elegant way to do this in pytorch lightning?

Apsod added the question label on May 7, 2020
github-actions bot commented May 7, 2020

Hi! Thanks for your contribution, great first issue!

williamFalcon (Contributor) commented

Good points. Been thinking about this as well.

can you share pseudocode so we can come up with the changes to the API?

Apsod (Author) commented May 7, 2020

Say that we have a Transformer model which takes a tokenizer path as part of its hparams:

class Transformer(LightningModule):
  def __init__(self, hparams):
    ...
    # load the tokenizer/data
    self.tokenizer = load_tokenizer(hparams.tokenizer_path)

    # Initialize the pytorch model (dependent on tokenizer)
    self.transformer = torch.nn.Transformer(
      dimension = hparams.dimension,
      num_embeddings = self.tokenizer.vocab_size,
      padding_index = self.tokenizer.padding_index,
      ...)

model = Transformer(hparams)
trainer = Trainer(...)
trainer.fit(model)

Later, we wish to load the transformer:

def do_transformer_stuff(checkpoint_path):
  transformer = Transformer.load_from_checkpoint(checkpoint_path)
  ...

This works perfectly fine, provided that tokenizer_path still points to the same tokenizer. However, if the original tokenizer_path was relative, or if the checkpoint was transferred somewhere else, it will fail.

One workaround is to make the tokenizer a kwarg:

class Transformer(LightningModule):
  def __init__(self, hparams, tokenizer=None):
    ...
    # set the tokenizer (tokenizer loading logic outside of Transformer)
    self.tokenizer = tokenizer

    # Initialize the pytorch model (dependent on tokenizer)
    self.transformer = torch.nn.Transformer(
      dimension = hparams.dimension,
      num_embeddings = self.tokenizer.vocab_size,
      padding_index = self.tokenizer.padding_index,
      ...)


tokenizer = load_tokenizer(hparams.tokenizer_path)
model = Transformer(hparams, tokenizer=tokenizer)
trainer = Trainer(...)
trainer.fit(model)

The consequence is that the initialization logic needs to take place outside of the Transformer, and that it is no longer self-contained:

def do_transformer_stuff(checkpoint_path, tokenizer_path):
  tokenizer = load_tokenizer(tokenizer_path)
  transformer = Transformer.load_from_checkpoint(checkpoint_path, tokenizer=tokenizer)

yukw777 (Contributor) commented May 7, 2020

I personally went with something similar to the workaround. I didn't think it was particularly bad that it wasn't "self-contained".

Apsod (Author) commented May 9, 2020

FWIW, we did find a workaround that makes the module self-contained. It builds on the kwarg workaround, adding a separate hparams-only classmethod make that is responsible for tokenizer initialization, and a separate classmethod load that is responsible for loading from a checkpoint.

class Transformer(LightningModule):
  def __init__(self, hparams, tokenizer=None):
    ...
    # set the tokenizer (tokenizer loading logic outside of Transformer)
    self.tokenizer = tokenizer

    # Initialize the pytorch model (dependent on tokenizer)
    self.transformer = torch.nn.Transformer(
      dimension = hparams.dimension,
      num_embeddings = self.tokenizer.vocab_size,
      padding_index = self.tokenizer.padding_index,
      ...)

  def on_save_checkpoint(self, checkpoint):
    # Persist the tokenizer inside the checkpoint itself
    checkpoint['tokenizer'] = self.tokenizer

  @classmethod
  def make(cls, hparams):
    # Essentially a wrapper around __init__ responsible for tokenizer loading
    tokenizer = get_tokenizer(hparams)
    return cls(hparams, tokenizer=tokenizer)

  @classmethod
  def load(cls, checkpoint_path, map_location=None):
    # Copied from load_from_checkpoint, but we extract the tokenizer
    # (saved during on_save_checkpoint) and pass it back as a kwarg.
    if map_location is not None:
      checkpoint = torch.load(checkpoint_path, map_location=map_location)
    else:
      checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    tokenizer = checkpoint['tokenizer']    # extract the tokenizer
    kwargs = {'tokenizer': tokenizer}      # make it a kwarg to _load_model_state

    model = cls._load_model_state(checkpoint, **kwargs)
    return model

This is obviously particular to our use case, but I think it is possible to polish it a bit by, for example, making it possible to use a saved kwargs dictionary in load_from_checkpoint.
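
For comparison, a rough sketch of how far the existing public API already gets us, since load_from_checkpoint forwards extra kwargs to __init__ (load_self_contained is a hypothetical helper; the checkpoint ends up being read twice):

import torch

def load_self_contained(checkpoint_path, map_location=None):
  # Read the checkpoint once to recover the tokenizer saved by on_save_checkpoint ...
  checkpoint = torch.load(checkpoint_path, map_location=map_location or (lambda storage, loc: storage))
  tokenizer = checkpoint['tokenizer']
  # ... then let load_from_checkpoint forward it to __init__ as a kwarg.
  return Transformer.load_from_checkpoint(checkpoint_path, map_location=map_location, tokenizer=tokenizer)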

elkotito (Contributor) commented

> If we transfer the model from one computer to another, we would need to move a tokenizer to the exact same path on the other computer.

It's a completely valid requirement. That's why people wrap up their training experiments in Dockerfiles. Another example: Polyaxon supports this with YAML scripts that help you define the running environment.

stale bot commented Jul 13, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the won't fix label on Jul 13, 2020
stale bot closed this as completed on Jul 22, 2020