
Commit bde996c: 0.9.10
Authored Sep 14, 2022
1 parent: ddbe9e0
14 files changed, +3428 −3422 lines
 

‎torchstudio/datasetanalyze.py

+117-117
@@ -1,117 +1,117 @@
(the old and new revisions of this file are textually identical in this extract, so the hunk is listed once)

import sys

import torchstudio.tcpcodec as tc
from torchstudio.modules import safe_exec
import os
import io
from collections.abc import Iterable
from tqdm.auto import tqdm
import pickle

original_path=sys.path

app_socket = tc.connect()
print("Analyze script connected\n", file=sys.stderr)
while True:
    msg_type, msg_data = tc.recv_msg(app_socket)

    if msg_type == 'SetAnalyzerCode':
        print("Setting analyzer code...\n", file=sys.stderr)
        analyzer = None
        analyzer_code = tc.decode_strings(msg_data)[0]
        error_msg, analyzer_env = safe_exec(analyzer_code, description='analyzer definition')
        if error_msg is not None or 'analyzer' not in analyzer_env:
            print("Unknown analyzer definition error" if error_msg is None else error_msg, file=sys.stderr)

    if msg_type == 'StartAnalysisServer' and 'analyzer' in analyzer_env:
        print("Analyzing...\n", file=sys.stderr)

        analysis_server, address = tc.generate_server()

        if analyzer_env['analyzer'].train is None:
            request_msg='AnalysisServerRequestingAllSamples'
        elif analyzer_env['analyzer'].train==True:
            request_msg='AnalysisServerRequestingTrainingSamples'
        elif analyzer_env['analyzer'].train==False:
            request_msg='AnalysisServerRequestingValidationSamples'
        tc.send_msg(app_socket, request_msg, tc.encode_strings(address))
        dataset_socket=tc.start_server(analysis_server)

        while True:
            dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket)

            if dataset_msg_type == 'NumSamples':
                num_samples=tc.decode_ints(dataset_msg_data)[0]
                pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters

            if dataset_msg_type == 'InputTensorsID':
                input_tensors_id=tc.decode_ints(dataset_msg_data)

            if dataset_msg_type == 'OutputTensorsID':
                output_tensors_id=tc.decode_ints(dataset_msg_data)

            if dataset_msg_type == 'Labels':
                labels=tc.decode_strings(dataset_msg_data)

            if dataset_msg_type == 'StartSending':
                error_msg, return_value = safe_exec(analyzer_env['analyzer'].start_analysis, (num_samples, input_tensors_id, output_tensors_id, labels), description='analyzer definition')
                if error_msg is not None:
                    pbar.close()
                    print(error_msg, file=sys.stderr)
                    dataset_socket.close()
                    analysis_server.close()
                    break

            if dataset_msg_type == 'TrainingSample':
                pbar.update(1)
                error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), True), description='analyzer definition')
                if error_msg is not None:
                    pbar.close()
                    print(error_msg, file=sys.stderr)
                    dataset_socket.close()
                    analysis_server.close()
                    break

            if dataset_msg_type == 'ValidationSample':
                pbar.update(1)
                error_msg, return_value = safe_exec(analyzer_env['analyzer'].analyze_sample, (tc.decode_numpy_tensors(dataset_msg_data), False), description='analyzer definition')
                if error_msg is not None:
                    pbar.close()
                    print(error_msg, file=sys.stderr)
                    dataset_socket.close()
                    analysis_server.close()
                    break

            if dataset_msg_type == 'DoneSending':
                pbar.close()
                error_msg, return_value = safe_exec(analyzer_env['analyzer'].finish_analysis, description='analyzer definition')
                tc.send_msg(dataset_socket, 'DoneReceiving')
                dataset_socket.close()
                analysis_server.close()
                if error_msg is not None:
                    print(error_msg, file=sys.stderr)
                else:
                    buffer=io.BytesIO()
                    pickle.dump(analyzer_env['analyzer'].state_dict(), buffer)
                    tc.send_msg(app_socket, 'AnalyzerState',buffer.getvalue())
                    tc.send_msg(app_socket, 'AnalysisWeights',tc.encode_floats(analyzer_env['analyzer'].weights))
                    print("Analysis complete")
                break

    if msg_type == 'LoadAnalyzerState':
        if 'analyzer' in analyzer_env:
            buffer=io.BytesIO(msg_data)
            analyzer_env['analyzer'].load_state_dict(pickle.load(buffer))
            print("Analyzer state loaded")

    if msg_type == 'RequestAnalysisReport':
        resolution = tc.decode_ints(msg_data)
        if 'analyzer' in analyzer_env:
            error_msg, return_value = safe_exec(analyzer_env['analyzer'].generate_report, (resolution[0:2],resolution[2]), description='analyzer definition')
            if error_msg is not None:
                print(error_msg, file=sys.stderr)
            if return_value is not None:
                tc.send_msg(app_socket, 'ReportImage', tc.encode_image(return_value))

    if msg_type == 'Exit':
        break
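The script above only assumes a small duck-typed interface on the analyzer object bound by the received code: a train attribute (None for all samples, True for training only, False for validation only), start_analysis, analyze_sample, finish_analysis, state_dict/load_state_dict, a weights sequence of floats, and generate_report. A minimal conforming sketch, for illustration only (the class name and the counting logic are hypothetical, not part of TorchStudio):

class SampleCounter:
    # None requests all samples; True/False restrict to training/validation
    train = None

    def start_analysis(self, num_samples, input_tensors_id, output_tensors_id, labels):
        self.counts = [0, 0]  # [validation, training]

    def analyze_sample(self, tensors, training):
        # tensors are the numpy arrays decoded from one dataset sample
        self.counts[1 if training else 0] += 1

    def finish_analysis(self):
        pass

    def state_dict(self):
        return {'counts': self.counts}

    def load_state_dict(self, state):
        self.counts = state['counts']

    @property
    def weights(self):
        # a sequence of floats, sent back to the app as AnalysisWeights
        return [float(c) for c in self.counts]

    def generate_report(self, size, dpi):
        return None  # a real analyzer would return a PIL image to display

analyzer = SampleCounter()  # the executed code must bind the name 'analyzer'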

‎torchstudio/datasetload.py

+286-286
Large diffs are not rendered by default.
+50-50
(file name missing from the extracted page; this hunk defines the RandomGenerator dataset; old and new revisions are textually identical in this extract, so the hunk is listed once)
@@ -1,50 +1,50 @@

import torch
from torch.utils.data import Dataset
import inspect

class RandomGenerator(Dataset):
    """A random generator that returns randomly generated tensors

    Args:
        size (int):
            Size of the dataset (number of samples)
        tensors:
            A list of tuples defining tensor properties: shape, type, range
            All properties are optional. Defaults are: empty shape, float, [0,1]
    """

    def __init__(self, size:int=256, tensors=[(3,64,64), (int,[0,9])]):
        torch.manual_seed(0)
        self.size = size
        self.tensors = tensors

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """
        Returns:
            A tuple of tensors.
        """
        sample = []
        for properties in self.tensors:
            shape=[]
            dtype=float
            drange=[0,1]
            for property in properties:
                if type(property)==int:
                    shape.append(property)
                elif inspect.isclass(property):
                    dtype=property
                elif type(property) is list:
                    drange=property
            shape=tuple(shape)

            if 'int' in str(dtype):
                tensor=torch.randint(low=drange[0], high=drange[1]+1, size=shape, dtype=dtype)
            else:
                tensor=torch.rand(size=shape,dtype=dtype)*(drange[1]-drange[0])+drange[0]

            sample.append(tensor)

        return tuple(sample)
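A quick usage sketch for the dataset above; the shape and dtype comments assume a recent PyTorch, where the builtin float and int map to torch.float64 and torch.int64 when used as dtype arguments:

from torch.utils.data import DataLoader

dataset = RandomGenerator()        # defaults: a (3,64,64) float tensor and a scalar int in [0,9]
x, y = dataset[0]
print(x.shape, x.dtype)            # torch.Size([3, 64, 64]) torch.float64
print(y.shape, y.dtype)            # torch.Size([]) torch.int64

loader = DataLoader(dataset, batch_size=8)
images, classes = next(iter(loader))   # shapes (8, 3, 64, 64) and (8,)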

‎torchstudio/graphdraw.py

+638-638
Large diffs are not rendered by default.

‎torchstudio/metricsplot.py

+185-185
@@ -1,185 +1,185 @@
(old and new revisions differ only in the two MaxNLocator calls, marked with -/+ below; the hunk is listed once)

import torchstudio.tcpcodec as tc
import inspect
import sys
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import PIL

def plot_metrics(prefix, size, dpi, samples=100, labels=[],
                 loss=[], loss_colors=[], loss_shift=(0,0), loss_scale=(1,1),
                 metric=[], metric_colors=[], metric_shift=(0,0), metric_scale=(1,1)):
    """Metrics Plot

    Usage:
        Drag: pan
        Scroll: zoom vertically
    """
    #set up matplotlib renderer, style, figure and axis
    mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/
    plt.style.use('dark_background')
    plt.rcParams.update({'font.size': 7})

    fig, [ax1, ax2] = plt.subplots(1 if size[0]>size[1] else 2, 2 if size[0]>size[1] else 1, figsize=(size[0]/dpi, size[1]/dpi), dpi=dpi)

    #LOSS
    ax1.set_title(prefix+"Loss")

    #fit
    loss_xmin=0
    loss_xmax=samples
    loss_ymin=0
    loss_ymax=1
    for l in loss:
        loss_xmax=max(loss_xmax,len(l))
        # if(len(l)>0):
        #     loss_ymax=max(loss_ymax,max(l))

    # #shift
    # render_size=(loss_xmax-loss_xmin,loss_ymax-loss_ymin)
    # loss_xmin-=loss_shift[0]/loss_scale[0]*render_size[0]
    # loss_xmax-=loss_shift[0]/loss_scale[0]*render_size[0]
    # loss_ymin-=loss_shift[1]/loss_scale[1]*render_size[1]
    # loss_ymax-=loss_shift[1]/loss_scale[1]*render_size[1]

    # #scale
    # render_center=(loss_xmin+render_size[0]/2,loss_ymin+render_size[1]/2)
    # loss_xmin=render_center[0]-(render_size[0]/loss_scale[0]/2)
    # loss_xmax=render_center[0]+(render_size[0]/loss_scale[0]/2)
    # loss_ymin=render_center[1]-(render_size[1]/loss_scale[1]/2)
    # loss_ymax=render_center[1]+(render_size[1]/loss_scale[1]/2)

    # loss_xmin=max(0,loss_xmin)
    # loss_ymin=max(0,loss_ymin)

    loss_ymin-=loss_shift[1]/loss_scale[1]
    loss_ymax-=loss_shift[1]/loss_scale[1]
    loss_ymax=loss_ymax/loss_scale[1]

    ax1.axis(xmin=loss_xmin,xmax=loss_xmax,ymin=loss_ymin,ymax=loss_ymax)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['left'].set_color('#707070')
    ax1.spines['bottom'].set_color('#707070')
    for i in range(len(loss)):
        ax1.plot(loss[i],label=str(i) if i>=len(labels) else labels[i],color=loss_colors[i%len(loss_colors)])
    if labels and loss and loss[0]:
        ax1.legend(bbox_to_anchor=(1, 1), loc='upper right', ncol=1, prop={'size': 8})
    ax1.grid(color = '#303030')
-    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
+    ax1.xaxis.set_major_locator(MaxNLocator(nbins='auto', integer=True))
    ax1.ticklabel_format(axis="y", style="sci", scilimits=(0,0))

    #METRIC
    ax2.set_title(prefix+"Metric")

    #fit
    metric_xmin=0
    metric_xmax=samples
    metric_ymin=0
    metric_ymax=1
    for m in metric:
        metric_xmax=max(metric_xmax,len(m))

    # #shift
    # render_size=(metric_xmax-metric_xmin,metric_ymax-metric_ymin)
    # metric_xmin-=metric_shift[0]/metric_scale[0]*render_size[0]
    # metric_xmax-=metric_shift[0]/metric_scale[0]*render_size[0]
    # metric_ymin-=metric_shift[1]/metric_scale[1]*render_size[1]
    # metric_ymax-=metric_shift[1]/metric_scale[1]*render_size[1]

    # #scale
    # render_center=(metric_xmin+render_size[0]/2,metric_ymin+render_size[1]/2)
    # metric_xmin=render_center[0]-(render_size[0]/metric_scale[0]/2)
    # metric_xmax=render_center[0]+(render_size[0]/metric_scale[0]/2)
    # metric_ymin=render_center[1]-(render_size[1]/metric_scale[1]/2)
    # metric_ymax=render_center[1]+(render_size[1]/metric_scale[1]/2)

    # metric_xmin=max(0,metric_xmin)

    metric_ymin-=metric_shift[1]/metric_scale[1]
    metric_ymax-=metric_shift[1]/metric_scale[1]
    metric_ymin=(metric_ymin-metric_ymax)/metric_scale[1]+metric_ymax

    ax2.axis(xmin=metric_xmin,xmax=metric_xmax,ymin=metric_ymin,ymax=metric_ymax)
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.spines['left'].set_color('#707070')
    ax2.spines['bottom'].set_color('#707070')
    for i in range(len(metric)):
        ax2.plot(metric[i],color=metric_colors[i%len(metric_colors)])
    ax2.grid(color = '#303030')
-    ax2.xaxis.set_major_locator(MaxNLocator(integer=True))
+    ax2.xaxis.set_major_locator(MaxNLocator(nbins='auto', integer=True))

    plt.tight_layout(pad=0)

    canvas = plt.get_current_fig_manager().canvas
    canvas.draw()
    img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
    plt.close()
    return img


prefix = ''
resolution = (256,256, 96)
samples=100
labels = []

loss=[]
loss_colors=[]
loss_shift = (0,0)
loss_scale = (1,1)

metric=[]
metric_colors=[]
metric_labels = []
metric_shift = (0,0)
metric_scale = (1,1)

app_socket = tc.connect()
while True:
    msg_type, msg_data = tc.recv_msg(app_socket)

    if msg_type == 'RequestDocumentation':
        tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_metrics.__doc__)))
    if msg_type == 'SetPrefix':
        prefix=tc.decode_strings(msg_data)[0]

    if msg_type == 'SetResolution':
        resolution = tc.decode_ints(msg_data)

    if msg_type == 'NumSamples':
        samples = tc.decode_ints(msg_data)[0]
    if msg_type == 'SetLabels':
        labels=tc.decode_strings(msg_data)

    if msg_type == 'ClearLoss':
        loss=[]
    if msg_type == 'AppendLoss':
        loss.append(tc.decode_floats(msg_data))
    if msg_type == 'SetLossColors':
        loss_colors=tc.decode_strings(msg_data)
    if msg_type == 'SetLossShift':
        loss_shift = tc.decode_floats(msg_data)
    if msg_type == 'SetLossScale':
        loss_scale = tc.decode_floats(msg_data)

    if msg_type == 'ClearMetric':
        metric=[]
    if msg_type == 'AppendMetric':
        metric.append(tc.decode_floats(msg_data))
    if msg_type == 'SetMetricColors':
        metric_colors=tc.decode_strings(msg_data)
    if msg_type == 'SetMetricShift':
        metric_shift = tc.decode_floats(msg_data)
    if msg_type == 'SetMetricScale':
        metric_scale = tc.decode_floats(msg_data)

    if msg_type == 'Render':
        if resolution[0]>0 and resolution[1]>0:
            img=plot_metrics(prefix,resolution[0:2],resolution[2],samples,labels,loss,loss_colors,loss_shift,loss_scale,metric,metric_colors,metric_shift,metric_scale)
            tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))

    if msg_type == 'Exit':
        break
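plot_metrics can also be exercised outside the message loop. A minimal sketch with made-up curves and colors, assuming a matplotlib version that still provides FigureCanvasAgg.tostring_rgb (which the function relies on):

loss_curves = [[1.0/(e+1) for e in range(50)], [0.8/(e+1) for e in range(50)]]
metric_curves = [[1.0 - 1.0/(e+2) for e in range(50)]]
img = plot_metrics('Demo ', (512, 256), 96, samples=50,
                   labels=['train', 'valid'],
                   loss=loss_curves, loss_colors=['#ff8080', '#80ff80'],
                   metric=metric_curves, metric_colors=['#8080ff'])
img.save('metrics.png')   # img is a PIL image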

‎torchstudio/modelbuild.py

+246-246
Large diffs are not rendered by default.

‎torchstudio/models/unet1d.py

+143-143
@@ -1,143 +1,143 @@
(the old and new revisions of this file are textually identical in this extract, so the hunk is listed once)

import torch
import torch.nn as nn
import torch.nn.functional as F

#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py
def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False):
    sequence = []
    for i in range(conv_per_block):
        sequence.append(nn.Conv1d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2))
        sequence.append(nn.ReLU(inplace=True))
        if batch_norm:
            #BatchNorm best after ReLU:
            #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/
            #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999
            #https://github.com/cvjena/cnn-models/issues/3
            sequence.append(nn.BatchNorm1d(out_channels))
    return nn.Sequential(*sequence)

class DownConv(nn.Module):
    def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True):
        super().__init__()

        self.pooling = pooling

        self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm)

        if self.pooling:
            if not conv_downscaling:
                self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
            else:
                self.pool = nn.Conv1d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2)

    def forward(self, x):
        x = self.block(x)
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


class UpConv(nn.Module):
    def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm,
                 add_merging, conv_upscaling):
        super().__init__()

        self.add_merging = add_merging

        if not conv_upscaling:
            self.upconv = nn.ConvTranspose1d(in_channels,out_channels,kernel_size=2,stride=2)
        else:
            self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2),
                                        nn.Conv1d(in_channels, out_channels,kernel_size=1,groups=1,stride=1))


        self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm)

    def forward(self, from_down, from_up):
        from_up = self.upconv(from_up)
        if not self.add_merging:
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = self.block(x)
        return x


class UNet1D(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    UNet is a convolutional encoder-decoder neural network.

    This 1D variant is inspired by the Wave-U-Net
    ( https://arxiv.org/pdf/1806.03185.pdf ).
    Default parameters correspond to the Wave-U-Net.
    Convolutions use padding to preserve the original size.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels in the output tensor.
        feature_channels: number of channels in the first and last hidden feature layer.
        depth: number of levels
        conv_per_block: number of convolutions per level block
        kernel_size: kernel size for all block convolutions
        batch_norm: add a batch norm after ReLU
        conv_upscaling: use a nearest upscale+conv instead of a transposed convolution
        conv_downscaling: use a strided convolution instead of maxpooling
        add_merging: merge layers from different levels using an add instead of a concat
    """

    def __init__(self, in_channels=1, out_channels=1, feature_channels=24,
                 depth=12, conv_per_block=1, kernel_size=5, batch_norm=False,
                 conv_upscaling=False, conv_downscaling=False, add_merging=False):
        super().__init__()

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.feature_channels = feature_channels
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.feature_channels*(i+1)
            pooling = True if i < depth-1 else False

            down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                                 conv_downscaling, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins - self.feature_channels
            up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                             conv_upscaling=conv_upscaling, add_merging=add_merging)
            self.up_convs.append(up_conv)

        self.conv_final = nn.Conv1d(outs, self.out_channels,kernel_size=1,groups=1,stride=1)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

    def forward(self, x):
        encoder_outs = []

        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)

        # No softmax is used, so use nn.CrossEntropyLoss
        # in your training script, as that loss already
        # includes a softmax.
        x = self.conv_final(x)
        return x
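A shape check for the default configuration, as a sketch: with depth=12 the encoder halves the length 11 times, so the input length should be a multiple of 2**11 = 2048 for the skip connections to line up.

model = UNet1D()               # in_channels=1, out_channels=1, feature_channels=24, depth=12
x = torch.randn(4, 1, 2048)    # (batch, channels, length)
y = model(x)
print(y.shape)                 # torch.Size([4, 1, 2048]): padding preserves the length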

‎torchstudio/models/unet2d.py

+166-166
@@ -1,166 +1,166 @@
(the old and new revisions of this file are textually identical in this extract, so the hunk is listed once)

import torch
import torch.nn as nn
import torch.nn.functional as F

#heavily modified from https://github.com/jaxony/unet-pytorch/blob/master/model.py
def block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm=False):
    sequence = []
    for i in range(conv_per_block):
        sequence.append(nn.Conv2d(in_channels if i==0 else out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2))
        sequence.append(nn.ReLU(inplace=True))
        if batch_norm:
            #BatchNorm best after ReLU:
            #https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/
            #https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout#comment78277697_40295999
            #https://github.com/cvjena/cnn-models/issues/3
            sequence.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*sequence)

class DownConv(nn.Module):
    def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm, conv_downscaling, pooling=True):
        super().__init__()

        self.in_channels=in_channels
        self.out_channels=out_channels
        self.conv_per_block=conv_per_block
        self.kernel_size=kernel_size
        self.batch_norm=batch_norm
        self.conv_downscaling=conv_downscaling
        self.pooling = pooling

        self.block = block(in_channels, out_channels, conv_per_block, kernel_size, batch_norm)

        if self.pooling:
            if not conv_downscaling:
                self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            else:
                self.pool = nn.Conv2d(out_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride=2)

    def forward(self, x):
        x = self.block(x)
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool

    def extra_repr(self):
        # (Optional) Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, conv_downscaling={}, pooling={}'.format(
            self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.conv_downscaling, self.pooling
        )


class UpConv(nn.Module):
    def __init__(self, in_channels, out_channels, conv_per_block, kernel_size, batch_norm,
                 add_merging, conv_upscaling):
        super().__init__()

        self.in_channels=in_channels
        self.out_channels=out_channels
        self.conv_per_block=conv_per_block
        self.kernel_size=kernel_size
        self.batch_norm=batch_norm
        self.add_merging = add_merging
        self.conv_upscaling = conv_upscaling

        if not conv_upscaling:
            self.upconv = nn.ConvTranspose2d(in_channels,out_channels,kernel_size=2,stride=2)
        else:
            self.upconv = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2),
                                        nn.Conv2d(in_channels, out_channels,kernel_size=1,groups=1,stride=1))


        self.block = block(out_channels*2 if not add_merging else out_channels, out_channels, conv_per_block, kernel_size, batch_norm)

    def forward(self, from_down, from_up):
        from_up = self.upconv(from_up)
        if not self.add_merging:
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = self.block(x)
        return x

    def extra_repr(self):
        # (Optional) Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'in_channels={}, out_channels={}, conv_per_block={}, kernel_size={}, batch_norm={}, add_merging={}, conv_upscaling={}'.format(
            self.in_channels, self.out_channels, self.conv_per_block, self.kernel_size, self.batch_norm, self.add_merging, self.conv_upscaling
        )

class UNet2D(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    UNet is a convolutional encoder-decoder neural network.

    Default parameters correspond to the original UNet, except
    convolutions use padding to preserve the original size.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels in the output tensor.
        feature_channels: number of channels in the first and last hidden feature layer.
        depth: number of levels
        conv_per_block: number of convolutions per level block
        kernel_size: kernel size for all block convolutions
        batch_norm: add a batch norm after ReLU
        conv_upscaling: use a nearest upscale+conv instead of a transposed convolution
        conv_downscaling: use a strided convolution instead of maxpooling
        add_merging: merge layers from different levels using an add instead of a concat
    """

    def __init__(self, in_channels=1, out_channels=2, feature_channels=64,
                 depth=5, conv_per_block=2, kernel_size=3, batch_norm=False,
                 conv_upscaling=False, conv_downscaling=False, add_merging=False):
        super().__init__()

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.feature_channels = feature_channels
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.feature_channels*(2**i)
            pooling = True if i < depth-1 else False

            down_conv = DownConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                                 conv_downscaling, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, conv_per_block, kernel_size, batch_norm,
                             conv_upscaling=conv_upscaling, add_merging=add_merging)
            self.up_convs.append(up_conv)

        self.conv_final = nn.Conv2d(outs, self.out_channels,kernel_size=1,groups=1,stride=1)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

    def forward(self, x):
        encoder_outs = []

        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)

        # No softmax is used, so use nn.CrossEntropyLoss
        # in your training script, as that loss already
        # includes a softmax.
        x = self.conv_final(x)
        return x
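The equivalent sketch for the 2D model: with depth=5 the resolution is halved 4 times, so height and width should be multiples of 2**4 = 16.

model = UNet2D()                   # in_channels=1, out_channels=2, feature_channels=64, depth=5
x = torch.randn(4, 1, 96, 96)      # (batch, channels, height, width)
y = model(x)                       # raw logits, to be paired with nn.CrossEntropyLoss
print(y.shape)                     # torch.Size([4, 2, 96, 96])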

‎torchstudio/modeltrain.py

+364-360
Large diffs are not rendered by default.

‎torchstudio/parametersplot.py

+170-170
@@ -1,170 +1,170 @@
(old and new revisions are textually identical where both appear in this extract; the second copy was truncated, so the hunk is listed once)

import torchstudio.tcpcodec as tc
import inspect
import sys
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import PIL

def sorted(l,reverse=False):
    floats=True
    for x in l:
        try:
            float(x)
        except:
            floats=False
            break
    l.sort(key=float if floats else None,reverse=reverse)
    return l

#inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib
def plot_parameters(size, dpi,
                    parameters=[], #parameters is a list of parameters
                    values=[], #values is a list of lists containing string values
                    order=[]): #sorting order for each parameter (1 or -1)
    """Parameters Plot

    Usage:
        Click: invert parameter sorting order
    """
    #set up matplotlib renderer, style, figure and axis
    mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/
    plt.style.use('dark_background')
    plt.rcParams.update({'font.size': 7})

    if len(parameters)<2:
        parameters=['Name', 'Validation\nMetric']

    # parameters=['Name', 'feature_channels', 'depth', 'Metric Value']
    # values=[['Model 1','32','3','.95'],['Model 2','24','4','.9'],['Model 3','16','3','.98'],['Model 4','16','3']]

    if len(order)<len(parameters):
        order=[1]*len(parameters)
        order[0]=-1

    param_values=[[] for i in range(len(parameters))]
    for value in values:
        for i, v in enumerate(value):
            if v not in param_values[i]:
                param_values[i].append(v)
    for i, v in enumerate(param_values):
        param_values[i]=sorted(param_values[i], True if order[i]==-1 else False)

    fig, host = plt.subplots(figsize=(size[0]/dpi, size[1]/dpi), dpi=dpi)

    axes = [host] + [host.twinx() for i in range(len(parameters)-1)]

    for i, ax in enumerate(axes):
        ax.set_ylim(0, 1)
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_position(("axes", i / (len(parameters) - 1)))
        ax.spines['left'].set_color((0.2,0.2,0.2))
        ax.yaxis.set_tick_params(width=0)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_tick_params(width=0)
        ax.set_yticks([j/(len(param_values[i])-1) if len(param_values[i])>1 else .5 for j in range(len(param_values[i]))])
        ax.set_yticklabels(param_values[i])
    #first parameter is the model name, keep the set_ticks
    axes[0].yaxis.set_tick_params(width=1)
    #last parameter is the metric, let the colorbar do the metric
    axes[-1].yaxis.set_ticks_position('none')
    axes[-1].set_yticklabels([])
    axes[-1].spines['left'].set_visible(False)

    #set the colorbar for the metric
    if param_values[-1]:
        max_metric=min_metric=float(param_values[-1][0])
        for metric_value in param_values[-1]:
            min_metric=min(min_metric,float(metric_value))
            max_metric=max(max_metric,float(metric_value))
    else:
        max_metric=min_metric=0

    cmap = plt.get_cmap('viridis') # 'viridis' or 'rainbow'
    sc = host.scatter([0,0], [0,0], s=[0,0], c=[min_metric, max_metric], cmap=cmap)
    cbar = fig.colorbar(sc, ax=axes[-1], pad=0)
    cbar.outline.set_visible(False)
    # cbar.set_ticks([])

    #set horizontal axis settings
    host.set_xlim(0, len(parameters) - 1)
    host.set_xticks(range(len(parameters)))
    host.set_xticklabels(parameters)
    host.tick_params(axis='x', which='major', pad=7)
    host.spines['right'].set_visible(False)
    host.xaxis.tick_top()


    from matplotlib.path import Path
    import matplotlib.patches as patches
    import numpy as np
    for tokens in values:
        values_num=[]
        for i, token in enumerate(tokens):
            if i<len(parameters)-1:
                values_num.append(param_values[i].index(token)/(len(param_values[i])-1) if len(param_values[i])>1 else .5)
            else:
                values_num.append((float(token)-min_metric)/(max_metric-min_metric) if len(param_values[i])>1 and max_metric>min_metric else .5)

        # create bezier curves
        # for each axis, there will be a control vertex at the point itself, one at 1/3rd towards the previous and one
        # at one third towards the next axis; the first and last axis have one less control vertex
        # x-coordinate of the control vertices: at each integer (for the axes) and two in between
        # y-coordinate: repeat every point three times, except the first and last only twice
        verts = list(zip([x for x in np.linspace(0, len(values_num) - 1, len(values_num) * 3 - 2, endpoint=True)],
                         np.repeat(values_num, 3)[1:-1]))
        # for x,y in verts: host.plot(x, y, 'go') # to show the control points of the beziers
        codes = [Path.MOVETO] + [Path.CURVE4 for _ in range(len(verts) - 1)]
        path = Path(verts, codes)
        patch = patches.PathPatch(path, facecolor='none', lw=1, edgecolor=cmap(values_num[-1]) if len(values_num)==len(parameters) else (0.33, 0.33, 0.33), zorder=values_num[-1] if len(values_num)==len(parameters) else -1)
        host.add_patch(patch)

    plt.tight_layout(pad=0)

    canvas = plt.get_current_fig_manager().canvas
    canvas.draw()
    img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
    plt.close()
    return img


resolution = (256,256, 96)

parameters=[]
values=[]
order=[]


app_socket = tc.connect()
while True:
    msg_type, msg_data = tc.recv_msg(app_socket)

    if msg_type == 'RequestDocumentation':
        tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_parameters.__doc__)))

    if msg_type == 'SetResolution':
        resolution = tc.decode_ints(msg_data)

    if msg_type == 'SetParameters':
        parameters=tc.decode_strings(msg_data)

    if msg_type == 'ClearValues':
        values = []
    if msg_type == 'AppendValues':
        values.append(tc.decode_strings(msg_data))

    if msg_type == 'SetOrder':
        order=tc.decode_ints(msg_data)

    if msg_type == 'Render':
        if resolution[0]>0 and resolution[1]>0:
            img=plot_parameters(resolution[0:2],resolution[2],parameters,values,order)
            tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))

    if msg_type == 'Exit':
        break
1+
import torchstudio.tcpcodec as tc
2+
import inspect
3+
import sys
4+
import os
5+
6+
import matplotlib as mpl
7+
import matplotlib.pyplot as plt
8+
from matplotlib.ticker import MaxNLocator
9+
import PIL
10+
11+
def sorted(l,reverse=False):
12+
floats=True
13+
for x in l:
14+
try:
15+
float(x)
16+
except:
17+
floats=False
18+
break
19+
l.sort(key=float if floats else None,reverse=reverse)
20+
return l
21+
22+
#inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib
23+
def plot_parameters(size, dpi,
24+
parameters=[], #list of parameter names; the last one is the metric
25+
values=[], #list of lists of string values, one inner list per model
26+
order=[]): #sorting order for each parameter (1 or -1)
27+
"""Parameters Plot
28+
29+
Usage:
30+
Click: invert parameter sorting order
31+
"""
32+
#set up matplotlib renderer, style, figure and axis
33+
mpl.use('agg') #https://www.namingcrisis.net/post/2019/03/11/interactive-matplotlib-ipython/
34+
plt.style.use('dark_background')
35+
plt.rcParams.update({'font.size': 7})
36+
37+
if len(parameters)<2:
38+
parameters=['Name', 'Validation\nMetric']
39+
40+
# parameters=['Name', 'feature_channels', 'depth', 'Metric Value']
41+
# values=[['Model 1','32','3','.95'],['Model 2','24','4','.9'],['Model 3','16','3','.98'],['Model 4','16','3']]
42+
43+
if len(order)<len(parameters):
44+
order=[1]*len(parameters)
45+
order[0]=-1
46+
47+
param_values=[[] for i in range(len(parameters))]
48+
for value in values:
49+
for i, v in enumerate(value):
50+
if v not in param_values[i]:
51+
param_values[i].append(v)
52+
for i, v in enumerate(param_values):
53+
param_values[i]=sorted(param_values[i], reverse=(order[i]==-1))
54+
55+
fig, host = plt.subplots(figsize=(size[0]/dpi, size[1]/dpi), dpi=dpi)
56+
57+
axes = [host] + [host.twinx() for i in range(len(parameters)-1)]
58+
59+
for i, ax in enumerate(axes):
60+
ax.set_ylim(0, 1)
61+
ax.spines['top'].set_visible(False)
62+
ax.spines['bottom'].set_visible(False)
63+
ax.spines['right'].set_visible(False)
64+
ax.spines['left'].set_position(("axes", i / (len(parameters) - 1)))
65+
ax.spines['left'].set_color((0.2,0.2,0.2))
66+
ax.yaxis.set_tick_params(width=0)
67+
ax.yaxis.set_ticks_position('left')
68+
ax.xaxis.set_tick_params(width=0)
69+
ax.set_yticks([j/(len(param_values[i])-1) if len(param_values[i])>1 else .5 for j in range(len(param_values[i]))])
70+
ax.set_yticklabels(param_values[i])
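
Each axis spreads its tick positions evenly across [0,1]; a one-line check of the positions used above (n is a hypothetical category count):

n = 3
print([j/(n-1) if n > 1 else .5 for j in range(n)])   # [0.0, 0.5, 1.0]; a single category sits at .5
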
71+
#first parameter is the model name, keep its tick marks (restore the width zeroed above)
72+
axes[0].yaxis.set_tick_params(width=1)
73+
#last parameter is the metric, let the colorbar represent it
74+
axes[-1].yaxis.set_ticks_position('none')
75+
axes[-1].set_yticklabels([])
76+
axes[-1].spines['left'].set_visible(False)
77+
78+
#set the colorbar for the metric
79+
if param_values[-1]:
80+
max_metric=min_metric=float(param_values[-1][0])
81+
for metric_value in param_values[-1]:
82+
min_metric=min(min_metric,float(metric_value))
83+
max_metric=max(max_metric,float(metric_value))
84+
else:
85+
max_metric=min_metric=0
86+
87+
cmap = plt.get_cmap('viridis') # 'viridis' or 'rainbow'
88+
sc = host.scatter([0,0], [0,0], s=[0,0], c=[min_metric, max_metric], cmap=cmap)
89+
cbar = fig.colorbar(sc, ax=axes[-1], pad=0)
90+
cbar.outline.set_visible(False)
91+
# cbar.set_ticks([])
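
The zero-size scatter above exists only to hand fig.colorbar a mappable to draw from; an equivalent sketch without the dummy artist, assuming matplotlib 3.1 or later:

import matplotlib as mpl
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
cmap = plt.get_cmap('viridis')
norm = mpl.colors.Normalize(vmin=0.0, vmax=1.0)   # illustrative metric range
sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)  # mappable with no artist behind it
fig.colorbar(sm, ax=ax, pad=0)
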
92+
93+
#set horizontal axis settings
94+
host.set_xlim(0, len(parameters) - 1)
95+
host.set_xticks(range(len(parameters)))
96+
host.set_xticklabels(parameters)
97+
host.tick_params(axis='x', which='major', pad=7)
98+
host.spines['right'].set_visible(False)
99+
host.xaxis.tick_top()
100+
101+
102+
103+
from matplotlib.path import Path
104+
import matplotlib.patches as patches
105+
import numpy as np
106+
for tokens in values:
107+
values_num=[]
108+
for i, token in enumerate(tokens):
109+
if i<len(parameters)-1:
110+
values_num.append(param_values[i].index(token)/(len(param_values[i])-1) if len(param_values[i])>1 else .5)
111+
else:
112+
values_num.append((float(token)-min_metric)/(max_metric-min_metric) if len(param_values[i])>1 and max_metric>min_metric else .5)
113+
114+
# create bezier curves
115+
# for each axis, there will be a control vertex at the point itself, one at 1/3rd towards the previous and one
116+
# at one third towards the next axis; the first and last axes have one less control vertex
117+
# x-coordinate of the control vertices: at each integer (for the axes) and two in between
118+
# y-coordinate: repeat every point three times, except the first and last only twice
119+
verts = list(zip([x for x in np.linspace(0, len(values_num) - 1, len(values_num) * 3 - 2, endpoint=True)],
120+
np.repeat(values_num, 3)[1:-1]))
121+
# for x,y in verts: host.plot(x, y, 'go') # to show the control points of the beziers
122+
codes = [Path.MOVETO] + [Path.CURVE4 for _ in range(len(verts) - 1)]
123+
path = Path(verts, codes)
124+
patch = patches.PathPatch(path, facecolor='none', lw=1, edgecolor=cmap(values_num[-1]) if len(values_num)==len(parameters) else (0.33, 0.33, 0.33), zorder=values_num[-1] if len(values_num)==len(parameters) else -1)
125+
host.add_patch(patch)
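
For each model line, the loop first normalizes every token to [0,1] (categorical tokens by their index on the axis, the metric by min-max), then threads a cubic Bezier through the axes with control vertices one third of the way toward each neighbor. A standalone sketch of the vertex layout for three hypothetical normalized values:

import numpy as np

# e.g. index 1 of 3 categories -> 0.5, top category -> 1.0,
# and metric 0.95 in [0.9, 0.98] -> (0.95-0.9)/(0.98-0.9) = 0.625
values_num = [0.5, 1.0, 0.625]
xs = np.linspace(0, len(values_num) - 1, len(values_num) * 3 - 2)  # 7 x-coordinates
ys = np.repeat(values_num, 3)[1:-1]                                # 7 y-coordinates
for x, y in zip(xs, ys):
    print(round(float(x), 2), float(y))
# 0.0 0.5 / 0.33 0.5 / 0.67 1.0 / 1.0 1.0 / 1.33 1.0 / 1.67 0.625 / 2.0 0.625

Each interior axis contributes three control vertices (the point itself plus one a third of the way toward each neighbor), while the first and last axes contribute two, matching the MOVETO + CURVE4 path codes.
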
126+
127+
plt.tight_layout(pad=0)
128+
129+
canvas = plt.get_current_fig_manager().canvas
130+
canvas.draw()
131+
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
132+
plt.close()
133+
return img
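
A minimal invocation, reusing the sample data from the commented-out test lines above (the 640x480 pixel size and 96 dpi are arbitrary):

params = ['Name', 'feature_channels', 'depth', 'Metric Value']
vals = [['Model 1','32','3','.95'], ['Model 2','24','4','.9'],
        ['Model 3','16','3','.98'], ['Model 4','16','3']]   # Model 4 has no metric yet
img = plot_parameters((640, 480), 96, params, vals)         # returns a PIL.Image
img.save('parameters.png')

Rows shorter than the parameter list (Model 4 here) are drawn as gray curves with zorder -1, so unfinished runs stay visible without competing with scored ones.
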
134+
135+
136+
resolution = (256,256, 96)
137+
138+
parameters=[]
139+
values=[]
140+
order=[]
141+
142+
143+
app_socket = tc.connect()
144+
while True:
145+
msg_type, msg_data = tc.recv_msg(app_socket)
146+
147+
if msg_type == 'RequestDocumentation':
148+
tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(plot_parameters.__doc__)))
149+
150+
if msg_type == 'SetResolution':
151+
resolution = tc.decode_ints(msg_data)
152+
153+
if msg_type == 'SetParameters':
154+
parameters=tc.decode_strings(msg_data)
155+
156+
if msg_type == 'ClearValues':
157+
values = []
158+
if msg_type == 'AppendValues':
159+
values.append(tc.decode_strings(msg_data))
160+
161+
if msg_type == 'SetOrder':
162+
order=tc.decode_ints(msg_data)
163+
164+
if msg_type == 'Render':
165+
if resolution[0]>0 and resolution[1]>0:
166+
img=plot_parameters(resolution[0:2],resolution[2],parameters,values,order)
167+
tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))
168+
169+
if msg_type == 'Exit':
170+
break

‎torchstudio/pythoninstall.cmd

+239-239

‎torchstudio/pythonparse.py

+414-415

‎torchstudio/sshtunnel.py

+340-337

‎torchstudio/tensorrender.py

+70-70
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,70 @@
1-
import torchstudio.tcpcodec as tc
2-
from torchstudio.modules import safe_exec
3-
import inspect
4-
import sys
5-
import os
6-
7-
title = ''
8-
tensor = None
9-
resolution = (256,256, 96)
10-
shift = (0,0,0,0)
11-
scale = (1,1,1,1)
12-
input_tensors = []
13-
target_tensor = None
14-
labels = []
15-
16-
app_socket = tc.connect()
17-
while True:
18-
msg_type, msg_data = tc.recv_msg(app_socket)
19-
20-
if msg_type == 'SetRendererCode':
21-
error_msg, renderer_env = safe_exec(tc.decode_strings(msg_data)[0],description='renderer definition')
22-
if error_msg is not None or 'renderer' not in renderer_env:
23-
print("Unknown renderer definition error" if error_msg is None else error_msg, file=sys.stderr)
24-
else:
25-
tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(renderer_env['renderer'].__doc__) if renderer_env['renderer'].__doc__ is not None else ""))
26-
27-
if msg_type == 'Clear':
28-
tensor = None
29-
input_tensors = []
30-
target_tensor = None
31-
32-
if msg_type == 'SetTitle':
33-
title = tc.decode_strings(msg_data)[0]
34-
35-
if msg_type == 'TensorData':
36-
tensor = tc.decode_numpy_tensors(msg_data)[0]
37-
38-
if msg_type == 'SetResolution':
39-
resolution = tc.decode_ints(msg_data)
40-
41-
if msg_type == 'SetShift':
42-
shift = tc.decode_floats(msg_data)
43-
if msg_type == 'SetScale':
44-
scale = tc.decode_floats(msg_data)
45-
46-
if msg_type == 'SetInputTensors':
47-
input_tensors = tc.decode_numpy_tensors(msg_data)
48-
49-
if msg_type == 'SetTargetTensors':
50-
target_tensors = tc.decode_numpy_tensors(msg_data)
51-
if target_tensors:
52-
target_tensor=target_tensors[0]
53-
else:
54-
target_tensor=None
55-
56-
if msg_type == 'SetLabels':
57-
labels = tc.decode_strings(msg_data)
58-
59-
if msg_type == 'Render':
60-
if 'renderer' in renderer_env and tensor is not None and resolution[0]>0 and resolution[1]>0:
61-
error_msg, img = safe_exec(renderer_env['renderer'].render, (title, tensor,resolution[0:2],resolution[2],shift,scale,input_tensors,target_tensor,labels), description='renderer definition')
62-
if error_msg is not None:
63-
print(error_msg, file=sys.stderr)
64-
if img is None:
65-
tc.send_msg(app_socket, 'ImageError')
66-
else:
67-
tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))
68-
69-
if msg_type == 'Exit':
70-
break
1+
import torchstudio.tcpcodec as tc
2+
from torchstudio.modules import safe_exec
3+
import inspect
4+
import sys
5+
import os
6+
7+
title = ''
8+
tensor = None
9+
resolution = (256,256, 96)
10+
shift = (0,0,0,0)
11+
scale = (1,1,1,1)
12+
input_tensors = []
13+
target_tensor = None
14+
labels = []
15+
16+
app_socket = tc.connect()
17+
while True:
18+
msg_type, msg_data = tc.recv_msg(app_socket)
19+
20+
if msg_type == 'SetRendererCode':
21+
error_msg, renderer_env = safe_exec(tc.decode_strings(msg_data)[0],description='renderer definition')
22+
if error_msg is not None or 'renderer' not in renderer_env:
23+
print("Unknown renderer definition error" if error_msg is None else error_msg, file=sys.stderr)
24+
else:
25+
tc.send_msg(app_socket, 'Documentation', tc.encode_strings(inspect.cleandoc(renderer_env['renderer'].__doc__) if renderer_env['renderer'].__doc__ is not None else ""))
26+
27+
if msg_type == 'Clear':
28+
tensor = None
29+
input_tensors = []
30+
target_tensor = None
31+
32+
if msg_type == 'SetTitle':
33+
title = tc.decode_strings(msg_data)[0]
34+
35+
if msg_type == 'TensorData':
36+
tensor = tc.decode_numpy_tensors(msg_data)[0]
37+
38+
if msg_type == 'SetResolution':
39+
resolution = tc.decode_ints(msg_data)
40+
41+
if msg_type == 'SetShift':
42+
shift = tc.decode_floats(msg_data)
43+
if msg_type == 'SetScale':
44+
scale = tc.decode_floats(msg_data)
45+
46+
if msg_type == 'SetInputTensors':
47+
input_tensors = tc.decode_numpy_tensors(msg_data)
48+
49+
if msg_type == 'SetTargetTensors':
50+
target_tensors = tc.decode_numpy_tensors(msg_data)
51+
if target_tensors:
52+
target_tensor=target_tensors[0]
53+
else:
54+
target_tensor=None
55+
56+
if msg_type == 'SetLabels':
57+
labels = tc.decode_strings(msg_data)
58+
59+
if msg_type == 'Render':
60+
if 'renderer' in renderer_env and tensor is not None and resolution[0]>0 and resolution[1]>0:
61+
error_msg, img = safe_exec(renderer_env['renderer'].render, (title, tensor,resolution[0:2],resolution[2],shift,scale,input_tensors,target_tensor,labels), description='renderer definition')
62+
if error_msg is not None:
63+
print(error_msg, file=sys.stderr)
64+
if img is None:
65+
tc.send_msg(app_socket, 'ImageError')
66+
else:
67+
tc.send_msg(app_socket, 'ImageData', tc.encode_image(img))
68+
69+
if msg_type == 'Exit':
70+
break
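
The loop above expects the executed code to define a renderer object whose render method takes (title, tensor, size, dpi, shift, scale, input_tensors, target_tensor, labels) and returns an image, or None on failure. A minimal sketch of a compatible renderer, assuming a 2D tensor and a grayscale preview (the class name and rendering choice are illustrative, not TorchStudio's own renderers):

import numpy as np
import PIL.Image

class GrayscaleRenderer:
    """Hypothetical renderer: grayscale preview of a 2D tensor."""
    def render(self, title, tensor, size, dpi, shift, scale,
               input_tensors, target_tensor, labels):
        array = np.asarray(tensor, dtype=np.float32)
        array = array - array.min()              # shift to 0..max
        if array.max() > 0:
            array = array / array.max()          # scale to 0..1
        img = PIL.Image.fromarray((array * 255).astype(np.uint8), mode='L')
        return img.resize(tuple(size))           # match the requested resolution

renderer = GrayscaleRenderer()  # the message loop looks up this name in renderer_env
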
