Commit 1cf4240

committed
Optimization: configured the encoding for file reading
1 parent 3764626 commit 1cf4240

18 files changed: +608 −143 lines

接口.bat → 1.运行API接口.bat (+2)

@@ -1,3 +1,5 @@
+chcp 65001
+
 SET FFMPEG_PATH=%cd%\runtime\ffmpeg\bin
 SET PATH=%FFMPEG_PATH%;%PATH%
 runtime\python.exe api_v2.py
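`chcp 65001` switches the Windows console to the UTF-8 code page before anything else runs, so the batch file and any console I/O are interpreted as UTF-8 instead of the locale default (GBK on Chinese-locale Windows). The commit title points at the same concern on the Python side; as a minimal hypothetical sketch of that pattern (not code from this diff), pass an explicit encoding rather than trusting the platform default:

```python
# Hypothetical helper, not from this commit: read a text file with an
# explicit encoding instead of the locale-dependent default.
def read_text(path: str) -> str:
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
```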

GPT_SoVITS/AR/models/t2s_model.py (+49 −59)
@@ -1,6 +1,9 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
 import math
+import os, sys
+now_dir = os.getcwd()
+sys.path.append(now_dir)
 from typing import List, Optional
 import torch
 from tqdm import tqdm
@@ -12,7 +15,7 @@
     logits_to_probs,
     multinomial_sample_one_no_sync,
     dpo_loss,
-    make_reject_y,
+    make_reject_y,
     get_batch_logps
 )
 from AR.modules.embedding import SinePositionalEmbedding
@@ -36,7 +39,7 @@
     "EOS": 1024,
 }

-# @torch.jit.script ## if enabled, the first inference is very slow and the speed is unstable
+@torch.jit.script
 # Efficient implementation equivalent to the following:
 def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
     B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
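The hunk above re-enables TorchScript on the hand-written SDPA; the comment it removes had warned that scripting makes the first inference very slow and unstable. The function body is not shown in this diff, so for orientation only, here is a minimal mask-aware SDPA that TorchScript accepts, with the mask polarity the later hunks imply (True = blocked) and the optional `scale` argument omitted:

```python
import math
from typing import Optional
import torch

@torch.jit.script
def sdpa_sketch(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # scores: (B, H, L, S), scaled by 1/sqrt(head_dim) as usual
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask, float("-inf"))  # True = blocked
    return torch.softmax(scores, dim=-1) @ value
```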
@@ -82,20 +85,20 @@ def forward(self, x):
 @torch.jit.script
 class T2SBlock:
     def __init__(
-            self,
-            num_heads,
-            hidden_dim: int,
-            mlp: T2SMLP,
-            qkv_w,
-            qkv_b,
-            out_w,
-            out_b,
-            norm_w1,
-            norm_b1,
-            norm_eps1,
-            norm_w2,
-            norm_b2,
-            norm_eps2,
+        self,
+        num_heads,
+        hidden_dim: int,
+        mlp: T2SMLP,
+        qkv_w,
+        qkv_b,
+        out_w,
+        out_b,
+        norm_w1,
+        norm_b1,
+        norm_eps1,
+        norm_w2,
+        norm_b2,
+        norm_eps2,
     ):
         self.num_heads = num_heads
         self.mlp = mlp
@@ -123,7 +126,7 @@ def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
         else:
             return x * padding_mask

-    def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
+    def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):


         q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
@@ -140,10 +143,7 @@ def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

-        if torch_sdpa:
-            attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
-        else:
-            attn = scaled_dot_product_attention(q, k, v, attn_mask)
+        attn = scaled_dot_product_attention(q, k, v, attn_mask)

         attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
         attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
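Note the polarity flip buried in this hunk: `F.scaled_dot_product_attention` treats a boolean mask as True = may attend (hence the removed `~attn_mask`), while the hand-rolled kernel consumes the mask directly as True = blocked. A standalone check that the two conventions compute the same thing:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)  # (B, H, L, D)
blocked = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)  # True = future position

# Built-in SDPA: True means "may attend", so the blocked mask is negated.
ref = F.scaled_dot_product_attention(q, k, v, ~blocked)

# Manual formulation with the opposite polarity (True = blocked).
scores = (q @ k.transpose(-2, -1)) / q.size(-1) ** 0.5
out = torch.softmax(scores.masked_fill(blocked, float("-inf")), dim=-1) @ v
assert torch.allclose(ref, out, atol=1e-5)
```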
@@ -186,7 +186,7 @@ def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:
         )
         return x, k_cache, v_cache

-    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
+    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None):
         q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

         k_cache = torch.cat([k_cache, k], dim=1)
@@ -201,10 +201,7 @@ def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)


-        if torch_sdpa:
-            attn = F.scaled_dot_product_attention(q, k, v)
-        else:
-            attn = scaled_dot_product_attention(q, k, v, attn_mask)
+        attn = scaled_dot_product_attention(q, k, v, attn_mask)

         attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
         attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
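`decode_next_token` grows the cache by exactly one position per step, then reshapes the flat (B, kv_len, H*D) cache into per-head layout for attention, so each step costs O(kv_len) instead of re-encoding the whole prefix. A shape-only illustration of that update (the dimensions are made up):

```python
import torch

batch_size, num_heads, head_dim = 2, 4, 16
k_cache = torch.randn(batch_size, 10, num_heads * head_dim)  # 10 cached positions
k_new = torch.randn(batch_size, 1, num_heads * head_dim)     # the freshly decoded token

k_cache = torch.cat([k_cache, k_new], dim=1)                 # (2, 11, 64)
kv_len = k_cache.shape[1]
k = k_cache.view(batch_size, kv_len, num_heads, -1).transpose(1, 2)
print(k.shape)  # torch.Size([2, 4, 11, 16]) -- (B, H, kv_len, D)
```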
@@ -233,26 +230,21 @@ def __init__(self, num_blocks : int, blocks: List[T2SBlock]):

     def process_prompt(
         self, x:torch.Tensor, attn_mask : torch.Tensor,
-        padding_mask : Optional[torch.Tensor]=None,
-        torch_sdpa:bool=True
+        padding_mask : Optional[torch.Tensor]=None,
     ):
         k_cache : List[torch.Tensor] = []
         v_cache : List[torch.Tensor] = []
         for i in range(self.num_blocks):
-            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
+            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
             k_cache.append(k_cache_)
             v_cache.append(v_cache_)
         return x, k_cache, v_cache

     def decode_next_token(
-        self, x:torch.Tensor,
-        k_cache: List[torch.Tensor],
-        v_cache: List[torch.Tensor],
-        attn_mask : Optional[torch.Tensor]=None,
-        torch_sdpa:bool=True
+        self, x:torch.Tensor, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], attn_mask : Optional[torch.Tensor]=None,
     ):
         for i in range(self.num_blocks):
-            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask)
         return x, k_cache, v_cache

@@ -464,6 +456,7 @@ def forward_old(self, x, x_lens, y, y_lens, bert_feature):
             (0, y_len),
             value=True,
         )
+        # x_attn_mask[:, x_len]=False
         y_attn_mask = F.pad(
             torch.triu(
                 torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -498,14 +491,14 @@ def forward_old(self, x, x_lens, y, y_lens, bert_feature):

     # need to check how this function differs from forward, and what to feed as prompts when there is no semantic input
     def infer(
-            self,
-            x,
-            x_lens,
-            prompts,
-            bert_feature,
-            top_k: int = -100,
-            early_stop_num: int = -1,
-            temperature: float = 1.0,
+        self,
+        x,
+        x_lens,
+        prompts,
+        bert_feature,
+        top_k: int = -100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -529,7 +522,7 @@ def infer(
             value=True,
         )
         y_attn_mask = F.pad(
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),
             (x_len, 0),
             value=False,
         )
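The one functional change here is `diagonal=1` → `diagonal=0`. With True meaning blocked, `diagonal=1` hides only strictly-future positions, while `diagonal=0` additionally stops each position from attending to itself. A 3×3 illustration:

```python
import torch

ones = torch.ones(3, 3, dtype=torch.bool)
print(torch.triu(ones, diagonal=1))
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
print(torch.triu(ones, diagonal=0))
# tensor([[ True,  True,  True],
#         [False,  True,  True],
#         [False, False,  True]])
```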
@@ -588,8 +581,7 @@ def infer_panel_batch_infer(
     ):
         if prompts is None:
             print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
-            return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
-
+            return self.infer_panel_0307(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)

         max_len = kwargs.get("max_len",x_lens.max())
         x_list = []
@@ -670,9 +662,10 @@ def infer_panel_batch_infer(
         idx_list = [None]*y.shape[0]
         for idx in tqdm(range(1500)):
             if idx == 0:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask)
             else:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask)
+
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
             )
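The loop is the standard prefill/decode split: step 0 pushes the whole prompt through `process_prompt` once to build the per-block KV caches, and every later step feeds only the newest position through `decode_next_token`. A toy, attention-free model of just that control flow (all names hypothetical, not from this repo):

```python
import torch

# Toy stand-in: "decoding" is embedding lookup plus a linear head, and the
# cache is simply the growing sequence of embedded positions.
torch.manual_seed(0)
emb = torch.nn.Embedding(16, 8)
head = torch.nn.Linear(8, 16)

def process_prompt(x):            # prefill: consume the whole prompt at once
    return x, x.clone()           # (decoder output, kv cache)

def decode_next_token(x, cache):  # decode: extend the cache by one position
    cache = torch.cat([cache, x], dim=1)
    return cache, cache

xy_pos = emb(torch.tensor([[1, 2, 3]]))  # embedded prompt, shape (1, 3, 8)
for idx in range(5):
    if idx == 0:
        dec, cache = process_prompt(xy_pos)
    else:
        dec, cache = decode_next_token(xy_pos, cache)
    nxt = head(dec[:, -1]).argmax(-1, keepdim=True)  # greedy next token
    xy_pos = emb(nxt)                                # one position next step
print(cache.shape)  # torch.Size([1, 7, 8]): 3 prompt + 4 decoded positions
```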
@@ -750,7 +743,7 @@ def infer_panel_batch_infer(
         # print(idx_list)
         return y_list, idx_list

-    def infer_panel_naive_batched(self,
+    def infer_panel_0307(self,
         x:List[torch.LongTensor], ##### all text tokens
         x_lens:torch.LongTensor,
         prompts:torch.LongTensor, #### reference audio tokens
@@ -799,7 +792,7 @@ def infer_panel_naive(

         # AR Decoder
         y = prompts
-
+
         x_len = x.shape[1]
         x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
         stop = False
@@ -836,12 +829,10 @@ def infer_panel_naive(
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
-            .unsqueeze(0)\
-            .expand(bsz*self.num_head, -1, -1)\
-            .view(bsz, self.num_head, src_len, src_len)\
-            .to(device=x.device, dtype=torch.bool)
-
+        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).unsqueeze(0).expand(bsz*self.num_head, -1, -1).view(bsz, self.num_head, src_len, src_len).to(x.device)
+        xy_attn_mask = xy_attn_mask.bool()
+        # new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
+        # xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
         for idx in tqdm(range(1500)):
             if xy_attn_mask is not None:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
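Old and new code build the same (bsz, num_head, src_len, src_len) boolean mask; the rewrite only collapses the chained calls onto one line and moves the bool cast into a separate statement (the commented lines sketch an additive float-mask alternative). A standalone shape check of the expand/view pattern, mirroring the masks in this file:

```python
import torch
import torch.nn.functional as F

bsz, num_head, x_len, y_len = 2, 4, 5, 3
src_len = x_len + y_len

# Text rows: see all text, none of the audio (True = blocked).
x_attn_mask_pad = F.pad(torch.zeros(x_len, x_len, dtype=torch.bool), (0, y_len), value=True)
# Audio rows: see all text, causal over audio.
y_attn_mask = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                    (x_len, 0), value=False)

xy_attn_mask = (torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
                .unsqueeze(0)
                .expand(bsz * num_head, -1, -1)
                .view(bsz, num_head, src_len, src_len))
print(xy_attn_mask.shape)  # torch.Size([2, 4, 8, 8])
```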
@@ -869,7 +860,7 @@ def infer_panel_naive(
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
-                if y.shape[1] == 0:
+                if y.shape[1]==0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                     print("bad zero prediction")
                 print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
@@ -883,7 +874,6 @@ def infer_panel_naive(
             return y[:, :-1], 0
         return y[:, :-1], idx - 1

-
     def infer_panel(
         self,
         x:torch.LongTensor, ##### all text tokens
@@ -897,4 +887,4 @@ def infer_panel(
         repetition_penalty: float = 1.35,
         **kwargs
     ):
-        return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
+        return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
