
Metrics reduction on distributed TPU setting #965

Closed
vfdev-5 opened this issue Apr 22, 2020 · 2 comments · Fixed by #1045

Comments


vfdev-5 commented Apr 22, 2020

🚀 Feature

Ignite will support distributed training on TPUs (e.g. #960). Currently, metrics computation is impacted in the same way as for DDP on GPUs: without a reduction step, each process only computes its metric over its own portion of the data.

The idea is to improve metrics computation and reduce internal values across processes, as is already done for DDP:

def _sync_all_reduce(self, tensor: Union[torch.Tensor, numbers.Number]) -> Union[torch.Tensor, numbers.Number]:
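For context, the existing DDP path roughly does the following (a simplified sketch, not Ignite's exact implementation): numbers are wrapped into tensors on the metric's device and summed across processes with torch.distributed.all_reduce.

# Simplified sketch of the DDP-style reduction (not Ignite's exact code)
import numbers

import torch
import torch.distributed as dist


def _ddp_sync_all_reduce(self, tensor):
    if not (dist.is_available() and dist.is_initialized()):
        # nothing to reduce outside a distributed run
        return tensor

    tensor_to_number = False
    if isinstance(tensor, numbers.Number):
        tensor = torch.tensor(tensor, device=self._device, dtype=torch.float)
        tensor_to_number = True

    # in-place sum across all DDP processes
    dist.all_reduce(tensor)

    return tensor.item() if tensor_to_number else tensor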

To check whether we are running on a distributed TPU setup, we can try importing torch_xla at module level:

# global definition
try:
    import torch_xla.core.xla_model as xm
    on_xla_device = True
except ImportError:
    on_xla_device = False

and check whether we actually need to reduce with:

xm.xrt_world_size() > 1
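
Putting the two checks together, a minimal helper could look like this (a sketch; the helper name is not part of Ignite):

# Sketch: combine the availability check with the world-size check
try:
    import torch_xla.core.xla_model as xm
    on_xla_device = True
except ImportError:
    on_xla_device = False


def _xla_reduction_needed() -> bool:
    # reduce only when torch_xla is importable and more than one process is running
    return on_xla_device and xm.xrt_world_size() > 1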

This issue depends on #963.


vfdev-5 commented Apr 23, 2020

The following monkeypatching works to reduce metrics:

# We need to monkeypatch the base Ignite Metric class to work on distributed TPU
# until full TPU support lands on the PyTorch-Ignite side:
# https://github.com/pytorch/ignite/issues/965
import numbers

import torch
import torch_xla.core.xla_model as xm

from ignite.metrics import Metric


def _tpu_sync_all_reduce(self, tensor):
    tensor_to_number = False
    if isinstance(tensor, numbers.Number):
        # wrap plain numbers into a float tensor on the metric's device
        tensor = torch.tensor(tensor, device=self._device, dtype=torch.float)
        tensor_to_number = True

    if isinstance(tensor, torch.Tensor):
        # make sure the tensor lives on the specified device
        if tensor.device != self._device:
            tensor = tensor.to(self._device)
    else:
        raise TypeError("Unhandled input type {}".format(type(tensor)))

    # synchronize and reduce: in-place sum across all TPU processes
    xm.all_reduce("sum", [tensor, ])

    if tensor_to_number:
        return tensor.item()
    return tensor


# replace the default (DDP-based) reduction with the XLA-based one
Metric._sync_all_reduce = _tpu_sync_all_reduce

https://colab.research.google.com/drive/1Gy8bblDyXYBqMI7PuLAcIDKzRZGSoeYl
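
A minimal usage sketch after applying the patch (illustrative only; it assumes an XLA device is available and that the metric's compute() goes through Ignite's sync_all_reduce decorator, as Accuracy does):

# Usage sketch: create a metric on the XLA device after applying the monkeypatch
import torch
import torch_xla.core.xla_model as xm
from ignite.metrics import Accuracy

device = xm.xla_device()
accuracy = Accuracy(device=device)

accuracy.reset()
# placeholder predictions and targets standing in for real model outputs
y_pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]], device=device)
y = torch.tensor([1, 0], device=device)
accuracy.update((y_pred, y))

# compute() now goes through the patched _sync_all_reduce, so the internal
# counters are summed across all TPU processes before the final division
print(accuracy.compute())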


vfdev-5 commented Apr 25, 2020

Follow-up about dtype support: pytorch/xla#1952
