Adding Minimal Reproducible Usage Example For TPU support on examples/seq2seq #5960

Closed
14 changes: 14 additions & 0 deletions examples/seq2seq/README.md
@@ -90,6 +90,20 @@ The following command should work on a 16GB GPU:
--model_name_or_path facebook/bart-large
```

The following command should work on TPUs:
```bash
./finetune_tpu.sh \
--data_dir $XSUM_DIR \
--train_batch_size=8 \
--eval_batch_size=4 \
--output_dir=xsum_results \
--num_train_epochs 1 \
--model_name_or_path facebook/bart-large \
--n_tpu_cores 1
```

NB: If you are using multiple TPU cores, adjust `train_batch_size` and `eval_batch_size`, the `learning_rate`, `n_tpu_cores`, etc. accordingly to find the combination that uses the TPUs to their full capacity.
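As an illustration only, a multi-core run might look like the sketch below; the scaled batch sizes, learning rate, and core count are assumptions, not values validated in this PR:

```bash
# Illustrative values only -- tune the batch sizes and learning rate for your TPU setup.
./finetune_tpu.sh \
    --data_dir $XSUM_DIR \
    --train_batch_size=4 \
    --eval_batch_size=2 \
    --learning_rate=1e-4 \
    --output_dir=xsum_results \
    --num_train_epochs 1 \
    --model_name_or_path facebook/bart-large \
    --n_tpu_cores 8
```

Because `finetune_tpu.sh` forwards `"$@"` after its own defaults, flags passed on the command line (such as `--learning_rate`) override the values set inside the script.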

### Translation Finetuning

First, follow the wmt_en_ro download instructions.
15 changes: 15 additions & 0 deletions examples/seq2seq/finetune_tpu.sh
@@ -0,0 +1,15 @@
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"

# Proper usage is documented in the README; you need to specify data_dir, output_dir and model_name_or_path.
# NB: you also need to adjust the learning_rate and the train/eval batch sizes, and pass in n_tpu_cores.
# TPUs are very sensitive to these params.

python finetune.py \
--learning_rate=3e-5 \
--gpus 0 \
--do_train \
--do_predict \
--n_val 1000 \
--val_check_interval 0.1 \
"$@"  # quoted so arguments containing spaces are forwarded intact
3 changes: 2 additions & 1 deletion src/transformers/modeling_bart.py
@@ -943,6 +943,7 @@ def __init__(self, config: BartConfig):
super().__init__(config)
base_model = BartModel(config)
self.model = base_model
self.lm_head = _make_linear_from_emb(self.model.shared)
Contributor comment on this line:
I think you need to call this again after line 952 to make the tests pass.

self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))

def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
@@ -1110,7 +1111,7 @@ def get_encoder(self):
return self.model.encoder

def get_output_embeddings(self):
return _make_linear_from_emb(self.model.shared) # make it on the fly
return self.lm_head  # don't make it on the fly as it's not compatible with TPUs


@add_start_docstrings(
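The diff above caches `lm_head` as a module built once in `__init__` (and, per the reviewer's note, it needs to be rebuilt after `resize_token_embeddings`), rather than constructing a fresh linear layer every time `get_output_embeddings()` is called, which does not play well with TPUs. A minimal, self-contained sketch of that pattern follows; it is not the actual transformers code, and names like `TinyLM` and `make_linear_from_emb` are placeholders:

```python
# Sketch of the pattern adopted in this PR: build the LM head once from the shared
# embedding and rebuild it after a resize, instead of creating a new nn.Linear on the fly.
import torch
import torch.nn as nn


def make_linear_from_emb(emb: nn.Embedding) -> nn.Linear:
    """Create a Linear layer whose weight is tied to the embedding matrix (no copy)."""
    vocab_size, emb_size = emb.weight.shape
    lin = nn.Linear(emb_size, vocab_size, bias=False)
    lin.weight = emb.weight  # share the same Parameter
    return lin


class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 100, d_model: int = 16):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, d_model)
        # Built once here, so no new module is created inside the forward/trace path.
        self.lm_head = make_linear_from_emb(self.shared)

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        old = self.shared
        new = nn.Embedding(new_num_tokens, old.embedding_dim)
        n = min(old.num_embeddings, new_num_tokens)
        new.weight.data[:n] = old.weight.data[:n]
        self.shared = new
        # Per the review comment above: rebuild the head so it points at the new embedding.
        self.lm_head = make_linear_from_emb(self.shared)
        return self.shared

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head  # cached, rather than made on the fly


model = TinyLM()
logits = model.get_output_embeddings()(model.shared(torch.tensor([[1, 2, 3]])))
print(logits.shape)  # torch.Size([1, 3, 100])
```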