1
1
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
2
# reference: https://github.com/lifeiteng/vall-e
3
3
import math
4
+ import os , sys
5
+ now_dir = os .getcwd ()
6
+ sys .path .append (now_dir )
4
7
from typing import List , Optional
5
8
import torch
6
9
from tqdm import tqdm
12
15
logits_to_probs ,
13
16
multinomial_sample_one_no_sync ,
14
17
dpo_loss ,
15
- make_reject_y ,
18
+ make_reject_y ,
16
19
get_batch_logps
17
20
)
18
21
from AR .modules .embedding import SinePositionalEmbedding
36
39
"EOS" : 1024 ,
37
40
}
38
41
39
- # @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
42
+ @torch .jit .script
40
43
# Efficient implementation equivalent to the following:
41
44
def scaled_dot_product_attention (query :torch .Tensor , key :torch .Tensor , value :torch .Tensor , attn_mask :Optional [torch .Tensor ]= None , scale :Optional [torch .Tensor ]= None ) -> torch .Tensor :
42
45
B , H , L , S = query .size (0 ), query .size (1 ), query .size (- 2 ), key .size (- 2 )
@@ -82,20 +85,20 @@ def forward(self, x):
82
85
@torch .jit .script
83
86
class T2SBlock :
84
87
def __init__ (
85
- self ,
86
- num_heads ,
87
- hidden_dim : int ,
88
- mlp : T2SMLP ,
89
- qkv_w ,
90
- qkv_b ,
91
- out_w ,
92
- out_b ,
93
- norm_w1 ,
94
- norm_b1 ,
95
- norm_eps1 ,
96
- norm_w2 ,
97
- norm_b2 ,
98
- norm_eps2 ,
88
+ self ,
89
+ num_heads ,
90
+ hidden_dim : int ,
91
+ mlp : T2SMLP ,
92
+ qkv_w ,
93
+ qkv_b ,
94
+ out_w ,
95
+ out_b ,
96
+ norm_w1 ,
97
+ norm_b1 ,
98
+ norm_eps1 ,
99
+ norm_w2 ,
100
+ norm_b2 ,
101
+ norm_eps2 ,
99
102
):
100
103
self .num_heads = num_heads
101
104
self .mlp = mlp
@@ -123,7 +126,7 @@ def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
123
126
else :
124
127
return x * padding_mask
125
128
126
- def process_prompt (self , x :torch .Tensor , attn_mask : torch .Tensor , padding_mask :Optional [torch .Tensor ]= None , torch_sdpa : bool = True ):
129
+ def process_prompt (self , x :torch .Tensor , attn_mask : torch .Tensor , padding_mask :Optional [torch .Tensor ]= None ):
127
130
128
131
129
132
q , k , v = F .linear (self .to_mask (x , padding_mask ), self .qkv_w , self .qkv_b ).chunk (3 , dim = - 1 )
@@ -140,10 +143,7 @@ def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:
140
143
k = k_cache .view (batch_size , kv_len , self .num_heads , - 1 ).transpose (1 , 2 )
141
144
v = v_cache .view (batch_size , kv_len , self .num_heads , - 1 ).transpose (1 , 2 )
142
145
143
- if torch_sdpa :
144
- attn = F .scaled_dot_product_attention (q , k , v , ~ attn_mask )
145
- else :
146
- attn = scaled_dot_product_attention (q , k , v , attn_mask )
146
+ attn = scaled_dot_product_attention (q , k , v , attn_mask )
147
147
148
148
attn = attn .permute (2 , 0 , 1 , 3 ).reshape (batch_size * q_len , self .hidden_dim )
149
149
attn = attn .view (q_len , batch_size , self .hidden_dim ).transpose (1 , 0 )
@@ -186,7 +186,7 @@ def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:
186
186
)
187
187
return x , k_cache , v_cache
188
188
189
- def decode_next_token (self , x :torch .Tensor , k_cache :torch .Tensor , v_cache :torch .Tensor , attn_mask :Optional [torch .Tensor ]= None , torch_sdpa : bool = True ):
189
+ def decode_next_token (self , x :torch .Tensor , k_cache :torch .Tensor , v_cache :torch .Tensor , attn_mask :Optional [torch .Tensor ]= None ):
190
190
q , k , v = F .linear (x , self .qkv_w , self .qkv_b ).chunk (3 , dim = - 1 )
191
191
192
192
k_cache = torch .cat ([k_cache , k ], dim = 1 )
@@ -201,10 +201,7 @@ def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.
201
201
v = v_cache .view (batch_size , kv_len , self .num_heads , - 1 ).transpose (1 , 2 )
202
202
203
203
204
- if torch_sdpa :
205
- attn = F .scaled_dot_product_attention (q , k , v )
206
- else :
207
- attn = scaled_dot_product_attention (q , k , v , attn_mask )
204
+ attn = scaled_dot_product_attention (q , k , v , attn_mask )
208
205
209
206
attn = attn .permute (2 , 0 , 1 , 3 ).reshape (batch_size * q_len , self .hidden_dim )
210
207
attn = attn .view (q_len , batch_size , self .hidden_dim ).transpose (1 , 0 )
@@ -233,26 +230,21 @@ def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
233
230
234
231
def process_prompt (
235
232
self , x :torch .Tensor , attn_mask : torch .Tensor ,
236
- padding_mask : Optional [torch .Tensor ]= None ,
237
- torch_sdpa :bool = True
233
+ padding_mask : Optional [torch .Tensor ]= None ,
238
234
):
239
235
k_cache : List [torch .Tensor ] = []
240
236
v_cache : List [torch .Tensor ] = []
241
237
for i in range (self .num_blocks ):
242
- x , k_cache_ , v_cache_ = self .blocks [i ].process_prompt (x , attn_mask , padding_mask , torch_sdpa )
238
+ x , k_cache_ , v_cache_ = self .blocks [i ].process_prompt (x , attn_mask , padding_mask )
243
239
k_cache .append (k_cache_ )
244
240
v_cache .append (v_cache_ )
245
241
return x , k_cache , v_cache
246
242
247
243
def decode_next_token (
248
- self , x :torch .Tensor ,
249
- k_cache : List [torch .Tensor ],
250
- v_cache : List [torch .Tensor ],
251
- attn_mask : Optional [torch .Tensor ]= None ,
252
- torch_sdpa :bool = True
244
+ self , x :torch .Tensor , k_cache : List [torch .Tensor ], v_cache : List [torch .Tensor ], attn_mask : Optional [torch .Tensor ]= None ,
253
245
):
254
246
for i in range (self .num_blocks ):
255
- x , k_cache [i ], v_cache [i ] = self .blocks [i ].decode_next_token (x , k_cache [i ], v_cache [i ], attn_mask , torch_sdpa )
247
+ x , k_cache [i ], v_cache [i ] = self .blocks [i ].decode_next_token (x , k_cache [i ], v_cache [i ], attn_mask )
256
248
return x , k_cache , v_cache
257
249
258
250
@@ -464,6 +456,7 @@ def forward_old(self, x, x_lens, y, y_lens, bert_feature):
464
456
(0 , y_len ),
465
457
value = True ,
466
458
)
459
+ # x_attn_mask[:, x_len]=False
467
460
y_attn_mask = F .pad (
468
461
torch .triu (
469
462
torch .ones (y_len , y_len , dtype = torch .bool , device = x .device ),
@@ -498,14 +491,14 @@ def forward_old(self, x, x_lens, y, y_lens, bert_feature):
498
491
499
492
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
500
493
def infer (
501
- self ,
502
- x ,
503
- x_lens ,
504
- prompts ,
505
- bert_feature ,
506
- top_k : int = - 100 ,
507
- early_stop_num : int = - 1 ,
508
- temperature : float = 1.0 ,
494
+ self ,
495
+ x ,
496
+ x_lens ,
497
+ prompts ,
498
+ bert_feature ,
499
+ top_k : int = - 100 ,
500
+ early_stop_num : int = - 1 ,
501
+ temperature : float = 1.0 ,
509
502
):
510
503
x = self .ar_text_embedding (x )
511
504
x = x + self .bert_proj (bert_feature .transpose (1 , 2 ))
@@ -529,7 +522,7 @@ def infer(
529
522
value = True ,
530
523
)
531
524
y_attn_mask = F .pad (
532
- torch .triu (torch .ones (y_len , y_len , dtype = torch .bool ), diagonal = 1 ),
525
+ torch .triu (torch .ones (y_len , y_len , dtype = torch .bool ), diagonal = 0 ),
533
526
(x_len , 0 ),
534
527
value = False ,
535
528
)
@@ -588,8 +581,7 @@ def infer_panel_batch_infer(
588
581
):
589
582
if prompts is None :
590
583
print ("Warning: Prompt free is not supported batch_infer! switch to naive_infer" )
591
- return self .infer_panel_naive_batched (x , x_lens , prompts , bert_feature , top_k = top_k , top_p = top_p , early_stop_num = early_stop_num , temperature = temperature , ** kwargs )
592
-
584
+ return self .infer_panel_0307 (x , x_lens , prompts , bert_feature , top_k = top_k , top_p = top_p , early_stop_num = early_stop_num , temperature = temperature , ** kwargs )
593
585
594
586
max_len = kwargs .get ("max_len" ,x_lens .max ())
595
587
x_list = []
@@ -670,9 +662,10 @@ def infer_panel_batch_infer(
670
662
idx_list = [None ]* y .shape [0 ]
671
663
for idx in tqdm (range (1500 )):
672
664
if idx == 0 :
673
- xy_dec , k_cache , v_cache = self .t2s_transformer .process_prompt (xy_pos , xy_attn_mask , xy_padding_mask , False )
665
+ xy_dec , k_cache , v_cache = self .t2s_transformer .process_prompt (xy_pos , xy_attn_mask , xy_padding_mask )
674
666
else :
675
- xy_dec , k_cache , v_cache = self .t2s_transformer .decode_next_token (xy_pos , k_cache , v_cache , xy_attn_mask , False )
667
+ xy_dec , k_cache , v_cache = self .t2s_transformer .decode_next_token (xy_pos , k_cache , v_cache , xy_attn_mask )
668
+
676
669
logits = self .ar_predict_layer (
677
670
xy_dec [:, - 1 ]
678
671
)
@@ -750,7 +743,7 @@ def infer_panel_batch_infer(
750
743
# print(idx_list)
751
744
return y_list , idx_list
752
745
753
- def infer_panel_naive_batched (self ,
746
+ def infer_panel_0307 (self ,
754
747
x :List [torch .LongTensor ], #####全部文本token
755
748
x_lens :torch .LongTensor ,
756
749
prompts :torch .LongTensor , ####参考音频token
@@ -799,7 +792,7 @@ def infer_panel_naive(
799
792
800
793
# AR Decoder
801
794
y = prompts
802
-
795
+
803
796
x_len = x .shape [1 ]
804
797
x_attn_mask = torch .zeros ((x_len , x_len ), dtype = torch .bool )
805
798
stop = False
@@ -836,12 +829,10 @@ def infer_panel_naive(
836
829
(x_len , 0 ),
837
830
value = False ,
838
831
)
839
- xy_attn_mask = torch .concat ([x_attn_mask_pad , y_attn_mask ], dim = 0 )\
840
- .unsqueeze (0 )\
841
- .expand (bsz * self .num_head , - 1 , - 1 )\
842
- .view (bsz , self .num_head , src_len , src_len )\
843
- .to (device = x .device , dtype = torch .bool )
844
-
832
+ xy_attn_mask = torch .concat ([x_attn_mask_pad , y_attn_mask ], dim = 0 ).unsqueeze (0 ).expand (bsz * self .num_head , - 1 , - 1 ).view (bsz , self .num_head , src_len , src_len ).to (x .device )
833
+ xy_attn_mask = xy_attn_mask .bool ()
834
+ # new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
835
+ # xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
845
836
for idx in tqdm (range (1500 )):
846
837
if xy_attn_mask is not None :
847
838
xy_dec , k_cache , v_cache = self .t2s_transformer .process_prompt (xy_pos , xy_attn_mask , None )
@@ -869,7 +860,7 @@ def infer_panel_naive(
869
860
if torch .argmax (logits , dim = - 1 )[0 ] == self .EOS or samples [0 , 0 ] == self .EOS :
870
861
stop = True
871
862
if stop :
872
- if y .shape [1 ] == 0 :
863
+ if y .shape [1 ]== 0 :
873
864
y = torch .concat ([y , torch .zeros_like (samples )], dim = 1 )
874
865
print ("bad zero prediction" )
875
866
print (f"T2S Decoding EOS [{ prefix_len } -> { y .shape [1 ]} ]" )
@@ -883,7 +874,6 @@ def infer_panel_naive(
883
874
return y [:, :- 1 ], 0
884
875
return y [:, :- 1 ], idx - 1
885
876
886
-
887
877
def infer_panel (
888
878
self ,
889
879
x :torch .LongTensor , #####全部文本token
@@ -897,4 +887,4 @@ def infer_panel(
897
887
repetition_penalty : float = 1.35 ,
898
888
** kwargs
899
889
):
900
- return self .infer_panel_naive (x , x_lens , prompts , bert_feature , top_k , top_p , early_stop_num , temperature , repetition_penalty , ** kwargs )
890
+ return self .infer_panel_naive (x , x_lens , prompts , bert_feature , top_k , top_p , early_stop_num , temperature , repetition_penalty , ** kwargs )
0 commit comments