# train_lina.py
from typing import List, Optional, Tuple

import pytorch_lightning as ptl
import torch
from pytorch_lightning.cli import LightningCLI
from torch import nn
from transformers import get_cosine_schedule_with_warmup

from model.accuracy import MulticlassAccuracy
from model.attentive_rnn import AttentiveRNN
from model.modeling_lina import LinaModel


class TrainLina(ptl.LightningModule):
    def __init__(
        self,
        attentive_rnn: AttentiveRNN,
        d_model: int,
        quant_layer: List[int],
        n_codebook: int,
        n_special_token_in: int,
        n_special_token_out: int,
        n_txt_vocab: int,
        tie_embed: bool = False,
        txt_encoder: Optional[nn.Module] = None,
        spk_encoder: Optional[nn.Module] = None,
        learning_rate: float = 5e-4,
        weight_decay: float = 0.1,
        betas: Tuple[float, float] = (0.9, 0.999),
        n_warmup_steps: int = 500,
        n_training_steps: int = 300000,
        mask_text_p: float = 0.0,
        load_weights: Optional[str] = None,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas
        self.n_warmup_steps = n_warmup_steps
        self.n_training_steps = n_training_steps
        self.model = LinaModel(
            attentive_rnn,
            d_model,
            len(quant_layer),
            n_codebook,
            n_special_token_in,
            n_special_token_out,
            n_txt_vocab,
            tie_embed=tie_embed,
            txt_encoder=txt_encoder,
            spk_encoder=spk_encoder,
            mask_text_p=mask_text_p,
        )
        self.save_hyperparameters()
        # Top-10 accuracy over the audio-token output vocabulary; indices 0 and 1 are excluded.
        self.accuracy_metric = MulticlassAccuracy(
            n_codebook + n_special_token_out,
            top_k=10,
            ignore_index=[0, 1],
        )
        if load_weights is not None:
            # Warm-start from an existing Lightning checkpoint.
            ckpt = torch.load(load_weights)
            self.load_state_dict(ckpt["state_dict"])

    def on_train_epoch_start(self):
        # Keep epoch-aware batch samplers (e.g. distributed/bucketing samplers) in sync.
        if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
            self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)

    def step(self, batch):
        text_token = batch["text_token"]
        audio_token = batch["audio_token"]
        crossatt_mask = batch["crossatt_mask"]
        crossatt_pos = batch["crossatt_pos"]
        encoder_mask = batch["encoder_mask"]
        y_mask = batch["y_mask"]
        logits, loss, att, masked_logits, masked_target = self.model(
            text_token, audio_token, encoder_mask, crossatt_mask,
            logits_mask=y_mask, crossatt_pos=crossatt_pos,
        )
        # Per-quantizer top-k accuracy on the masked target positions
        # (the original left accs empty although training_step logs one value per quantizer).
        n_quant = masked_logits.shape[1]
        accs = [
            self.accuracy_metric(masked_logits[:, i], masked_target[:, i])
            for i in range(n_quant)
        ]
        return logits, loss, att, accs

    def training_step(self, batch, idx):
        logits, loss, att, accs = self.step(batch)
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
        for i, acc in enumerate(accs):
            self.log(f"train_acc_{i}", acc, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, idx):
        logits, loss, att, accs = self.step(batch)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        for i, acc in enumerate(accs):
            self.log(f"val_acc_{i}", acc, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        params = [
            {
                "params": self.model.parameters(),
                "weight_decay": self.weight_decay,
            }
        ]
        opt = torch.optim.AdamW(params, lr=self.learning_rate, betas=self.betas)
        scheduler = get_cosine_schedule_with_warmup(
            opt,
            num_warmup_steps=self.n_warmup_steps,
            num_training_steps=self.n_training_steps,
        )
        # Step the cosine schedule every optimizer step rather than every epoch.
        return [opt], [{"scheduler": scheduler, "interval": "step"}]


def cli(run=True):
    class LinaCli(LightningCLI):
        def add_arguments_to_parser(self, parser):
            # Reuse the datamodule's quant_layer setting for the model so both stay consistent.
            parser.link_arguments(
                "data.init_args.quant_layer", "model.init_args.quant_layer"
            )

    return LinaCli(
        parser_kwargs={"parser_mode": "omegaconf"},
        save_config_kwargs={"overwrite": True},
        run=run,
    )


if __name__ == "__main__":
    cli()
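
# Usage sketch (illustrative, not part of the repository): the LightningCLI above is
# normally driven by an OmegaConf-style YAML config, e.g.
#
#   python train_lina.py fit --config configs/lina.yaml
#
# where `configs/lina.yaml` is a hypothetical path whose `model` section supplies the
# TrainLina init args (d_model, quant_layer, n_codebook, ...) and whose `data` section
# points at a LightningDataModule; its init_args.quant_layer is linked to
# model.init_args.quant_layer by LinaCli.add_arguments_to_parser.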