import torch
import torch.nn as nn
import numpy as np
from time import time
from ..librosa_functions import *
from ..utils import *


class VQT(torch.nn.Module):
    def __init__(
        self,
        sr=22050,
        hop_length=512,
        fmin=32.70,
        fmax=None,
        n_bins=84,
        filter_scale=1,
        bins_per_octave=12,
        norm=True,
        basis_norm=1,
        gamma=0,
        window='hann',
        pad_mode='reflect',
        earlydownsample=True,
        trainable=False,
        output_format='Magnitude',
        verbose=True
    ):

        super().__init__()

        self.norm = norm
        self.hop_length = hop_length
        self.pad_mode = pad_mode
        self.n_bins = n_bins
        self.earlydownsample = earlydownsample
        self.trainable = trainable
        self.output_format = output_format
        self.filter_scale = filter_scale
        self.bins_per_octave = bins_per_octave
        self.sr = sr
        self.gamma = gamma
        self.basis_norm = basis_norm

        # Used to calculate filter_cutoff and to create the CQT kernels
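        # (Q is the shared quality factor; with gamma > 0 the per-bin bandwidth
        # becomes alpha * f_k + gamma, which is what makes the transform
        # "variable-Q" rather than constant-Q.)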
        Q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)

        # Create the lowpass filter and make it a torch tensor
        if verbose:
            print("Creating low pass filter ...", end='\r')
        start = time()
        lowpass_filter = torch.tensor(create_lowpass_filter(
            band_center=0.50,
            kernelLength=256,
            transitionBandwidth=0.001)
        )

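        # A buffer (rather than an nn.Parameter) moves with the module across
        # devices and is saved in the state_dict, but is never updated by the
        # optimizer.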
        self.register_buffer('lowpass_filter', lowpass_filter[None, None, :])
        if verbose:
            print("Low pass filter created, time used = {:.4f} seconds".format(time() - start))

        n_filters = min(bins_per_octave, n_bins)
        self.n_filters = n_filters
        self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
        if verbose:
            print("num_octave = ", self.n_octaves)

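        # fmin_t is the lowest frequency of the *top* octave; the forward pass
        # walks down one octave at a time while halving the sample rate.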
        self.fmin_t = fmin * 2 ** (self.n_octaves - 1)
        remainder = n_bins % bins_per_octave

        if remainder == 0:
            # Calculate the top bin frequency
            fmax_t = self.fmin_t * 2 ** ((bins_per_octave - 1) / bins_per_octave)
        else:
            # Calculate the top bin frequency
            fmax_t = self.fmin_t * 2 ** ((remainder - 1) / bins_per_octave)

        # Adjust the lowest frequency of the top octave
        self.fmin_t = fmax_t / 2 ** (1 - 1 / bins_per_octave)
        if fmax_t > sr / 2:
            raise ValueError('The top bin {} Hz has exceeded the Nyquist frequency, '
                             'please reduce n_bins'.format(fmax_t))

        if self.earlydownsample:  # Do early downsampling if this argument is True
            if verbose:
                print("Creating early downsampling filter ...", end='\r')
            start = time()
            sr, self.hop_length, self.downsample_factor, early_downsample_filter, \
                self.earlydownsample = get_early_downsample_params(sr,
                                                                   hop_length,
                                                                   fmax_t,
                                                                   Q,
                                                                   self.n_octaves,
                                                                   verbose)
            self.register_buffer('early_downsample_filter', early_downsample_filter)

            if verbose:
                print("Early downsampling filter created, "
                      "time used = {:.4f} seconds".format(time() - start))
        else:
            self.downsample_factor = 1.

        # For normalization at the end:
        # the freqs returned by create_cqt_kernels cannot be used,
        # since it returns only the top-octave bins;
        # we need the information for every frequency bin
        alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
        self.frequencies = freqs
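        # Per-bin kernel length l_k = ceil(Q * sr / (f_k + gamma / alpha)); a
        # larger gamma therefore shortens the low-frequency kernels (better time
        # resolution at the cost of frequency resolution).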
        lengths = np.ceil(Q * sr / (freqs + gamma / alpha))

        # Get the maximum window length, which depends on the gamma value
        max_len = int(max(lengths))
        self.n_fft = int(2 ** (np.ceil(np.log2(max_len))))

        lengths = torch.tensor(lengths).float()
        self.register_buffer('lengths', lengths)

    def forward(self, x, output_format=None, normalization_type='librosa'):
        """
        Convert a batch of waveforms to VQT spectrograms.

        Parameters
        ----------
        x : torch tensor
            Input signal should be in either of the following shapes.\n
            1. ``(len_audio)``\n
            2. ``(num_audio, len_audio)``\n
            3. ``(num_audio, 1, len_audio)``
            It will be automatically broadcast to the right shape

        output_format : str
            'Magnitude', 'Complex', or 'Phase'. Overrides the value given at
            construction time when provided.

        normalization_type : str
            'librosa' (matches librosa's amplitude scaling), 'convolutional'
            (raw convolution output), or 'wrap' (multiplies the output by 2).
        """
        output_format = output_format or self.output_format

        x = broadcast_dim(x)
        if self.earlydownsample:
            x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
        hop = self.hop_length
        vqt = []

        x_down = x  # Preparing a new variable for downsampling
        my_sr = self.sr

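        # Process one octave per iteration, from the top octave downwards: the
        # audio is halved in sample rate (and hop size) each time, so the
        # kernels for every octave stay short.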
        for i in range(self.n_octaves):
            if i > 0:
                x_down = downsampling_by_2(x_down, self.lowpass_filter)
                my_sr /= 2
                hop //= 2
            else:
                x_down = x

            Q = float(self.filter_scale) / (2 ** (1 / self.bins_per_octave) - 1)

            basis, self.n_fft, lengths, _ = create_cqt_kernels(Q,
                                                               my_sr,
                                                               self.fmin_t * 2 ** -i,
                                                               self.n_filters,
                                                               self.bins_per_octave,
                                                               norm=self.basis_norm,
                                                               topbin_check=False,
                                                               gamma=self.gamma)

            cqt_kernels_real = torch.tensor(basis.real.astype(np.float32)).unsqueeze(1)
            cqt_kernels_imag = torch.tensor(basis.imag.astype(np.float32)).unsqueeze(1)

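            # Pad by half the kernel length so that each output frame is
            # centred on its corresponding sample (librosa's center=True
            # convention).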
            if self.pad_mode == 'constant':
                my_padding = nn.ConstantPad1d(cqt_kernels_real.shape[-1] // 2, 0)
            elif self.pad_mode == 'reflect':
                my_padding = nn.ReflectionPad1d(cqt_kernels_real.shape[-1] // 2)

            cur_vqt = get_cqt_complex(x_down, cqt_kernels_real, cqt_kernels_imag, hop, my_padding)
            vqt.insert(0, cur_vqt)

        vqt = torch.cat(vqt, dim=1)
        vqt = vqt[:, -self.n_bins:, :]  # Removing unwanted bottom bins
        vqt = vqt * self.downsample_factor

        # Normalize again to get the same result as librosa
        if normalization_type == 'librosa':
            vqt = vqt * torch.sqrt(self.lengths.view(-1, 1, 1))
        elif normalization_type == 'convolutional':
            pass
        elif normalization_type == 'wrap':
            vqt *= 2
        else:
            raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)

        if output_format == 'Magnitude':
            if not self.trainable:
                # Getting CQT amplitude
                return torch.sqrt(vqt.pow(2).sum(-1))
            else:
                return torch.sqrt(vqt.pow(2).sum(-1) + 1e-8)

        elif output_format == 'Complex':
            return vqt

        elif output_format == 'Phase':
            phase_real = torch.cos(torch.atan2(vqt[:, :, :, 1], vqt[:, :, :, 0]))
            phase_imag = torch.sin(torch.atan2(vqt[:, :, :, 1], vqt[:, :, :, 0]))
            return torch.stack((phase_real, phase_imag), -1)
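

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original file). It
# assumes this module is importable from within its package so the relative
# imports above resolve (the exact import path may differ), and gamma=4.0 is
# just an example value.
#
#     vqt_layer = VQT(sr=22050, hop_length=512, n_bins=84,
#                     bins_per_octave=12, gamma=4.0, output_format='Magnitude')
#     audio = torch.randn(2, 22050)  # batch of two 1-second signals
#     mag = vqt_layer(audio)                           # (2, 84, n_frames)
#     cpx = vqt_layer(audio, output_format='Complex')  # (2, 84, n_frames, 2)
#     phs = vqt_layer(audio, output_format='Phase')    # (2, 84, n_frames, 2)
# ---------------------------------------------------------------------------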