{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c5c3f1c",
   "metadata": {},
   "source": [
    "# Imputation using random draws from shifted normal distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e38874f3",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "import pandas as pd\n",
    "from IPython.display import display\n",
    "\n",
    "import pimmslearn\n",
    "import pimmslearn.imputation\n",
    "import pimmslearn.model\n",
    "import pimmslearn.models as models\n",
    "import pimmslearn.nb\n",
    "from pimmslearn.io import datasplits\n",
    "\n",
    "logger = pimmslearn.logging.setup_logger(logging.getLogger('pimmslearn'))\n",
    "logger.info(\"Median Imputation\")\n",
    "\n",
    "figures = {}  # collection of ax or figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca2c3fb3",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# catch passed parameters\n",
    "args = None\n",
    "args = dict(globals()).keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23955adb",
   "metadata": {},
   "source": [
    "Papermill script parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33375547",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# files and folders\n",
    "# Datasplit folder with data for experiment\n",
    "folder_experiment: str = 'runs/example'\n",
    "file_format: str = 'csv'  # file format of create splits, default pickle (pkl)\n",
    "# Machine parsed metadata from rawfile workflow\n",
    "fn_rawfile_metadata: str = 'data/dev_datasets/HeLa_6070/files_selected_metadata_N50.csv'\n",
    "# model\n",
    "sample_idx_position: int = 0  # position of index which is sample ID\n",
    "# model key (lower cased version will be used for file names)\n",
    "axis: int = 1  # impute per row/sample (1) or per column/feat (0).\n",
    "completeness = 0.6  # fractio of non missing values for row/sample (axis=0) or column/feat (axis=1)\n",
    "model_key: str = 'RSN'\n",
    "model: str = 'RSN'  # model name\n",
    "save_pred_real_na: bool = True  # Save all predictions for real na\n",
    "# metadata -> defaults for metadata extracted from machine data\n",
    "meta_date_col: str = None  # date column in meta data\n",
    "meta_cat_col: str = None  # category column in meta data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "924a4f0d",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "Some argument transformations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36f708a2",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "args = pimmslearn.nb.get_params(args, globals=globals())\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4bb6bf2",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "args = pimmslearn.nb.args_from_dict(args)\n",
    "args"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c467efb4",
   "metadata": {},
   "source": [
    "Some naming conventions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3ded735",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "TEMPLATE_MODEL_PARAMS = 'model_params_{}.json'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dea94507",
   "metadata": {},
   "source": [
    "## Load data in long format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92d787e1",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "data = datasplits.DataSplits.from_folder(\n",
    "    args.data, file_format=args.file_format)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "191bcf6f",
   "metadata": {},
   "source": [
    "data is loaded in long format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbd2df5b",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "data.train_X.sample(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f597ca76",
   "metadata": {},
   "source": [
    "Infer index names from long format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8e2b780",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "index_columns = list(data.train_X.index.names)\n",
    "sample_id = index_columns.pop(args.sample_idx_position)\n",
    "if len(index_columns) == 1:\n",
    "    index_column = index_columns.pop()\n",
    "    index_columns = None\n",
    "    logger.info(f\"{sample_id = }, single feature: {index_column = }\")\n",
    "else:\n",
    "    logger.info(f\"{sample_id = }, multiple features: {index_columns = }\")\n",
    "\n",
    "if not index_columns:\n",
    "    index_columns = [sample_id, index_column]\n",
    "else:\n",
    "    raise NotImplementedError(\n",
    "        \"More than one feature: Needs to be implemented. see above logging output.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29247b80",
   "metadata": {},
   "source": [
    "load meta data for splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd8cc67",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "if args.fn_rawfile_metadata:\n",
    "    df_meta = pd.read_csv(args.fn_rawfile_metadata, index_col=0)\n",
    "    display(df_meta.loc[data.train_X.index.levels[0]])\n",
    "else:\n",
    "    df_meta = None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23f75036",
   "metadata": {},
   "source": [
    "## Initialize Comparison\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63a9a8c8",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "freq_feat = pimmslearn.io.datasplits.load_freq(args.data)\n",
    "freq_feat.head()  # training data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "922d0a8d",
   "metadata": {},
   "source": [
    "### Produce some addional fake samples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ac69d00",
   "metadata": {},
   "source": [
    "The validation simulated NA is used to by all models to evaluate training performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5855a725",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "val_pred_fake_na = data.val_y.to_frame(name='observed')\n",
    "val_pred_fake_na"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e0ae839",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "test_pred_fake_na = data.test_y.to_frame(name='observed')\n",
    "test_pred_fake_na.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4208681",
   "metadata": {},
   "source": [
    "## Data in wide format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8b41aae",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "data.to_wide_format()\n",
    "args.M = data.train_X.shape[-1]\n",
    "data.train_X.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7bf83f1",
   "metadata": {},
   "source": [
    "### Impute using shifted normal distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f5349d6",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "imputed_shifted_normal = pimmslearn.imputation.impute_shifted_normal(\n",
    "    data.train_X,\n",
    "    mean_shift=1.8,\n",
    "    std_shrinkage=0.3,\n",
    "    completeness=args.completeness,\n",
    "    axis=args.axis)\n",
    "imputed_shifted_normal = imputed_shifted_normal.to_frame('intensity')\n",
    "imputed_shifted_normal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d32d445e",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "val_pred_fake_na[args.model] = imputed_shifted_normal\n",
    "test_pred_fake_na[args.model] = imputed_shifted_normal\n",
    "val_pred_fake_na"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72a5da63",
   "metadata": {},
   "source": [
    "Save predictions for NA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3198a37c",
   "metadata": {
    "lines_to_next_cell": 0,
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "if args.save_pred_real_na:\n",
    "    mask = data.train_X.isna().stack()\n",
    "    idx_real_na = mask.index[mask]\n",
    "    idx_real_na = (idx_real_na\n",
    "                   .drop(val_pred_fake_na.index)\n",
    "                   .drop(test_pred_fake_na.index))\n",
    "    # hacky, but works:\n",
    "    pred_real_na = (pd.Series(0, index=idx_real_na, name='placeholder')\n",
    "                    .to_frame()\n",
    "                    .join(imputed_shifted_normal)\n",
    "                    .drop('placeholder', axis=1))\n",
    "    # pred_real_na.name = 'intensity'\n",
    "    display(pred_real_na)\n",
    "    pred_real_na.to_csv(args.out_preds / f\"pred_real_na_{args.model_key}.csv\")\n",
    "\n",
    "\n",
    "# # %% [markdown]\n",
    "# ### Plots\n",
    "#"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df99da67",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "ax, _ = pimmslearn.plotting.errors.plot_errors_binned(val_pred_fake_na)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16637d79",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "ax, _ = pimmslearn.plotting.errors.plot_errors_binned(test_pred_fake_na)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c0f7424",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "## Comparisons"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17027ab0",
   "metadata": {},
   "source": [
    "### Validation data\n",
    "\n",
    "- all measured (identified, observed) peptides in validation data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43d42650",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# papermill_description=metrics\n",
    "d_metrics = models.Metrics()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f5d7459",
   "metadata": {},
   "source": [
    "The fake NA for the validation step are real test data (not used for training nor early stopping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed0498d0",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "added_metrics = d_metrics.add_metrics(val_pred_fake_na, 'valid_fake_na')\n",
    "added_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fa6a5f7",
   "metadata": {},
   "source": [
    "### Test Datasplit\n",
    "\n",
    "Fake NAs : Artificially created NAs. Some data was sampled and set\n",
    "explicitly to misssing before it was fed to the model for\n",
    "reconstruction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ee61d53",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "added_metrics = d_metrics.add_metrics(test_pred_fake_na, 'test_fake_na')\n",
    "added_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "832e3f7c",
   "metadata": {},
   "source": [
    "The fake NA for the validation step are real test data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5a7a319",
   "metadata": {},
   "source": [
    "### Save all metrics as json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9973b3ee",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "pimmslearn.io.dump_json(d_metrics.metrics, args.out_metrics /\n",
    "                        f'metrics_{args.model_key}.json')\n",
    "d_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2421c3",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "metrics_df = models.get_df_from_nested_dict(\n",
    "    d_metrics.metrics, column_levels=['model', 'metric_name']).T\n",
    "metrics_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b296a6a1",
   "metadata": {},
   "source": [
    "## Save predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39c41bcd",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# val\n",
    "fname = args.out_preds / f\"pred_val_{args.model_key}.csv\"\n",
    "setattr(args, fname.stem, fname.as_posix())  # add [] assignment?\n",
    "val_pred_fake_na.to_csv(fname)\n",
    "# test\n",
    "fname = args.out_preds / f\"pred_test_{args.model_key}.csv\"\n",
    "setattr(args, fname.stem, fname.as_posix())\n",
    "test_pred_fake_na.to_csv(fname)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6a95d00",
   "metadata": {},
   "source": [
    "## Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f2f7404",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "figures  # switch to fnames?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ad37d15",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "args.dump(fname=args.out_models / f\"model_config_{args.model_key}.yaml\")\n",
    "args"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:percent"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}