{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5e8ae3d7-3e2e-4755-a0b6-709ef4180719",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "    http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "191c5d77-8ae5-49ab-be22-45f5ba41641f",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "886952c4-0be4-459d-9c53-b81b29199c76",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T13:48:44.235392252Z",
     "start_time": "2023-10-16T13:48:28.253469477Z"
    }
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite,pyyaml]\"\n",
    "!pip install -q pytorch-lightning~=2.0.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a20e1274-0a27-4e37-95d7-fb813243c34c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-06T15:07:19.730871161Z",
     "start_time": "2023-10-06T15:07:11.317018521Z"
    }
   },
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1144d87-ec2f-4b9b-907a-16ea2da279c4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-06T15:11:48.797015283Z",
     "start_time": "2023-10-06T15:11:42.300276550Z"
    }
   },
   "outputs": [],
   "source": [
    "from monai.apps import download_and_extract\n",
    "from monai.config import print_config\n",
    "import os\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c572d8b6-3dca-4487-80ad-928090b3e8ab",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-06T15:07:34.380130283Z",
     "start_time": "2023-10-06T15:07:34.330086596Z"
    }
   },
   "source": [
    "# Spleen Segmentation Lightning Bundle\n",
    "\n",
    "In this tutorial we'll describe how to create a bundle for a segmentation network. This will include how to train and apply the network on the command line. Medical  will be used as the dataset with the bundle based off the [Spleen 3D segmentation with MONAI](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d_lightning.ipynb) from Spleen segmentation using Task_09 subset from the Medical Segmentation Decathlon.\n",
    "\n",
    "This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-sa/4.0/.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a18d5cd-6338-4b41-87fd-4e119723bfee",
   "metadata": {},
   "source": [
    "Let's start by initialising a bundle directory structure and create a python module `scripts`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e00b5416-dfab-4043-9293-ec2acdf5842d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T14:44:10.513242253Z",
     "start_time": "2023-10-16T14:44:02.546210817Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/usr/bin/tree\n",
      "SpleenSegLightning\n",
      "├── configs\n",
      "│   └── metadata.json\n",
      "├── docs\n",
      "│   └── README.md\n",
      "├── LICENSE\n",
      "├── models\n",
      "└── scripts\n",
      "    └── __init__.py\n",
      "\n",
      "4 directories, 4 files\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "python -m monai.bundle init_bundle SpleenSegLightning\n",
    "rm SpleenSegLightning/configs/inference.json\n",
    "mkdir SpleenSegLightning/scripts\n",
    "touch SpleenSegLightning/scripts/__init__.py\n",
    "which tree && tree SpleenSegLightning || true"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6228ccdd-d2d5-4344-b838-aa58e63f04f9",
   "metadata": {},
   "source": [
    "## Download dataset and put into a directory"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fac6ab59-ca77-4d0e-ab0d-e5df91ead522",
   "metadata": {},
   "source": [
    "First, we set up a temporary data directory, download the data, and move to a temporary directory `data_dir`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "672f0168-950a-440e-b77c-f785a92c0f74",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T14:44:10.513242253Z",
     "start_time": "2023-10-16T14:44:02.546210817Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Task09_Spleen.tar: 1.50GB [02:12, 12.2MB/s]                                "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-09-18 08:32:51,645 - INFO - Downloaded: /tmp/tmphbjnt0zm/Task09_Spleen.tar\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-09-18 08:32:53,738 - INFO - Verified 'Task09_Spleen.tar', md5: 410d4a301da4e5b2f6f86ec3ddba524e.\n",
      "2024-09-18 08:32:53,739 - INFO - Writing into directory: /tmp/tmphbjnt0zm.\n"
     ]
    }
   ],
   "source": [
    "resource = \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar\"\n",
    "md5 = \"410d4a301da4e5b2f6f86ec3ddba524e\"\n",
    "\n",
    "directory = os.environ.get(\"DATA_DIR\")\n",
    "print(directory)\n",
    "if directory is not None:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "compressed_file = os.path.join(root_dir, \"Task09_Spleen.tar\")\n",
    "data_dir = os.path.join(root_dir, \"Task09_Spleen\")\n",
    "os.environ[\"DATA_DIR\"] = data_dir\n",
    "\n",
    "if not os.path.exists(data_dir):\n",
    "    download_and_extract(resource, compressed_file, root_dir, md5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5888c9bd-5022-40b5-9dec-84d9f737f868",
   "metadata": {},
   "source": [
    "## Metadata\n",
    "\n",
    "We'll first replace the `metadata.json` file with our description of what the network will do:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b29f053b-cf16-4ffc-bbe7-d9433fdfa872",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T14:45:11.617630093Z",
     "start_time": "2023-10-16T14:45:11.573340254Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting SpleenSegLightning/configs/metadata.json\n"
     ]
    }
   ],
   "source": [
    "%%writefile SpleenSegLightning/configs/metadata.json\n",
    "\n",
    "{\n",
    "    \"version\": \"0.0.1\",\n",
    "    \"changelog\": {\n",
    "        \"0.0.1\": \"Initial version\"\n",
    "    },\n",
    "    \"monai_version\": \"1.2.0\",\n",
    "    \"pytorch_version\": \"2.0.0\",\n",
    "    \"numpy_version\": \"1.23.5\",\n",
    "    \"required_packages_version\": {},\n",
    "    \"name\": \"SpleenSegLightning\",\n",
    "    \"task\": \"3D Spleen segmentation network using MONAI and Pytorch Lightning\",\n",
    "    \"description\": \"This is a demo network for segmentation of the spleen from 3D MRI images.\",\n",
    "    \"authors\": \"Your Name Here\",\n",
    "    \"copyright\": \"Copyright (c) Your Name Here\",\n",
    "    \"data_source\": \"Task_09 subset from the Medical Segmentation Decathlon\",\n",
    "    \"data_type\": \"Nifti\",\n",
    "    \"intended_use\": \"This is suitable for demonstration only\",\n",
    "    \"network_data_format\": {\n",
    "        \"inputs\": {\n",
    "            \"image\": {\n",
    "                \"type\": \"image\",\n",
    "                \"format\": \"magnitude\",\n",
    "                \"modality\": \"MR\",\n",
    "                \"num_channels\": 1,\n",
    "                \"spatial_shape\": [160, 160, 160],\n",
    "                \"dtype\": \"float32\",\n",
    "                \"value_range\": [0, 1],\n",
    "                \"is_patch_data\": false,\n",
    "                \"channel_def\": {\"0\": \"image\"}\n",
    "            }\n",
    "        },\n",
    "        \"outputs\": {\n",
    "            \"pred\": {\n",
    "                \"type\": \"image\",\n",
    "                \"format\": \"labels\",\n",
    "                \"num_channels\": 2,\n",
    "                \"spatial_shape\": [160, 160, 160],\n",
    "                \"dtype\": \"float32\",\n",
    "                \"value_range\": [],\n",
    "                \"is_patch_data\": false,\n",
    "                \"channel_def\": {\"0\": \"background\", \"1\": \"spleen\"}\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
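  {
   "cell_type": "markdown",
   "id": "a3f1c2d4-5e6b-4a7c-8d9e-0f1a2b3c4d5e",
   "metadata": {},
   "source": [
    "As an optional sanity check, the channel definitions in `network_data_format` can be compared against `num_channels` with a few lines of standard-library Python. This is a minimal sketch rather than a MONAI API: `check_channel_defs` is a hypothetical helper, and it only verifies that each `channel_def` has one entry per declared channel.\n",
    "\n",
    "```python\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def check_channel_defs(meta):\n",
    "    \"\"\"Return mismatches between num_channels and channel_def (hypothetical helper).\"\"\"\n",
    "    problems = []\n",
    "    fmt = meta.get('network_data_format', {})\n",
    "    for section in ('inputs', 'outputs'):\n",
    "        for name, spec in fmt.get(section, {}).items():\n",
    "            n_channels = spec.get('num_channels')\n",
    "            n_defs = len(spec.get('channel_def', {}))\n",
    "            if n_channels != n_defs:\n",
    "                problems.append(f'{section}.{name}: num_channels={n_channels}, channel_def has {n_defs}')\n",
    "    return problems\n",
    "\n",
    "\n",
    "meta_path = Path('SpleenSegLightning/configs/metadata.json')\n",
    "if meta_path.exists():\n",
    "    # an empty list means every declared channel has a definition\n",
    "    print(check_channel_defs(json.loads(meta_path.read_text())))\n",
    "```"
   ]
  },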
  {
   "cell_type": "markdown",
   "id": "3f208bf8-0c3a-4def-ab0f-6091cebcd532",
   "metadata": {},
   "source": [
    "\n",
    "## Common Definitions\n",
    "\n",
    "What we'll now do is construct the bundle configuration scripts to implement training, testing, and inference based off the original script file given above. Common definitions should be placed in a common file used with other scripts to reduce duplication. In our original script, the network definition and transform sequence will be used in multiple places so should go in this common file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d11681af-3210-4b2b-b7bd-8ad8dedfe230",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T14:56:36.558682685Z",
     "start_time": "2023-10-16T14:56:36.528064430Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing SpleenSegLightning/configs/common.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile SpleenSegLightning/configs/common.yaml\n",
    "\n",
    "# common imports\n",
    "imports: \n",
    "- $import glob\n",
    "- $import os\n",
    "\n",
    "# define a default root directory value, this can \n",
    "# overridden on the command line\n",
    "bundle_dir: .\n",
    "data_dir: .\n",
    "\n",
    "# use constants from MONAI instead of hard-coding names\n",
    "image: $monai.utils.CommonKeys.IMAGE\n",
    "label: $monai.utils.CommonKeys.LABEL\n",
    "\n",
    "# define a train and validation files from the data directory\n",
    "train_images: '$sorted(glob.glob(os.path.join(@data_dir, ''imagesTr'', ''*.nii.gz'')))'\n",
    "train_labels: '$sorted(glob.glob(os.path.join(@data_dir, ''labelsTr'', ''*.nii.gz'')))'\n",
    "\n",
    "data_dicts: '$[{''image'': img, ''label'': lbl} for img, lbl in zip(@train_images, @train_labels)]'\n",
    "\n",
    "train_files: '$@data_dicts[:-9]'\n",
    "val_files: '$@data_dicts[-9:]'"
   ]
  },
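  {
   "cell_type": "markdown",
   "id": "b4e2d3c5-6f7a-4b8d-9e0f-1a2b3c4d5e6f",
   "metadata": {},
   "source": [
    "To see what the `$` expressions in `common.yaml` evaluate to, the plain-Python sketch below reproduces the same logic: the sorted image and label lists are paired by position, and the last nine pairs are held out for validation. `make_data_dicts` is a hypothetical name used only for illustration, and pairing by sorted order assumes each image in `imagesTr` has a label file with the same name in `labelsTr`.\n",
    "\n",
    "```python\n",
    "import glob\n",
    "import os\n",
    "\n",
    "\n",
    "def make_data_dicts(data_dir, n_val=9):\n",
    "    \"\"\"Mirror common.yaml: pair sorted images with sorted labels, split off the last n_val.\"\"\"\n",
    "    images = sorted(glob.glob(os.path.join(data_dir, 'imagesTr', '*.nii.gz')))\n",
    "    labels = sorted(glob.glob(os.path.join(data_dir, 'labelsTr', '*.nii.gz')))\n",
    "    # zip pairs by position, so both lists must follow the same sorted order\n",
    "    data_dicts = [{'image': img, 'label': lbl} for img, lbl in zip(images, labels)]\n",
    "    # like the YAML slices [:-9] and [-9:]; n_val must be at least 1\n",
    "    return data_dicts[:-n_val], data_dicts[-n_val:]\n",
    "\n",
    "\n",
    "train_files, val_files = make_data_dicts(os.environ.get('DATA_DIR', '.'))\n",
    "print(len(train_files), 'training cases,', len(val_files), 'validation cases')\n",
    "```\n",
    "\n",
    "With the dataset downloaded above, this split should match the 32 training and 9 validation cases reported by the `Loading dataset` progress bars in the training output below."
   ]
  },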
  {
   "cell_type": "markdown",
   "id": "60ee968cb538d983",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Scripts for training and evaluation\n",
    "\n",
    "We'll define the training and evaluation yaml files and scripts contained the Pytorch Lightning-based network. First, in the Python module `scripts`, we'll add `model.py` file containing the network definition:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2c15149785c2192",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T14:45:39.050765311Z",
     "start_time": "2023-10-16T14:45:39.031364249Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "lines_to_next_cell": 2
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing SpleenSegLightning/scripts/model.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile SpleenSegLightning/scripts/model.py\n",
    "\n",
    "import pytorch_lightning\n",
    "from monai.utils import set_determinism\n",
    "from monai.transforms import (\n",
    "    AsDiscrete,\n",
    "    Compose,\n",
    "    EnsureType,\n",
    ")\n",
    "from monai.networks.nets import UNet\n",
    "from monai.networks.layers import Norm\n",
    "from monai.metrics import DiceMetric\n",
    "from monai.losses import DiceLoss\n",
    "from monai.inferers import sliding_window_inference\n",
    "from monai.data import decollate_batch\n",
    "import torch\n",
    "\n",
    "\n",
    "class MySegNet(pytorch_lightning.LightningModule):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self._model = UNet(\n",
    "            spatial_dims=3,\n",
    "            in_channels=1,\n",
    "            out_channels=2,\n",
    "            channels=(16, 32, 64, 128, 256),\n",
    "            strides=(2, 2, 2, 2),\n",
    "            num_res_units=2,\n",
    "            norm=Norm.BATCH,\n",
    "        )\n",
    "        self.learning_rate = 1e-4\n",
    "        self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)\n",
    "        self.post_pred = Compose([EnsureType(\"tensor\", device=\"cpu\"),\n",
    "                                  AsDiscrete(argmax=True, to_onehot=2)])\n",
    "        self.post_label = Compose([EnsureType(\"tensor\", device=\"cpu\"),\n",
    "                                   AsDiscrete(to_onehot=2)])\n",
    "        self.dice_metric = DiceMetric(include_background=False, reduction=\"mean\",\n",
    "                                      get_not_nans=False)\n",
    "        self.best_val_dice = 0\n",
    "        self.best_val_epoch = 0\n",
    "        self.validation_step_outputs = []\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self._model(x)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        print(\"configure_optimizers\", self.learning_rate)\n",
    "        optimizer = torch.optim.Adam(self._model.parameters(), self.learning_rate)\n",
    "        return optimizer\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        images, labels = batch[\"image\"], batch[\"label\"]\n",
    "        output = self.forward(images)\n",
    "        loss = self.loss_function(output, labels)\n",
    "        tensorboard_logs = {\"train_loss\": loss.item()}\n",
    "        return {\"loss\": loss, \"log\": tensorboard_logs}\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        images, labels = batch[\"image\"], batch[\"label\"]\n",
    "        roi_size = (160, 160, 160)\n",
    "        sw_batch_size = 4\n",
    "        outputs = sliding_window_inference(images, roi_size, sw_batch_size, self.forward)\n",
    "        loss = self.loss_function(outputs, labels)\n",
    "        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]\n",
    "        labels = [self.post_label(i) for i in decollate_batch(labels)]\n",
    "        self.dice_metric(y_pred=outputs, y=labels)\n",
    "        d = {\"val_loss\": loss, \"val_number\": len(outputs)}\n",
    "        self.validation_step_outputs.append(d)\n",
    "        return d\n",
    "\n",
    "    def on_validation_epoch_end(self):\n",
    "        val_loss, num_items = 0, 0\n",
    "        for output in self.validation_step_outputs:\n",
    "            val_loss += output[\"val_loss\"].sum().item()\n",
    "            num_items += output[\"val_number\"]\n",
    "        mean_val_dice = self.dice_metric.aggregate().item()\n",
    "        self.dice_metric.reset()\n",
    "        mean_val_loss = torch.tensor(val_loss / num_items)\n",
    "        tensorboard_logs = {\n",
    "            \"val_dice\": mean_val_dice,\n",
    "            \"val_loss\": mean_val_loss,\n",
    "        }\n",
    "        if mean_val_dice > self.best_val_dice:\n",
    "            self.best_val_dice = mean_val_dice\n",
    "            self.best_val_epoch = self.current_epoch\n",
    "        print(\n",
    "            f\"current epoch: {self.current_epoch} \"\n",
    "            f\"current mean dice: {mean_val_dice:.4f}\"\n",
    "            f\"\\nbest mean dice: {self.best_val_dice:.4f} \"\n",
    "            f\"at epoch: {self.best_val_epoch}\"\n",
    "        )\n",
    "        self.validation_step_outputs.clear()  # free memory\n",
    "        return {\"log\": tensorboard_logs}"
   ]
  },
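  {
   "cell_type": "markdown",
   "id": "c5f3e4d6-7a8b-4c9e-8f1a-2b3c4d5e6f70",
   "metadata": {},
   "source": [
    "In `validation_step`, `AsDiscrete(argmax=True, to_onehot=2)` converts the network output into a one-hot label map, and `DiceMetric(include_background=False)` then scores only the spleen channel. The toy function below illustrates that per-class Dice score, `2·|P∩G| / (|P| + |G|)`, on flattened integer label lists. It is a simplified sketch for intuition with a hypothetical name (`dice_for_class`), not MONAI's implementation, which works on batched one-hot tensors and treats empty masks in its own way.\n",
    "\n",
    "```python\n",
    "def dice_for_class(pred, label, cls=1):\n",
    "    \"\"\"Dice score of one class on flattened integer label maps (illustrative only).\"\"\"\n",
    "    p = [v == cls for v in pred]\n",
    "    g = [v == cls for v in label]\n",
    "    intersection = sum(a and b for a, b in zip(p, g))\n",
    "    denom = sum(p) + sum(g)\n",
    "    # this sketch returns nan when the class is absent from both maps\n",
    "    return 2.0 * intersection / denom if denom else float('nan')\n",
    "\n",
    "\n",
    "# 0 = background, 1 = spleen; include_background=False corresponds to scoring only cls=1\n",
    "pred = [0, 1, 1, 0, 0, 1]\n",
    "label = [0, 1, 0, 0, 1, 1]\n",
    "print(round(dice_for_class(pred, label), 3))  # 0.667\n",
    "```"
   ]
  },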
  {
   "cell_type": "markdown",
   "id": "92e303feb8d4edca",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "lines_to_next_cell": 0
   },
   "source": [
    "Next, we'll create a `main.py` file to house the training and evaluation scripts. In this example, we use the `lightning_param` dictionary to customize some default arguments in the PyTorch Lightning `Trainer` class. We've set `num_nodes` and `devices` to 1, turned off the sanity checking (`num_sanity_val_steps=0`), and logged the training for every 3 steps (`log_every_n_steps=3`) for demonstration purposes. For more information about the PyTorch Lightning `Trainer` arguments, please refer to the following [link](https://lightning.ai/docs/pytorch/stable/common/trainer.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d49daec7d4ce0b75",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T14:45:48.900509600Z",
     "start_time": "2023-10-16T14:45:48.887211086Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing SpleenSegLightning/scripts/main.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile SpleenSegLightning/scripts/main.py\n",
    "\n",
    "from scripts.model import MySegNet\n",
    "import pytorch_lightning\n",
    "\n",
    "def train(lightninig_param, train_dl, val_dl):\n",
    "    net = MySegNet()\n",
    "    trainer = pytorch_lightning.Trainer(max_epochs=lightninig_param['max_epochs'], \n",
    "                                        default_root_dir=lightninig_param['default_root_dir'],\n",
    "                                        check_val_every_n_epoch=lightninig_param['check_val_every_n_epoch'],\n",
    "                                        devices=1, num_nodes=1, log_every_n_steps=3, num_sanity_val_steps=0)\n",
    "    trainer.fit(model=net, train_dataloaders=train_dl, val_dataloaders=val_dl)\n",
    "\n",
    "\n",
    "def evaluate(lightninig_param, ckpt_file, val_dl):\n",
    "    net = MySegNet()\n",
    "    trainer = pytorch_lightning.Trainer(default_root_dir=lightninig_param['default_root_dir'],\n",
    "                                        devices=1, num_nodes=1)\n",
    "    trainer.validate(model=net, dataloaders=val_dl, ckpt_path=ckpt_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eaf81ea7-9ea3-4548-a32e-992f0b9bc0ab",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Training\n",
    "Now, we'll define a `train.yaml` file to be used to set the configurations for the training stage:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4dfd052e-abe7-473a-bbf4-25674a3b20ea",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T14:55:44.372511589Z",
     "start_time": "2023-10-16T14:55:44.304832953Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing SpleenSegLightning/configs/train.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile SpleenSegLightning/configs/train.yaml\n",
    "\n",
    "imports:\n",
    "- $from scripts.main import train\n",
    "- $import glob\n",
    "- $import os\n",
    "\n",
    "# define a default root directory value, this can overridden on the command line\n",
    "bundle_dir: .\n",
    "data_dir: .\n",
    "\n",
    "# define hyperparameters for the lightning trainer\n",
    "max_epochs: 50\n",
    "default_root_dir: $@bundle_dir+\"/logs\"\n",
    "check_val_every_n_epoch: 1\n",
    "\n",
    "lightninig_param:  '${\n",
    "    ''max_epochs'': @max_epochs,\n",
    "    ''default_root_dir'': @default_root_dir,\n",
    "    ''check_val_every_n_epoch'': @check_val_every_n_epoch,\n",
    "}'\n",
    "\n",
    "\n",
    "# define a transform sequence by instantiating a Compose instance with a transform sequence\n",
    "train_transform:\n",
    "  _target_: Compose\n",
    "  transforms:\n",
    "  - _target_: LoadImaged\n",
    "    keys: ['@image','@label']\n",
    "    image_only: true\n",
    "  - _target_: EnsureChannelFirstd\n",
    "    keys:  ['@image','@label']\n",
    "  - _target_: Orientationd\n",
    "    keys:  ['@image','@label']\n",
    "    axcodes: 'RAS'\n",
    "  - _target_: Spacingd\n",
    "    keys:  ['@image','@label']\n",
    "    pixdim: [1.5, 1.5, 2.0]\n",
    "  - _target_: ScaleIntensityRanged\n",
    "    keys: '@image'\n",
    "    a_min: -57\n",
    "    a_max: 164\n",
    "    b_min: 0.0\n",
    "    b_max: 1.0\n",
    "    clip: True\n",
    "  - _target_: CropForegroundd\n",
    "    keys: ['@image','@label']\n",
    "    allow_smaller: False\n",
    "    source_key: '@image'\n",
    "  - _target_: RandCropByPosNegLabeld\n",
    "    keys: ['@image','@label']\n",
    "    label_key: '@label'\n",
    "    spatial_size: [96, 96, 96]\n",
    "    pos: 1\n",
    "    neg: 1\n",
    "    num_samples: 4\n",
    "    image_key: '@image'\n",
    "    image_threshold: 0\n",
    "\n",
    "val_transform:\n",
    "  _target_: Compose\n",
    "  transforms:\n",
    "  - _target_: LoadImaged\n",
    "    keys: ['@image','@label']\n",
    "    image_only: true\n",
    "  - _target_: EnsureChannelFirstd\n",
    "    keys: ['@image','@label']\n",
    "  - _target_: Orientationd\n",
    "    keys: ['@image','@label']\n",
    "    axcodes: 'RAS'\n",
    "  - _target_: Spacingd\n",
    "    keys: ['@image','@label']\n",
    "    pixdim: [1.5, 1.5, 2.0]\n",
    "  - _target_: ScaleIntensityRanged\n",
    "    keys: '@image'\n",
    "    a_min: -57\n",
    "    a_max: 164\n",
    "    b_min: 0.0\n",
    "    b_max: 1.0\n",
    "    clip: True\n",
    "  - _target_: CropForegroundd\n",
    "    keys: ['@image','@label']\n",
    "    source_key: '@image'\n",
    "    allow_smaller: False\n",
    "\n",
    "val_dataset:\n",
    "  _target_: CacheDataset\n",
    "  data: '@val_files'\n",
    "  transform: '@val_transform'\n",
    "  cache_rate: 1.0\n",
    "  num_workers: 4\n",
    "\n",
    "train_dataset:\n",
    "  _target_: CacheDataset\n",
    "  data: '@train_files'\n",
    "  transform: '@train_transform'\n",
    "  cache_rate: 1.0\n",
    "  num_workers: 4\n",
    "  \n",
    "train_dl:\n",
    "  _target_: DataLoader\n",
    "  dataset: '@train_dataset'\n",
    "  batch_size: 1\n",
    "  shuffle: true\n",
    "  num_workers: 4\n",
    "  \n",
    "val_dl:\n",
    "  _target_: DataLoader\n",
    "  dataset: '@val_dataset'\n",
    "  batch_size: 1\n",
    "  shuffle: false\n",
    "  num_workers: 4\n",
    "\n",
    "train:\n",
    "- '$train(@lightninig_param, @train_dl, @val_dl)'"
   ]
  },
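  {
   "cell_type": "markdown",
   "id": "d6a4f5e7-8b9c-4d0f-9a2b-3c4d5e6f7081",
   "metadata": {},
   "source": [
    "`ScaleIntensityRanged` with `a_min: -57`, `a_max: 164`, `b_min: 0.0`, `b_max: 1.0` and `clip: True` maps the intensity window [-57, 164] linearly onto [0, 1] and clips values that fall outside it; for CT images these are Hounsfield units, giving a soft-tissue window. The scalar function below sketches that mapping for intuition. `window_intensity` is a hypothetical name; MONAI applies the equivalent operation to whole image arrays.\n",
    "\n",
    "```python\n",
    "def window_intensity(x, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0, clip=True):\n",
    "    \"\"\"Map [a_min, a_max] linearly onto [b_min, b_max], optionally clipping (scalar sketch).\"\"\"\n",
    "    y = (x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min\n",
    "    if clip:\n",
    "        y = min(max(y, b_min), b_max)\n",
    "    return y\n",
    "\n",
    "\n",
    "for hu in (-1000, -57, 53.5, 164, 400):\n",
    "    print(hu, '->', round(window_intensity(hu), 3))\n",
    "```"
   ]
  },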
  {
   "cell_type": "markdown",
   "id": "de752181-80b1-4221-9e4a-315e5f7f22a6",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "We can now train as normal to replicate the original code. For demonstration purpose, we set `max_epochs=1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1d8ac6fd81493874",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T14:57:13.955241252Z",
     "start_time": "2023-10-16T14:57:05.329343826Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-09-18 08:33:02,488 - INFO - --- input summary of monai.bundle.scripts.run ---\n",
      "2024-09-18 08:33:02,488 - INFO - > config_file: ['./SpleenSegLightning/configs/common.yaml',\n",
      " './SpleenSegLightning/configs/train.yaml']\n",
      "2024-09-18 08:33:02,488 - INFO - > meta_file: './SpleenSegLightning/configs/metadata.json'\n",
      "2024-09-18 08:33:02,488 - INFO - > run_id: 'train'\n",
      "2024-09-18 08:33:02,488 - INFO - > bundle_dir: './SpleenSegLightning'\n",
      "2024-09-18 08:33:02,488 - INFO - > data_dir: '/tmp/tmphbjnt0zm/Task09_Spleen'\n",
      "2024-09-18 08:33:02,488 - INFO - > max_epochs: 1\n",
      "2024-09-18 08:33:02,488 - INFO - ---\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading dataset: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]\n",
      "Loading dataset: 100%|██████████| 9/9 [00:07<00:00,  1.25it/s]\n",
      "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n",
      "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n",
      "INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs\n",
      "INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs\n",
      "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "configure_optimizers 0.0001\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pytorch_lightning.callbacks.model_summary:\n",
      "  | Name          | Type     | Params\n",
      "-------------------------------------------\n",
      "0 | _model        | UNet     | 4.8 M \n",
      "1 | loss_function | DiceLoss | 0     \n",
      "-------------------------------------------\n",
      "4.8 M     Trainable params\n",
      "0         Non-trainable params\n",
      "4.8 M     Total params\n",
      "19.236    Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 32/32 [00:03<00:00, 10.30it/s, v_num=0]\n",
      "Validation: 0it [00:00, ?it/s]\u001b[A\n",
      "Validation:   0%|          | 0/9 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|          | 0/9 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:  11%|█         | 1/9 [00:00<00:03,  2.44it/s]\u001b[A\n",
      "Validation DataLoader 0:  22%|██▏       | 2/9 [00:01<00:04,  1.44it/s]\u001b[A\n",
      "Validation DataLoader 0:  33%|███▎      | 3/9 [00:02<00:04,  1.38it/s]\u001b[A\n",
      "Validation DataLoader 0:  44%|████▍     | 4/9 [00:03<00:04,  1.09it/s]\u001b[A\n",
      "Validation DataLoader 0:  56%|█████▌    | 5/9 [00:05<00:04,  1.02s/it]\u001b[A\n",
      "Validation DataLoader 0:  67%|██████▋   | 6/9 [00:05<00:02,  1.03it/s]\u001b[A\n",
      "Validation DataLoader 0:  78%|███████▊  | 7/9 [00:06<00:01,  1.06it/s]\u001b[A\n",
      "Validation DataLoader 0:  89%|████████▉ | 8/9 [00:07<00:00,  1.11it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|██████████| 9/9 [00:07<00:00,  1.17it/s]\u001b[Acurrent epoch: 0 current mean dice: 0.0307\n",
      "best mean dice: 0.0307 at epoch: 0\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 32/32 [00:11<00:00,  2.78it/s, v_num=0]     \u001b[A\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "BUNDLE=\"./SpleenSegLightning\"\n",
    "export PYTHONPATH=\"$BUNDLE\"\n",
    "\n",
    "# run the bundle with epochs set to 1 for speed during testing, change this to get a better result\n",
    "python -m monai.bundle run train \\\n",
    "    --bundle_dir \"$BUNDLE\" \\\n",
    "    --data_dir \"$DATA_DIR\" \\\n",
    "    --meta_file \"$BUNDLE/configs/metadata.json\" \\\n",
    "    --config_file \"['$BUNDLE/configs/common.yaml','$BUNDLE/configs/train.yaml']\" \\\n",
    "    --max_epochs 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a84fd22d-4e68-41c2-b133-ee497810ee64",
   "metadata": {},
   "source": [
    "The trained model is inside the subdir `lightning_logs` which the parent folder is defined in the yaml file as `default_root_dir`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c5ba337d-a5b0-47de-9ae2-1554a2cb4f86",
   "metadata": {},
    "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/usr/bin/tree\n",
      "\u001b[01;34mSpleenSegLightning\u001b[00m\n",
      "├── \u001b[01;34mconfigs\u001b[00m\n",
      "│   ├── common.yaml\n",
      "│   ├── metadata.json\n",
      "│   └── train.yaml\n",
      "├── \u001b[01;34mdocs\u001b[00m\n",
      "│   └── README.md\n",
      "├── LICENSE\n",
      "├── \u001b[01;34mlogs\u001b[00m\n",
      "│   └── \u001b[01;34mlightning_logs\u001b[00m\n",
      "│       └── \u001b[01;34mversion_0\u001b[00m\n",
      "│           ├── \u001b[01;34mcheckpoints\u001b[00m\n",
      "│           │   └── epoch=0-step=32.ckpt\n",
      "│           ├── events.out.tfevents.1697764071.sie082-pc.29342.0\n",
      "│           └── hparams.yaml\n",
      "├── \u001b[01;34mmodels\u001b[00m\n",
      "└── \u001b[01;34mscripts\u001b[00m\n",
      "    ├── __init__.py\n",
      "    ├── main.py\n",
      "    ├── model.py\n",
      "    └── \u001b[01;34m__pycache__\u001b[00m\n",
      "        ├── __init__.cpython-39.pyc\n",
      "        ├── main.cpython-39.pyc\n",
      "        └── model.cpython-39.pyc\n",
      "\n",
      "9 directories, 14 files\n"
     ]
    }
   ],
   "source": [
    "!which tree && tree SpleenSegLightning || true"
   ]
  },
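  {
   "cell_type": "markdown",
   "id": "e7b5a6f8-9c0d-4e1a-8b3c-4d5e6f708192",
   "metadata": {},
   "source": [
    "Lightning creates a new `version_N` folder under `lightning_logs` for every run, so after several runs you may want to locate the most recent checkpoint to pass as `ckpt_file`. The sketch below assumes the `logs/lightning_logs/version_*/checkpoints/*.ckpt` layout shown in the tree above; `latest_checkpoint` is a hypothetical helper that picks the file with the newest modification time.\n",
    "\n",
    "```python\n",
    "import glob\n",
    "import os\n",
    "\n",
    "\n",
    "def latest_checkpoint(bundle_dir):\n",
    "    \"\"\"Return the newest .ckpt under logs/lightning_logs, or None (hypothetical helper).\"\"\"\n",
    "    pattern = os.path.join(bundle_dir, 'logs', 'lightning_logs', 'version_*', 'checkpoints', '*.ckpt')\n",
    "    ckpts = glob.glob(pattern)\n",
    "    # pick by modification time rather than by version number\n",
    "    return max(ckpts, key=os.path.getmtime) if ckpts else None\n",
    "\n",
    "\n",
    "print(latest_checkpoint('SpleenSegLightning'))\n",
    "```"
   ]
  },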
  {
   "cell_type": "markdown",
   "id": "bbf58fac-b6d5-424d-9e98-1a30937f2116",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Evaluation\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abf40c4f-3349-4c40-9eef-811388ffd704",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-06T15:08:30.839184117Z",
     "start_time": "2023-10-06T15:08:30.804583579Z"
    }
   },
   "source": [
    "Here we defined `evaluate` script to reproduce the results from the original code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b4e1f99a-a68b-4aeb-bcf2-842f26609b52",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-16T13:48:20.810381132Z",
     "start_time": "2023-10-16T13:48:20.802106081Z"
    },
    "lines_to_next_cell": 2
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing SpleenSegLightning/configs/evaluate.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile SpleenSegLightning/configs/evaluate.yaml\n",
    "\n",
    "# common imports\n",
    "imports:\n",
    "- $from scripts.main import evaluate\n",
    "- $import glob\n",
    "- $import os\n",
    "\n",
    "ckpt_file: \"\"\n",
    "\n",
    "# define hyperparameters for the lightning trainer\n",
    "default_root_dir: $@bundle_dir+\"/logs\"\n",
    "lightning_param: '${''default_root_dir'': @default_root_dir,}'\n",
    "\n",
    "val_transform:\n",
    "  _target_: Compose\n",
    "  transforms:\n",
    "  - _target_: LoadImaged\n",
    "    keys: ['@image','@label']\n",
    "    image_only: true\n",
    "  - _target_: EnsureChannelFirstd\n",
    "    keys: ['@image','@label']\n",
    "  - _target_: Orientationd\n",
    "    keys: ['@image','@label']\n",
    "    axcodes: 'RAS'\n",
    "  - _target_: Spacingd\n",
    "    keys: ['@image','@label']\n",
    "    pixdim: [1.5, 1.5, 2.0]\n",
    "  - _target_: ScaleIntensityRanged\n",
    "    keys: '@image'\n",
    "    a_min: -57\n",
    "    a_max: 164\n",
    "    b_min: 0.0\n",
    "    b_max: 1.0\n",
    "    clip: True\n",
    "  - _target_: CropForegroundd\n",
    "    keys: ['@image','@label']\n",
    "    source_key: '@image'\n",
    "    allow_smaller: False\n",
    "\n",
    "val_dataset:\n",
    "  _target_: CacheDataset\n",
    "  data: '@val_files'\n",
    "  transform: '@val_transform'\n",
    "  cache_rate: 1.0\n",
    "  num_workers: 4\n",
    "\n",
    "val_dl:\n",
    "  _target_: DataLoader\n",
    "  dataset: '@val_dataset'\n",
    "  batch_size: 1\n",
    "  shuffle: false\n",
    "  num_workers: 4\n",
    "\n",
    "# loads the weights from the checkpoint file (which must be set on the command line), then calls the \"evaluate\" function\n",
    "evaluate:\n",
    "- '$evaluate(@lightning_param, @ckpt_file, @val_dl)'"
   ]
  },
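  {
   "cell_type": "markdown",
   "id": "7c2e9f41-3a5b-4d8e-b6f0-1e2d3c4b5a61",
   "metadata": {},
   "source": [
    "The `ScaleIntensityRanged` step in `val_transform` linearly maps CT intensities in the window `[a_min, a_max] = [-57, 164]` onto `[b_min, b_max] = [0.0, 1.0]`, and `clip: True` clamps anything that lands outside the target range. The plain-Python sketch below illustrates that mapping on a few scalar values. It is a simplified illustration rather than MONAI's implementation, and the function name `scale_intensity_range` is hypothetical.\n",
    "\n",
    "```python\n",
    "def scale_intensity_range(value, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0, clip=True):\n",
    "    # Hypothetical helper: map [a_min, a_max] linearly onto [b_min, b_max]\n",
    "    scaled = (value - a_min) / (a_max - a_min)\n",
    "    scaled = scaled * (b_max - b_min) + b_min\n",
    "    if clip:\n",
    "        # Clamp results that fall outside [b_min, b_max]\n",
    "        scaled = min(max(scaled, b_min), b_max)\n",
    "    return scaled\n",
    "\n",
    "\n",
    "# Example voxel intensities in Hounsfield units (made-up values)\n",
    "for hu in [-1000.0, -57.0, 53.5, 164.0, 400.0]:\n",
    "    print(hu, '->', round(scale_intensity_range(hu), 4))\n",
    "```\n",
    "\n",
    "Intensities inside the window are spread linearly over `[0, 1]`. For example, 53.5 HU, the window midpoint, maps to 0.5. Values below -57 HU, such as air, saturate at 0, and values above 164 HU saturate at 1."
   ]
  },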
  {
   "cell_type": "markdown",
   "id": "64bb2286-3107-49e9-8dbe-66fe1a2ae08c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-06T15:08:57.169262881Z",
     "start_time": "2023-10-06T15:08:35.934836977Z"
    }
   },
   "source": [
    "Evaluation is then run from the command line, using `evaluate` as the run ID. The path to the model checkpoint is passed with `ckpt_file`, and the dataset location is passed with `data_dir` (taken from the `DATA_DIR` environment variable). For demonstration purposes, we use the model trained for one epoch earlier in this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fb5a28d3-ebe4-476a-b8f9-7860a0452d99",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-09T15:42:47.723276002Z",
     "start_time": "2023-10-09T15:42:10.080244205Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-09-18 08:34:06,379 - INFO - --- input summary of monai.bundle.scripts.run ---\n",
      "2024-09-18 08:34:06,379 - INFO - > config_file: ['./SpleenSegLightning/configs/common.yaml',\n",
      " './SpleenSegLightning/configs/evaluate.yaml']\n",
      "2024-09-18 08:34:06,379 - INFO - > meta_file: './SpleenSegLightning/configs/metadata.json'\n",
      "2024-09-18 08:34:06,379 - INFO - > run_id: 'evaluate'\n",
      "2024-09-18 08:34:06,380 - INFO - > bundle_dir: './SpleenSegLightning'\n",
      "2024-09-18 08:34:06,380 - INFO - > data_dir: '/tmp/tmphbjnt0zm/Task09_Spleen'\n",
      "2024-09-18 08:34:06,380 - INFO - > ckpt_file: './SpleenSegLightning/logs/lightning_logs/version_0/checkpoints/epoch=0-step=32.ckpt'\n",
      "2024-09-18 08:34:06,380 - INFO - ---\n",
      "\n",
      "\n",
      "2024-09-18 08:34:06,381 - WARNING - Default logging file in SpleenSegLightning/configs/logging.conf does not exist, skipping logging.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading dataset: 100%|██████████| 9/9 [00:07<00:00,  1.28it/s]\n",
      "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n",
      "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n",
      "INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs\n",
      "INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs\n",
      "INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at ./SpleenSegLightning/logs/lightning_logs/version_0/checkpoints/epoch=0-step=32.ckpt\n",
      "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at ./SpleenSegLightning/logs/lightning_logs/version_0/checkpoints/epoch=0-step=32.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 9/9 [00:07<00:00,  1.13it/s]current epoch: 0 current mean dice: 0.0307\n",
      "best mean dice: 0.0307 at epoch: 0\n",
      "Validation DataLoader 0: 100%|██████████| 9/9 [00:07<00:00,  1.13it/s]\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "BUNDLE=\"./SpleenSegLightning\"\n",
    "CKPT_FILE=$(find \"$BUNDLE\" -type f -name \"*.ckpt\" | head -n 1)\n",
    "export PYTHONPATH=\"$BUNDLE\"\n",
    "\n",
    "python -m monai.bundle run evaluate \\\n",
    "    --bundle_dir \"$BUNDLE\" \\\n",
    "    --data_dir \"$DATA_DIR\" \\\n",
    "    --meta_file \"$BUNDLE/configs/metadata.json\" \\\n",
    "    --config_file \"['$BUNDLE/configs/common.yaml','$BUNDLE/configs/evaluate.yaml']\" \\\n",
    "    --ckpt_file \"$CKPT_FILE\""
   ]
  },
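  {
   "cell_type": "markdown",
   "id": "a4f0d2b8-6c1e-4f7a-9b3d-5e6f7a8b9c02",
   "metadata": {},
   "source": [
    "The command in the bash cell above can also be assembled and launched from Python, which avoids shell quoting of the `--config_file` list. This sketch rests on several assumptions. The helper names `build_evaluate_command` and `run_evaluate` are hypothetical. The checkpoint is the first match in sorted order, whereas `find | head -n 1` takes whichever file `find` lists first. The dataset path is read from the `DATA_DIR` environment variable, as in the bash cell.\n",
    "\n",
    "```python\n",
    "import glob\n",
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "\n",
    "def build_evaluate_command(bundle, data_dir):\n",
    "    # Hypothetical helper: find checkpoints anywhere under the bundle, like `find -name '*.ckpt'`\n",
    "    ckpts = sorted(glob.glob(os.path.join(bundle, '**', '*.ckpt'), recursive=True))\n",
    "    if not ckpts:\n",
    "        raise FileNotFoundError('no .ckpt file found under ' + bundle)\n",
    "    configs = [os.path.join(bundle, 'configs', name) for name in ('common.yaml', 'evaluate.yaml')]\n",
    "    return [\n",
    "        sys.executable, '-m', 'monai.bundle', 'run', 'evaluate',\n",
    "        '--bundle_dir', bundle,\n",
    "        '--data_dir', data_dir,\n",
    "        '--meta_file', os.path.join(bundle, 'configs', 'metadata.json'),\n",
    "        # passed as a list literal string, mirroring the bash cell\n",
    "        '--config_file', str(configs),\n",
    "        '--ckpt_file', ckpts[0],\n",
    "    ]\n",
    "\n",
    "\n",
    "def run_evaluate(bundle='./SpleenSegLightning'):\n",
    "    cmd = build_evaluate_command(bundle, os.environ['DATA_DIR'])\n",
    "    # PYTHONPATH lets the bundle config import `scripts.main`\n",
    "    env = dict(os.environ, PYTHONPATH=bundle)\n",
    "    return subprocess.run(cmd, env=env, check=True)\n",
    "```\n",
    "\n",
    "Calling `run_evaluate()` where `DATA_DIR` is set should launch the same run as the bash cell. With `check=True`, a failed run raises `subprocess.CalledProcessError` instead of failing silently."
   ]
  },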
  {
   "cell_type": "markdown",
   "id": "e930919d-61fb-4e1b-bb00-8f1bbb69acf1",
   "metadata": {},
   "source": [
    "## Cleanup data directory"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e98026d-998d-4645-83ec-afe3a448b3e5",
   "metadata": {},
   "source": [
    "Remove the data directory if a temporary directory was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b273214f-7653-468c-b311-fa8d097b5bee",
   "metadata": {},
   "outputs": [],
   "source": [
    "shutil.rmtree(root_dir)"
   ]
  },
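  {
   "cell_type": "markdown",
   "id": "e1b7c3d5-9f2a-4b6c-8d0e-2f3a4b5c6d03",
   "metadata": {},
   "source": [
    "The cell above removes `root_dir` unconditionally. If the data could instead live in a user-supplied location that should be kept, a guarded variant is safer. In this sketch, `cleanup_data_dir` and its `user_dir` argument are hypothetical. `user_dir` stands for whatever variable holds a user-provided data location, and it is `None` when a temporary directory was used.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "\n",
    "def cleanup_data_dir(root_dir, user_dir=None):\n",
    "    # Delete the data only when no user-provided directory was given,\n",
    "    # i.e. when root_dir is a temporary directory created for this tutorial\n",
    "    if user_dir is None and os.path.isdir(root_dir):\n",
    "        shutil.rmtree(root_dir)\n",
    "        return True\n",
    "    # Otherwise leave the user's data in place\n",
    "    return False\n",
    "```\n",
    "\n",
    "For example, `cleanup_data_dir(root_dir)` deletes a temporary download, while `cleanup_data_dir(root_dir, user_dir=root_dir)` leaves it untouched."
   ]
  },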
  {
   "cell_type": "markdown",
   "id": "6fd62905-4ea8-4f08-bcea-823074fc4ce4",
   "metadata": {},
   "source": [
    "## Summary and Next\n",
    "\n",
    "This tutorial has covered:\n",
    "* Creating full training and evaluation scripts in bundles using MONAI and PyTorch Lightning\n",
    "* Training a network and then evaluating its performance with scripts"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}