Skip to content

Commit 3bba9bd

Browse files
authored
add xi_vector (#404)
* add xi_vector * fix flake8 error * fix flake8 error * fix flake8 error * update Readme * fix lint errors
1 parent 91a3d6e commit 3bba9bd

File tree

7 files changed

+226
-2
lines changed

7 files changed

+226
-2
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pre-commit install # for clean and tidy code
6060
```
6161

6262
## 🔥 News
63-
63+
* 2025.02.23: Add support for the Xi-vector, see [#404](https://github.com/wenet-e2e/wespeaker/pull/404).
6464
* 2024.09.03: Support the SimAM_ResNet and the model pretrained on VoxBlink2, check [Pretrained Models](docs/pretrained.md) for the pretrained model, [VoxCeleb Recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v2) for the super performance, and [python usage](docs/python_package.md) for the command line usage!
6565
* 2024.08.30: We support whisper_encoder based frontend and propose the [Whisper-PMFA](https://arxiv.org/pdf/2408.15585) framework, check [#356](https://github.com/wenet-e2e/wespeaker/pull/356).
6666
* 2024.08.20: Update diarization recipe for VoxConverse dataset by leveraging umap dimensionality reduction and hdbscan clustering, see [#347](https://github.com/wenet-e2e/wespeaker/pull/347) and [#352](https://github.com/wenet-e2e/wespeaker/pull/352).

examples/voxceleb/v2/README.md

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
| SimAM_ResNet100 (VoxBlink2 Pretrain) | 50.2M | || x | × | 0.229 | 0.458 | 0.868 |
6363
| | | ||| × | 0.207 | 0.424 | 0.804 |
6464
| | | |||| 0.202 | 0.421 | 0.795 |
65+
| XI_VEC_ECAPA_TDNN_c512 | 5.9M | 0.68G | x | x | × | 0.995 | 1.130 | 2.169 |
66+
| | | | × || × | 0.883 | 1.056 | 1.976 |
6567

6668

6769
## PLDA results
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
### train configuraton
2+
3+
exp_dir: exp/XI_VEC_ECAPA_TDNN_c512-emb192-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
4+
gpus: "[0]"
5+
num_avg: 10
6+
enable_amp: False # whether enable automatic mixed precision training
7+
8+
seed: 42
9+
num_epochs: 150
10+
save_epoch_interval: 5 # save model every 5 epochs
11+
log_batch_interval: 100 # log every 100 batchs
12+
13+
dataloader_args:
14+
batch_size: 512
15+
num_workers: 16
16+
pin_memory: False
17+
prefetch_factor: 8
18+
drop_last: True
19+
20+
dataset_args:
21+
# the sample number which will be traversed within one epoch, if the value equals to 0,
22+
# the utterance number in the dataset will be used as the sample_num_per_epoch.
23+
sample_num_per_epoch: 0
24+
shuffle: True
25+
shuffle_args:
26+
shuffle_size: 2500
27+
filter: True
28+
filter_args:
29+
min_num_frames: 100
30+
max_num_frames: 800
31+
resample_rate: 16000
32+
speed_perturb: True
33+
num_frms: 200
34+
aug_prob: 0.6 # prob to add reverb & noise aug per sample
35+
frontend: "fbank" # fbank, s3prl
36+
fbank_args:
37+
num_mel_bins: 80
38+
frame_shift: 10
39+
frame_length: 25
40+
dither: 1.0
41+
spec_aug: False
42+
spec_aug_args:
43+
num_t_mask: 1
44+
num_f_mask: 1
45+
max_t: 10
46+
max_f: 8
47+
prob: 0.6
48+
49+
model: XI_VEC_ECAPA_TDNN_c512 # XI_VEC_ECAPA_TDNN_c512, XI_VEC_ECAPA_TDNN_c1024, XI_VEC_XVEC
50+
model_init: null
51+
model_args:
52+
feat_dim: 80
53+
embed_dim: 192
54+
pooling_func: "XI" # the default pooling_func in ECAPA_TDNN is ASTP
55+
projection_args:
56+
project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
57+
scale: 32.0
58+
easy_margin: False
59+
60+
margin_scheduler: MarginScheduler
61+
margin_update:
62+
initial_margin: 0.0
63+
final_margin: 0.2
64+
increase_start_epoch: 20
65+
fix_start_epoch: 40
66+
update_margin: True
67+
increase_type: "exp" # exp, linear
68+
69+
loss: CrossEntropyLoss
70+
loss_args: {}
71+
72+
optimizer: SGD
73+
optimizer_args:
74+
momentum: 0.9
75+
nesterov: True
76+
weight_decay: 0.0001
77+
78+
scheduler: ExponentialDecrease
79+
scheduler_args:
80+
initial_lr: 0.1
81+
final_lr: 0.00005
82+
warm_up_epoch: 6
83+
warm_from_zero: True

wespeaker/models/ecapa_tdnn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def ECAPA_TDNN_GLOB_c512(feat_dim,
281281
pooling_func='ASTP')
282282
model.eval()
283283
out = model(x)
284-
print(out.shape)
284+
print(out[-1].shape)
285285

286286
num_params = sum(param.numel() for param in model.parameters())
287287
print("{} M".format(num_params / 1e6))

wespeaker/models/pooling_layers.py

+71
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,77 @@ def get_out_dim(self):
310310
return self.out_dim
311311

312312

313+
class XI(torch.nn.Module):
314+
def __init__(self, in_dim, hidden_size=256, stddev=False,
315+
train_mean=True, train_prec=True, **kwargs):
316+
super(XI, self).__init__()
317+
self.input_dim = in_dim
318+
self.stddev = stddev
319+
if self.stddev:
320+
self.output_dim = 2 * self.input_dim
321+
else:
322+
self.output_dim = self.input_dim
323+
self.prior_mean = torch.nn.Parameter(torch.zeros(1, self.input_dim),
324+
requires_grad=train_mean)
325+
self.prior_logprec = torch.nn.Parameter(torch.zeros(1, self.input_dim),
326+
requires_grad=train_prec)
327+
self.softmax = torch.nn.Softmax(dim=2)
328+
329+
# Log-precision estimator
330+
self.lin1_relu_bn = nn.Sequential(
331+
nn.Conv1d(self.input_dim, hidden_size,
332+
kernel_size=1, stride=1, bias=True),
333+
nn.ReLU(inplace=True),
334+
nn.BatchNorm1d(hidden_size))
335+
self.lin2 = nn.Conv1d(hidden_size, self.input_dim, kernel_size=1,
336+
stride=1, bias=True)
337+
self.softplus2 = torch.nn.Softplus(beta=1, threshold=20)
338+
339+
def forward(self, inputs):
340+
"""
341+
@inputs: a 3-dimensional tensor (a batch),
342+
including [samples-index, frames-dim-index, frames-index]
343+
"""
344+
assert len(inputs.shape) == 3
345+
assert inputs.shape[1] == self.input_dim
346+
feat = inputs
347+
# Log-precision estimator
348+
# frame precision estimate
349+
logprec = self.softplus2(self.lin2(self.lin1_relu_bn(feat)))
350+
351+
# Square and take log before softmax
352+
logprec = 2.0 * torch.log(logprec)
353+
# Gaussian Posterior Inference
354+
# Option 1: a_o (prior_mean-phi) included in variance
355+
weight_attn = self.softmax(
356+
torch.cat(
357+
(logprec,
358+
self.prior_logprec.repeat(
359+
logprec.shape[0], 1).unsqueeze(dim=2)), 2))
360+
# Posterior precision
361+
Ls = torch.sum(torch.exp(torch.cat(
362+
(logprec, self.prior_logprec.repeat(
363+
logprec.shape[0], 1).unsqueeze(dim=2)), 2)), dim=2)
364+
# Posterior mean
365+
phi = torch.sum(torch.cat(
366+
(feat, self.prior_mean.repeat(
367+
feat.shape[0], 1).unsqueeze(dim=2)), 2) * weight_attn, dim=2)
368+
369+
if self.stddev:
370+
sigma2 = torch.sum(torch.cat((
371+
feat, self.prior_mean.repeat(
372+
feat.shape[0], 1).unsqueeze(dim=2)), 2).pow(2) * weight_attn, dim=2)
373+
sigma = torch.sqrt(torch.clamp(sigma2 - phi ** 2, min=1.0e-12))
374+
return torch.cat((phi, sigma), dim=1).unsqueeze(dim=2)
375+
else:
376+
return phi
377+
378+
def get_out_dim(self):
379+
return self.output_dim
380+
381+
def get_prior(self):
382+
return self.prior_mean, self.prior_logprec
383+
313384
if __name__ == '__main__':
314385
data = torch.randn(16, 512, 10, 35)
315386
# model = StatisticsPooling()

wespeaker/models/speaker_model.py

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import wespeaker.models.whisper_PMFA as whisper_PMFA
2424
import wespeaker.models.redimnet as redimnet
2525
import wespeaker.models.samresnet as samresnet
26+
import wespeaker.models.xi_vector as xi_vector
2627

2728

2829

@@ -49,6 +50,8 @@ def get_speaker_model(model_name: str):
4950
return getattr(redimnet, model_name)
5051
elif model_name.startswith("SimAM_ResNet"):
5152
return getattr(samresnet, model_name)
53+
elif model_name.startswith("XI_VEC"):
54+
return getattr(xi_vector, model_name)
5255
else: # model_name error !!!
5356
print(model_name + " not found !!!")
5457
exit(1)

wespeaker/models/xi_vector.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) 2025 Shuai Wang ([email protected])
2+
# 2025 Junjie LI ([email protected])
3+
# 2025 Tianchi Liu ([email protected])
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
'''The implementation of Xi_vector.
17+
18+
Reference:
19+
[1] Lee, K. A., Wang, Q., & Koshinaka, T. (2021). Xi-vector embedding
20+
for speaker recognition. IEEE Signal Processing Letters, 28, 1385-1389.
21+
'''
22+
23+
24+
import torch
25+
import wespeaker.models.ecapa_tdnn as ecapa_tdnn
26+
import wespeaker.models.tdnn as tdnn
27+
28+
29+
30+
31+
def XI_VEC_ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func='XI', emb_bn=False):
32+
return ecapa_tdnn.ECAPA_TDNN(channels=1024,
33+
feat_dim=feat_dim,
34+
embed_dim=embed_dim,
35+
pooling_func=pooling_func,
36+
emb_bn=emb_bn)
37+
38+
39+
def XI_VEC_ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func='XI', emb_bn=False):
40+
return ecapa_tdnn.ECAPA_TDNN(channels=512,
41+
feat_dim=feat_dim,
42+
embed_dim=embed_dim,
43+
pooling_func=pooling_func,
44+
emb_bn=emb_bn)
45+
46+
47+
48+
def XI_VEC_XVEC(feat_dim, embed_dim, pooling_func='XI'):
49+
return tdnn.XVEC(feat_dim=feat_dim, embed_dim=embed_dim, pooling_func=pooling_func)
50+
51+
52+
if __name__ == '__main__':
53+
x = torch.rand(1, 200, 80)
54+
model = XI_VEC_XVEC(feat_dim=80, embed_dim=512, pooling_func='XI')
55+
model.eval()
56+
y = model(x)
57+
print(y[-1].size())
58+
59+
num_params = sum(p.numel() for p in model.parameters())
60+
print("{} M".format(num_params / 1e6))
61+
62+
from thop import profile
63+
x_np = torch.randn(1, 200, 80)
64+
flops, params = profile(model, inputs=(x_np, ))
65+
print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))

0 commit comments

Comments
 (0)