Kaldi-based forced alignment
This section describes in detail how to use kaldi-decoder for FST-based forced alignment with models trained with CTC loss.
Prepare the environment
Before you continue, make sure you have set up icefall by following Installation.
Get the test data
We use the test wave from the CTC forced alignment API tutorial:
import torchaudio
# Download test wave
speech_file = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
print(speech_file)
waveform, sr = torchaudio.load(speech_file)
transcript = "i had that curiosity beside me at this moment".split()
print(waveform.shape, sr)
assert waveform.ndim == 2
assert waveform.shape[0] == 1
assert sr == 16000
The test wave is downloaded to:
$HOME/.cache/torch/hub/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav
Wave filename | Text
---|---
Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav | i had that curiosity beside me at this moment
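If you want to try a recording of your own instead of the test wave, make sure it is mono and matches the model's 16 kHz sample rate first. A minimal sketch using torchaudio (an addition for illustration, not part of the original tutorial):

# Optional: convert your own recording to 16 kHz mono before feeding it to the model.
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    sr = 16000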
Get the test model
We use the test model from the CTC forced alignment API tutorial:
import torch
bundle = torchaudio.pipelines.MMS_FA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = bundle.get_model(with_star=False).to(device)
The model is downloaded to:
$HOME/.cache/torch/hub/checkpoints/model.pt
Compute log_probs
with torch.inference_mode():
    emission, _ = model(waveform.to(device))

print(emission.shape)
It should print:
torch.Size([1, 169, 28])
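The three dimensions are the batch size, the number of frames, and the number of tokens. As a quick sanity check (not part of the alignment pipeline), you can inspect the per-frame argmax of the emission; for this model, id 0 is the blank token:

# Per-frame greedy token ids; most frames are 0, i.e., the blank token "-".
greedy_ids = emission[0].argmax(dim=-1)
print(greedy_ids[:50])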
Create token2id and id2token
token2id = bundle.get_dict(star=None)
id2token = {i: t for t, i in token2id.items()}

# In the FST framework below, id 0 is reserved for <eps>, so map <eps> to 0
# and drop the blank symbol "-", which also has id 0 in this dictionary.
# id2token still maps 0 to "-", so blank frames print as "-" later on.
token2id["<eps>"] = 0
del token2id["-"]
Create word2id and id2word
words = list(set(transcript))

word2id = dict()
# id 0 is reserved for <eps>
word2id["<eps>"] = 0
for i, w in enumerate(words):
    word2id[w] = i + 1

id2word = {i: w for w, i in word2id.items()}
Note that we only use words from the transcript of the test wave.
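The prepare_lang_fst.py script used in the next step reads its inputs from the lang directory. We assume here that it needs tokens.txt, words.txt, and a character-level lexicon.txt, which matches the usual icefall lang-directory layout; check the script itself if it reports missing or differently formatted files. A sketch for writing them from the mappings above:

# Assumption: prepare_lang_fst.py reads tokens.txt, words.txt, and lexicon.txt
# from the lang directory. Adjust names/format if the script expects otherwise.
with open("tokens.txt", "w", encoding="utf-8") as f:
    for t, i in sorted(token2id.items(), key=lambda x: x[1]):
        f.write(f"{t} {i}\n")

with open("words.txt", "w", encoding="utf-8") as f:
    for w, i in sorted(word2id.items(), key=lambda x: x[1]):
        f.write(f"{w} {i}\n")

# Spell each word with its letters, matching the character-level tokens of MMS_FA.
with open("lexicon.txt", "w", encoding="utf-8") as f:
    for w in words:
        f.write(f"{w} {' '.join(w)}\n")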
Convert transcript to an FST graph
egs/librispeech/ASR/local/prepare_lang_fst.py --lang-dir ./
The above command should generate two files, H.fst and HL.fst. We will use HL.fst below:
-rw-r--r-- 1 root root 13K Jun 12 08:28 H.fst
-rw-r--r-- 1 root root 3.7K Jun 12 08:28 HL.fst
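If you are curious about the generated graph, you can load it with kaldifst and print it; HL.fst is small for this short transcript, so its text form is easy to inspect. This step is optional and only for illustration:

import kaldifst

HL = kaldifst.StdVectorFst.read("./HL.fst")
# Prints the FST in OpenFst text format: src dst ilabel olabel [weight]
print(HL)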
Force aligner
Now, everything is ready. We can use the following code to get forced alignments.
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
import kaldifst

def force_align():
    HL = kaldifst.StdVectorFst.read("./HL.fst")
    decodable = DecodableCtc(emission[0].contiguous().cpu().numpy())
    decoder_opts = FasterDecoderOptions(max_active=3000)
    decoder = FasterDecoder(HL, decoder_opts)
    decoder.decode(decodable)
    if not decoder.reached_final():
        print(f"failed to decode xxx")
        return None
    ok, best_path = decoder.get_best_path()
    (
        ok,
        isymbols_out,
        osymbols_out,
        total_weight,
    ) = kaldifst.get_linear_symbol_sequence(best_path)
    if not ok:
        print(f"failed to get linear symbol sequence for xxx")
        return None
    # We need to use i - 1 here since token ids were incremented by 1
    # during HL construction, where id 0 is reserved for <eps>.
    alignment = [i - 1 for i in isymbols_out]
    return alignment

alignment = force_align()
for i, a in enumerate(alignment):
    print(i, id2token[a])
The output should be identical to https://pytorch.org/audio/main/tutorials/ctc_forced_alignment_api_tutorial.html#frame-level-alignments.
For ease of reference, we list the output below:
0 -
1 -
2 -
3 -
4 -
5 -
6 -
7 -
8 -
9 -
10 -
11 -
12 -
13 -
14 -
15 -
16 -
17 -
18 -
19 -
20 -
21 -
22 -
23 -
24 -
25 -
26 -
27 -
28 -
29 -
30 -
31 -
32 i
33 -
34 -
35 h
36 h
37 a
38 -
39 -
40 -
41 d
42 -
43 -
44 t
45 h
46 -
47 a
48 -
49 -
50 t
51 -
52 -
53 -
54 c
55 -
56 -
57 -
58 u
59 u
60 -
61 -
62 -
63 r
64 -
65 i
66 -
67 -
68 -
69 -
70 -
71 -
72 o
73 -
74 -
75 -
76 -
77 -
78 -
79 s
80 -
81 -
82 -
83 i
84 -
85 t
86 -
87 -
88 y
89 -
90 -
91 -
92 -
93 b
94 -
95 e
96 -
97 -
98 -
99 -
100 -
101 s
102 -
103 -
104 -
105 -
106 -
107 -
108 -
109 -
110 i
111 -
112 -
113 d
114 e
115 -
116 m
117 -
118 -
119 e
120 -
121 -
122 -
123 -
124 a
125 -
126 -
127 t
128 -
129 t
130 h
131 -
132 i
133 -
134 -
135 -
136 s
137 -
138 -
139 -
140 -
141 m
142 -
143 -
144 o
145 -
146 -
147 -
148 m
149 -
150 -
151 e
152 -
153 n
154 -
155 t
156 -
157 -
158 -
159 -
160 -
161 -
162 -
163 -
164 -
165 -
166 -
167 -
168 -
To merge tokens, we use:
from icefall.ctc import merge_tokens

token_spans = merge_tokens(alignment)
for span in token_spans:
    print(id2token[span.token], span.start, span.end)
The output is given below:
i 32 33
h 35 37
a 37 38
d 41 42
t 44 45
h 45 46
a 47 48
t 50 51
c 54 55
u 58 60
r 63 64
i 65 66
o 72 73
s 79 80
i 83 84
t 85 86
y 88 89
b 93 94
e 95 96
s 101 102
i 110 111
d 113 114
e 114 115
m 116 117
e 119 120
a 124 125
t 127 128
t 129 130
h 130 131
i 132 133
s 136 137
m 141 142
o 144 145
m 148 149
e 151 152
n 153 154
t 155 156
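The start and end fields of each span are frame indices into emission. To convert them to seconds, scale by the number of waveform samples per emission frame; a small sketch (the preview_word helper below applies the same conversion to whole words):

# Samples per emission frame; dividing by the sample rate gives seconds.
ratio = waveform.size(1) / emission.size(1)
for span in token_spans:
    start_sec = span.start * ratio / sr
    end_sec = span.end * ratio / sr
    print(f"{id2token[span.token]}: {start_sec:.3f}s - {end_sec:.3f}s")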
All of the code below is copied and modified from https://pytorch.org/audio/main/tutorials/ctc_forced_alignment_api_tutorial.html.
Segment each word using the computed alignments
def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret
word_spans = unflatten(token_spans, [len(word) for word in transcript])
print(word_spans)
The output is:
[[TokenSpan(token=2, start=32, end=33)],
[TokenSpan(token=15, start=35, end=37), TokenSpan(token=1, start=37, end=38), TokenSpan(token=13, start=41, end=42)],
[TokenSpan(token=7, start=44, end=45), TokenSpan(token=15, start=45, end=46), TokenSpan(token=1, start=47, end=48), TokenSpan(token=7, start=50, end=51)],
[TokenSpan(token=20, start=54, end=55), TokenSpan(token=6, start=58, end=60), TokenSpan(token=9, start=63, end=64), TokenSpan(token=2, start=65, end=66), TokenSpan(token=5, start=72, end=73), TokenSpan(token=8, start=79, end=80), TokenSpan(token=2, start=83, end=84), TokenSpan(token=7, start=85, end=86), TokenSpan(token=16, start=88, end=89)],
[TokenSpan(token=17, start=93, end=94), TokenSpan(token=3, start=95, end=96), TokenSpan(token=8, start=101, end=102), TokenSpan(token=2, start=110, end=111), TokenSpan(token=13, start=113, end=114), TokenSpan(token=3, start=114, end=115)],
[TokenSpan(token=10, start=116, end=117), TokenSpan(token=3, start=119, end=120)],
[TokenSpan(token=1, start=124, end=125), TokenSpan(token=7, start=127, end=128)],
[TokenSpan(token=7, start=129, end=130), TokenSpan(token=15, start=130, end=131), TokenSpan(token=2, start=132, end=133), TokenSpan(token=8, start=136, end=137)],
[TokenSpan(token=10, start=141, end=142), TokenSpan(token=5, start=144, end=145), TokenSpan(token=10, start=148, end=149), TokenSpan(token=3, start=151, end=152), TokenSpan(token=4, start=153, end=154), TokenSpan(token=7, start=155, end=156)]
]
import IPython

def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / num_frames
    x0 = int(ratio * spans[0].start)
    x1 = int(ratio * spans[-1].end)
    print(f"{transcript} {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)
num_frames = emission.size(1)
preview_word(waveform, word_spans[0], num_frames, transcript[0])
preview_word(waveform, word_spans[1], num_frames, transcript[1])
preview_word(waveform, word_spans[2], num_frames, transcript[2])
preview_word(waveform, word_spans[3], num_frames, transcript[3])
preview_word(waveform, word_spans[4], num_frames, transcript[4])
preview_word(waveform, word_spans[5], num_frames, transcript[5])
preview_word(waveform, word_spans[6], num_frames, transcript[6])
preview_word(waveform, word_spans[7], num_frames, transcript[7])
preview_word(waveform, word_spans[8], num_frames, transcript[8])
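preview_word relies on IPython, so it is only useful inside a notebook. If you run the code as a plain script, you can write each word segment to a wav file instead; the helper below is our own addition, not part of the torchaudio tutorial:

import os

def save_word(waveform, spans, num_frames, transcript, out_dir="word_segments", sample_rate=bundle.sample_rate):
    # Same frame-to-sample conversion as preview_word, but write the segment to disk.
    ratio = waveform.size(1) / num_frames
    x0 = int(ratio * spans[0].start)
    x1 = int(ratio * spans[-1].end)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{transcript}.wav")
    torchaudio.save(out_path, waveform[:, x0:x1], sample_rate)
    return out_path

for spans, word in zip(word_spans, transcript):
    print(save_word(waveform, spans, num_frames, word))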
The segmented wave of each word along with its time stamp is given below:
Word | Time
---|---
i | 0.644 - 0.664 sec
had | 0.704 - 0.845 sec
that | 0.885 - 1.026 sec
curiosity | 1.086 - 1.790 sec
beside | 1.871 - 2.314 sec
me | 2.334 - 2.414 sec
at | 2.495 - 2.575 sec
this | 2.595 - 2.756 sec
moment | 2.837 - 3.138 sec
We repost the whole wave below for ease of reference:
Wave filename | Text
---|---
Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav | i had that curiosity beside me at this moment
Summary
Congratulations! You have succeeded in using the FST-based approach to compute the alignment of a test wave.