Commit 1db03fd

initial add of new competition tutorial
1 parent 4a40380 commit 1db03fd

13 files changed: +3389 -0 lines changed

@@ -0,0 +1,88 @@
## Introduction

This tutorial illustrates how to use MONAI for cryo-electron tomography (Cryo-ET). The pipeline and models were part of the solution that won the [Cryo-ET competition on Kaggle](https://www.kaggle.com/competitions/czii-cryo-et-object-identification/overview). The tutorial was tested with the nvidia/pytorch:24.08-py3 Docker container on a single A100 GPU.

## What is Cryo-ET?

If you ask ChatGPT:

Cryo-ET (Cryo-Electron Tomography) is an advanced imaging technique that allows scientists to visualize biological structures in near-native states at high resolution. It combines cryogenic sample preservation with electron tomography to generate three-dimensional (3D) reconstructions of cellular structures, protein complexes, and organelles.

### How It Works
1. Cryo-Fixation: The sample (e.g., a cell or a purified macromolecular complex) is rapidly frozen using liquid ethane or similar methods to prevent ice crystal formation, preserving its natural state.
2. Electron Microscopy: The frozen sample is placed under a transmission electron microscope (TEM), where images are taken from multiple angles by tilting the sample.
3. Tomographic Reconstruction: Computational algorithms combine these 2D images to create a detailed 3D model of the structure.

### Applications
- Studying cellular architecture at nanometer resolution.
- Visualizing macromolecular complexes in their native environments.
- Understanding interactions between viruses and host cells.
- Investigating neurodegenerative diseases, cancer, and infectious diseases.

Cryo-ET is particularly powerful because it enables direct imaging of biological systems without the need for staining or chemical fixation, preserving their native conformation.

## Environment

To have a common environment, it is suggested to use the basic PyTorch Docker container and add a few pip packages on top.

1. This tutorial was tested with tag 24.08-py3, i.e. run the following command to pull and start the container with GPU access and an interactive shell.

```docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:24.08-py3```

2. Within the container, clone this repository

```
git clone https://github.com/Project-MONAI/tutorials
cd tutorials/competitions/kaggle/Cryo-ET/1st_place_solution/
```

3. Install the necessary additional pip packages via

```pip install -r requirements.txt```

## Required Data

This tutorial is built upon the official Cryo-ET competition data. It can be downloaded directly from Kaggle: https://www.kaggle.com/competitions/czii-cryo-et-object-identification/data

Alternatively, it can be downloaded using the Kaggle API (which can be installed via ```pip install kaggle```)

```kaggle competitions download -c czii-cryo-et-object-identification```

Then adjust the path to the data in ```configs/common_config.py``` via ```cfg.data_folder```.

## Training models

For the competition, we created a cross-validation scheme by simply splitting the 7 training tomograms into 7 folds, i.e. we train on 6 tomograms and use the 7th for validation.
For convenience, we provide a file ```train_folded_v1.csv``` which contains the original training annotations, extended by a column containing fold_ids.

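The leave-one-tomogram-out scheme described above can be sketched as follows. This is only an illustration: the run names and the helpers `assign_folds` and `split` are hypothetical and are not taken from the repository; the real fold ids are the ones stored in ```train_folded_v1.csv```.

```python
# Hypothetical tomogram (run) names; the real ones come from the competition data.
runs = ["run_a", "run_b", "run_c", "run_d", "run_e", "run_f", "run_g"]


def assign_folds(run_names):
    """Give every tomogram its own fold id (one fold per tomogram)."""
    return {name: fold for fold, name in enumerate(sorted(set(run_names)))}


def split(fold_of_run, fold):
    """Return (train_runs, val_runs) when `fold` is held out for validation."""
    train = [r for r, f in fold_of_run.items() if f != fold]
    val = [r for r, f in fold_of_run.items() if f == fold]
    return train, val


fold_of_run = assign_folds(runs)
train_runs, val_runs = split(fold_of_run, fold=0)
print(len(train_runs), "training tomograms, validation on", val_runs)
```

With 7 tomograms this yields 7 folds, each training on 6 tomograms and validating on the remaining one.
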
We solve the competition with a 3D segmentation approach leveraging [MONAI's FlexibleUNet](https://docs.monai.io/en/stable/networks.html#flexibleunet) architecture. Compared to the original implementation, we adjusted the network to output more feature maps and to enable deep supervision. The following illustrates the resulting architecture at a high level:

![alt text](partly_Unet.png "Partly UNet")

We provide three different configurations, which differ only in the backbone used and the output feature maps; all other hyperparameters are shared. The configuration files are .py files located under ```configs```. Each hyperparameter can be overwritten by adding a flag to the training command. To train a resnet34 version of our segmentation model, simply run

```python train.py -C cfg_resnet34 --output_dir WHATEVERISYOUROUTPUTDIR```

This will save checkpoints under the specified WHATEVERISYOUROUTPUTDIR.
By default, models are trained using bfloat16, which requires a GPU that supports it. Alternatively, you can set ```cfg.bf16=False``` or overwrite it with the flag ```--bf16 False``` when running ```train.py```.

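Overriding config values with command-line flags, as described above, could work along the following lines. This is a minimal sketch under assumptions: `apply_overrides` and `str2bool` are hypothetical helpers, and the actual argument handling in ```train.py``` is not shown here and may differ.

```python
import argparse
from types import SimpleNamespace


def str2bool(text):
    """Parse 'True'/'False'-style strings, since bool('False') would be True."""
    return text.lower() in ("1", "true", "yes")


def apply_overrides(cfg, argv):
    """Expose every simple config attribute as an optional --flag."""
    parser = argparse.ArgumentParser()
    for key, value in vars(cfg).items():
        if isinstance(value, bool):  # check bool first: bool is a subclass of int
            parser.add_argument(f"--{key}", type=str2bool, default=None)
        elif isinstance(value, (int, float, str)):
            parser.add_argument(f"--{key}", type=type(value), default=None)
    args = parser.parse_args(argv)
    for key, value in vars(args).items():
        if value is not None:  # only overwrite what was passed explicitly
            setattr(cfg, key, value)
    return cfg


# Stand-in for a config module; the attribute names mirror the ones used in this tutorial.
cfg = SimpleNamespace(bf16=True, fold=0, lr=1e-3, output_dir="out")
cfg = apply_overrides(cfg, ["--bf16", "False", "--fold", "-1"])
print(cfg)
```

Values that are not passed on the command line keep the defaults from the config file.
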
### Replicating 1st place solution (segmentation part)

To train the checkpoints necessary for replicating the segmentation part of the 1st place solution, run two full fits for each model. Setting ```cfg.fold = -1``` trains on all data and uses ```fold 0``` as validation.
```
python train.py -C cfg_resnet34 --fold -1
python train.py -C cfg_resnet34 --fold -1
python train.py -C cfg_resnet34_ds --fold -1
python train.py -C cfg_resnet34_ds --fold -1
python train.py -C cfg_effnetb3 --fold -1
python train.py -C cfg_effnetb3 --fold -1
```

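The meaning of ```fold = -1``` can be illustrated with a small sketch. The function `train_val_split` and the row layout (a list of dicts with a `fold` key) are assumptions made for this example; the dataset code in the repository may select rows differently.

```python
def train_val_split(rows, fold):
    """Select training and validation rows from fold-annotated rows.

    fold == -1 means a full fit: train on everything and keep fold 0
    as a (no longer independent) validation set.
    """
    if fold == -1:
        train = list(rows)
        val = [r for r in rows if r["fold"] == 0]
    else:
        train = [r for r in rows if r["fold"] != fold]
        val = [r for r in rows if r["fold"] == fold]
    return train, val


# Hypothetical annotation rows, one per tomogram for brevity.
rows = [{"experiment": f"run_{i}", "fold": i} for i in range(7)]

full_train, full_val = train_val_split(rows, fold=-1)
cv_train, cv_val = train_val_split(rows, fold=3)
print(len(full_train), len(full_val), len(cv_train), len(cv_val))
```

Note that in a full fit the validation rows are also part of the training set, so the validation score is only a sanity check and not an estimate of generalization.
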
## Inference

Inference after the models are converted with torch.jit is shown in our 1st place submission Kaggle kernel.

https://www.kaggle.com/code/christofhenkel/cryo-et-1st-place-solution?scriptVersionId=223259615
@@ -0,0 +1,23 @@
from common_config import basic_cfg
import os
import pandas as pd
import numpy as np
import monai.transforms as mt

cfg = basic_cfg

# paths: run name and output folder are derived from this file's name
cfg.name = os.path.basename(__file__).split(".")[0]
cfg.output_dir = f"/mount/cryo/models/{os.path.basename(__file__).split('.')[0]}"

# model
cfg.backbone = 'efficientnet-b3'
cfg.backbone_args = dict(spatial_dims=3,
                         in_channels=cfg.in_channels,
                         out_channels=cfg.n_classes,
                         backbone=cfg.backbone,
                         pretrained=cfg.pretrained)
cfg.class_weights = np.array([64,64,64,64,64,64,1])
cfg.lvl_weights = np.array([0,0,0,1])
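One common way to use a vector like `cfg.class_weights` is to scale a per-voxel cross-entropy by the weight of the true class. The sketch below assumes exactly that, and it also assumes that the seventh weight belongs to a background class; neither the loss function nor the class order is shown in this commit excerpt, so treat both as assumptions. `weighted_cross_entropy` is a hypothetical helper.

```python
import math


def weighted_cross_entropy(logits, target, class_weights):
    """Cross-entropy for a single voxel, scaled by the weight of its true class."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    log_prob = logits[target] - log_sum_exp
    return -class_weights[target] * log_prob


# 6 particle classes followed by background (assumed order).
class_weights = [64, 64, 64, 64, 64, 64, 1]
logits = [0.0] * 7  # an uninformative prediction

particle_loss = weighted_cross_entropy(logits, 0, class_weights)
background_loss = weighted_cross_entropy(logits, 6, class_weights)
print(particle_loss / background_loss)
```

Under these assumptions, a misclassified particle voxel costs 64 times as much as a misclassified background voxel, which counteracts the fact that most voxels in a tomogram are background.
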
@@ -0,0 +1,22 @@
from common_config import basic_cfg
import os
import pandas as pd
import numpy as np

cfg = basic_cfg

# paths
cfg.name = os.path.basename(__file__).split(".")[0]
cfg.output_dir = f"/mount/cryo/models/{os.path.basename(__file__).split('.')[0]}"

# model
cfg.backbone = 'resnet34'
cfg.backbone_args = dict(spatial_dims=3,
                         in_channels=cfg.in_channels,
                         out_channels=cfg.n_classes,
                         backbone=cfg.backbone,
                         pretrained=cfg.pretrained)
cfg.class_weights = np.array([256,256,256,256,256,256,1])
cfg.lvl_weights = np.array([0,0,0,1])
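A note on the `cfg = basic_cfg` pattern used by these config files: the assignment binds the same `SimpleNamespace` object rather than copying it, so every attribute set in a model config also changes `basic_cfg`. The sketch below demonstrates this with stand-in objects (all names here are illustrative) and shows `copy.deepcopy` as one way to get an independent config.

```python
import copy
from types import SimpleNamespace

# Stand-in for the shared base config.
basic_cfg = SimpleNamespace(backbone=None, lr=1e-3)

# Plain assignment: both names refer to the same object.
alias = basic_cfg
alias.backbone = "resnet34"
print(basic_cfg.backbone)  # the base config now also says "resnet34"

# A deep copy keeps the base config untouched.
base = SimpleNamespace(backbone=None, lr=1e-3)
independent = copy.deepcopy(base)
independent.backbone = "efficientnet-b3"
print(base.backbone)  # still None
```

Sharing the object is harmless as long as only one config file is imported per process; a copy only matters when several configs are loaded side by side.
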
@@ -0,0 +1,19 @@
from common_config import basic_cfg
import os
import pandas as pd
import numpy as np

cfg = basic_cfg

# paths
cfg.name = os.path.basename(__file__).split(".")[0]
cfg.output_dir = f"/mount/cryo/models/{os.path.basename(__file__).split('.')[0]}"

cfg.backbone = 'resnet34'
cfg.backbone_args = dict(spatial_dims=3,
                         in_channels=cfg.in_channels,
                         out_channels=cfg.n_classes,
                         backbone=cfg.backbone,
                         pretrained=cfg.pretrained)
cfg.class_weights = np.array([64,64,64,64,64,64,1])
cfg.lvl_weights = np.array([0,0,1,1])
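The `lvl_weights` vector is the setting that differs between this deep-supervision config and the plain resnet34 one. A plausible reading is that each entry weights the loss computed on one decoder output level. The sketch below assumes a weighted average with normalization by the sum of the weights; `combine_level_losses`, the normalization, and the ordering of the levels are assumptions, since the model code is not part of this excerpt.

```python
def combine_level_losses(level_losses, lvl_weights):
    """Weighted average of per-level losses; levels with weight 0 are ignored."""
    total_weight = sum(lvl_weights)
    if total_weight == 0:
        raise ValueError("at least one level weight must be non-zero")
    weighted = sum(w * loss for w, loss in zip(lvl_weights, level_losses))
    return weighted / total_weight


# Hypothetical losses of four output levels for one batch.
level_losses = [0.9, 0.7, 0.5, 0.3]

single = combine_level_losses(level_losses, [0, 0, 0, 1])  # only the last level
deep = combine_level_losses(level_losses, [0, 0, 1, 1])    # last two levels
print(single, deep)
```

With `[0,0,0,1]` only one output contributes to the loss, while `[0,0,1,1]` additionally supervises a second level.
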
@@ -0,0 +1,198 @@
from types import SimpleNamespace
from monai import transforms as mt

# generic defaults; several of them are overridden further below
cfg = SimpleNamespace()

# stages
cfg.train = True
cfg.val = True
cfg.test = True
cfg.train_val = True

# dataset
cfg.batch_size_val = None
cfg.use_custom_batch_sampler = False
cfg.val_df = None
cfg.test_df = None
cfg.val_data_folder = None
cfg.train_aug = None
cfg.val_aug = None
cfg.data_sample = -1

# model
cfg.pretrained = False
cfg.pretrained_weights = None
cfg.pretrained_weights_strict = True
cfg.pop_weights = None
cfg.compile_model = False

# training routine
cfg.fold = 0
cfg.optimizer = "Adam"
cfg.sgd_momentum = 0
cfg.sgd_nesterov = False
cfg.lr = 1e-4
cfg.schedule = "cosine"
cfg.num_cycles = 0.5
cfg.weight_decay = 0
cfg.epochs = 10
cfg.seed = -1
cfg.resume_training = False
cfg.distributed = False
cfg.clip_grad = 0
cfg.save_val_data = True
cfg.gradient_checkpointing = False
cfg.apex_ddp = False
cfg.synchronize_step = True

# eval
cfg.eval_ddp = True
cfg.calc_metric = True
cfg.calc_metric_epochs = 1
cfg.eval_steps = 0
cfg.eval_epochs = 1
cfg.save_pp_csv = True

# resources
cfg.find_unused_parameters = False
cfg.grad_accumulation = 1
cfg.syncbn = False
cfg.gpu = 0
cfg.dp = False
cfg.num_workers = 8
cfg.drop_last = True
cfg.save_checkpoint = True
cfg.save_only_last_ckpt = False
cfg.save_weights_only = False

# logging
cfg.neptune_project = None
cfg.neptune_connection_mode = "debug"
cfg.save_first_batch = False
cfg.save_first_batch_preds = False
cfg.clip_mode = "norm"
cfg.data_sample = -1
cfg.track_grad_norm = True
cfg.grad_norm_type = 2.
cfg.track_weight_norm = True
cfg.norm_eps = 1e-4
cfg.disable_tqdm = False

# paths
cfg.data_folder = '/mount/cryo/data/czii-cryo-et-object-identification/train/static/ExperimentRuns/'
cfg.train_df = 'train_folded_v1.csv'

# stages
cfg.test = False
cfg.train = True
cfg.train_val = False

# logging
cfg.neptune_project = None
cfg.neptune_connection_mode = "async"

# model
cfg.model = "mdl_1"
cfg.mixup_p = 1.
cfg.mixup_beta = 1.
cfg.in_channels = 1
cfg.pretrained = False

# data
cfg.dataset = "ds_1"
cfg.classes = ['apo-ferritin','beta-amylase','beta-galactosidase','ribosome','thyroglobulin','virus-like-particle']
cfg.n_classes = len(cfg.classes)

cfg.post_process_pipeline = 'pp_1'
cfg.metric = 'metric_1'

cfg.particle_radi = {'apo-ferritin':60,
                     'beta-amylase':65,
                     'beta-galactosidase':90,
                     'ribosome':150,
                     'thyroglobulin':130,
                     'virus-like-particle':135
                     }

cfg.voxel_spacing = 10.0

# OPTIMIZATION & SCHEDULE
cfg.fold = 0
cfg.epochs = 10

cfg.lr = 1e-3
cfg.optimizer = "Adam"
cfg.weight_decay = 0.
cfg.warmup = 0.
cfg.batch_size = 8
cfg.batch_size_val = 16
cfg.sub_batch_size = 4
cfg.roi_size = [96,96,96]
cfg.train_sub_epochs = 1112
cfg.val_sub_epochs = 1
cfg.mixed_precision = False
cfg.bf16 = True
cfg.force_fp16 = True
cfg.pin_memory = False
cfg.grad_accumulation = 1.
cfg.num_workers = 8

# Saving
cfg.save_weights_only = True
cfg.save_only_last_ckpt = False
cfg.save_val_data = False
cfg.save_checkpoint = True
cfg.save_pp_csv = False

# applied to every tomogram: add a channel dimension and normalize intensities
cfg.static_transforms = static_transforms = mt.Compose([
    mt.EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    mt.NormalizeIntensityd(keys="image"),
])

# training: random crops of roi_size, random flips along each axis, random rotations
cfg.train_aug = mt.Compose([
    mt.RandSpatialCropSamplesd(keys=["image", "label"],
                               roi_size=cfg.roi_size,
                               num_samples=cfg.sub_batch_size),
    mt.RandFlipd(
        keys=["image", "label"],
        prob=0.5,
        spatial_axis=0,
    ),
    mt.RandFlipd(
        keys=["image", "label"],
        prob=0.5,
        spatial_axis=1,
    ),
    mt.RandFlipd(
        keys=["image", "label"],
        prob=0.5,
        spatial_axis=2,
    ),
    mt.RandRotate90d(
        keys=["image", "label"],
        prob=0.75,
        max_k=3,
        spatial_axes=(0, 1),
    ),
    # rotation angle sampled from (-0.78, 0.78) rad, i.e. up to about 45 degrees
    mt.RandRotated(keys=["image", "label"], prob=0.5, range_x=0.78, range_y=0., range_z=0., padding_mode='reflection'),
])

# validation: split the volume into a grid of roi_size patches
cfg.val_aug = mt.Compose([mt.GridPatchd(keys=["image","label"], patch_size=cfg.roi_size, pad_mode='reflect')])

basic_cfg = cfg
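To relate `cfg.particle_radi` to `cfg.voxel_spacing` and `cfg.roi_size`, the radii can be converted to voxel units. The sketch below assumes that the radii are given in Ångström and the voxel spacing in Ångström per voxel, which matches the competition's conventions but is not stated in this file; `radius_in_voxels` is a hypothetical helper.

```python
# Values copied from the config above.
particle_radi = {
    "apo-ferritin": 60,
    "beta-amylase": 65,
    "beta-galactosidase": 90,
    "ribosome": 150,
    "thyroglobulin": 130,
    "virus-like-particle": 135,
}
voxel_spacing = 10.0  # assumed unit: Angstrom per voxel
roi_size = [96, 96, 96]


def radius_in_voxels(radius, spacing):
    """Convert a physical radius to voxel units."""
    return radius / spacing


radii_vox = {name: radius_in_voxels(r, voxel_spacing) for name, r in particle_radi.items()}
largest_diameter = 2 * max(radii_vox.values())

for name, r in radii_vox.items():
    print(f"{name}: {r:.1f} voxels")
print("largest diameter:", largest_diameter, "voxels; patch edge:", min(roi_size))
```

Under this assumption, even the largest particle spans about 30 voxels, so a whole particle fits well within a single 96-voxel patch.
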
