|
| 1 | +import datetime |
| 2 | +import torch |
| 3 | +from torch.nn import CrossEntropyLoss |
| 4 | +from torch.utils.data import DataLoader |
| 5 | +from torch.utils.data import random_split |
| 6 | + |
| 7 | +import yews.datasets as dsets |
| 8 | +import yews.transforms as transforms |
| 9 | +from yews.train import Trainer |
| 10 | + |
| 11 | +#from yews.models import cpic |
| 12 | +#from yews.models import cpic_v1 |
| 13 | +#from yews.models import cpic_v2 |
| 14 | +#cpic = cpic_v1 |
| 15 | + |
| 16 | +from yews.models import polarity_v1 |
| 17 | +from yews.models import polarity_v2 |
| 18 | +from yews.models import polarity_lstm |
| 19 | +polarity=polarity_lstm |
| 20 | + |
| 21 | + |
| 22 | +if __name__ == '__main__': |
| 23 | + |
| 24 | + print("Now: start : " + str(datetime.datetime.now())) |
| 25 | + |
| 26 | + # Preprocessing |
| 27 | + waveform_transform = transforms.Compose([ |
| 28 | + transforms.ZeroMean(), |
| 29 | + #transforms.SoftClip(1e-4), |
| 30 | + transforms.ToTensor(), |
| 31 | + ]) |
| 32 | + |
| 33 | + # Prepare dataset |
| 34 | + dsets.set_memory_limit(10 * 1024 ** 3) # first number is GB |
| 35 | + # dset = dsets.Wenchuan(path='/home/qszhai/temp_project/deep_learning_course_project/cpic', download=False,sample_transform=waveform_transform) |
| 36 | + dset = dsets.SCSN_polarity(path='/home/qszhai/temp_project/deep_learning_course_project/first_motion_polarity/scsn_data/train_npy', download=False, sample_transform=waveform_transform) |
| 37 | + |
| 38 | + # Split datasets into training and validation |
| 39 | + train_length = int(len(dset) * 0.8) |
| 40 | + val_length = len(dset) - train_length |
| 41 | + train_set, val_set = random_split(dset, [train_length, val_length]) |
| 42 | + |
| 43 | + # Prepare dataloaders |
| 44 | + train_loader = DataLoader(train_set, batch_size=5000, shuffle=True, num_workers=4) |
| 45 | + val_loader = DataLoader(val_set, batch_size=10000, shuffle=False, num_workers=4) |
| 46 | + |
| 47 | + # Prepare trainer |
| 48 | + # trainer = Trainer(cpic(), CrossEntropyLoss(), lr=0.1) |
| 49 | + # note: please use only 1 gpu to run LSTM, https://github.com/pytorch/pytorch/issues/21108 |
| 50 | + model_conf = {"hidden_size": 64} |
| 51 | + plt = polarity(**model_conf) |
| 52 | + trainer = Trainer(plt, CrossEntropyLoss(), lr=0.001) |
| 53 | + |
| 54 | + # Train model over training dataset |
| 55 | + trainer.train(train_loader, val_loader, epochs=50, print_freq=100) |
| 56 | + #resume='checkpoint_best.pth.tar') |
| 57 | + |
| 58 | + # Save training results to disk |
| 59 | + trainer.results(path='scsn_polarity_results.pth.tar') |
| 60 | + |
| 61 | + # Validate saved model |
| 62 | + results = torch.load('scsn_polarity_results.pth.tar') |
| 63 | + #model = cpic() |
| 64 | + model = plt |
| 65 | + model.load_state_dict(results['model']) |
| 66 | + trainer = Trainer(model, CrossEntropyLoss(), lr=0.001) |
| 67 | + trainer.validate(val_loader, print_freq=100) |
| 68 | + |
| 69 | + print("Now: end : " + str(datetime.datetime.now())) |
| 70 | + |
| 71 | + import matplotlib.pyplot as plt |
| 72 | + import numpy as np |
| 73 | + |
| 74 | + myfontsize1=14 |
| 75 | + myfontsize2=18 |
| 76 | + myfontsize3=24 |
| 77 | + |
| 78 | + results = torch.load('scsn_polarity_results.pth.tar') |
| 79 | + |
| 80 | + fig, axes = plt.subplots(2, 1, num=0, figsize=(6, 4), sharex=True) |
| 81 | + axes[0].plot(results['val_acc'], label='Validation') |
| 82 | + axes[0].plot(results['train_acc'], label='Training') |
| 83 | + |
| 84 | + #axes[1].set_xlabel("Epochs",fontsize=myfontsize2) |
| 85 | + axes[0].set_xscale('log') |
| 86 | + axes[0].set_xlim([1, 100]) |
| 87 | + axes[0].xaxis.set_tick_params(labelsize=myfontsize1) |
| 88 | + |
| 89 | + axes[0].set_ylabel("Accuracies (%)",fontsize=myfontsize2) |
| 90 | + axes[0].set_ylim([0, 100]) |
| 91 | + axes[0].set_yticks(np.arange(0, 101, 10)) |
| 92 | + axes[0].yaxis.set_tick_params(labelsize=myfontsize1) |
| 93 | + |
| 94 | + axes[0].grid(True, 'both') |
| 95 | + axes[0].legend(loc=4) |
| 96 | + |
| 97 | + #axes[1].semilogx(results['val_loss'], label='Validation') |
| 98 | + #axes[1].semilogx(results['train_loss'], label='Training') |
| 99 | + axes[1].plot(results['val_loss'], label='Validation') |
| 100 | + axes[1].plot(results['train_loss'], label='Training') |
| 101 | + |
| 102 | + axes[1].set_xlabel("Epochs",fontsize=myfontsize2) |
| 103 | + axes[1].set_xscale('log') |
| 104 | + axes[1].set_xlim([1, 100]) |
| 105 | + axes[1].xaxis.set_tick_params(labelsize=myfontsize1) |
| 106 | + |
| 107 | + axes[1].set_ylabel("Losses",fontsize=myfontsize2) |
| 108 | + axes[1].set_ylim([0.0, 1.0]) |
| 109 | + axes[1].set_yticks(np.arange(0.0,1.01,0.2)) |
| 110 | + axes[1].yaxis.set_tick_params(labelsize=myfontsize1) |
| 111 | + |
| 112 | + axes[1].grid(True, 'both') |
| 113 | + axes[1].legend(loc=1) |
| 114 | + |
| 115 | + fig.tight_layout() |
| 116 | + plt.savefig('Accuracies_train_val.pdf') |
0 commit comments