Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kaggle cryo #1947

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions competitions/kaggle/Cryo-ET/1st_place_solution/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
## Introduction

This tutorial illustrates how to use MONAI for cryo electron tomography. The pipeline and models were partly used to win the [Cryo-ET competition on kaggle](https://www.kaggle.com/competitions/czii-cryo-et-object-identification/overview). The tutorial was tested with nvidia/pytorch:24.08-py3 docker container and a single A100 GPU.

## What is Cryo-ET?

If you ask ChatGPT:

Cryo-ET (Cryo-Electron Tomography) is an advanced imaging technique that allows scientists to visualize biological structures in near-native states at high resolution. It combines cryogenic sample preservation with electron tomography to generate three-dimensional (3D) reconstructions of cellular structures, protein complexes, and organelles.

### How It Works
1. Cryo-Fixation: The sample (e.g., a cell or a purified macromolecular complex) is rapidly frozen using liquid ethane or similar methods to prevent ice crystal formation, preserving its natural state.
2. Electron Microscopy: The frozen sample is placed under a transmission electron microscope (TEM), where images are taken from multiple angles by tilting the sample.
3. Tomographic Reconstruction: Computational algorithms combine these 2D images to create a detailed 3D model of the structure.

### Applications
Studying cellular architecture at nanometer resolution.
Visualizing macromolecular complexes in their native environments.
Understanding interactions between viruses and host cells.
Investigating neurodegenerative diseases, cancer, and infectious diseases.
Cryo-ET is particularly powerful because it enables direct imaging of biological systems without the need for staining or chemical fixation, preserving their native conformation.


## Environment:

To have a common environment its suggested to use the basic pytorch docker container and add a few pip packages on top

1. This tutorial was tested with tag 24.08-py3, i.e. run the following command to pull/ start the container.

```docker run nvcr.io/nvidia/pytorch:24.08-py3```

2. Within the container clone this repository

```
git clone https://github.com/ProjectMONAI/tutorials
cd tutorials/competitions/kaggle/Cryo-ET/1st_place_solution/
```


3. And install necessary additional pip packages via

```pip install -r requirements.txt```

## Required Data

This tutorial is build upon the official Cryo ET competition data. It can be downloaded directly from kaggle: https://www.kaggle.com/competitions/czii-cryo-et-object-identification/data

Alternativly it can be downloaded using the kaggle API (which can be installed via ```pip install kaggle```)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you decide to use the Kaggle API you need to create a Kaggle account and configure your token as described here.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You will also need to follow the competition url and click "join competition" to accept the terms and conditions and then be allowed to download the data with the following command:

```kaggle competitions download -c czii-cryo-et-object-identification```

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you do this inside of the docker pod, this data will be ephemeral and be deleted when the pod goes down. I suggest adding an option to mount the data in the correct location and instruct people to download the data outside the pod.


and adjust path to it in ```configs/common_config.py``` with ```cfg.data_folder```.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you mount the folder in the correct location you don't need to adjust this. So this line can be removed.




## Training models

For the competition we created a cross-validation scheme by simply simply splitting the 7 training tomographs into 7 folds. I.e. we train on 6 tomographs and use the 7th as validation.
For convenience we provide a file ```train_folded_v1.csv``` which contains the original training annotations and was also extended by a column containing fold_ids.

We solve the competition with a 3D-segmentation approach leveraging [MONAI's FlexibleUNet](https://docs.monai.io/en/stable/networks.html#flexibleunet) architecture. Compared to the original implementation we adjusted the network to output more featuremap and enable deep-supervision. The following illustrates the resulting architecture at a high level:

<p align="center">
<img src="partly_Unet.png" alt="figure of a Partly UNet")
</p>

We provide three different configurations which differ only in the used backbone and output feature maps. The configuration files are .py files and located under ```configs``` and share all other hyper-parameters. Each hyperparameter can be overwriten by adding a flag to the training command. To train a resnet34 version of our segmentation model simply run

```python train.py -C cfg_resnet34 --output_dir WHATEVERISYOUROUTPUTDIR```

This will save checkpoints under the specified WHATEVERISYOUROUTPUTDIR when training is finished.
By default models are trained using bfloat16 which requires a GPU capable of that. Alternatively you can set ```cfg.bf16=False``` or overwrite as flag ```--bf16 False``` when running ```train.py ```.

### Replicating 1st place solution (segmentation part)

To train checkpoints necessary for replicating the segmentation part of the 1st place solution run training of 2x fullfits for each model. Thereby ```cfg.fold = -1``` results in training on all data, and using ```fold 0``` as validation.
```
python train.py -C cfg_resnet34 --fold -1
python train.py -C cfg_resnet34 --fold -1
python train.py -C cfg_resnet34_ds --fold -1
python train.py -C cfg_resnet34_ds --fold -1
python train.py -C cfg_effnetb3 --fold -1
python train.py -C cfg_effnetb3 --fold -1
```

## Inference

Inference after models are converted with torch jit is shown in our 1st place submission kaggle kernel.

https://www.kaggle.com/code/christofhenkel/cryo-et-1st-place-solution?scriptVersionId=223259615
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from common_config import basic_cfg
import os
import pandas as pd
import numpy as np
import monai.transforms as mt

cfg = basic_cfg

cfg.name = os.path.basename(__file__).split(".")[0]
cfg.output_dir = f"/mount/cryo/models/{os.path.basename(__file__).split('.')[0]}"

# model
cfg.backbone = "efficientnet-b3"
cfg.backbone_args = dict(
spatial_dims=3,
in_channels=cfg.in_channels,
out_channels=cfg.n_classes,
backbone=cfg.backbone,
pretrained=cfg.pretrained,
)
cfg.class_weights = np.array([64, 64, 64, 64, 64, 64, 1])
cfg.lvl_weights = np.array([0, 0, 0, 1])
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from common_config import basic_cfg
import os
import pandas as pd
import numpy as np

cfg = basic_cfg

# paths
cfg.name = os.path.basename(__file__).split(".")[0]
cfg.output_dir = f"/mount/cryo/models/{os.path.basename(__file__).split('.')[0]}"


# model

cfg.backbone = "resnet34"
cfg.backbone_args = dict(
spatial_dims=3,
in_channels=cfg.in_channels,
out_channels=cfg.n_classes,
backbone=cfg.backbone,
pretrained=cfg.pretrained,
)
cfg.class_weights = np.array([256, 256, 256, 256, 256, 256, 1])
cfg.lvl_weights = np.array([0, 0, 0, 1])
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from common_config import basic_cfg
import os
import pandas as pd
import numpy as np

cfg = basic_cfg

# paths
cfg.name = os.path.basename(__file__).split(".")[0]
cfg.output_dir = f"/mount/cryo/models/{os.path.basename(__file__).split('.')[0]}"

cfg.backbone = "resnet34"
cfg.backbone_args = dict(
spatial_dims=3,
in_channels=cfg.in_channels,
out_channels=cfg.n_classes,
backbone=cfg.backbone,
pretrained=cfg.pretrained,
)
cfg.class_weights = np.array([64, 64, 64, 64, 64, 64, 1])
cfg.lvl_weights = np.array([0, 0, 1, 1])
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from types import SimpleNamespace
from monai import transforms as mt

cfg = SimpleNamespace(**{})

# stages
cfg.train = True
cfg.val = True
cfg.test = True
cfg.train_val = True

# dataset
cfg.batch_size_val = None
cfg.use_custom_batch_sampler = False
cfg.val_df = None
cfg.test_df = None
cfg.val_data_folder = None
cfg.train_aug = None
cfg.val_aug = None
cfg.data_sample = -1

# model

cfg.pretrained = False
cfg.pretrained_weights = None
cfg.pretrained_weights_strict = True
cfg.pop_weights = None
cfg.compile_model = False

# training routine
cfg.fold = 0
cfg.optimizer = "Adam"
cfg.sgd_momentum = 0
cfg.sgd_nesterov = False
cfg.lr = 1e-4
cfg.schedule = "cosine"
cfg.num_cycles = 0.5
cfg.weight_decay = 0
cfg.epochs = 10
cfg.seed = -1
cfg.resume_training = False
cfg.distributed = False
cfg.clip_grad = 0
cfg.save_val_data = True
cfg.gradient_checkpointing = False
cfg.apex_ddp = False
cfg.synchronize_step = True

# eval
cfg.eval_ddp = True
cfg.calc_metric = True
cfg.calc_metric_epochs = 1
cfg.eval_steps = 0
cfg.eval_epochs = 1
cfg.save_pp_csv = True


# ressources
cfg.find_unused_parameters = False
cfg.grad_accumulation = 1
cfg.syncbn = False
cfg.gpu = 0
cfg.dp = False
cfg.num_workers = 8
cfg.drop_last = True
cfg.save_checkpoint = True
cfg.save_only_last_ckpt = False
cfg.save_weights_only = False

# logging,
cfg.neptune_project = None
cfg.neptune_connection_mode = "debug"
cfg.save_first_batch = False
cfg.save_first_batch_preds = False
cfg.clip_mode = "norm"
cfg.data_sample = -1
cfg.track_grad_norm = True
cfg.grad_norm_type = 2.0
cfg.track_weight_norm = True
cfg.norm_eps = 1e-4
cfg.disable_tqdm = False


# paths

cfg.data_folder = "/mount/cryo/data/czii-cryo-et-object-identification/train/static/ExperimentRuns/"
cfg.train_df = "train_folded_v1.csv"


# stages
cfg.test = False
cfg.train = True
cfg.train_val = False

# logging
cfg.neptune_project = None
cfg.neptune_connection_mode = "async"

# model
cfg.model = "mdl_1"
cfg.mixup_p = 1.0
cfg.mixup_beta = 1.0
cfg.in_channels = 1
cfg.pretrained = False

# data
cfg.dataset = "ds_1"
cfg.classes = ["apo-ferritin", "beta-amylase", "beta-galactosidase", "ribosome", "thyroglobulin", "virus-like-particle"]
cfg.n_classes = len(cfg.classes)

cfg.post_process_pipeline = "pp_1"
cfg.metric = "metric_1"


cfg.particle_radi = {
"apo-ferritin": 60,
"beta-amylase": 65,
"beta-galactosidase": 90,
"ribosome": 150,
"thyroglobulin": 130,
"virus-like-particle": 135,
}

cfg.voxel_spacing = 10.0


# OPTIMIZATION & SCHEDULE

cfg.fold = 0
cfg.epochs = 10

cfg.lr = 1e-3
cfg.optimizer = "Adam"
cfg.weight_decay = 0.0
cfg.warmup = 0.0
cfg.batch_size = 8
cfg.batch_size_val = 16
cfg.sub_batch_size = 4
cfg.roi_size = [96, 96, 96]
cfg.train_sub_epochs = 1112
cfg.val_sub_epochs = 1
cfg.mixed_precision = False
cfg.bf16 = True
cfg.force_fp16 = True
cfg.pin_memory = False
cfg.grad_accumulation = 1.0
cfg.num_workers = 8


# Saving
cfg.save_weights_only = True
cfg.save_only_last_ckpt = False
cfg.save_val_data = False
cfg.save_checkpoint = True
cfg.save_pp_csv = False


cfg.static_transforms = static_transforms = mt.Compose(
[
mt.EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
mt.NormalizeIntensityd(keys="image"),
]
)
cfg.train_aug = mt.Compose(
[
mt.RandSpatialCropSamplesd(keys=["image", "label"], roi_size=cfg.roi_size, num_samples=cfg.sub_batch_size),
mt.RandFlipd(
keys=["image", "label"],
prob=0.5,
spatial_axis=0,
),
mt.RandFlipd(
keys=["image", "label"],
prob=0.5,
spatial_axis=1,
),
mt.RandFlipd(
keys=["image", "label"],
prob=0.5,
spatial_axis=2,
),
mt.RandRotate90d(
keys=["image", "label"],
prob=0.75,
max_k=3,
spatial_axes=(0, 1),
),
mt.RandRotated(
keys=["image", "label"], prob=0.5, range_x=0.78, range_y=0.0, range_z=0.0, padding_mode="reflection"
),
]
)

cfg.val_aug = mt.Compose([mt.GridPatchd(keys=["image", "label"], patch_size=cfg.roi_size, pad_mode="reflect")])


basic_cfg = cfg
Loading
Loading