#AUTOGENERATED! DO NOT EDIT! File to edit: dev/20a_distributed.ipynb (unless otherwise specified).
__all__ = ['ParallelTrainer', 'setup_distrib', 'DistributedDL', 'DistributedTrainer']
#Cell
from .basics import *
from .callback.progress import ProgressCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.utils.data.distributed import DistributedSampler
#Cell
@patch
def reset(self: DataParallel):
    if hasattr(self.module, 'reset'): self.module.reset()
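# Note (editor's sketch, not part of the autogenerated module): fastai's RNN callbacks call
# `reset` on the model between sequences; the patch above forwards that call through the
# `DataParallel` wrapper, e.g.:
#
#   dp_model = DataParallel(rnn_model)  # `rnn_model` is any hypothetical module exposing reset()
#   dp_model.reset()                    # forwarded to rnn_model.reset(); silently skipped if absent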
#Cell
class ParallelTrainer(Callback):
    run_after,run_before = TrainEvalCallback,Recorder
    def __init__(self, device_ids): self.device_ids = device_ids
    def begin_fit(self): self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)
    def after_fit(self): self.learn.model = self.learn.model.module
#Cell
@patch
def to_parallel(self: Learner, device_ids=None):
    self.add_cb(ParallelTrainer(device_ids))
    return self
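# Usage sketch (editor addition; assumes an existing `Learner` named `learn`):
#
#   learn = learn.to_parallel(device_ids=[0, 1])  # attach ParallelTrainer
#   learn.fit(1)                                  # model is wrapped in DataParallel for the fit,
#                                                 # then unwrapped again by after_fit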
#Cell
@patch
def reset(self: DistributedDataParallel):
    if hasattr(self.module, 'reset'): self.module.reset()
#Cell
def setup_distrib(gpu=None):
    if gpu is None: return gpu
    gpu = int(gpu)
    torch.cuda.set_device(int(gpu))
    if num_distrib() > 1:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    return gpu
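# Launch sketch (editor addition): `setup_distrib` expects the per-process GPU index handed out
# by a launcher such as `python -m torch.distributed.launch --nproc_per_node=2 train.py`, which
# also sets the env:// variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) read by
# `init_process_group`. A hypothetical training script would start with:
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--local_rank', type=int, default=0)  # filled in by the launcher
#   args = parser.parse_args()
#   gpu = setup_distrib(args.local_rank)  # pins this process to its GPU and joins the NCCL group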
#Cell
@delegates()
class DistributedDL(TfmdDL):
    def __init__(self, dataset, rank, world_size, **kwargs):
        super().__init__(dataset, **kwargs)
        if self.n%world_size != 0: self.n += world_size-self.n%world_size
        self.total_n,self.n = self.n,self.n//world_size
        store_attr(self, 'rank,world_size')
    def get_idxs(self):
        idxs = Inf.count if self.indexed else Inf.nones
        return idxs if self.n is None else list(itertools.islice(idxs, self.total_n))
    def shuffle_fn(self, idxs):
        "Deterministically shuffle on each training process based on epoch."
        g = torch.Generator()
        g.manual_seed(self.epoch)
        return L(idxs)[torch.randperm(self.total_n, generator=g)]
    def sample(self):
        idxs = self.get_idxs()
        if self.shuffle: idxs = self.shuffle_fn(idxs)
        # add extra samples to make it evenly divisible
        idxs += idxs[:(self.total_n - len(idxs))]
        # subsample
        idxs = idxs[self.rank:self.total_n:self.world_size]
        return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)
    def create_item(self, s):
        if s is not None and s >= len(self.dataset): s = s%len(self.dataset)
        return super().create_item(s)
    def set_epoch(self, epoch): self.epoch = epoch
    @classmethod
    def from_dl(cls, dl, rank, world_size, **kwargs):
        cur_kwargs = dict(num_workers=dl.fake_l.num_workers, pin_memory=dl.pin_memory, timeout=dl.timeout,
                          bs=dl.bs, shuffle=dl.shuffle, drop_last=dl.drop_last, indexed=dl.indexed)
        cur_kwargs.update({n: getattr(dl, n) for n in cls._methods if n not in "get_idxs sample shuffle_fn create_item".split()})
        return cls(dl.dataset, rank, world_size, **merge(cur_kwargs, kwargs))
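# How the sharding works (editor's note): the dataset length is padded up to a multiple of
# `world_size`, then each rank keeps every `world_size`-th index starting at its own rank.
# For example with 10 items and 2 processes (total_n=10, n=5), without shuffling:
#
#   rank 0 draws indices 0, 2, 4, 6, 8
#   rank 1 draws indices 1, 3, 5, 7, 9
#
# With 11 items the length is padded to 12; the out-of-range index 11 is wrapped back by
# `create_item` (s % len(dataset)), so every rank still sees exactly `n` items per epoch.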
#Cell
class DistributedTrainer(Callback):
    run_after,run_before = TrainEvalCallback,Recorder
    def __init__(self, cuda_id=0): self.cuda_id = cuda_id
    def begin_fit(self):
        self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id)
        self.old_dls = [dl for dl in self.dbunch.dls]
        self.learn.dbunch.dls = [self._wrap_dl(dl) for dl in self.dbunch.dls]
        if rank_distrib() > 0: self.learn.logger=noop
    def _wrap_dl(self, dl):
        return dl if isinstance(dl, DistributedDL) else DistributedDL.from_dl(dl, rank_distrib(), num_distrib())
    def begin_epoch(self):
        for dl in self.dbunch.dls: dl.set_epoch(self.epoch)
    def begin_train(self): self.dl = self._wrap_dl(self.dl)
    def begin_validate(self): self.dl = self._wrap_dl(self.dl)
    def after_fit(self):
        self.learn.model = self.learn.model.module
        self.learn.dbunch.dls = self.old_dls
#Cell
@patch
def to_distributed(self: Learner, cuda_id):
    self.add_cb(DistributedTrainer(cuda_id))
    if rank_distrib() > 0: self.remove_cb(self.progress)
    return self
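# End-to-end sketch (editor addition; `train.py`, `get_dbunch` and `get_model` are hypothetical,
# and the import path depends on how the library is installed):
#
#   # train.py -- launched with: python -m torch.distributed.launch --nproc_per_node=<n_gpus> train.py
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--local_rank', type=int, default=0)
#   args = parser.parse_args()
#   gpu = setup_distrib(args.local_rank)
#
#   dbunch = get_dbunch()                 # hypothetical helper building a DataBunch
#   learn = Learner(dbunch, get_model(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)
#   learn.to_distributed(gpu)             # wrap in DistributedDataParallel and shard the DataLoaders
#   learn.fit(5)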