
Commit 5ac089e

xx205 and czy97 authored
Add diarization recipe v3 (#347)
* Add diarization recipe v3
* resolve pylint issues and add missing modifications
* eliminate trailing whitespace
* deterministic clustering; update README.md
* fix args usage in umap_clusterer.py
* local import; remove unused diarization args; self.model.eval() when init
* compact embedding clustering procedure into a single source file
* link to local and path.sh; update requirements.txt and extract_emb.py
* fix lint error: extract_emb.py
* Update README.md

  Update News section in README.md
* Update voxconverse/v3/README.md

  Update clustering method
* Update README.md

---------

Co-authored-by: Zhengyang Chen <[email protected]>
1 parent 8934efe commit 5ac089e

13 files changed, with 515 additions and 64 deletions

README.md

+1
@@ -60,6 +60,7 @@ pre-commit install # for clean and tidy code
6060
```
6161

6262
## 🔥 News
63+
* 2024.08.20: Update diarization recipe for VoxConverse dataset by leveraging umap dimensionality reduction and hdbscan clustering, see [#347](https://github.com/wenet-e2e/wespeaker/pull/347).
6364
* 2024.08.18: Support using ssl pre-trained models as the frontend. The [WavLM recipe](https://github.com/wenet-e2e/wespeaker/blob/master/examples/voxceleb/v2/run_wavlm.sh) is also provided, see [#344](https://github.com/wenet-e2e/wespeaker/pull/344).
6465
* 2024.05.15: Add support for [quality-aware score calibration](https://arxiv.org/pdf/2211.00815), see [#320](https://github.com/wenet-e2e/wespeaker/pull/320).
6566
* 2024.04.25: Add support for the gemini-dfresnet model, see [#291](https://github.com/wenet-e2e/wespeaker/pull/291).

examples/voxconverse/v3/README.md

+34
@@ -0,0 +1,34 @@
1+
## Overview
2+
3+
* We suggest running this recipe on a machine with a GPU and onnxruntime-gpu installed.
4+
* Dataset: voxconverse_dev, which consists of 216 utterances
5+
* Speaker model: ResNet34 model pretrained by wespeaker
6+
    * Refer to [voxceleb sv recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v2)
7+
    * [pretrained model path](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx)
8+
* Speech activity detection model: oracle SAD (from ground truth annotation) or system SAD (VAD model pretrained by silero, https://github.com/snakers4/silero-vad)
9+
* Clustering method: umap dimensionality reduction + hdbscan clustering
10+
* Metric: DER = MISS + FALSE ALARM + SPEAKER CONFUSION (%)
11+
12+
## Results
13+
14+
* Dev set
15+
16+
| system | MISS | FA | SC | DER |
17+
|:---|:---:|:---:|:---:|:---:|
18+
| This repo (with oracle SAD) | 2.3 | 0.0 | 1.3 | 3.6 |
19+
| This repo (with system SAD) | 3.4 | 0.6 | 1.4 | 5.4 |
20+
| DIHARD 2019 baseline [^1] | 11.1 | 1.4 | 11.3 | 23.8 |
21+
| DIHARD 2019 baseline w/ SE [^1] | 9.3 | 1.3 | 9.7 | 20.2 |
22+
| (SyncNet ASD only) [^1] | 2.2 | 4.1 | 4.0 | 10.4 |
23+
| (AVSE ASD only) [^1] | 2.0 | 5.9 | 4.6 | 12.4 |
24+
| (proposed) [^1] | 2.4 | 2.3 | 3.0 | 7.7 |
25+
26+
* Test set
27+
28+
| system | MISS | FA | SC | DER |
29+
|:---|:---:|:---:|:---:|:---:|
30+
| This repo (with oracle SAD) | 1.6 | 0.0 | 1.9 | 3.5 |
31+
| This repo (with system SAD) | 3.8 | 1.7 | 1.8 | 7.4 |
32+
33+
34+
[^1]: Spot the conversation: speaker diarisation in the wild, https://arxiv.org/pdf/2007.01216.pdf
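
As a quick sanity check on the tables above, the sketch below recomputes DER as MISS + FA + SC for the four rows reported for this repo. The values are copied from the tables; the helper name `der` and the 0.1 tolerance are assumptions made here for illustration. Because the published components are already rounded, the sum can differ from the reported DER by 0.1 (the test-set system SAD row sums to 7.3 against a reported 7.4).

```python
# Rows copied from the result tables: (MISS, FA, SC, reported DER), all in %.
ROWS = {
    "dev / oracle SAD": (2.3, 0.0, 1.3, 3.6),
    "dev / system SAD": (3.4, 0.6, 1.4, 5.4),
    "test / oracle SAD": (1.6, 0.0, 1.9, 3.5),
    "test / system SAD": (3.8, 1.7, 1.8, 7.4),
}


def der(miss: float, fa: float, sc: float) -> float:
    """DER = MISS + FALSE ALARM + SPEAKER CONFUSION, rounded to one decimal."""
    return round(miss + fa + sc, 1)


if __name__ == "__main__":
    for name, (miss, fa, sc, reported) in ROWS.items():
        # The recomputed value may differ from the reported one by rounding.
        print(f"{name}: recomputed {der(miss, fa, sc)}, reported {reported}")
```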

examples/voxconverse/v3/local

+1
@@ -0,0 +1 @@
1+
../v2/local

examples/voxconverse/v3/path.sh

+1
@@ -0,0 +1 @@
1+
../v2/path.sh

examples/voxconverse/v3/run.sh

+186
@@ -0,0 +1,186 @@
1+
#!/bin/bash
2+
# Copyright (c) 2022-2023 Xu Xiang
3+
# 2022 Zhengyang Chen ([email protected])
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
. ./path.sh || exit 1
18+
19+
stage=-1
20+
stop_stage=-1
21+
sad_type="oracle"
22+
partition="dev"
23+
24+
# do cmn on the sub-segment or on the vad segment
25+
subseg_cmn=true
26+
# whether to print the evaluation result for each file
27+
get_each_file_res=1
28+
29+
. tools/parse_options.sh
30+
31+
# Prerequisite
32+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
33+
mkdir -p external_tools
34+
35+
# [1] Download evaluation toolkit
36+
wget -c https://github.com/usnistgov/SCTK/archive/refs/tags/v2.4.12.zip -O external_tools/SCTK-v2.4.12.zip
37+
unzip -o external_tools/SCTK-v2.4.12.zip -d external_tools
38+
39+
# [2] Download ResNet34 speaker model pretrained by WeSpeaker Team
40+
mkdir -p pretrained_models
41+
42+
wget -c https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx -O pretrained_models/voxceleb_resnet34_LM.onnx
43+
fi
44+
45+
46+
# Download VoxConverse dev/test audios and the corresponding annotations
47+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
48+
mkdir -p data
49+
50+
# Download annotations for dev and test sets (version 0.0.3)
51+
wget -c https://github.com/joonson/voxconverse/archive/refs/heads/master.zip -O data/voxconverse_master.zip
52+
unzip -o data/voxconverse_master.zip -d data
53+
54+
# Download annotations from VoxSRC-23 validation toolkit (looks like version 0.0.2)
55+
# cd data && git clone https://github.com/JaesungHuh/VoxSRC2023.git --recursive && cd -
56+
57+
# Download dev audios
58+
mkdir -p data/dev
59+
60+
#wget --no-check-certificate -c https://mm.kaist.ac.kr/datasets/voxconverse/data/voxconverse_dev_wav.zip -O data/voxconverse_dev_wav.zip
61+
# The above url may not be reachable, you can try the link below.
62+
# This url is from https://github.com/joonson/voxconverse/blob/master/README.md
63+
wget --no-check-certificate -c https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_dev_wav.zip -O data/voxconverse_dev_wav.zip
64+
unzip -o data/voxconverse_dev_wav.zip -d data/dev
65+
66+
# Create wav.scp for dev audios
67+
ls `pwd`/data/dev/audio/*.wav | awk -F/ '{print substr($NF, 1, length($NF)-4), $0}' > data/dev/wav.scp
68+
69+
# Test audios
70+
mkdir -p data/test
71+
72+
#wget --no-check-certificate -c https://mm.kaist.ac.kr/datasets/voxconverse/data/voxconverse_test_wav.zip -O data/voxconverse_test_wav.zip
73+
# The above url may not be reachable, you can try the link below.
74+
# This url is from https://github.com/joonson/voxconverse/blob/master/README.md
75+
wget --no-check-certificate -c https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_test_wav.zip -O data/voxconverse_test_wav.zip
76+
unzip -o data/voxconverse_test_wav.zip -d data/test
77+
78+
# Create wav.scp for test audios
79+
ls `pwd`/data/test/voxconverse_test_wav/*.wav | awk -F/ '{print substr($NF, 1, length($NF)-4), $0}' > data/test/wav.scp
80+
fi
81+
82+
83+
# Voice activity detection
84+
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
85+
# Set VAD min duration
86+
min_duration=0.255
87+
88+
if [[ "x${sad_type}" == "xoracle" ]]; then
89+
# Oracle SAD: handling overlapping or too short regions in ground truth RTTM
90+
while read -r utt wav_path; do
91+
python3 wespeaker/diar/make_oracle_sad.py \
92+
--rttm data/voxconverse-master/${partition}/${utt}.rttm \
93+
--min-duration $min_duration
94+
done < data/${partition}/wav.scp > data/${partition}/oracle_sad
95+
fi
96+
97+
if [[ "x${sad_type}" == "xsystem" ]]; then
98+
# System SAD: applying 'silero' VAD
99+
python3 wespeaker/diar/make_system_sad.py \
100+
--scp data/${partition}/wav.scp \
101+
--min-duration $min_duration > data/${partition}/system_sad
102+
fi
103+
fi
104+
105+
106+
# Extract fbank features
107+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
108+
109+
[ -d "exp/${partition}_${sad_type}_sad_fbank" ] && rm -r exp/${partition}_${sad_type}_sad_fbank
110+
111+
echo "Make Fbank features and store them under exp/${partition}_${sad_type}_sad_fbank"
112+
echo "..."
113+
bash local/make_fbank.sh \
114+
--scp data/${partition}/wav.scp \
115+
--segments data/${partition}/${sad_type}_sad \
116+
--store_dir exp/${partition}_${sad_type}_sad_fbank \
117+
--subseg_cmn ${subseg_cmn} \
118+
--nj 24
119+
fi
120+
121+
# Extract embeddings
122+
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
123+
124+
[ -d "exp/${partition}_${sad_type}_sad_embedding" ] && rm -r exp/${partition}_${sad_type}_sad_embedding
125+
126+
echo "Extract embeddings and store them under exp/${partition}_${sad_type}_sad_embedding"
127+
echo "..."
128+
bash local/extract_emb.sh \
129+
--scp exp/${partition}_${sad_type}_sad_fbank/fbank.scp \
130+
--pretrained_model pretrained_models/voxceleb_resnet34_LM.onnx \
131+
--device cuda \
132+
--store_dir exp/${partition}_${sad_type}_sad_embedding \
133+
--batch_size 96 \
134+
--frame_shift 10 \
135+
--window_secs 1.5 \
136+
--period_secs 0.75 \
137+
--subseg_cmn ${subseg_cmn} \
138+
--nj 1
139+
fi
140+
141+
142+
# Applying umap clustering algorithm
143+
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
144+
145+
[ -f "exp/umap_cluster/${partition}_${sad_type}_sad_labels" ] && rm exp/umap_cluster/${partition}_${sad_type}_sad_labels
146+
147+
echo "Doing umap clustering and store the result in exp/umap_cluster/${partition}_${sad_type}_sad_labels"
148+
echo "..."
149+
python3 wespeaker/diar/umap_clusterer.py \
150+
--scp exp/${partition}_${sad_type}_sad_embedding/emb.scp \
151+
--output exp/umap_cluster/${partition}_${sad_type}_sad_labels
152+
fi
153+
154+
155+
# Convert labels to RTTMs
156+
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
157+
python3 wespeaker/diar/make_rttm.py \
158+
--labels exp/umap_cluster/${partition}_${sad_type}_sad_labels \
159+
--channel 1 > exp/umap_cluster/${partition}_${sad_type}_sad_rttm
160+
fi
161+
162+
163+
# Evaluate the result
164+
if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
165+
ref_dir=data/voxconverse-master/
166+
#ref_dir=data/VoxSRC2023/voxconverse/
167+
echo -e "Get the DER results\n..."
168+
perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \
169+
-c 0.25 \
170+
-r <(cat ${ref_dir}/${partition}/*.rttm) \
171+
-s exp/umap_cluster/${partition}_${sad_type}_sad_rttm 2>&1 | tee exp/umap_cluster/${partition}_${sad_type}_sad_res
172+
173+
if [ ${get_each_file_res} -eq 1 ];then
174+
single_file_res_dir=exp/umap_cluster/${partition}_${sad_type}_single_file_res
175+
mkdir -p $single_file_res_dir
176+
echo -e "\nGet the DER results for each file; the results will be stored under ${single_file_res_dir}\n..."
177+
178+
awk '{print $2}' exp/umap_cluster/${partition}_${sad_type}_sad_rttm | sort -u | while read file_name; do
179+
perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \
180+
-c 0.25 \
181+
-r <(cat ${ref_dir}/${partition}/${file_name}.rttm) \
182+
-s <(grep "${file_name}" exp/umap_cluster/${partition}_${sad_type}_sad_rttm) > ${single_file_res_dir}/${partition}_${file_name}_res
183+
done
184+
echo "Done!"
185+
fi
186+
fi
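
Stage 3 of the script above notes that oracle SAD has to handle overlapping or too-short regions in the ground-truth RTTM. The sketch below shows one way this could work: merge overlapping speaker turns into speech regions, then drop regions shorter than `min_duration`. The function name `merge_regions` and the exact merging rule are assumptions for illustration; this is not the code in `wespeaker/diar/make_oracle_sad.py`.

```python
from typing import List, Tuple

Region = Tuple[float, float]  # (start, end) in seconds


def merge_regions(turns: List[Region],
                  min_duration: float = 0.255) -> List[Region]:
    """Merge overlapping turns, then drop regions shorter than min_duration."""
    merged: List[Region] = []
    for start, end in sorted(turns):
        if merged and start <= merged[-1][1]:
            # Overlaps (or touches) the previous region: extend it.
            prev_start, prev_end = merged[-1]
            merged[-1] = (prev_start, max(prev_end, end))
        else:
            merged.append((start, end))
    return [(s, e) for s, e in merged if e - s >= min_duration]


if __name__ == "__main__":
    # Two overlapping turns, one very short turn, one isolated turn.
    turns = [(0.0, 2.0), (1.5, 3.0), (5.0, 5.1), (7.0, 8.0)]
    print(merge_regions(turns))
```

With the default threshold of 0.255 s (the value set in stage 3), the two overlapping turns collapse into one region and the 0.1 s turn is discarded.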

examples/voxconverse/v3/tools

+1
@@ -0,0 +1 @@
1+
../../../tools

examples/voxconverse/v3/wespeaker

+1
@@ -0,0 +1 @@
1+
../../../wespeaker

requirements.txt

+2
@@ -24,3 +24,5 @@ pypeln==0.4.9
2424
silero-vad
2525
pre-commit==3.5.0
2626
s3prl
27+
hdbscan==0.8.37
28+
umap-learn==0.5.6

wespeaker/cli/speaker.py

+9-23
@@ -29,7 +29,7 @@
2929
from wespeaker.cli.utils import get_args
3030
from wespeaker.models.speaker_model import get_speaker_model
3131
from wespeaker.utils.checkpoint import load_checkpoint
32-
from wespeaker.diar.spectral_clusterer import cluster
32+
from wespeaker.diar.umap_clusterer import cluster
3333
from wespeaker.diar.extract_emb import subsegment
3434
from wespeaker.diar.make_rttm import merge_segments
3535
from wespeaker.utils.utils import set_seed
@@ -47,6 +47,7 @@ def __init__(self, model_dir: str):
4747
self.model = get_speaker_model(
4848
configs['model'])(**configs['model_args'])
4949
load_checkpoint(self.model, model_path)
50+
self.model.eval()
5051
self.vad = load_silero_vad()
5152
self.table = {}
5253
self.resample_rate = 16000
@@ -55,9 +56,6 @@ def __init__(self, model_dir: str):
5556
self.wavform_norm = False
5657

5758
# diarization params
58-
self.diar_num_spks = None
59-
self.diar_min_num_spks = 1
60-
self.diar_max_num_spks = 20
6159
self.diar_min_duration = 0.255
6260
self.diar_window_secs = 1.5
6361
self.diar_period_secs = 0.75
@@ -83,18 +81,12 @@ def set_gpu(self, device_id: int):
8381
self.model = self.model.to(self.device)
8482

8583
def set_diarization_params(self,
86-
num_spks=None,
87-
min_num_spks=1,
88-
max_num_spks=20,
8984
min_duration: float = 0.255,
9085
window_secs: float = 1.5,
9186
period_secs: float = 0.75,
9287
frame_shift: int = 10,
9388
batch_size: int = 32,
9489
subseg_cmn: bool = True):
95-
self.diar_num_spks = num_spks
96-
self.diar_min_num_spks = min_num_spks
97-
self.diar_max_num_spks = max_num_spks
9890
self.diar_min_duration = min_duration
9991
self.diar_window_secs = window_secs
10092
self.diar_period_secs = period_secs
@@ -127,10 +119,10 @@ def extract_embedding_feats(self, fbanks, batch_size, subseg_cmn):
127119
fbanks_array = torch.from_numpy(fbanks_array).to(self.device)
128120
for i in tqdm(range(0, fbanks_array.shape[0], batch_size)):
129121
batch_feats = fbanks_array[i:i + batch_size]
130-
# _, batch_embs = self.model(batch_feats)
131-
batch_embs = self.model(batch_feats)
132-
batch_embs = batch_embs[-1] if isinstance(batch_embs,
133-
tuple) else batch_embs
122+
with torch.no_grad():
123+
batch_embs = self.model(batch_feats)
124+
batch_embs = batch_embs[-1] if isinstance(batch_embs,
125+
tuple) else batch_embs
134126
embeddings.append(batch_embs.detach().cpu().numpy())
135127
embeddings = np.vstack(embeddings)
136128
return embeddings
@@ -162,7 +154,7 @@ def extract_embedding(self, audio_path: str):
162154
cmn=True)
163155
feats = feats.unsqueeze(0)
164156
feats = feats.to(self.device)
165-
self.model.eval()
157+
166158
with torch.no_grad():
167159
outputs = self.model(feats)
168160
outputs = outputs[-1] if isinstance(outputs, tuple) else outputs
@@ -251,10 +243,7 @@ def diarize(self, audio_path: str, utt: str = "unk"):
251243

252244
# 4. cluster
253245
subseg2label = []
254-
labels = cluster(embeddings,
255-
num_spks=self.diar_num_spks,
256-
min_num_spks=self.diar_min_num_spks,
257-
max_num_spks=self.diar_max_num_spks)
246+
labels = cluster(embeddings)
258247
for (_subseg, _label) in zip(subsegs, labels):
259248
# b, e = process_seg_id(_subseg, frame_shift=self.diar_frame_shift)
260249
# subseg2label.append([b, e, _label])
@@ -316,10 +305,7 @@ def main():
316305
model.set_resample_rate(args.resample_rate)
317306
model.set_vad(args.vad)
318307
model.set_gpu(args.gpu)
319-
model.set_diarization_params(num_spks=args.diar_num_spks,
320-
min_num_spks=args.diar_min_num_spks,
321-
max_num_spks=args.diar_max_num_spks,
322-
min_duration=args.diar_min_duration,
308+
model.set_diarization_params(min_duration=args.diar_min_duration,
323309
window_secs=args.diar_window_secs,
324310
period_secs=args.diar_period_secs,
325311
frame_shift=args.diar_frame_shift,
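
The `window_secs`, `period_secs` and `frame_shift` parameters kept in `set_diarization_params` control how each speech segment is cut into overlapping sub-segments before embedding extraction. The sketch below illustrates the idea in frame units, using the defaults shown above (1.5 s windows, 0.75 s shift, 10 ms frames). The function name `sliding_windows` and the handling of short segments and of the final window are assumptions; this is not the `subsegment` function imported from `wespeaker.diar.extract_emb`.

```python
from typing import List, Tuple


def sliding_windows(num_frames: int,
                    window_secs: float = 1.5,
                    period_secs: float = 0.75,
                    frame_shift: int = 10) -> List[Tuple[int, int]]:
    """Return (begin, end) frame indices of overlapping windows over a segment.

    frame_shift is in milliseconds, so 1.5 s at 10 ms per frame is 150 frames.
    """
    window = int(round(window_secs * 1000 / frame_shift))
    period = int(round(period_secs * 1000 / frame_shift))
    if num_frames <= window:
        # Assumption: a short segment is kept whole as a single window.
        return [(0, num_frames)]
    windows = []
    begin = 0
    while begin + window < num_frames:
        windows.append((begin, begin + window))
        begin += period
    # Assumption: the last window is shortened to end at the segment boundary.
    windows.append((begin, num_frames))
    return windows


if __name__ == "__main__":
    # A 4-second segment at 10 ms per frame has 400 frames.
    for begin, end in sliding_windows(400):
        print(begin, end)
```

Each window then yields one embedding, and the clustering step assigns one speaker label per window.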

wespeaker/cli/utils.py

-12
@@ -75,18 +75,6 @@ def get_args():
7575
help='output file to save speaker embedding '
7676
'or save diarization result')
7777
# diarization params
78-
parser.add_argument('--diar_num_spks',
79-
type=int,
80-
default=None,
81-
help='number of speakers')
82-
parser.add_argument('--diar_min_num_spks',
83-
type=int,
84-
default=1,
85-
help='minimum number of speakers')
86-
parser.add_argument('--diar_max_num_spks',
87-
type=int,
88-
default=20,
89-
help='maximum number of speakers')
9078
parser.add_argument('--diar_min_duration',
9179
type=float,
9280
default=0.255,

wespeaker/diar/extract_emb.py

+1
@@ -37,6 +37,7 @@ def init_session(source, device):
3737
opts = ort.SessionOptions()
3838
opts.inter_op_num_threads = 1
3939
opts.intra_op_num_threads = 1
40+
opts.log_severity_level = 0
4041
session = ort.InferenceSession(source,
4142
sess_options=opts,
4243
providers=providers)
