import os
import copy
import pickle

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader   # explicit, in case the wildcard imports below don't re-export it

from Datahelper2 import *
from Model import *

from hyperboard import Agent


# dataset sizes: the 40479 Planet training images are split either 36431/4048 or
# 32384/8095 (train/validate); this script uses the 32384/8095 split; test set: 61191

def train(model_path, lr, train_batch_size, validate_batch_size, validate_batch_num, resize, train_gpu, validate_gpu=-1):
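    """Train PyNet_10 with k-fold cross-validation on the Planet Amazon data.

    model_path           directory for per-fold checkpoints; its last path component
                         must match the model's getname()
    lr                   Adam learning rate
    train_batch_size     mini-batch size for training
    validate_batch_size  size of each validation mini-batch
    validate_batch_num   number of validation batches averaged per evaluation
    resize               side length images are resized to
    train_gpu            CUDA device id used for training
    validate_gpu         CUDA device id for validation, or -1 to reuse train_gpu
    """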

    # train_gpu = 0
    # validate_gpu = 1
    # model_path = '../amazon2/alexnet'
    # train_batch_size = 256
    # validate_batch_size = 128
    # validate_batch_num = 8

    # parameters (train_num / validate_num are informational; the k-fold indices define the actual split)
    train_num = 32384
    validate_num = 8095
    DS = '/home/jianglibin/PythonProject/Amazon/train_validate_dataset.h5'   # unused in this script
    IMG_TRAIN_PATH = '/home/jianglibin/PythonProject/Amazon/data/train-jpg/'
    LABEL_PATH = '/home/jianglibin/PythonProject/amazon2/labels.pkl'
    IMG_EXT = '.jpg'


    k = 5
    epochs = 1
    weight_decay = 0
    momentum = 0.9    # unused: Adam takes no momentum argument


    criteria2metric = {
        'train loss': 'loss',
        'valid loss': 'loss'
    }
    hyperparameters_train = {
        'name': 'train',
        'learning rate': lr,
        'batch size': train_batch_size,
        'optimizer': 'Adam',
        'momentum': 0,
        'net': model_path.split('/')[-1],
        'epoch': 'No.1',
    }
    # the validate curve shares every setting except its name
    hyperparameters_validate = dict(hyperparameters_train, name='validate')


    agent = Agent(username='jlb', password='1993610')
    train_loss_show = agent.register(hyperparameters_train, criteria2metric['train loss'])
    validate_loss_show = agent.register(hyperparameters_validate, criteria2metric['valid loss'])
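    # global_step indexes both hyperboard curves so train and validate points share one x-axis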
    global_step = 0

    # kfold.pkl holds k precomputed (train_index, validate_index) pairs over the training set
    with open('kfold.pkl', 'rb') as f:
        kfold = pickle.load(f)

    loss_info = []    # loss_info[i] holds the best (train_loss, validate_loss) pair for fold i
    for fold in range(k):
        train_index = kfold[fold][0]
        validate_index = kfold[fold][1]


        model = PyNet_10((3, resize, resize), 17)    # 17 labels in the Planet multi-label task
        if model.getname() != model_path.split('/')[-1]:
            print('Wrong Model!')
            return
        model.cuda(device_id=train_gpu)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        dset_train = AmazonDateset_train(train_index, IMG_TRAIN_PATH, IMG_EXT, LABEL_PATH, resize=resize)
        train_loader = DataLoader(dset_train, batch_size=train_batch_size, shuffle=True, num_workers=6)
        # build the validation loader once per fold instead of once per evaluation
        dset_validate = AmazonDateset_validate(validate_index, IMG_TRAIN_PATH, IMG_EXT, LABEL_PATH, random_transform=True, resize=resize)
        validate_loader = DataLoader(dset_validate, batch_size=validate_batch_size, shuffle=True, num_workers=6)
        min_loss = [float('inf'), float('inf')]    # best (train_loss, validate_loss); inf so the first checkpoint always saves
        for epoch in range(epochs):
            print('--------------Epoch %d: train-----------' % epoch)
            model.train()
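            # iterate over the training fold; validation is interleaved every 10 steps below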
            for step, (data, target) in enumerate(train_loader):
                data, target = Variable(data), Variable(target)
                data = data.cuda(device_id=train_gpu)
                target = target.cuda(device_id=train_gpu)

                optimizer.zero_grad()
                output = model(data)
                # multi-label loss: F.binary_cross_entropy expects probabilities, so the model's forward should end in a sigmoid
                loss = F.binary_cross_entropy(output, target)
                loss.backward()
                optimizer.step()
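                # stream the per-step training loss to hyperboard (loss.data[0] is the pre-0.4 scalar access)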
                agent.append(train_loss_show, global_step, loss.data[0])
                global_step += 1
                # evaluate every 10 training steps on validate_batch_num batches
                if step % 10 == 0:
                    model.eval()
                    if validate_gpu != -1:
                        model.cuda(validate_gpu)    # run validation on its own GPU if one was given
                    total_vloss = 0
                    for vstep, (vdata, vtarget) in enumerate(validate_loader):
                        # volatile Variables skip graph construction during evaluation (pre-0.4 idiom)
                        vdata, vtarget = Variable(vdata, volatile=True), Variable(vtarget, volatile=True)
                        eval_gpu = validate_gpu if validate_gpu != -1 else train_gpu
                        vdata = vdata.cuda(eval_gpu)
                        vtarget = vtarget.cuda(eval_gpu)

                        voutput = model(vdata)
                        vloss = F.binary_cross_entropy(voutput, vtarget)
                        total_vloss += vloss.data[0]
                        if vstep == (validate_batch_num - 1):
                            break
                    # average over the batches actually seen (the loader may yield fewer than validate_batch_num)
                    vloss = total_vloss / (vstep + 1)
                    model.train()
                    if validate_gpu != -1:
                        model.cuda(train_gpu)    # move the model back to the training GPU

                    agent.append(validate_loss_show, global_step, vloss)

                    print('{} Fold{} Epoch{} Step{}: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\tValidate Loss: {:.6f}'.format(
                        model_path.split('/')[-1], fold, epoch, global_step,
                        step * train_batch_size, len(train_loader.dataset),
                        100. * step / len(train_loader),
                        loss.data[0], vloss))
                    if vloss < min_loss[1]:
                        min_loss[1] = vloss
                        min_loss[0] = loss.data[0]
                        # checkpoint the best model so far; deepcopy first so .cpu() does not move the live model off the GPU
                        model_save = copy.deepcopy(model)
                        torch.save(model_save.cpu(), os.path.join(model_path, 'fold%d.mod' % fold))
        loss_info.append(min_loss)


    print('-----------------------------------------')
    print(model_path.split('/')[-1] + ':')
    for i, l in enumerate(loss_info):
        print('Fold%d: Train loss:%f\tValidate loss:%f' % (i, l[0], l[1]))

    with open(os.path.join(model_path, 'train_loss_info.pkl'), 'wb') as f:
        pickle.dump(loss_info, f)


if __name__ == '__main__':
    # (model_path, lr, train_batch_size, validate_batch_size, validate_batch_num, resize, train_gpu, validate_gpu)
    train('../amazon2/pynet10', 1e-4, 24, 8, 4, 256, 0, 1)