File tree 1 file changed +39
-0
lines changed
pytorch_lightning/trainer
1 file changed +39
-0
lines changed Original file line number Diff line number Diff line change
1
+ import torch
2
+
3
+
4
+ class TensorRunningMean (object ):
5
+ """
6
+ Tracks a running mean without graph references.
7
+ Round robbin for the mean
8
+ """
9
+ def __init__ (self , window_length ):
10
+ self .window_length = window_length
11
+ self .reset ()
12
+ self .last_idx = 0
13
+
14
+ def reset (self ):
15
+ self .memory = torch .Tensor (self .window_length )
16
+ self .current_idx = 0
17
+
18
+ def last (self ):
19
+ return self .memory [self .last_idx ]
20
+
21
+ def append (self , x ):
22
+ # map proper type for memory if they don't match
23
+ if self .memory .type () != x .type ():
24
+ self .memory .type_as (x )
25
+
26
+ # store without grads
27
+ with torch .no_grad ():
28
+ self .memory [self .current_idx ] = x
29
+ self .last_idx = self .current_idx
30
+
31
+ # increase index
32
+ self .current_idx += 1
33
+
34
+ # reset index when hit limit of tensor
35
+ if self .current_idx >= self .window_length :
36
+ self .current_idx = 0
37
+
38
+ def mean (self ):
39
+ return self .memory .mean ()
You can’t perform that action at this time.
0 commit comments