
Commit e620136

Merge pull request #113 from gudgud96/dev-vqt
feat: Add Variable-Q Transform
2 parents 2db9a49 + 3018729 commit e620136

File tree

4 files changed: +290 −8 lines

Installation/nnAudio/features/__init__.py

+1
@@ -11,3 +11,4 @@
 from .griffin_lim import *
 from .mel import *
 from .stft import *
+from .vqt import *
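With this re-export in place, the new class is importable straight from the features subpackage, which is how the test file below pulls it in:

    from nnAudio.features import VQT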

Installation/nnAudio/features/vqt.py

+202
@@ -0,0 +1,202 @@
import torch
import torch.nn as nn
import numpy as np
from time import time
from ..librosa_functions import *
from ..utils import *


class VQT(torch.nn.Module):
    def __init__(
        self,
        sr=22050,
        hop_length=512,
        fmin=32.70,
        fmax=None,
        n_bins=84,
        filter_scale=1,
        bins_per_octave=12,
        norm=True,
        basis_norm=1,
        gamma=0,
        window='hann',
        pad_mode='reflect',
        earlydownsample=True,
        trainable=False,
        output_format='Magnitude',
        verbose=True
    ):

        super().__init__()

        self.norm = norm
        self.hop_length = hop_length
        self.pad_mode = pad_mode
        self.n_bins = n_bins
        self.earlydownsample = earlydownsample
        self.trainable = trainable
        self.output_format = output_format
        self.filter_scale = filter_scale
        self.bins_per_octave = bins_per_octave
        self.sr = sr
        self.gamma = gamma
        self.basis_norm = basis_norm

        # It will be used to calculate filter_cutoff and creating CQT kernels
        Q = float(filter_scale)/(2**(1/bins_per_octave)-1)

        # Creating lowpass filter and make it a torch tensor
        if verbose==True:
            print("Creating low pass filter ...", end='\r')
        start = time()
        lowpass_filter = torch.tensor(create_lowpass_filter(
            band_center = 0.50,
            kernelLength=256,
            transitionBandwidth=0.001)
        )

        self.register_buffer('lowpass_filter', lowpass_filter[None,None,:])
        if verbose == True:
            print("Low pass filter created, time used = {:.4f} seconds".format(time()-start))

        n_filters = min(bins_per_octave, n_bins)
        self.n_filters = n_filters
        self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
        if verbose == True:
            print("num_octave = ", self.n_octaves)

        self.fmin_t = fmin * 2 ** (self.n_octaves - 1)
        remainder = n_bins % bins_per_octave

        if remainder==0:
            # Calculate the top bin frequency
            fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave)
        else:
            # Calculate the top bin frequency
            fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave)

        # Adjusting the top minimum bins
        self.fmin_t = fmax_t / 2 ** (1 - 1 / bins_per_octave)
        if fmax_t > sr/2:
            raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \
                              please reduce the n_bins'.format(fmax_t))

        if self.earlydownsample == True:  # Do early downsampling if this argument is True
            if verbose == True:
                print("Creating early downsampling filter ...", end='\r')
            start = time()
            sr, self.hop_length, self.downsample_factor, early_downsample_filter, \
                self.earlydownsample = get_early_downsample_params(sr,
                                                                   hop_length,
                                                                   fmax_t,
                                                                   Q,
                                                                   self.n_octaves,
                                                                   verbose)
            self.register_buffer('early_downsample_filter', early_downsample_filter)

            if verbose==True:
                print("Early downsampling filter created, \
                       time used = {:.4f} seconds".format(time()-start))
        else:
            self.downsample_factor = 1.

        # For normalization in the end
        # The freqs returned by create_cqt_kernels cannot be used
        # Since that returns only the top octave bins
        # We need the information for all freq bin
        alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
        self.frequencies = freqs
        lenghts = np.ceil(Q * sr / (freqs + gamma / alpha))

        # get max window length depending on gamma value
        max_len = int(max(lenghts))
        self.n_fft = int(2 ** (np.ceil(np.log2(max_len))))

        lenghts = torch.tensor(lenghts).float()
        self.register_buffer('lenghts', lenghts)

    def forward(self, x, output_format=None, normalization_type='librosa'):
        """
        Convert a batch of waveforms to VQT spectrograms.

        Parameters
        ----------
        x : torch tensor
            Input signal should be in either of the following shapes.\n
            1. ``(len_audio)``\n
            2. ``(num_audio, len_audio)``\n
            3. ``(num_audio, 1, len_audio)``
            It will be automatically broadcast to the right shape
        """
        output_format = output_format or self.output_format

        x = broadcast_dim(x)
        if self.earlydownsample==True:
            x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
        hop = self.hop_length
        vqt = []

        x_down = x  # Preparing a new variable for downsampling
        my_sr = self.sr

        for i in range(self.n_octaves):
            if i > 0:
                x_down = downsampling_by_2(x_down, self.lowpass_filter)
                my_sr /= 2
                hop //= 2

            else:
                x_down = x

            Q = float(self.filter_scale)/(2**(1/self.bins_per_octave)-1)

            basis, self.n_fft, lengths, _ = create_cqt_kernels(Q,
                                                               my_sr,
                                                               self.fmin_t * 2 ** -i,
                                                               self.n_filters,
                                                               self.bins_per_octave,
                                                               norm=self.basis_norm,
                                                               topbin_check=False,
                                                               gamma=self.gamma)

            cqt_kernels_real = torch.tensor(basis.real.astype(np.float32)).unsqueeze(1)
            cqt_kernels_imag = torch.tensor(basis.imag.astype(np.float32)).unsqueeze(1)

            if self.pad_mode == 'constant':
                my_padding = nn.ConstantPad1d(cqt_kernels_real.shape[-1] // 2, 0)
            elif self.pad_mode == 'reflect':
                my_padding = nn.ReflectionPad1d(cqt_kernels_real.shape[-1] // 2)

            cur_vqt = get_cqt_complex(x_down, cqt_kernels_real, cqt_kernels_imag, hop, my_padding)
            vqt.insert(0, cur_vqt)

        vqt = torch.cat(vqt, dim=1)
        vqt = vqt[:,-self.n_bins:,:]  # Removing unwanted bottom bins
        vqt = vqt * self.downsample_factor

        # Normalize again to get same result as librosa
        if normalization_type == 'librosa':
            vqt = vqt * torch.sqrt(self.lenghts.view(-1,1,1))
        elif normalization_type == 'convolutional':
            pass
        elif normalization_type == 'wrap':
            vqt *= 2
        else:
            raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)

        if output_format=='Magnitude':
            if self.trainable==False:
                # Getting CQT Amplitude
                return torch.sqrt(vqt.pow(2).sum(-1))
            else:
                return torch.sqrt(vqt.pow(2).sum(-1) + 1e-8)

        elif output_format=='Complex':
            return vqt

        elif output_format=='Phase':
            phase_real = torch.cos(torch.atan2(vqt[:,:,:,1], vqt[:,:,:,0]))
            phase_imag = torch.sin(torch.atan2(vqt[:,:,:,1], vqt[:,:,:,0]))
            return torch.stack((phase_real,phase_imag), -1)
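As a quick orientation for the new module, here is a minimal usage sketch; the random waveform and the particular parameter values are illustrative only and not part of the PR:

    import torch
    from nnAudio.features import VQT

    # Dummy batch of one 1-second waveform at 22050 Hz; per the docstring above,
    # (len_audio), (num_audio, len_audio) and (num_audio, 1, len_audio) inputs are all accepted.
    y = torch.randn(1, 22050)

    vqt_layer = VQT(sr=22050, hop_length=512, n_bins=84, bins_per_octave=12,
                    gamma=5, output_format='Magnitude', verbose=False)
    spec = vqt_layer(y, normalization_type='librosa')
    print(spec.shape)  # (num_audio, n_bins, n_frames)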

Installation/nnAudio/utils.py

+12 −8
@@ -39,7 +39,6 @@ def rfft_fn(x, n=None, onesided=False):
     else:
         return torch.rfft(x, n, onesided=onesided)
 
-
 ## --------------------------- Filter Design ---------------------------##
 def torch_window_sumsquare(w, n_frames, stride, n_fft, power=2):
     w_stacks = w.unsqueeze(-1).repeat((1, n_frames)).unsqueeze(0)
@@ -407,6 +406,8 @@ def create_cqt_kernels(
     window="hann",
     fmax=None,
     topbin_check=True,
+    gamma=0,
+    pad_fft=True
 ):
     """
     Automatically create CQT kernels in time domain
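With the extended signature, kernels for one octave can be requested with an explicit gamma, mirroring the call made in VQT.forward above. The concrete numbers here are illustrative, and the import path assumes the packaged nnAudio.utils module:

    from nnAudio.utils import create_cqt_kernels

    # One octave of 12 bins starting at roughly C5, with a non-zero gamma
    Q = 1.0 / (2 ** (1 / 12) - 1)
    basis, n_fft, lengths, _ = create_cqt_kernels(
        Q, 22050, 523.25, 12, 12,
        norm=1, topbin_check=False, gamma=10)
    print(basis.shape, n_fft)  # (12, n_fft) complex kernels, zero-padded to a power of two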
@@ -439,25 +440,28 @@ def create_cqt_kernels(
             )
         )
 
+    alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
+    lengths = np.ceil(Q * fs / (freqs + gamma / alpha))
+
+    # get max window length depending on gamma value
+    max_len = int(max(lengths))
+    fftLen = int(2 ** (np.ceil(np.log2(max_len))))
+
     tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
     specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
 
-    lengths = np.ceil(Q * fs / freqs)
     for k in range(0, int(n_bins)):
         freq = freqs[k]
-        l = np.ceil(Q * fs / freq)
+        l = lengths[k]
 
         # Centering the kernels
         if l % 2 == 1:  # pad more zeros on RHS
             start = int(np.ceil(fftLen / 2.0 - l / 2.0)) - 1
         else:
             start = int(np.ceil(fftLen / 2.0 - l / 2.0))
 
-        sig = (
-            get_window_dispatch(window, int(l), fftbins=True)
-            * np.exp(np.r_[-l // 2 : l // 2] * 1j * 2 * np.pi * freq / fs)
-            / l
-        )
+        window_dispatch = get_window_dispatch(window, int(l), fftbins=True)
+        sig = window_dispatch * np.exp(np.r_[-l // 2 : l // 2] * 1j * 2 * np.pi * freq / fs) / l
 
         if norm:  # Normalizing the filter # Trying to normalize like librosa
             tempKernel[k, start : start + int(l)] = sig / np.linalg.norm(sig, norm)
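The added lines implement the variable-Q window lengths: lengths[k] = ceil(Q * fs / (freqs[k] + gamma / alpha)) with alpha = 2**(1 / bins_per_octave) - 1, so gamma = 0 recovers the constant-Q lengths while a larger gamma mainly shortens the low-frequency kernels. A standalone sketch of the formula, using the class defaults as illustrative values:

    import numpy as np

    fs, fmin, n_bins, bins_per_octave, filter_scale = 22050, 32.70, 84, 12, 1
    Q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)
    alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
    freqs = fmin * 2.0 ** (np.arange(n_bins) / bins_per_octave)

    for gamma in (0, 5, 10):
        lengths = np.ceil(Q * fs / (freqs + gamma / alpha))
        # The lowest-bin kernels shrink sharply as gamma grows; the top bins barely change.
        print(gamma, int(lengths[0]), int(lengths[-1]))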

Installation/tests/test_vqt.py

+75
@@ -0,0 +1,75 @@
import pytest
import librosa
import torch
import sys

sys.path.insert(0, "./")

import os

dir_path = os.path.dirname(os.path.realpath(__file__))

from nnAudio.features import CQT2010v2, VQT
import numpy as np
from parameters import *
import warnings

gpu_idx = 0  # Choose which GPU to use

# If GPU is avaliable, also test on GPU
if torch.cuda.is_available():
    device_args = ["cpu", f"cuda:{gpu_idx}"]
else:
    warnings.warn("GPU is not avaliable, testing only on CPU")
    device_args = ["cpu"]

# librosa example audio for testing
y, sr = librosa.load(librosa.ex('choice'), duration=5)


@pytest.mark.parametrize("device", [*device_args])
def test_vqt_gamma_zero(device):
    # nnAudio cqt
    spec = CQT2010v2(sr=sr, verbose=False)
    C2 = spec(torch.tensor(y).unsqueeze(0), output_format="Magnitude", normalization_type='librosa')
    C2 = C2.numpy().squeeze()

    # nnAudio vqt
    spec = VQT(sr=sr, gamma=0, verbose=False)
    V2 = spec(torch.tensor(y).unsqueeze(0), output_format="Magnitude", normalization_type='librosa')
    V2 = V2.numpy().squeeze()

    assert (C2 == V2).all() == True


@pytest.mark.parametrize("device", [*device_args])
def test_vqt(device):
    for gamma in [0, 1, 2, 5, 10]:

        # librosa vqt
        V1 = np.abs(librosa.vqt(y, sr=sr, gamma=gamma))

        # nnAudio vqt
        spec = VQT(sr=sr, gamma=gamma, verbose=False)
        V2 = spec(torch.tensor(y).unsqueeze(0), output_format="Magnitude", normalization_type='librosa')
        V2 = V2.numpy().squeeze()

        # NOTE: there will still be some diff between librosa and nnAudio vqt values (same as cqt)
        # mainly due to the lengths of both - librosa uses float but nnAudio uses int
        # this test aims to keep the diff range within a baseline threshold
        vqt_diff = np.abs(V1 - V2)

        if gamma == 0:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.6785
        elif gamma == 1:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.6510
        elif gamma == 2:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.5962
        elif gamma == 5:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.3695
        else:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.1
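The tests follow the existing nnAudio layout (the parameters import comes from the sibling tests directory), so a local run from the Installation directory could look like the sketch below; the exact invocation is an assumption, not something specified in the PR:

    import pytest

    # Run only the new VQT tests; -q keeps the output short.
    raise SystemExit(pytest.main(["-q", "tests/test_vqt.py"]))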
