Skip to content

Commit 65ffabc

Browse files
williamFalconalexeykarnachev
authored andcommitted
remove .item which causes sync issues (Lightning-AI#1254)
* remove .item which causes sync issues * fixed gradient acc sched * fixed gradient acc sched
1 parent 135e22f commit 65ffabc

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import torch
2+
3+
4+
class TensorRunningMean(object):
5+
"""
6+
Tracks a running mean without graph references.
7+
Round robbin for the mean
8+
"""
9+
def __init__(self, window_length):
10+
self.window_length = window_length
11+
self.reset()
12+
self.last_idx = 0
13+
14+
def reset(self):
15+
self.memory = torch.Tensor(self.window_length)
16+
self.current_idx = 0
17+
18+
def last(self):
19+
return self.memory[self.last_idx]
20+
21+
def append(self, x):
22+
# map proper type for memory if they don't match
23+
if self.memory.type() != x.type():
24+
self.memory.type_as(x)
25+
26+
# store without grads
27+
with torch.no_grad():
28+
self.memory[self.current_idx] = x
29+
self.last_idx = self.current_idx
30+
31+
# increase index
32+
self.current_idx += 1
33+
34+
# reset index when hit limit of tensor
35+
if self.current_idx >= self.window_length:
36+
self.current_idx = 0
37+
38+
def mean(self):
39+
return self.memory.mean()

0 commit comments

Comments
 (0)