
OOM during training on v3-8 #1881

Closed
blizda opened this issue Apr 7, 2020 · 7 comments
Labels
stale Has not had recent activity

Comments

@blizda

blizda commented Apr 7, 2020

Hi, I am trying to train a new sinkhorn-transformer model on a TPU, and training fails with an OOM error at the optimizer step.
Training a reformer model of the same size does not fail with OOM.

To reproduce, clone and install the package from here.

Script to reproduce issue

import os
os.environ['XLA_USE_32BIT_LONG'] = '1'
# imports the torch_xla package
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch.nn.functional as F
import torch.optim as optim
from sinkhorn_transformer import SinkhornTransformerLM
import numpy as np
from tqdm import tqdm

dev = xm.xla_device()
model = SinkhornTransformerLM(
    num_tokens= 30522,
    dim = 512,
    depth = 12,
    max_seq_len = 40960,
    heads = 16,
    buckets = 64,
    causal = False,          # auto-regressive or not
    sinkhorn_iter = 7,       # number of sinkhorn iterations - default is set at reported best in paper
    n_sortcut = 2,           # use sortcut to reduce complexity to linear time
    temperature = 0.75,      # gumbel temperature - default is set at reported best in paper
    non_permutative = False, # allow buckets of keys to be sorted to queries more than once
    ff_chunks = 10,          # feedforward chunking, from Reformer paper
    reversible = True,       # make network reversible, from Reformer paper
).to(dev)
opt = optim.Adam(model.parameters(), lr=3e-5)
for i in range(3):
  opt.zero_grad()
  x = torch.randint(0, 30522, (1, 40960)).to(dev)
  y = F.log_softmax(model(x), dim=-1)
  print("model passed")
  loss = F.nll_loss(y.reshape(y.shape[0] * y.shape[1], -1)[:-1], x.reshape(x.shape[0] * x.shape[1])[1:])
  loss.backward()
  print("loss passed")
  xm.optimizer_step(opt, barrier=True)
  print("optimezer steped")

StackTrace:

2020-04-07 10:20:31.802999: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] StackTrace:
2020-04-07 10:20:31.803006: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** Begin stack trace ***
2020-04-07 10:20:31.803014: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76]        tensorflow::CurrentStackTrace[abi:cxx11]()
2020-04-07 10:20:31.803022: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76]        xla::util::ReportComputationError(tensorflow::Status const&, absl::Span<xla::XlaComputation const* const>, absl::Span<xla::Shape const* const>)
2020-04-07 10:20:31.803031: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76]        xla::util::ShapeHash(xla::Shape const&)
2020-04-07 10:20:31.803043: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76]        xla::XrtComputationClient::ExecuteComputation(xla::ComputationClient::Computation const&, absl::Span<std::shared_ptr<xla::ComputationClient::Data> const>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, xla::ComputationClient::ExecuteComputationOptions const&)
2020-04-07 10:20:31.803054: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
2020-04-07 10:20:31.803061: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
2020-04-07 10:20:31.803068: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
2020-04-07 10:20:31.803075: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
2020-04-07 10:20:31.803086: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
2020-04-07 10:20:31.803095: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76]        clone
2020-04-07 10:20:31.803103: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** End stack trace ***
2020-04-07 10:20:31.803111: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
2020-04-07 10:20:31.803118: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] Status: Resource exhausted: From /job:tpu_worker/replica:0/task:0:
2020-04-07 10:20:31.803126: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] Attempting to reserve 12.28G at the bottom of memory. That was not possible. There are 9.79G free, 0B reserved, and 9.79G reservable.
2020-04-07 10:20:31.803134: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76]         [[{{node XRTExecute}}]]
2020-04-07 10:20:31.803141: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
2020-04-07 10:20:31.803149: E    3709 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
Traceback (most recent call last):
  File "test_new_fat_model.py", line 44, in <module>
    y = F.log_softmax(model(x), dim=-1)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 320, in forward
    x = self.sinkhorn_transformer(x)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 306, in forward
    return self.layers(x)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 277, in forward
    x = self.layers(x, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py", line 130, in forward
    return _ReversibleFunction.apply(x, blocks, block_kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py", line 100, in forward
    x = block(x, **kwargs)
File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py", line 52,
 in forward
    y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py", line 27,
 in forward
    return self.net(*args, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 101, in forward
    return self.fn(x, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 265, in forward
    out = self.sinkhorn_attention(q, k, v)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 138, in forward
    gumbel_noise = sample_gumbel(R.shape, device)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 50, in sample_gumbel
    u = torch.empty(shape, device = device).uniform_(0, 1)
RuntimeError: Resource exhausted: From /job:tpu_worker/replica:0/task:0:
Attempting to reserve 12.28G at the bottom of memory. That was not possible. There are 9.79G free, 0B reserved, and 9.79G reservable.
         [[{{node XRTExecute}}]]

Why does this issue happen? A v3-8 TPU has 128GB of memory, and this model is exactly the same size as the reformer model. On a GPU this model allocates 40GB including the optimizer and loss (with batch size 1).

@dlibenzi
Collaborator

dlibenzi commented Apr 7, 2020

There are 16GB per core, not 128GB unified memory 😉
And there is a hard separation among cores. The only thing you can do across cores is all-reduce/collective operations.

But I can see we currently do not support uniform_() (given your OOM comes from there). We will add it today.
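For reference, a minimal sketch (not part of the original comment) of what "per core" means in practice: each process spawned by torch_xla's xla_multiprocessing module owns one TPU core and its 16GB, and gradients are combined across cores only through the all-reduce inside xm.optimizer_step(). It reuses the model configuration from the script above and illustrates the execution model; it does not shrink a single graph that already exceeds one core's memory.

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from sinkhorn_transformer import SinkhornTransformerLM

def _mp_fn(index):
    # Each of the 8 spawned processes is bound to its own TPU core with its own 16GB.
    dev = xm.xla_device()
    model = SinkhornTransformerLM(num_tokens=30522, dim=512, depth=12, max_seq_len=40960,
                                  heads=16, buckets=64, causal=False, reversible=True).to(dev)
    opt = optim.Adam(model.parameters(), lr=3e-5)
    for _ in range(3):
        opt.zero_grad()
        x = torch.randint(0, 30522, (1, 40960)).to(dev)
        y = F.log_softmax(model(x), dim=-1)
        loss = F.nll_loss(y.view(-1, y.size(-1))[:-1], x.view(-1)[1:])
        loss.backward()
        # optimizer_step() all-reduces gradients across cores; barrier=True forces
        # execution of the pending graph, as in the single-core script above.
        xm.optimizer_step(opt, barrier=True)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, nprocs=8)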

@blizda
Author

blizda commented Apr 7, 2020

There are 16GB per core, not 128GB unified memory 😉
And there is a hard separation among cores. The only thing you can do across cores is all-reduce/collective operations.

But I can see we currently do not support uniform_() (given your OOM comes from there). We will add it today.

Thanks for clarifying. I will try again when you add uniform_() support.

@dlibenzi
Collaborator

dlibenzi commented Apr 7, 2020

The uniform_() op will be in tomorrow's nightly.
One other attempt you can do is to turn on automatic bfloat16:

export XLA_USE_BF16=1
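(For completeness: the same flag can also be set from Python, as long as it happens before any XLA tensor or device is created, mirroring the XLA_USE_32BIT_LONG line in the reproduction script above.)

import os
os.environ['XLA_USE_BF16'] = '1'  # must be set before torch_xla creates the first device/tensor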

@blizda
Author

blizda commented Apr 8, 2020

The uniform_() op will be in tomorrow's nightly.
One other attempt you can do is to turn on automatic bfloat16:

export XLA_USE_BF16=1

Thanks. I set XLA_USE_BF16 and training runs fine, but the optimizer step is very slow, much slower than on the GPU.

@blizda
Author

blizda commented Apr 8, 2020

OK, on the second try the optimizer step runs at normal speed. That is a little bit strange.

@dlibenzi
Collaborator

dlibenzi commented Apr 8, 2020

This is normal.
The IR graphs generated underneath the pytorch operations are compiled to XLA, and this takes time.
After a few steps, if the graphs do not change, the compilations are cached and steps should proceed at full speed.
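One way to confirm this (a debugging aid, not something mentioned earlier in the thread) is torch_xla's metrics report: the CompileTime counter should stop growing once the step graph is cached, while ExecuteTime keeps increasing by one entry per step.

import torch_xla.debug.metrics as met

# Print after running a few training steps; look at the CompileTime and ExecuteTime metrics.
print(met.metrics_report())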

@stale

stale bot commented May 9, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the stale Has not had recent activity label May 9, 2020
@stale stale bot closed this as completed May 16, 2020