xla_device with args #2374

Closed
zcain117 opened this issue Jul 24, 2020 · 24 comments
Labels
stale Has not had recent activity

Comments

@zcain117
Collaborator

PyTorch Lightning had some interest in a use case that we haven't explored much: running separate and independent processes on each TPU core, each with its own slice of the dataset. I think the use case was k-fold training and/or hyperparam tuning but I wanted @williamFalcon and @lezwon and @Borda to correct me if I'm wrong.

@ailzhang and @JackCaoG and @davidel are not convinced that this will work with the current system.

Right now, the API gives the option of choosing a device: https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_model.py#L221

This makes it seem like the use case is supported. PyTorch Lightning was able to get it working on Colab but is running into issues when running in GKE.

Should we remove this option from the API if we don't intend for anyone to use it?

@zcain117
Collaborator Author

@ultrons FYI in case you were curious about this use case

@davidel
Collaborator

davidel commented Jul 25, 2020

What does "independent" mean?
Do they do any collective ops? If they do, they are dependent.
I am not sure what different slices of the dataset means here; distributed training already does that. It might be worth a better explanation.

@williamFalcon

there are no collective ops... each one runs on its own process with its own optimizer and dataset

@davidel
Collaborator

davidel commented Jul 25, 2020

Then yes, no problem. As long as you do not call xm.optimizer_step().

@davidel
Collaborator

davidel commented Jul 25, 2020

There is still no need to pass an ordinal to xm.xla_device().

@williamFalcon

ok. i think we still call the xm.optimizer_step()
i guess we can use the regular optimizer when operating in this mode @lezwon

@davidel
Collaborator

davidel commented Jul 25, 2020

Wait ... you do need to call xm.mark_step() though (instead of xm.optimizer_step()).
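For illustration, a minimal sketch of the difference (independent_train_step is a hypothetical helper; model, optimizer, loss_fn, data and target are assumed to already live on the XLA device):

import torch_xla.core.xla_model as xm

def independent_train_step(model, optimizer, loss_fn, data, target):
  # Independent per-core training: a plain optimizer.step() plus xm.mark_step(),
  # instead of xm.optimizer_step(optimizer), so no cross-core gradient reduction happens.
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  optimizer.step()   # no all-reduce of gradients
  xm.mark_step()     # cut and execute the pending XLA graph for this step
  return loss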

@lezwon
Contributor

lezwon commented Jul 25, 2020

There is still no need to pass an ordinal to xm.xla_device().

@davidel I'm a little confused. How would a process select a different core without providing an ordinal? Will it automatically pick the next free core? Also, could you explain why xm.optimizer_step() should not be called?

@davidel
Collaborator

davidel commented Jul 25, 2020

The pytorch/xla multiprocessing automatically partitions the devices and assigns a proper "current device" to each process.

If you call xm.optimizer_step() the different cores will try to reduce the gradients, hence the cores are not really independent.
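As a sketch of what that automatic assignment looks like (assuming an 8-core TPU; _mp_fn is a hypothetical name):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  # No ordinal is passed: each spawned process is handed its own "current device".
  device = xm.xla_device()
  print('process', index, 'ordinal', xm.get_ordinal(), 'device', device)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(), nprocs=8)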

@tmabraham

I was just checking out the issues, and found this one.

I don't know if you guys know, but @abhishekkrthakur invented this exact technique of training multiple folds on TPUs over here. Here's a YouTube video on the same.

@lezwon
Contributor

lezwon commented Jul 26, 2020

The pytorch/xla multiprocessing automatically partitions the devices and assigns a proper "current device" to each process.

I meant that when we run each process on a separate core manually, instead of using multiprocessing, we would have to choose a device using xla_device, right?

If you call xm.optimizer_step() the different cores will try to reduce the gradients, hence the cores are not really independent.

In the API Guide, it's mentioned to call xm.optimizer_step(optimizer, barrier=True). Does this create a problem when running multiple processes in parallel? What if we provide a replica group using the groups parameter, i.e. xm.optimizer_step(optimizer, barrier=True, groups=[[xm.get_ordinal()]])?

@lezwon
Contributor

lezwon commented Jul 26, 2020

@tmabraham This functionality was added based on his kernel :]

@tmabraham

tmabraham commented Jul 26, 2020

@lezwon So the implementation is in that kernel. I am curious, then: what is the confusion?

@lezwon
Contributor

lezwon commented Jul 26, 2020

@tmabraham So abhishek's kernel basically demonstrates training K models in parallel. The functionality implemented in PyTorch Lightning supports both training with multi-processing and training on a single core. Given the code differences between the two, there are some issues we are trying to resolve around training and checkpointing, to give the user an experience consistent with training on GPUs. You can view this PR for more info.

@davidel
Collaborator

davidel commented Jul 26, 2020

I was just checking out the issues, and found this one.

I don't know if you guys know, but @abhishekkrthakur invented this exact technique of training multiple folds on TPUs over here. Here's a YouTube video on the same.

That code "happens to work" 😄
It does the same:

device = xm.xla_device(fold + 1)

It assumes the positions of the TPU devices start from 1.
It should have called xm.get_xla_supported_devices(NUM_CORES) and indexed the resulting list with fold.

Also, that code uses multi-threading, which is considerably (20-30%) slower than multi-processing due to GIL serialization over the model's Python code.
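In other words, a sketch of the safer lookup (fold is illustrative and 0-indexed; this applies to the multi-threaded setup where all cores are visible to one process):

import torch_xla.core.xla_model as xm

devices = xm.get_xla_supported_devices()  # all XLA devices visible to this process
fold = 0                                  # folds are 0-indexed
device = devices[fold]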

@zcain117
Collaborator Author

Ok to summarize so far:

  • specifying a TPU core via xm.xla_device(fold) will probably work but it's not a use case we test or promote
  • if specifying a TPU core:
    • remember to use 0-indexed arg, i.e. fold instead of fold+1.
    • maybe try device = xm.get_xla_supported_devices(NUM_CORES)[fold] instead of device = xm.xla_device(fold) if the latter isn't working.
    • use xm.mark_step() instead of xm.optimizer_step(), since the latter will consolidate gradients between cores whereas you want your cores to train independent models.

Davide mentioned "There is still no need to pass an ordinal to xm.xla_device()." I think his point was that each core is already independent (as long as you stop calling xm.optimizer_step). Instead of requesting a particular device, you could use our recommended flow and run the code without knowing before spawn time which device you'll end up on. In the code that runs on the TPU core, you can find which device you ended up on using code like this (example usage), and then maybe use something like torch.utils.data.Subset inside that core's code to make sure the core is using the right data. @davidel let me know if that seems right.
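For example, a rough sketch of that flow (names like _mp_fn and fold_indices are hypothetical; fold_indices would be a precomputed list of index lists, one per fold):

import torch.utils.data
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, dataset, fold_indices):
  device = xm.xla_device()          # assigned automatically, no ordinal argument
  fold = xm.get_ordinal()           # which core/process this turned out to be
  subset = torch.utils.data.Subset(dataset, fold_indices[fold])
  loader = torch.utils.data.DataLoader(subset, batch_size=128, shuffle=True)
  # ... build the model/optimizer on `device` and train on `loader` independently,
  # using optimizer.step() + xm.mark_step() as discussed above ...

# xmp.spawn(_mp_fn, args=(dataset, fold_indices), nprocs=8)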

@williamFalcon @lezwon @Borda let us know if you're still running into issues.

@davidel
Collaborator

davidel commented Jul 30, 2020

Yes. You just cannot select a device. It gets assigned to you.
And xm.xla_device() will tell you what it is.
In case of thread-based parallelism (which I would not use, as it's deprecated and 20-30% slower), the device gets passed to the target function.

If you need to know the ordinal (in order to create data samplers, for example), use xm.get_ordinal() (and xm.xrt_world_size() for the world size).
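That is the usual pytorch/xla sampler pattern, roughly (train_dataset is assumed to exist; batch size is illustrative):

import torch.utils.data
from torch.utils.data.distributed import DistributedSampler
import torch_xla.core.xla_model as xm

sampler = DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),  # total number of cores/processes
    rank=xm.get_ordinal(),             # this process' shard
    shuffle=True)
loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, sampler=sampler)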

@abhishekkrthakur

When I look at the documentation section on single-core training, it says to use xm.optimizer_step with barrier=True, not xm.mark_step. Can it be updated to reflect what we have learned here? Or is the documentation deprecated?

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# MNIST, train_loader, lr and momentum are defined elsewhere in the docs example.
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  xm.optimizer_step(optimizer, barrier=True)

Or is mark_step used only when we specify the device?

@davidel The code I wrote uses multi-threading because multi-processing needed 8 times the memory. I was not able to fit the model with multi-processing in Kaggle kernels, which have 16GB of RAM.

So the conclusion is: we call xm.xla_device() without an index, we use mark_step(), and that's it. Right?

@davidel
Collaborator

davidel commented Jul 30, 2020

In cases where you do not call xm.optimizer_step() (as mentioned earlier in this thread), AND you are not using the ParallelLoader, you need to call xm.mark_step() yourself.
Otherwise the ParallelLoader calls it for you.

WRT the OOM on the 16GB Kaggle kernel, did you try the MpModelWrapper:

class MpModelWrapper(object):
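For reference, a sketch of how the wrapper is typically used (not code from this thread; MNIST() stands in for whatever model class you construct):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Build the model once in the parent process; the wrapper keeps a single host-memory
# copy that the forked child processes share until they move it to their device.
WRAPPED_MODEL = xmp.MpModelWrapper(MNIST())

def _mp_fn(index):
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)  # sends the shared weights to this core's device
  # ... train ...

# xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')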

@abhishekkrthakur

abhishekkrthakur commented Jul 30, 2020 via email

@davidel
Collaborator

davidel commented Jul 30, 2020

You could try the serial executor:

class MpSerialExecutor(object):

Where the function you pass to it is like:

import gc

def _make_device_model(device):
  model = MyModel(...)  # whatever model class you use
  return model.to(device)

def _serial_model_create(device):
  model = _make_device_model(device)
  gc.collect()  # reclaim temporary host memory left over from model construction
  return model
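Presumably this is then used together with the serial executor like so (a sketch building on the functions above; _mp_fn is a hypothetical name):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

SERIAL_EXEC = xmp.MpSerialExecutor()

def _mp_fn(index):
  device = xm.xla_device()
  # Only one process builds and moves its model at a time, keeping peak host RAM low.
  model = SERIAL_EXEC.run(lambda: _serial_model_create(device))
  # ... train ...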

@ultrons
Contributor

ultrons commented Jul 31, 2020

@davidel, if we do not call xm.optimizer_step and instead do an optimizer.step() followed by mark_step in the training loop, this would mean that no all-reduce happens across the cores. Since each core is already working on a separate shard of the dataset, essentially we would be training 8 independent models in parallel, and at the end of the training loop we can write out those models. Is that the right understanding? If so, then @abhishekkrthakur can probably consider going that route. @abhishekkrthakur, is that what you want to accomplish with k-fold training?

@davidel
Collaborator

davidel commented Jul 31, 2020

Actually, it is a bit more complex.
The TPUs (in the way we configure them in replication mode) have a global barrier that all cores have to reach before execution starts.
This means that the model described above only works if all the cores run the same number of TPU executions.
A totally independent training setup (where the number of TPU executions is uneven across cores) requires changes and the addition of a special mode.
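If the fold sizes are known up front, one hypothetical way to respect that constraint is to cap every core at the same number of steps per epoch (illustrative only; fold_indices, batch_size and loader are assumed from the earlier sketches):

# Cap every core at the same number of steps so all cores issue the same
# number of TPU executions per epoch and keep reaching the global barrier together.
steps_per_epoch = min(len(idx) for idx in fold_indices) // batch_size

for step, (data, target) in enumerate(loader):
  if step >= steps_per_epoch:
    break
  # ... independent train step (optimizer.step() + xm.mark_step()) ...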

lezwon mentioned this issue Aug 9, 2020
@stale

stale bot commented Aug 30, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the stale (Has not had recent activity) label Aug 30, 2020
stale bot closed this as completed Sep 6, 2020