
machine translation validation fails with multi-process #1280

Closed
sIncerass opened this issue Oct 31, 2019 · 12 comments

sIncerass commented Oct 31, 2019

❓ Questions and Help

To Reproduce

Steps to reproduce the behavior:

  1. Create an instance using the latest torch-xla image:
export PROJECT_NAME=xxx
gcloud config set project ${PROJECT_NAME}
gcloud compute --project=${PROJECT_NAME} instances create instance-1 \
--zone=europe-west4-a  \
--machine-type=n1-standard-8  \
--image=debian-9-torch-xla-v20191026 \
--image-project=ml-images  \
--boot-disk-size=200GB
  2. conda activate torch-xla-nightly
  3. Run the machine translation script following https://cloud.google.com/tpu/docs/tutorials/transformer-pytorch, on the tpu branch of pytorch-tpu/fairseq (https://github.com/pytorch-tpu/fairseq/tree/tpu):
gcloud compute tpus create transformer-pytorch-tutorial \
--zone=europe-west4-a \
--network=default \
--range=10.2.3.0 \
--version=pytorch-nightly \
--accelerator-type=v3-8

export TPU_IP_ADDRESS=ip-address; \
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470";

python train.py \
  $HOME/pytorch-tutorial-data/wmt18_en_de_bpej32k \
  --save-interval=1 \
  --arch=transformer_vaswani_wmt_en_de_big \
  --max-target-positions=64 \
  --attention-dropout=0.1 \
  --no-progress-bar \
  --criterion=label_smoothed_cross_entropy \
  --source-lang=en \
  --lr-scheduler=inverse_sqrt \
  --min-lr 1e-09 \
  --skip-invalid-size-inputs-valid-test \
  --target-lang=de \
  --label-smoothing=0.1 \
  --update-freq=1 \
  --optimizer adam \
  --adam-betas '(0.9, 0.98)' \
  --warmup-init-lr 1e-07 \
  --lr 0.0005 \
  --warmup-updates 4000 \
  --share-all-embeddings \
  --dropout 0.3 \
  --weight-decay 0.0 \
  --valid-subset=valid \
  --max-epoch=25 \
  --input_shapes 128x64 \
  --num_cores=8 \
  --metrics_debug \
  --log_steps=100

After the first epoch, during validation, it reports

/anaconda3/envs/torch-xla-nightly/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown len(cache))

and then crashes. No checkpoint is saved either.

Expected behavior

It crashes with a SIGKILL from multiprocessing:

Traceback (most recent call last):
  File "train.py", line 632, in <module>
    cli_main()
  File "train.py", line 623, in cli_main
    xmp.spawn(_mp_fn, args=(args,), nprocs=args.num_cores)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 154, in spawn
    _start_fn, args=(fn, args), nprocs=nprocs, join=join, daemon=daemon)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
    while not spawn_context.join():
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 107, in join
    (error_index, name)
Exception: process 0 terminated with signal SIGKILL

Environment

  • reproducible on XLA backend [CPU/TPU]: TPU
  • torch_xla version: torch-xla-nightly (v1026)
  • Any other relevant information:
sIncerass changed the title from "machine translation validation fails with multi-preprocess" to "machine translation validation fails with multi-process" on Oct 31, 2019
taylanbil (Collaborator) commented:

Hello,

This is not a fatal error, right? The process should keep running after you see this message in stderr; can you confirm?

This was discussed here. As far as I can tell, this issue is not really related to TPUs and it is benign.

sIncerass (Author) commented Oct 31, 2019

Thanks for the information. It then continues and crashes:

| epoch 001 | valid on xla:0/1 'valid' subset | loss 5.485 | nll_loss 3.768 | ppl 13.62 | num_updates 4167
| epoch 001 | valid on xla:0/7 'valid' subset | loss 5.485 | nll_loss 3.768 | ppl 13.62 | num_updates 4167
| epoch 001 | valid on xla:0/2 'valid' subset | loss 5.485 | nll_loss 3.768 | ppl 13.62 | num_updates 4167
| epoch 001 | valid on xla:0/4 'valid' subset | loss 5.485 | nll_loss 3.768 | ppl 13.62 | num_updates 4167
| epoch 001 | valid on xla:0/3 'valid' subset | loss 5.485 | nll_loss 3.768 | ppl 13.62 | num_updates 4167
Traceback (most recent call last):
  File "train.py", line 632, in <module>
    cli_main()
  File "train.py", line 623, in cli_main
    xmp.spawn(_mp_fn, args=(args,), nprocs=args.num_cores)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 154, in spawn
    _start_fn, args=(fn, args), nprocs=nprocs, join=join, daemon=daemon)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
    while not spawn_context.join():
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 107, in join
    (error_index, name)
Exception: process 0 terminated with signal SIGKILL

taylanbil (Collaborator) commented:

I see. Is there any other error message? Something in the code is genuinely erroring, and it is independent of the semaphore_tracker message above. I'm going to go through the steps now to see if I can reproduce.

taylanbil (Collaborator) commented:

Oh, I just noticed that the commands above create a VM in Europe, while the TPUs are in the US. Can you retry with the VM and TPU in the same region?

sIncerass (Author) commented Oct 31, 2019

Sorry, that's a typo; the TPU and the VM instance are both in Europe.
Thanks for helping. That's the only error message I have seen.

taylanbil self-assigned this issue on Oct 31, 2019
taylanbil (Collaborator) commented:

I am currently trying to repro. I'll report back on whether epoch 1 validation errors out or completes.

sIncerass (Author) commented Oct 31, 2019

Many thanks for helping! I am also restarting a new run to see if it reports the same issue.
Confirmed: the same issue appears after the first epoch.

Eric-Wallace commented Oct 31, 2019

You can also reproduce this error by just adding an

if i == 10:
    return tracker

inside train_loop_fn so you don't have to wait for epoch 1 training to finish.

taylanbil (Collaborator) commented:

So I created a new VM + TPU and ran through the tutorial. The process indeed died as described in the issue, around validation step ~300; it received a SIGKILL. Looking at sudo dmesg -T, it became obvious that this is an OOM error.
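For anyone checking their own run, the kernel log on the VM makes the OOM kill explicit. This is a generic check rather than the exact output from this run:

# Confirm the SIGKILL came from the kernel OOM killer rather than from the code
sudo dmesg -T | grep -iE 'out of memory|oom-killer|killed process'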

The reason for this is, I believe, the following:

  • The tutorial was written assuming the torch-xla-0.5 environment, whereas you are using torch-xla-nightly. There have been big changes since the 0.5 release, including the switch from multithreading to multiprocessing.
  • Multiprocessing loads the input data into every process, whereas multithreading loads it once, so memory usage is significantly higher with MP (a rough way to observe this on the VM is sketched below).
  • Since the tutorial uses n1-standard-8, the processes OOM.
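A rough way to watch this while train.py is running, not something from the thread, assuming the spawned workers show up under the command name python:

# Per-process resident memory of the spawned workers; with multiprocessing each
# worker holds its own copy of the loaded data, so RSS adds up quickly on an
# n1-standard-8 (30 GB of RAM).
ps -C python -o pid,ppid,rss,etime,cmd --sort=-rss | head -n 12

# Overall memory pressure on the VM
free -h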

I have verified that the combo n1-standard-64 and torch-xla-nightly works. I will now verify that it works on torch-xla-0.5 and n1-standard-8.

Does that make sense?

sIncerass (Author) commented Nov 1, 2019

Yes, that makes sense.
@Eric-Wallace and I found that it might be better to merge facebookresearch/fairseq@a1c997b into the pytorch-tpu/fairseq repo; it offers a more efficient data loader and may resolve this problem easily ("mmap" means the script does not copy the data into every process's memory).
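One possible way to pull that commit into a local pytorch-tpu/fairseq checkout, sketched here and untested; conflicts are likely and would need resolving by hand:

# From inside a clone of pytorch-tpu/fairseq, on the tpu branch
git remote add upstream https://github.com/facebookresearch/fairseq.git
git fetch upstream
git cherry-pick a1c997b   # the mmap data loader commit mentioned above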

taylanbil (Collaborator) commented:

Thanks for the suggestion, that does seem like a useful commit. We plan to rebase our tpu branch on top of fairseq master, which will include this change too. Feel free to submit a PR if you have already cherry-picked that commit and resolved the conflicts.

I verified that both combinations below work.

  • n1-standard-64 and torch-xla-nightly
  • torch-xla-0.5 and n1-standard-8

So, to use multiprocessing in the meantime, you can switch to a bigger machine.
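For example, reusing the instance name and zone from the commands at the top of this issue, resizing the existing VM would look roughly like this (a sketch; the VM must be stopped before its machine type can be changed):

gcloud compute instances stop instance-1 --zone=europe-west4-a
gcloud compute instances set-machine-type instance-1 \
  --zone=europe-west4-a \
  --machine-type=n1-standard-64
gcloud compute instances start instance-1 --zone=europe-west4-a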

yingyukexiansheng commented:

Can you tell me which fairseq version you used? I cannot find the --num_cores option in my version, which is 0.10.2. Thank you very much.
