Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] text-classification PL example #6027

Merged
9 commits merged into the base branch
Aug 6, 2020
10 changes: 7 additions & 3 deletions examples/lightning_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
# self.save_hyperparameters()
# can also expand arguments into trainer signature for easier reading

self.hparams = hparams
self.save_hyperparameters(hparams)
self.step_count = 0
self.tfmr_ckpts = {}
self.output_dir = Path(self.hparams.output_dir)
Expand Down Expand Up @@ -194,7 +194,7 @@ def add_model_specific_args(parser, root_dir):

class LoggingCallback(pl.Callback):
def on_batch_end(self, trainer, pl_module):
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
pl_module.logger.log_metrics(lrs)

def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
Expand Down Expand Up @@ -227,6 +227,10 @@ def add_generic_args(parser, root_dir) -> None:
help="The output directory where the model predictions and checkpoints will be written.",
)

parser.add_argument(
"--gpus", default=0, type=int, help="The number of GPUs allocated for this, it is by default 0 meaning none",
)

parser.add_argument(
"--fp16",
action="store_true",
Expand All @@ -240,7 +244,7 @@ def add_generic_args(parser, root_dir) -> None:
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
Expand Down
2 changes: 1 addition & 1 deletion examples/text-classification/run_pl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"

python3 run_pl_glue.py --data_dir $DATA_DIR \
python3 run_pl_glue.py --gpus 1 --data_dir $DATA_DIR \
--task $TASK \
--model_name_or_path $BERT_MODEL \
--output_dir $OUTPUT_DIR \
Expand Down
8 changes: 6 additions & 2 deletions examples/text-classification/run_pl_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
import time
from argparse import Namespace

import numpy as np
import torch
Expand All @@ -24,6 +25,8 @@ class GLUETransformer(BaseTransformer):
mode = "sequence-classification"

def __init__(self, hparams):
if type(hparams) == dict:
hparams = Namespace(**hparams)
hparams.glue_output_mode = glue_output_modes[hparams.task]
num_labels = glue_tasks_num_labels[hparams.task]

Expand All @@ -41,7 +44,8 @@ def training_step(self, batch, batch_idx):
outputs = self(**inputs)
loss = outputs[0]

tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
# tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
tensorboard_logs = {"loss": loss}
return {"loss": loss, "log": tensorboard_logs}

def prepare_data(self):
Expand Down Expand Up @@ -72,7 +76,7 @@ def prepare_data(self):
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)

def load_dataset(self, mode, batch_size):
def get_dataloader(self, mode: int, batch_size: int, shuffle: bool) -> DataLoader:
"Load datasets. Called after prepare data."

# We test on dev set to compare to benchmarks without having to submit to GLUE server
Expand Down