
Text Classification with PyTorch Lightning: 'dict' object has no attribute 'task' #5452

Closed
stefan-it opened this issue Jul 2, 2020 · 15 comments
Labels: Examples, wontfix

Comments

@stefan-it
Collaborator

Hi,

after manually resolving the n_gpu attribute issue in lightning_base.py (see #5385), I found another strange behaviour in the Text Classification example.

I used PL version 0.8.1 with the run_pl.sh script. Training works, but after reloading the model for evaluation, the following error is thrown:

Traceback (most recent call last):
  File "run_pl_glue.py", line 189, in <module>
    model = model.load_from_checkpoint(checkpoints[-1])
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py", line 171, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/saving.py", line 201, in _load_model_state
    model = cls(*args, **kwargs)
  File "run_pl_glue.py", line 28, in __init__
    hparams.glue_output_mode = glue_output_modes[hparams.task]
AttributeError: 'dict' object has no attribute 'task'

I did some debugging; the interesting part is in the constructor:

def __init__(self, hparams):
    # Fails on reload: hparams is then a plain dict, so attribute access breaks
    hparams.glue_output_mode = glue_output_modes[hparams.task]
    num_labels = glue_tasks_num_labels[hparams.task]
    super().__init__(hparams, num_labels, self.mode)

For training (first initialization), the hparams variable looks like:

Namespace(adam_epsilon=1e-08, cache_dir='', config_name='', data_dir='./glue_data/MRPC/', do_predict=True, do_train=True, eval_batch_size=32, fast_dev_run=False, fp16=True, fp16_opt_level='O1', gpus=1, gradient_accumulation_steps=1, learning_rate=2e-05, max_grad_norm=1.0, max_seq_length=128, model_name_or_path='bert-base-cased', n_tpu_cores=0, num_train_epochs=1, num_workers=4, output_dir='/mnt/transformers-pl/examples/text-classification/mrpc-pl-bert', overwrite_cache=False, resume_from_checkpoint=None, seed=2, task='mrpc', tokenizer_name=None, train_batch_size=32, val_check_interval=1.0, warmup_steps=0, weight_decay=0.0)

Notice the type: it is a Namespace. After training... and re-loading the model checkpoint, hparams looks like:

{'output_dir': '/mnt/transformers-pl/examples/text-classification/mrpc-pl-bert', 'fp16': True, 'fp16_opt_level': 'O1', 'fast_dev_run': False, 'gpus': 1, 'n_tpu_cores': 0, 'max_grad_norm': 1.0, 'do_train': True, 'do_predict': True, 'gradient_accumulation_steps': 1, 'seed': 2, 'resume_from_checkpoint': None, 'val_check_interval': 1.0, 'model_name_or_path': 'bert-base-cased', 'config_name': '', 'tokenizer_name': None, 'cache_dir': '', 'learning_rate': 2e-05, 'weight_decay': 0.0, 'adam_epsilon': 1e-08, 'warmup_steps': 0, 'num_workers': 4, 'num_train_epochs': 1, 'train_batch_size': 32, 'eval_batch_size': 32, 'max_seq_length': 128, 'task': 'mrpc', 'data_dir': './glue_data/MRPC/', 'overwrite_cache': False, 'glue_output_mode': 'classification'}

It's strange: it is now a plain dictionary, so hparams.task no longer works 😢
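
For anyone skimming, a quick self-contained illustration of why attribute access breaks on a plain dict (nothing here is specific to the example):

from argparse import Namespace

ns = Namespace(task="mrpc")
print(ns.task)    # 'mrpc' -- a Namespace supports attribute access

d = {"task": "mrpc"}
print(d["task"])  # 'mrpc' -- a dict only supports item access
# d.task          # would raise AttributeError: 'dict' object has no attribute 'task'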

@sshleifer could you help with that issue 🤔

@sshleifer
Contributor

You could manually cast it to a namespace with

argparse.Namespace(**ckpt["hparams"])

But @williamFalcon may have a cleaner solution
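
Spelled out, a minimal sketch of that workaround (assuming the PL 0.8.x checkpoint stores the hyperparameters under the "hparams" key as in the snippet above; the checkpoint filename is just illustrative):

import argparse
import torch

from run_pl_glue import GLUETransformer

# Load the raw checkpoint and rebuild an argparse.Namespace from the stored dict
ckpt = torch.load("mrpc-pl-bert/epoch=0.ckpt", map_location="cpu")
hparams = argparse.Namespace(**ckpt["hparams"])

# Re-create the module with attribute-style hparams and restore the weights
model = GLUETransformer(hparams)
model.load_state_dict(ckpt["state_dict"])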

sshleifer added the Examples label and removed the Ex: Named Entity Recognition label on Jul 2, 2020
@bhashithe
Contributor

bhashithe commented Jul 22, 2020

I patched it with a very dirty fix: in the GLUETransformer init, cast hparams to a Namespace if it is a dict:

if isinstance(hparams, dict):
    hparams = Namespace(**hparams)

@nagyrajmund

nagyrajmund commented Jul 23, 2020

The official way to do this is to call self.save_hyperparameters(hparams) in the constructor of the module; the hyperparameters will then be accessible both as self.hparams['some_param'] and as self.hparams.some_param.

@bhashithe
Contributor

@nagyrajmund Hey, that does not seem to solve the issue. Even without the save_hyperparameters() call, the hparams are saved in the checkpoint and the YAML file.

@nagyrajmund

Hey-hey,

I think you misunderstood me; my proposed fix is to replace this line with self.save_hyperparameters(hparams). Then the hparams will be loaded correctly from the checkpoint without changing any other functionality in the module. Let me know if you run into any issues :)
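
For illustration, a minimal sketch of that change (assuming the line in question is the self.hparams = hparams assignment in the BaseTransformer constructor in lightning_base.py; the signature defaults here are illustrative):

import pytorch_lightning as pl

class BaseTransformer(pl.LightningModule):
    def __init__(self, hparams, num_labels=None, mode="base"):
        super().__init__()
        # Instead of `self.hparams = hparams`, let Lightning register the
        # hyperparameters itself; load_from_checkpoint then restores them
        # with attribute access intact instead of passing in a raw dict.
        self.save_hyperparameters(hparams)
        # ... rest of the original __init__ unchanged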

@williamFalcon
Contributor

@nateraw @Borda

@Borda

Borda commented Jul 23, 2020

The conclusion after sharing a minimal example: the missing piece is self.save_hyperparameters() in __init__.
https://pytorch-lightning.slack.com/archives/CRBLFHY79/p1595502354412700

@bhashithe
Contributor

bhashithe commented Jul 23, 2020

EDIT: Does not work as intended, please check the other comments

@nagyrajmund Hey, that does not seem to solve the issue. Even without the save_hyperparameters() call, the hparams are saved in the checkpoint and the YAML file.

It does work; as @Borda mentioned, the example is missing that. Among other things, the gpus parameter and the load_datasets() functions were also issues.

@Borda

Borda commented Jul 23, 2020

@bhashithe mind sharing the code, or is it this example? transformers/examples/text-classification/run_pl_glue.py

@bhashithe
Contributor

bhashithe commented Jul 23, 2020

EDIT: Does not work as intended, please check the other comments
@Borda It is actually the example, but I had to alter both lightning_base.py and run_pl_glue.py to get it to work.

@sshleifer
Contributor

would you mind sending a PR with your fix @bhashithe ?

@bhashithe
Contributor

No problem, let me send that now.

@bhashithe
Contributor

Sorry @Borda, that save_hyperparameters() fix does not work, @nagyrajmund.

Small oversight on my part; anyway, I have it working by resetting hparams to a Namespace.

@bhashithe
Contributor

Created #6027 with fixes.

@stale

stale bot commented Sep 24, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label on Sep 24, 2020
stale bot closed this as completed on Oct 1, 2020