# %% [markdown]
# # PIMMS Tutorial: Scikit-learn style transformers
#
# 1. Load data into pandas dataframe
# 2. Fit model on training data, potentially specify validation data
# 3. Impute only missing values with predictions from model
#
# Autoencoders need wide training data, i.e. a sample with all its features' intensities,
# whereas Collaborative Filtering needs long training data, i.e. a sample identifier, a
# feature identifier and the intensity. The two formats can be transformed into each other,
# but models using the long data format do not need to handle missing values explicitly,
# as shown in the sketch below.
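# %% [markdown]
# A minimal sketch of the two formats with illustrative data (not part of the
# tutorial dataset): stacking a wide table yields the long format with one observed
# intensity per row, so missing values simply drop out.
# %%
import pandas as pd

_wide = pd.DataFrame({'p1': [1.0, None], 'p2': [2.0, 3.0]}, index=['s1', 's2'])
_long = _wide.stack()  # MultiIndex (sample, feature) -> intensity; NAs are dropped
print(_long)
print(_long.unstack())  # back to wide format; the NA reappears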
# %%
import os
from importlib import metadata

IN_COLAB = 'COLAB_GPU' in os.environ
if IN_COLAB:
    try:
        _v = metadata.version('pimms-learn')
        print(f"Running in colab and pimms-learn ({_v}) is installed.")
    except metadata.PackageNotFoundError:
        print("Install PIMMS...")
        # # !pip install git+https://github.com/RasmussenLab/pimms.git
        # !pip install pimms-learn
# %% [markdown]
# If on colab, please restart the environment and run everything from here on.
#
# Specify example data:
# %%
import os

IN_COLAB = 'COLAB_GPU' in os.environ

fn_intensities = 'data/dev_datasets/HeLa_6070/protein_groups_wide_N50.csv'
if IN_COLAB:
    fn_intensities = ('https://raw.githubusercontent.com/RasmussenLab/pimms/main/'
                      'project/data/dev_datasets/HeLa_6070/protein_groups_wide_N50.csv')
# %% [markdown]
# Load packages:
# %%
import logging
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display
import pimmslearn.filter
import pimmslearn.logging
import pimmslearn.models
import pimmslearn.plotting.data
import pimmslearn.plotting.errors
import pimmslearn.sampling
from pimmslearn.plotting.defaults import color_model_mapping
from pimmslearn.sklearn.ae_transformer import AETransformer
from pimmslearn.sklearn.cf_transformer import CollaborativeFilteringTransformer

pimmslearn.plotting.make_large_descriptors(8)
logger = pimmslearn.logging.setup_nb_logger()
logging.getLogger('fontTools').setLevel(logging.ERROR)
# %% [markdown]
# ## Parameters
# Can be set by papermill on the command line (see the example after the next cell)
# or manually in the (colab) notebook.
# %%
fn_intensities: str = fn_intensities # path or url to the data file in csv format
index_name: str = 'Sample ID' # name of the index column
column_name: str = 'protein group' # name of the column index
select_features: bool = True # Whether to select features based on prevalence
feat_prevalence: float = 0.2 # minimum prevalence of a feature to be included
sample_completeness: float = 0.3 # minimum completeness of a sample to be included
sample_splits: bool = True # Whether to sample validation and test data
frac_non_train: float = 0.1 # fraction of non training data (validation and test split)
frac_mnar: float = 0.0 # fraction of missing not at random data, rest: missing completely at random
random_state: int = 42 # random state for reproducibility
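# %% [markdown]
# For example (illustrative command lines, assuming papermill and jupytext are
# installed): convert this script with `jupytext --to ipynb 04_1_train_pimms_models.py`
# and then override a parameter with
# `papermill 04_1_train_pimms_models.ipynb out.ipynb -p frac_mnar 0.25`.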
# %% [markdown]
# ## Data
# %%
df = pd.read_csv(fn_intensities, index_col=0)
df.head()
# %% [markdown]
# We will need the data in long format for Collaborative Filtering.
# Naming both the row and column index ensures
# that the data can easily be transformed into long format:
# %%
df.index.name = index_name  # already set
df.columns.name = column_name  # not preserved in the csv file format
df.head()
# %% [markdown]
# ### Data transformation: log2 transformation
# Transform the intensities using a logarithm, here base 2. Adding 1 before
# taking the logarithm keeps zeros at zero and avoids taking the log of zero:
# %%
df = np.log2(df + 1)
df.head()
# %% [markdown]
# ### Two plots inspecting data availability
#
# 1. feature median intensity over the proportion of missing values (N = protein groups)
# 2. sorted counts of available intensities per protein group (an empirical CDF)
# %%
ax = pimmslearn.plotting.data.plot_feat_median_over_prop_missing(
    data=df, type='boxplot')
# %%
df.notna().sum().sort_values().plot()
# %% [markdown]
# ### Data selection
# Define a minimum prevalence for features and a minimum completeness for samples
# to be included:
# %%
if select_features:
    # potentially this can take a few iterations to stabilize.
    df = pimmslearn.filter.select_features(df, feat_prevalence=feat_prevalence)
    # the same function filters samples when applied along axis=1
    df = pimmslearn.filter.select_features(df=df, feat_prevalence=sample_completeness, axis=1)
df.shape
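# %% [markdown]
# For intuition, a minimal pandas sketch of prevalence-based filtering
# (an assumption about the semantics of `select_features`, which may in
# addition iterate until the selection stabilizes):
# %%
# features kept: observed in at least `feat_prevalence` of the samples
_feat_mask = df.notna().mean(axis=0) >= feat_prevalence
# samples kept: at least `sample_completeness` of the remaining features observed
_sample_mask = df.loc[:, _feat_mask].notna().mean(axis=1) >= sample_completeness
df.loc[_sample_mask, _feat_mask].shape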
# %% [markdown]
# Transform to long-data format:
# %%
df = df.stack().to_frame('intensity')
df
# %% [markdown]
# ## Optionally: Sample data
# - models can be trained without subsetting the data
# - allows evaluation of the models
# %%
if sample_splits:
    splits, thresholds, fake_na_mcar, fake_na_mnar = pimmslearn.sampling.sample_mnar_mcar(
        df_long=df,
        frac_non_train=frac_non_train,
        frac_mnar=frac_mnar,
        random_state=random_state,
    )
    splits = pimmslearn.sampling.check_split_integrity(splits)
else:
    splits = pimmslearn.sampling.DataSplits(is_wide_format=False)
    splits.train_X = df
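# %% [markdown]
# For intuition: MCAR sampling holds out a random fraction of the observed
# intensities, while MNAR sampling preferentially holds out intensities below a
# threshold. A minimal MCAR-only sketch (an illustration, not the library code):
# %%
_held_out = df.sample(frac=frac_non_train, random_state=random_state)
_train = df.drop(_held_out.index)
_train.shape, _held_out.shape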
# %% [markdown]
# The resulting DataFrame with one column has a `MultiIndex` with the sample and feature identifiers.
# %% [markdown]
# ## Collaborative Filtering
#
# Inspect annotations of the scikit-learn like Transformer:
# %%
# CollaborativeFilteringTransformer?  # uncomment in a notebook to inspect the docstring
# %% [markdown]
# Let's set up collaborative filtering. If no validation split was sampled above,
# `splits.val_y` is `None` and the model is trained on all the data there is.
# %%
cf_model = CollaborativeFilteringTransformer(
    target_column='intensity',
    sample_column='Sample ID',
    item_column='protein group',
    out_folder='runs/scikit_interface')
# %% [markdown]
# We use `fit` and `transform` to train the model and impute the missing values.
# > Scikit-learn's interface requires an `X` and a `y`. In our context, `y` is the validation data.
# > We might have to change the interface to allow usage within pipelines (-> `y` is not needed).
# > This will probably mean setting up a validation set within the model.
# %%
cf_model.fit(splits.train_X,
             splits.val_y,
             cuda=False,
             epochs_max=20,
             )
# %% [markdown]
# Impute missing values using the `transform` method:
# %%
df_imputed = cf_model.transform(df).unstack()
assert df_imputed.isna().sum().sum() == 0
df_imputed.head()
# %% [markdown]
# Let's plot the distribution of the imputed values vs the ones used for training:
# %%
df_imputed = df_imputed.stack() # long-format
observed = df_imputed.loc[df.index]
imputed = df_imputed.loc[df_imputed.index.difference(df.index)]
df_imputed = df_imputed.unstack() # back to wide-format
# some checks
assert len(df) == len(observed)
assert df_imputed.shape[0] * df_imputed.shape[1] == len(imputed) + len(observed)
fig, axes = plt.subplots(2, figsize=(8, 4))
min_max = pimmslearn.plotting.data.get_min_max_iterable(
    [observed, imputed])
label_template = '{method} (N={n:,d})'
ax, _ = pimmslearn.plotting.data.plot_histogram_intensities(
    observed,
    ax=axes[0],
    min_max=min_max,
    label=label_template.format(method='measured',
                                n=len(observed),
                                ),
    color='grey',
    alpha=1)
_ = ax.legend()
ax, _ = pimmslearn.plotting.data.plot_histogram_intensities(
    imputed,
    ax=axes[1],
    min_max=min_max,
    label=label_template.format(method='CF imputed',
                                n=len(imputed),
                                ),
    color=color_model_mapping['CF'],
    alpha=1)
_ = ax.legend()
# %% [markdown]
# ## AutoEncoder architectures
# %%
# Use wide format of data
splits.to_wide_format()
splits.train_X
# %% [markdown]
# Validation data for early stopping (if specified)
# %%
splits.val_y
# %% [markdown]
# Training and validation data need the same shape:
# %%
if splits.val_y is not None:
    # align the validation data to the training data's axes: start from an all-NA
    # frame with the training shape and fill in the validation values
    splits.val_y = pd.DataFrame(pd.NA, index=splits.train_X.index,
                                columns=splits.train_X.columns).fillna(splits.val_y)
    print(splits.train_X.shape, splits.val_y.shape)
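# %% [markdown]
# To see the alignment trick in isolation: an all-NA frame shaped like the
# training data is filled with the sparser validation values, giving both frames
# identical axes. A toy sketch with hypothetical names:
# %%
_template = pd.DataFrame(pd.NA, index=['s1', 's2'], columns=['p1', 'p2'])
_sparse = pd.DataFrame({'p1': [1.0]}, index=['s1'])
_template.fillna(_sparse)  # only s1/p1 is filled; all other cells remain NA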
# %% [markdown]
# Select either the `DAE` or the `VAE` model at random:
# %%
model_selected = random.choice(['DAE', 'VAE'])
print("Selected model at random:", model_selected)

model = AETransformer(
    model=model_selected,
    hidden_layers=[512,],
    latent_dim=50,
    out_folder='runs/scikit_interface',
    batch_size=10,
)
# %%
model.fit(splits.train_X, splits.val_y,
          epochs_max=50,
          cuda=False)
# %% [markdown]
# Impute missing values using the `transform` method:
# %%
df_imputed = model.transform(splits.train_X).stack()
df_imputed
# %% [markdown]
# Evaluate the model using the validation data:
# %%
if splits.val_y is not None:
    pred_val = splits.val_y.stack().to_frame('observed')
    pred_val[model_selected] = df_imputed
    val_metrics = pimmslearn.models.calculte_metrics(pred_val, 'observed')
    display(val_metrics)
if splits.val_y is not None:
    fig, ax = plt.subplots(figsize=(8, 2))
    ax, errors_binned = pimmslearn.plotting.errors.plot_errors_by_median(
        pred=pred_val,
        target_col='observed',
        feat_medians=splits.train_X.median(),
        ax=ax,
        metric_name='MAE',
        palette=color_model_mapping)
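# %% [markdown]
# For reference, the headline error can be recomputed by hand. A sketch of a
# manual MAE on the validation predictions (not the library's exact computation):
# %%
if splits.val_y is not None:
    _mae = (pred_val['observed'] - pred_val[model_selected]).abs().mean()
    print(f"manual MAE ({model_selected}): {_mae:.3f}")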
# %% [markdown]
# Replace the predicted values with the observed validation and test values:
# %%
splits.to_long_format()
# overwrite the model's predictions with the held-out observed values, so the
# final table contains measured intensities wherever they exist
if splits.val_y is not None:
    df_imputed.loc[splits.val_y.index] = splits.val_y
if splits.test_y is not None:
    df_imputed.loc[splits.test_y.index] = splits.test_y
df_imputed
# %% [markdown]
# Plot the distribution of the imputed values vs the observed data:
# %%
observed = df_imputed.loc[df.index].squeeze()
imputed = df_imputed.loc[df_imputed.index.difference(df.index)].squeeze()
fig, axes = plt.subplots(2, figsize=(8, 4))
min_max = pimmslearn.plotting.data.get_min_max_iterable([observed, imputed])
label_template = '{method} (N={n:,d})'
ax, _ = pimmslearn.plotting.data.plot_histogram_intensities(
    observed,
    ax=axes[0],
    min_max=min_max,
    label=label_template.format(method='measured',
                                n=len(observed),
                                ),
    color='grey',
    alpha=1)
_ = ax.legend()
ax, _ = pimmslearn.plotting.data.plot_histogram_intensities(
    imputed,
    ax=axes[1],
    min_max=min_max,
    label=label_template.format(method=f'{model_selected} imputed',
                                n=len(imputed),
                                ),
    color=color_model_mapping[model_selected],
    alpha=1)
_ = ax.legend()