-
Notifications
You must be signed in to change notification settings - Fork 504
/
Copy pathtest_utils.py
165 lines (133 loc) · 5.05 KB
/
test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from datetime import datetime
import functools
import multiprocessing
import os
import sys
import time
import unittest

import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.metrics_compare_utils as mcu
import torch_xla.utils.utils as xu
def mp_test(func):
  """Wraps a `unittest.TestCase` function running it within an isolated process.

  Example::

    import torch_xla.test.test_utils as xtu
    import unittest

    class MyTest(unittest.TestCase):

      @xtu.mp_test
      def test_basic(self):
        ...

  Args:
    func (callable): The `unittest.TestCase` function to be wrapped.

  Returns:
    The wrapped function; calling it runs `func` in a child process and
    returns the child's exit code.
  """

  # functools.wraps preserves the wrapped test's __name__/__doc__, which
  # unittest relies on for discovery and reporting.
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    proc = multiprocessing.Process(target=func, args=args, kwargs=kwargs)
    proc.start()
    proc.join()
    # Guard against empty args so the wrapper is also safe when the decorated
    # callable is invoked as a plain function rather than a bound test method.
    if args and isinstance(args[0], unittest.TestCase):
      # Propagate child-process failures into the parent TestCase.
      args[0].assertEqual(proc.exitcode, 0)
    return proc.exitcode

  return wrapper
def _get_device_spec(device):
  """Returns a printable spec for `device`, appending the ordinal if known."""
  ordinal = xm.get_ordinal(defval=-1)
  if ordinal < 0:
    return str(device)
  return '{}/{}'.format(device, ordinal)
def write_to_summary(summary_writer,
                     global_step=None,
                     dict_to_write=None,
                     write_xla_metrics=False):
  """Writes scalars to a Tensorboard SummaryWriter.

  Optionally writes XLA perf metrics.

  Args:
    summary_writer (SummaryWriter): The Tensorboard SummaryWriter to write to.
      If None, no summary files will be written.
    global_step (int, optional): The global step value for these data points.
      If None, global_step will not be set for this datapoint.
    dict_to_write (dict, optional): Dict where key is the scalar name and value
      is the scalar value to be written to Tensorboard. Defaults to writing
      nothing.
    write_xla_metrics (bool, optional): If true, this method will retrieve XLA
      performance metrics, parse them, and write them as scalars to Tensorboard.
  """
  if summary_writer is None:
    return
  # None sentinel instead of a mutable `{}` default argument.
  if dict_to_write is None:
    dict_to_write = {}
  for k, v in dict_to_write.items():
    summary_writer.add_scalar(k, v, global_step)
  if write_xla_metrics:
    metrics = mcu.parse_metrics_report(met.metrics_report())
    aten_ops_sum = 0
    for metric_name, metric_value in metrics.items():
      # Aggregate all `aten::`-prefixed metric values into one extra scalar.
      if metric_name.startswith('aten::'):
        aten_ops_sum += metric_value
      summary_writer.add_scalar(metric_name, metric_value, global_step)
    summary_writer.add_scalar('aten_ops_sum', aten_ops_sum, global_step)
def close_summary_writer(summary_writer):
  """Flush and close a SummaryWriter.

  Args:
    summary_writer (SummaryWriter, optional): The Tensorboard SummaryWriter to
      close and flush. If None, no action is taken.
  """
  if summary_writer is None:
    return
  summary_writer.flush()
  summary_writer.close()
def get_summary_writer(logdir):
  """Initialize a Tensorboard SummaryWriter.

  Args:
    logdir (str): File location where logs will be written or None. If None, no
      writer is created.

  Returns:
    Instance of Tensorboard SummaryWriter, or None when `logdir` is falsy.
  """
  if not logdir:
    return None
  from tensorboardX import SummaryWriter
  writer = SummaryWriter(log_dir=logdir)
  # Record the run start time as a scalar so runs can be located in time.
  write_to_summary(
      writer, 0, dict_to_write={'TensorboardStartTimestamp': time.time()})
  return writer
def now(format='%H:%M:%S'):
  """Returns the current local time as a formatted string.

  Args:
    format (str, optional): A `strftime`-style format specification.
  """
  current = datetime.now()
  return current.strftime(format)
def print_training_update(device,
                          step,
                          loss,
                          rate,
                          global_rate,
                          epoch=None,
                          summary_writer=None):
  """Prints the training metrics at a given step.

  Args:
    device (torch.device): The device where these statistics came from.
    step (int): Current step number.
    loss (float): Current loss.
    rate (float): The examples/sec rate for the current batch.
    global_rate (float): The average examples/sec rate since training began.
    epoch (int, optional): The epoch number; omitted from the line when None.
    summary_writer (SummaryWriter, optional): If provided, this method will
      write some of the provided statistics to Tensorboard.
  """
  fields = ['Training', 'Device={}'.format(_get_device_spec(device))]
  if epoch is not None:
    fields.append('Epoch={}'.format(epoch))
  fields.append('Step={}'.format(step))
  fields.append('Loss={:.5f}'.format(loss))
  fields.append('Rate={:.2f}'.format(rate))
  fields.append('GlobalRate={:.2f}'.format(global_rate))
  fields.append('Time={}'.format(now()))
  print('|', ' '.join(fields), flush=True)
  if summary_writer:
    write_to_summary(
        summary_writer,
        dict_to_write={
            'examples/sec': rate,
            'average_examples/sec': global_rate,
        })
def print_test_update(device, accuracy, epoch=None, step=None):
  """Prints single-core test metrics.

  Args:
    device (torch.device): The device where these statistics came from.
    accuracy (float, optional): The test accuracy; omitted from the printed
      line when None.
    epoch (int, optional): The epoch number; omitted when None.
    step (int, optional): The step number; omitted when None.
  """
  fields = ['Test', 'Device={}'.format(_get_device_spec(device))]
  if step is not None:
    fields.append('Step={}'.format(step))
  if epoch is not None:
    fields.append('Epoch={}'.format(epoch))
  if accuracy is not None:
    fields.append('Accuracy={:.2f}'.format(accuracy))
  fields.append('Time={}'.format(now()))
  print('|', ' '.join(fields), flush=True)