import argparse
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import download_url_to_file
from torch.utils.data import DataLoader
from sklearn import metrics
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor

from datasets.audioset import get_test_set, get_full_training_set, get_ft_weighted_sampler
from models.mn.model import get_model as get_mobilenet
from models.dymn.model import get_model as get_dymn
from models.ensemble import get_ensemble_model
from models.preprocess import AugmentMelSTFT
from helpers.init import worker_init_fn
from helpers.utils import NAME_TO_WIDTH, exp_warmup_linear_down, mixup
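
# Stored teacher predictions: logits of a PaSST ensemble evaluated on AudioSet
# (mAP: .495), plus a dict mapping AudioSet filenames to row indices in that
# logits array. Both files are downloaded automatically if not present locally.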

preds_url = \
    "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/passt_enemble_logits_mAP_495.npy"

fname_to_index_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/fname_to_index.pkl"


class PLModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # model to preprocess waveform to mel spectrograms
        self.mel = AugmentMelSTFT(n_mels=config.n_mels,
                                  sr=config.resample_rate,
                                  win_length=config.window_size,
                                  hopsize=config.hop_size,
                                  n_fft=config.n_fft,
                                  freqm=config.freqm,
                                  timem=config.timem,
                                  fmin=config.fmin,
                                  fmax=config.fmax,
                                  fmin_aug_range=config.fmin_aug_range,
                                  fmax_aug_range=config.fmax_aug_range
                                  )

        # load prediction model
        model_name = config.model_name
        pretrained_name = model_name if config.pretrained else None
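        # for pretrained models the width multiplier is encoded in the model name
        # (e.g. 'mn10_as' corresponds to width 1.0); otherwise it is taken from the config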
        width = NAME_TO_WIDTH(model_name) if model_name and config.pretrained else config.model_width
        if model_name.startswith("dymn"):
            model = get_dymn(width_mult=width, pretrained_name=pretrained_name,
                             strides=config.strides, pretrain_final_temp=config.pretrain_final_temp)
        else:
            model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name,
                                  strides=config.strides, head_type=config.head_type, se_dims=config.se_dims)
        self.model = model

        # prepare ingredients for knowledge distillation
        assert 0 <= config.kd_lambda <= 1, "Lambda for Knowledge Distillation must be between 0 and 1."
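        # reduction="none" keeps per-sample losses, which are needed in 'training_step'
        # for mixup weighting and for masking samples without teacher predictions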
        self.distillation_loss = nn.BCEWithLogitsLoss(reduction="none")

        # load stored teacher predictions
        if not os.path.isfile(config.teacher_preds):
            # download file
            print("Download teacher predictions...")
            download_url_to_file(preds_url, config.teacher_preds)
        print(f"Load teacher predictions from file {config.teacher_preds}")
        teacher_preds = np.load(config.teacher_preds)
        teacher_preds = torch.from_numpy(teacher_preds).float()
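        # temperature scaling: dividing the teacher logits by 'temperature' before the
        # sigmoid softens the soft targets (the default of 1 leaves them unchanged)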
        teacher_preds = torch.sigmoid(teacher_preds / config.temperature)
        teacher_preds.requires_grad = False
        self.teacher_preds = teacher_preds

        if not os.path.isfile(config.fname_to_index):
            print("Download filename to teacher prediction index dictionary...")
            download_url_to_file(fname_to_index_url, config.fname_to_index)
        with open(config.fname_to_index, 'rb') as f:
            fname_to_index = pickle.load(f)
        self.fname_to_index = fname_to_index

        self.distributed_mode = config.num_devices > 1
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def mel_forward(self, x):
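        # merge batch and channel dims so the mel transform sees a 2d tensor of
        # waveforms, then restore (batch, channels, mels, frames)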
        old_shape = x.size()
        x = x.reshape(-1, old_shape[2])
        x = self.mel(x)
        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
        return x

    def forward(self, x):
        """
        :param x: batch of raw audio signals (waveforms)
        :return: final model predictions
        """
        x = self.mel_forward(x)
        x = self.model(x)
        return x

    def configure_optimizers(self):
        """
        This is the way pytorch lightening requires optimizers and learning rate schedulers to be defined.
        The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
        :return: dict containing optimizer and learning rate scheduler
        """
        if self.config.adamw:
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.max_lr,
                                          weight_decay=self.config.weight_decay)
        else:
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.max_lr,
                                         weight_decay=self.config.weight_decay)

        # phases of lr schedule: exponential increase, constant lr, linear decrease, fine-tune
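        # with the default arguments: exponential warmup for 8 epochs, constant at 'max_lr'
        # until epoch 80, linear decay over the next 95 epochs, then a constant fine-tune
        # phase at 1% of 'max_lr'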
        schedule_lambda = \
            exp_warmup_linear_down(self.config.warm_up_len, self.config.ramp_down_len, self.config.ramp_down_start,
                                   self.config.last_lr_value)
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda)

        return {
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler
        }

    def on_train_epoch_start(self):
        # in case of DyMN: update DyConv temperature
        if hasattr(self.model, "update_params"):
            self.model.update_params(self.current_epoch)

    def training_step(self, train_batch, batch_idx):
        """
        :param train_batch: contains one batch from train dataloader
        :param batch_idx
        :return: a dict containing at least loss that is used to update model parameters, can also contain
                    other items that can be processed in 'training_epoch_end' to log other metrics than loss
        """
        x, f, y, i = train_batch
        bs = x.size(0)
        x = self.mel_forward(x)

        rn_indices, lam = None, None
        if self.config.mixup_alpha:
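            # mixup: blend each sample and its targets with a randomly chosen partner;
            # 'rn_indices' is a random permutation of the batch and 'lam' holds the
            # per-sample mixing coefficients returned by the 'mixup' helper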
            rn_indices, lam = mixup(bs, self.config.mixup_alpha)
            lam = lam.to(x.device)
            x = x * lam.reshape(bs, 1, 1, 1) + \
                x[rn_indices] * (1. - lam.reshape(bs, 1, 1, 1))
            y_hat, _ = self.model(x)
            y_mix = y * lam.reshape(bs, 1) + y[rn_indices] * (1. - lam.reshape(bs, 1))
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y_mix, reduction="none")
        else:
            y_hat, _ = self.model(x)
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")

        # hard label loss
        label_loss = samples_loss.mean()

        # distillation loss
        if self.config.kd_lambda > 0:
            # fetch the correct index in 'teacher_preds' for given filename
            # insert -1 for files not in fname_to_index (proportion of files successfully downloaded from
            # YouTube can vary for AudioSet)
            indices = torch.tensor(
                [self.fname_to_index.get(fname, -1) for fname in f], dtype=torch.int64
            )
            # get indices of files we could not find the teacher predictions for
            unknown_indices = indices == -1
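            # index -1 simply selects the last row of 'teacher_preds'; these placeholder
            # targets are zeroed out below via 'unknown_indices'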
            y_soft_teacher = self.teacher_preds[indices]
            y_soft_teacher = y_soft_teacher.to(y_hat.device).type_as(y_hat)

            if self.config.mixup_alpha:
                soft_targets_loss = \
                    self.distillation_loss(y_hat, y_soft_teacher).mean(dim=1) * lam.reshape(bs) + \
                    self.distillation_loss(y_hat, y_soft_teacher[rn_indices]).mean(dim=1) \
                    * (1. - lam.reshape(bs))
            else:
                soft_targets_loss = self.distillation_loss(y_hat, y_soft_teacher).mean(dim=1)

            # zero out loss for samples we don't have teacher predictions for
            soft_targets_loss[unknown_indices] = soft_targets_loss[unknown_indices] * 0
            soft_targets_loss = soft_targets_loss.mean()

            # weighting losses
            label_loss = self.config.kd_lambda * label_loss
            soft_targets_loss = (1 - self.config.kd_lambda) * soft_targets_loss
        else:
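            # no distillation: log a zero kd loss so the epoch aggregation stays uniform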
            soft_targets_loss = torch.tensor(0., device=label_loss.device, dtype=label_loss.dtype)

        # total loss: kd_lambda * label_loss + (1 - kd_lambda) * distillation loss (weights applied above)
        loss = label_loss + soft_targets_loss

        results = {"loss": loss.detach().cpu(), "label_loss": label_loss.detach().cpu(),
                   "kd_loss": soft_targets_loss.detach().cpu()}
        self.training_step_outputs.append(results)
        return loss

    def on_train_epoch_end(self):
        """
        :return: a dict containing the metrics you want to log to Weights and Biases
        """
        avg_loss = torch.stack([x['loss'] for x in self.training_step_outputs]).mean()
        avg_label_loss = torch.stack([x['label_loss'] for x in self.training_step_outputs]).mean()
        avg_kd_loss = torch.stack([x['kd_loss'] for x in self.training_step_outputs]).mean()
        self.log_dict({'train/loss': avg_loss.to(self.device),
                       'train/label_loss': avg_label_loss.to(self.device),
                       'train/kd_loss': avg_kd_loss.to(self.device)
                       }, sync_dist=True)

        self.training_step_outputs.clear()

    def validation_step(self, val_batch, batch_idx):
        x, _, y = val_batch
        x = self.mel_forward(x)
        y_hat, _ = self.model(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        preds = torch.sigmoid(y_hat)
        results = {'val_loss': loss, "preds": preds, "targets": y}
        results = {k: v.cpu() for k, v in results.items()}
        self.validation_step_outputs.append(results)

    def on_validation_epoch_end(self):
        loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs])
        preds = torch.cat([x['preds'] for x in self.validation_step_outputs], dim=0)
        targets = torch.cat([x['targets'] for x in self.validation_step_outputs], dim=0)

        # gather predictions, targets and losses from all devices so the metrics
        # below are computed over the full evaluation set
        all_preds = self.all_gather(preds).reshape(-1, preds.shape[-1]).cpu().float().numpy()
        all_targets = self.all_gather(targets).reshape(-1, targets.shape[-1]).cpu().float().numpy()
        all_loss = self.all_gather(loss).reshape(-1)

        try:
            average_precision = metrics.average_precision_score(
                all_targets, all_preds, average=None)
        except ValueError:
            # metrics are undefined if a class never appears in the targets; AudioSet has 527 classes
            average_precision = np.array([np.nan] * 527)
        try:
            roc = metrics.roc_auc_score(all_targets, all_preds, average=None)
        except ValueError:
            roc = np.array([np.nan] * 527)
        logs = {'val/loss': all_loss.mean().to(self.device),
                'val/ap': torch.as_tensor(average_precision).mean().to(self.device),
                'val/roc': torch.as_tensor(roc).mean().to(self.device)
                }
        # results were already gathered across all devices above, so no sync is needed here
        self.log_dict(logs, sync_dist=False)
        self.validation_step_outputs.clear()


def train(config):
    # Train models on AudioSet, either from scratch or starting from ImageNet pre-trained weights.
    # The PaSST ensemble (https://github.com/kkoutini/PaSST), with predictions stored in
    # 'resources/passt_enemble_logits_mAP_495.npy', can be used as a teacher.

    # logging is done using wandb
    wandb_logger = WandbLogger(
        project="EfficientAudioTagging",
        notes="Training efficient audio tagging models on AudioSet using Knowledge Distillation.",
        tags=["AudioSet", "Audio Tagging", "Knowledge Disitillation"],
        config=config,
        name=config.experiment_name
    )

    train_dl = DataLoader(dataset=get_full_training_set(resample_rate=config.resample_rate,
                                                        roll=config.roll,
                                                        wavmix=config.wavmix,
                                                        gain_augment=config.gain_augment),
                          sampler=get_ft_weighted_sampler(config.epoch_len),  # sampler important to balance classes
                          worker_init_fn=worker_init_fn,
                          num_workers=config.num_workers,
                          batch_size=config.batch_size)

    # eval dataloader
    eval_dl = DataLoader(dataset=get_test_set(resample_rate=config.resample_rate),
                         worker_init_fn=worker_init_fn,
                         num_workers=config.num_workers,
                         batch_size=config.batch_size)

    # create the PyTorch Lightning module
    pl_module = PLModule(config)

    # monitor to keep track of the learning rate - useful to verify the learning rate schedule
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    # create the PyTorch Lightning trainer by specifying the number of epochs to train, the logger,
    # the kind of device(s) to train on and possible callbacks
    trainer = pl.Trainer(max_epochs=config.n_epochs,
                         logger=wandb_logger,
                         accelerator='auto',
                         devices=config.num_devices,
                         precision=config.precision,
                         num_sanity_val_steps=0,
                         callbacks=[lr_monitor])

    # start training and validation for the specified number of epochs
    trainer.fit(pl_module, train_dl, eval_dl)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training arguments for audio tagging on AudioSet.')
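
    # typical invocation (script name is an assumption, adjust to your setup):
    #   python ex_audioset.py --pretrained --model_name=mn10_as --num_devices=4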

    # general
    parser.add_argument('--experiment_name', type=str, default="AudioSet")
    parser.add_argument('--batch_size', type=int, default=120)
    parser.add_argument('--num_workers', type=int, default=12)
    parser.add_argument('--num_devices', type=int, default=4)

    # evaluation (note: '--ensemble' and '--cuda' are not used by 'train' above)
    # if ensemble is set, 'model_name' is not used
    parser.add_argument('--ensemble', nargs='+', default=[])
    parser.add_argument('--model_name', type=str, default="mn10_as")  # used also for training
    parser.add_argument('--cuda', action='store_true', default=False)

    # training
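    # precision=16 selects 16-bit mixed-precision training in Lightning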
    parser.add_argument('--precision', type=int, default=16)
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--pretrain_final_temp', type=float, default=30.0)  # for DyMN
    parser.add_argument('--model_width', type=float, default=1.0)
    parser.add_argument('--strides', nargs=4, default=[2, 2, 2, 2], type=int)
    parser.add_argument('--head_type', type=str, default="mlp")
    parser.add_argument('--se_dims', type=str, default="c")
    parser.add_argument('--n_epochs', type=int, default=200)
    parser.add_argument('--mixup_alpha', type=float, default=0.3)
    parser.add_argument('--epoch_len', type=int, default=100000)
    parser.add_argument('--roll', action='store_true', default=False)
    parser.add_argument('--wavmix', action='store_true', default=False)
    parser.add_argument('--gain_augment', type=int, default=0)

    # optimizer
    parser.add_argument('--adamw', action='store_true', default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    # lr schedule
    parser.add_argument('--max_lr', type=float, default=0.003)
    parser.add_argument('--warm_up_len', type=int, default=8)
    parser.add_argument('--ramp_down_start', type=int, default=80)
    parser.add_argument('--ramp_down_len', type=int, default=95)
    parser.add_argument('--last_lr_value', type=float, default=0.01)

    # knowledge distillation
    parser.add_argument('--teacher_preds', type=str,
                        default=os.path.join("resources", "passt_enemble_logits_mAP_495.npy"))
    parser.add_argument('--fname_to_index', type=str,
                        default=os.path.join("resources", "fname_to_index.pkl"))
    parser.add_argument('--temperature', type=float, default=1)
    parser.add_argument('--kd_lambda', type=float, default=0.1)

    # preprocessing
    parser.add_argument('--resample_rate', type=int, default=32000)
    parser.add_argument('--window_size', type=int, default=800)
    parser.add_argument('--hop_size', type=int, default=320)
    parser.add_argument('--n_fft', type=int, default=1024)
    parser.add_argument('--n_mels', type=int, default=128)
    parser.add_argument('--freqm', type=int, default=0)
    parser.add_argument('--timem', type=int, default=0)
    parser.add_argument('--fmin', type=int, default=0)
    parser.add_argument('--fmax', type=int, default=None)
    parser.add_argument('--fmin_aug_range', type=int, default=10)
    parser.add_argument('--fmax_aug_range', type=int, default=2000)

    args = parser.parse_args()
    train(args)