Commit 4fa997e: update files
Author: Yuekai Zhang (committed)
Parent(s): 21f2415
Files changed:
- README.md (+11 -2)
- generate_perf_input.py (+134 -0)
- mid.wav (+0 -0)
- model_repo_cuda_decoder/attention_rescoring/1/.gitkeep (+0 -0)
- model_repo_cuda_decoder/scoring/1/decoder.py (+21 -19)
- model_repo_cuda_decoder/scoring/1/frame_reducer.py (+193 -0)
- model_repo_cuda_decoder/scoring/1/model.py (+52 -7)
- model_repo_cuda_decoder/scoring/1/wfst_decoding_config.yaml (+2 -2)
- model_repo_cuda_decoder/scoring/config.pbtxt (+2 -2)
- model_repo_cuda_decoder/scoring/config.pbtxt.template (+1 -1)
- run.sh (+1 -4)
README.md CHANGED
@@ -1,5 +1,5 @@
-###
-
+### Tutorial
+Start Server
 ```
 docker pull soar97/triton-wenet:22.12
 docker run -it --rm --name "wenet_tlg_test" --gpus all --shm-size 1g --net host soar97/triton-wenet:22.12
@@ -8,4 +8,13 @@ git clone https://huggingface.co/yuekai/model_repo_conformer_aishell_wenet_tlg.git
 cd model_repo_conformer_aishell_wenet_tlg
 bash run.sh
 ```
+Start Client
+```
+pip3 install tritonclient[all]==2.29
+apt-get install -y libsndfile1
+pip3 install soundfile
+
+python3 generate_perf_input.py --audio_file ./mid.wav
+perf_analyzer -m attention_rescoring -b 1 -p 20000 --concurrency-range 100 -i gRPC --input-data=offline_input.json -u localhost:8001
+```
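Note: before running perf_analyzer, it can help to sanity-check the generated payload. A minimal sketch (not part of this commit) that assumes offline_input.json was written by generate_perf_input.py below:

```python
import json

# Inspect the perf_analyzer payload produced by generate_perf_input.py.
# The keys ("data", "WAV", "WAV_LENS") match what the script writes.
with open("offline_input.json") as f:
    payload = json.load(f)

sample = payload["data"][0]
print("samples      :", len(payload["data"]))
print("WAV_LENS     :", sample["WAV_LENS"])
print("WAV shape    :", sample["WAV"]["shape"])
print("first values :", sample["WAV"]["content"][:5])
```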
generate_perf_input.py ADDED
@@ -0,0 +1,134 @@
+import os
+import json
+import soundfile as sf
+import numpy as np
+import argparse
+import math
+
+
+def generate_offline_input(args):
+    wav_file = args.audio_file
+    print("Reading {}".format(wav_file))
+    waveform, sample_rate = sf.read(wav_file)
+    batch_size = 1
+    mat = np.array([waveform] * batch_size, dtype=np.float32)
+
+    out_dict = {
+        "data": [
+            {
+                "WAV_LENS": [len(waveform)],
+                "WAV": {
+                    "shape": [len(waveform)],
+                    "content": mat.flatten().tolist(),
+                },
+            }
+        ]
+    }
+    json.dump(out_dict, open("offline_input.json", "w"))
+
+
+def generate_online_input(args):
+    wav_file = args.audio_file
+    waveform, sample_rate = sf.read(wav_file)
+    chunk_size, subsampling = args.chunk_size, args.subsampling
+    context = args.context
+    first_chunk_length = (chunk_size - 1) * subsampling + context
+    frame_length_ms, frame_shift_ms = args.frame_length_ms, args.frame_shift_ms
+    # for the first chunk,
+    # we need additional frame to generate the exact first chunk length frames
+    add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)
+    first_chunk_ms = (first_chunk_length + add_frames) * frame_shift_ms
+    other_chunk_ms = chunk_size * subsampling * frame_shift_ms
+    first_chunk_s = first_chunk_ms / 1000
+    other_chunk_s = other_chunk_ms / 1000
+
+    wav_segs = []
+    i = 0
+    while i < len(waveform):
+        if i == 0:
+            stride = int(first_chunk_s * sample_rate)
+            wav_segs.append(waveform[i : i + stride])
+        else:
+            stride = int(other_chunk_s * sample_rate)
+            wav_segs.append(waveform[i : i + stride])
+        i += len(wav_segs[-1])
+
+    data = {"data": [[]]}
+
+    for idx, seg in enumerate(wav_segs):  # 0, num_frames + 5, 64
+        chunk_len = len(seg)
+        if idx == 0:
+            length = int(first_chunk_s * sample_rate)
+            expect_input = np.zeros((1, length), dtype=np.float32)
+        else:
+            length = int(other_chunk_s * sample_rate)
+            expect_input = np.zeros((1, length), dtype=np.float32)
+
+        expect_input[0][0:chunk_len] = seg
+
+        flat_chunk = expect_input.flatten().astype(np.float32).tolist()
+        seq = {
+            "WAV": {"content": flat_chunk, "shape": expect_input[0].shape},
+            "WAV_LENS": [chunk_len],
+        }
+        data["data"][0].append(seq)
+
+    json.dump(data, open("online_input.json", "w"))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--audio_file", type=str, default=None, help="single wav file"
+    )
+    # below is only for streaming input
+    parser.add_argument("--streaming", action="store_true", required=False)
+    parser.add_argument(
+        "--sample_rate",
+        type=int,
+        required=False,
+        default=16000,
+        help="sample rate used in training",
+    )
+    parser.add_argument(
+        "--frame_length_ms",
+        type=int,
+        required=False,
+        default=25,
+        help="frame length used in training",
+    )
+    parser.add_argument(
+        "--frame_shift_ms",
+        type=int,
+        required=False,
+        default=10,
+        help="frame shift length used in training",
+    )
+    parser.add_argument(
+        "--chunk_size",
+        type=int,
+        required=False,
+        default=16,
+        help="chunk size default is 16",
+    )
+    parser.add_argument(
+        "--context",
+        type=int,
+        required=False,
+        default=7,
+        help="conformer context default is 7",
+    )
+    parser.add_argument(
+        "--subsampling",
+        type=int,
+        required=False,
+        default=4,
+        help="subsampling rate default is 4",
+    )
+
+    args = parser.parse_args()
+
+    if args.streaming and os.path.exists(args.audio_file):
+        generate_online_input(args)
+    else:
+        generate_offline_input(args)
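Note: with the script's defaults (chunk_size=16, subsampling=4, context=7, 25 ms frames, 10 ms shift), the streaming chunk sizes work out as below. A worked sketch of the same arithmetic, not part of the commit:

```python
import math

# Reproduce the chunk-size arithmetic from generate_online_input()
# using the parser defaults above.
chunk_size, subsampling, context = 16, 4, 7
frame_length_ms, frame_shift_ms = 25, 10
sample_rate = 16000

first_chunk_length = (chunk_size - 1) * subsampling + context               # 67 frames
add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)  # 2
first_chunk_ms = (first_chunk_length + add_frames) * frame_shift_ms          # 690 ms
other_chunk_ms = chunk_size * subsampling * frame_shift_ms                   # 640 ms

# Segment sizes in samples at 16 kHz:
print(int(first_chunk_ms / 1000 * sample_rate))  # 11040
print(int(other_chunk_ms / 1000 * sample_rate))  # 10240
```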
mid.wav ADDED
Binary file (160 kB)
model_repo_cuda_decoder/attention_rescoring/1/.gitkeep ADDED
File without changes
model_repo_cuda_decoder/scoring/1/decoder.py CHANGED
@@ -3,6 +3,7 @@ import torch
 from typing import List
 from riva.asrlib.decoder.python_decoder import (BatchedMappedDecoderCuda,
                                                 BatchedMappedDecoderCudaConfig)
+from frame_reducer import FrameReducer
 
 def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     """Make mask tensor containing indices of padded part.
@@ -81,52 +82,53 @@ class RivaWFSTDecoder:
 
         config.online_opts.lattice_postprocessor_opts.nbest = beam_size
 
+        # config.online_opts.decoder_opts.blank_penalty = -5.0
+
         self.decoder = BatchedMappedDecoderCuda(
             config, os.path.join(tlg_dir, "TLG.fst"),
             os.path.join(tlg_dir, "words.txt"), vocab_size
         )
         self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt"))
         self.nbest = beam_size
+        self.vocab_size = vocab_size
+        self.frame_reducer = FrameReducer(0.98)
 
     def decode_nbest(self, logits, length):
+        logits, length = self.frame_reducer(logits, length.cuda(), logits)
         logits = logits.to(torch.float32).contiguous()
         sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
-        before = logits.shape
-        if logits.shape[0] == 1:
-            logits = logits.repeat(2,1,1)
-            sequence_lengths_tensor = sequence_lengths_tensor.repeat(2)
-        print(before, logits.shape)
         results = self.decoder.decode_nbest(logits, sequence_lengths_tensor)
-        if logits.shape[0] == 1:
-            results = results[0:1]
         total_hyps, total_hyps_id = [], []
+        max_hyp_len = 3
         for nbest_sentences in results:
-            nbest_list, nbest_id_list = []
+            nbest_list, nbest_id_list = [], []
             for sent in nbest_sentences:
                 # subtract 1 to get the label id,
+                # since fst decoder adds 1 to the label id
                 hyp_ids = [label - 1 for label in sent.ilabels]
+                # padding for hyps_pad_sos_eos
+                new_hyp = [self.vocab_size - 1] + remove_duplicates_and_blank(hyp_ids, eos=self.vocab_size - 1, blank_id=0) + [self.vocab_size - 1]  # noqa
+                max_hyp_len = max(max_hyp_len, len(new_hyp))
                 nbest_id_list.append(new_hyp)
 
                 hyp = "".join(self.word_id_to_word_str[word]
+                              for word in sent.words if word != 0)
                 nbest_list.append(hyp)
+            nbest_list += [""] * (self.nbest - len(nbest_list))
             total_hyps.append(nbest_list)
+            nbest_id_list += [[self.vocab_size - 1, 0, self.vocab_size - 1]] * (self.nbest - len(nbest_id_list))  # noqa
            total_hyps_id.append(nbest_id_list)
-        return total_hyps, total_hyps_id
+        return total_hyps, total_hyps_id, max_hyp_len
 
     def decode_mbr(self, logits, length):
+        logits, length = self.frame_reducer(logits, length.cuda(), logits)
+        # logits[:,:,0] -= 2.0
         logits = logits.to(torch.float32).contiguous()
         sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
-        if logits.shape[0] == 1:
-            logits = logits.repeat(2,1,1)
-            sequence_lengths_tensor = sequence_lengths_tensor.repeat(2)
         results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
-        if logits.shape[0] == 1:
-            results = results[0:1]
         total_hyps = []
         for sent in results:
             hyp = [word[0] for word in sent]
             hyp_zh = "".join(hyp)
             total_hyps.append(hyp_zh)
         return total_hyps
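Note: decode_nbest now calls remove_duplicates_and_blank, whose definition is not part of this diff. For orientation, a minimal sketch of what a CTC-style helper with this signature typically does; the helper actually defined in the repo may differ:

```python
def remove_duplicates_and_blank(hyp, eos, blank_id=0):
    """Collapse repeated CTC labels, then drop blank/eos symbols.

    Sketch only: matches the call site in decode_nbest above, but the
    repo's own implementation may differ in detail.
    """
    new_hyp = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != blank_id and hyp[cur] != eos:
            new_hyp.append(hyp[cur])
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp


# e.g. with blank_id=0, eos=9: [0, 3, 3, 0, 5, 5, 5, 0] -> [3, 5]
```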
model_repo_cuda_decoder/scoring/1/frame_reducer.py ADDED
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+#
+# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
+#                                       Zengwei Yao,
+#                                       Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """
+    Args:
+      lengths:
+        A 1-D tensor containing sentence lengths.
+      max_len:
+        The length of masks.
+    Returns:
+      Return a 2-D bool tensor, where masked positions
+      are filled with `True` and non-masked positions are
+      filled with `False`.
+
+    >>> lengths = torch.tensor([1, 3, 2, 5])
+    >>> make_pad_mask(lengths)
+    tensor([[False,  True,  True,  True,  True],
+            [False, False, False,  True,  True],
+            [False, False,  True,  True,  True],
+            [False, False, False, False, False]])
+    """
+    assert lengths.ndim == 1, lengths.ndim
+    max_len = max(max_len, lengths.max())
+    n = lengths.size(0)
+    seq_range = torch.arange(0, max_len, device=lengths.device)
+    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+
+    return expaned_lengths >= lengths.unsqueeze(-1)
+
+
+class FrameReducer(nn.Module):
+    """The encoder output is first used to calculate
+    the CTC posterior probability; then for each output frame,
+    if its blank posterior is bigger than some thresholds,
+    it will be simply discarded from the encoder output.
+    """
+
+    def __init__(
+        self,
+        blank_threshlod: float = 0.95,
+    ):
+        super().__init__()
+        self.blank_threshlod = blank_threshlod
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        ctc_output: torch.Tensor,
+        y_lens: Optional[torch.Tensor] = None,
+        blank_id: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            The shared encoder output with shape [N, T, C].
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+          ctc_output:
+            The CTC output with shape [N, T, vocab_size].
+          y_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `y` before padding.
+          blank_id:
+            The blank id of ctc_output.
+        Returns:
+          out:
+            The frame reduced encoder output with shape [N, T', C].
+          out_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `out` before padding.
+        """
+        N, T, C = x.size()
+
+        padding_mask = make_pad_mask(x_lens, x.size(1))
+        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(self.blank_threshlod)) * (~padding_mask)  # noqa
+
+        if y_lens is not None:
+            # Limit the maximum number of reduced frames
+            limit_lens = T - y_lens
+            max_limit_len = limit_lens.max().int()
+            fake_limit_indexes = torch.topk(
+                ctc_output[:, :, blank_id], max_limit_len
+            ).indices
+            T = (
+                torch.arange(max_limit_len)
+                .expand_as(
+                    fake_limit_indexes,
+                )
+                .to(device=x.device)
+            )
+            T = torch.remainder(T, limit_lens.unsqueeze(1))
+            limit_indexes = torch.gather(fake_limit_indexes, 1, T)
+            limit_mask = torch.full_like(
+                non_blank_mask,
+                False,
+                device=x.device,
+            ).scatter_(1, limit_indexes, True)
+
+            non_blank_mask = non_blank_mask | ~limit_mask
+
+        out_lens = non_blank_mask.sum(dim=1)
+        max_len = out_lens.max()
+        pad_lens_list = (
+            torch.full_like(
+                out_lens,
+                max_len.item(),
+                device=x.device,
+            )
+            - out_lens
+        )
+        max_pad_len = pad_lens_list.max()
+
+        out = F.pad(x, (0, 0, 0, max_pad_len))
+
+        valid_pad_mask = ~make_pad_mask(pad_lens_list)
+        total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
+
+        out = out[total_valid_mask].reshape(N, -1, C)
+
+        return out, out_lens
+
+
+if __name__ == "__main__":
+    import time
+
+    test_times = 10000
+    device = "cuda:0"
+    frame_reducer = FrameReducer()
+
+    # non zero case
+    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.log(
+        torch.randn(15, 498, 500, dtype=torch.float32, device=device),
+    )
+
+    avg_time = 0
+    for i in range(test_times):
+        torch.cuda.synchronize(device=x.device)
+        delta_time = time.time()
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
+        torch.cuda.synchronize(device=x.device)
+        delta_time = time.time() - delta_time
+        avg_time += delta_time
+    print(x_fr.shape)
+    print(x_lens_fr)
+    print(avg_time / test_times)
+
+    # all zero case
+    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)
+
+    avg_time = 0
+    for i in range(test_times):
+        torch.cuda.synchronize(device=x.device)
+        delta_time = time.time()
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
+        torch.cuda.synchronize(device=x.device)
+        delta_time = time.time() - delta_time
+        avg_time += delta_time
+    print(x_fr.shape)
+    print(x_lens_fr)
+    print(avg_time / test_times)
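Note: decoder.py above passes the CTC log-probabilities as both x and ctc_output, so frames whose blank log-probability reaches log(0.98) are pruned before WFST decoding. A minimal CPU sketch of that call pattern (illustrative, not part of the commit):

```python
import torch
from frame_reducer import FrameReducer

# Same call pattern as decoder.py: the CTC log-probs are both the
# frames being filtered and the evidence for filtering them.
reducer = FrameReducer(blank_threshlod=0.98)

log_probs = torch.randn(2, 100, 500).log_softmax(dim=-1)  # (N, T, vocab)
lengths = torch.tensor([100, 80])

reduced, reduced_lens = reducer(log_probs, lengths, log_probs)
print(reduced.shape)   # (2, T', 500) with T' <= 100
print(reduced_lens)    # reduced frame count per utterance
```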
model_repo_cuda_decoder/scoring/1/model.py CHANGED
@@ -16,7 +16,7 @@ import triton_python_backend_utils as pb_utils
 import numpy as np
 
 import torch
-from torch.utils.dlpack import from_dlpack
+from torch.utils.dlpack import from_dlpack, to_dlpack
 import json
 import os
 import yaml
@@ -123,7 +123,7 @@ class TritonPythonModel:
         self.eos = eos
         self.ignore_id = ignore_id
 
-        if self.decoding_method
+        if "tlg" in self.decoding_method:
             self.decoder = RivaWFSTDecoder(len(self.vocabulary),
                                            self.tlg_dir,
                                            self.tlg_decoding_config,
@@ -175,12 +175,57 @@ class TritonPythonModel:
         encoder_out_len = torch.cat(encoder_out_lens_list, dim=0)
         return encoder_out, encoder_out_len, logits, batch_count_list
 
-    def rescore_hyps(self,
+    def rescore_hyps(self, total_tokens, max_hyp_len, encoder_out, encoder_out_len):
         """
         Rescore the hypotheses with attention rescoring
         """
-
-
+        input1 = pb_utils.Tensor.from_dlpack("encoder_out", to_dlpack(encoder_out))
+        input2 = pb_utils.Tensor.from_dlpack("encoder_out_lens",
+                                             to_dlpack(encoder_out_len.unsqueeze(-1)))
+        hyps_pad_sos_eos = np.zeros([len(total_tokens),
+                                     self.beam_size, max_hyp_len], dtype=np.int64)
+        hyps_lens_sos = np.zeros([len(total_tokens), self.beam_size], dtype=np.int32)
+        ctc_scores = np.zeros([len(total_tokens),
+                               self.beam_size], dtype=np.float16)  # TODO: zero here
+
+        for i, hyps in enumerate(total_tokens):
+            for j, hyp in enumerate(hyps):
+                hyps_pad_sos_eos[i][j][:len(hyp)] = hyp
+                hyps_lens_sos[i][j] = len(hyp) - 1
+        input3 = pb_utils.Tensor("hyps_pad_sos_eos", hyps_pad_sos_eos)
+        input4 = pb_utils.Tensor("hyps_lens_sos", hyps_lens_sos)
+        input5 = pb_utils.Tensor("ctc_score", ctc_scores)
+        input_tensors = [input1, input2, input3, input4, input5]
+
+        if self.bidecoder:
+            r_hyps_pad_sos_eos = np.zeros([len(total_tokens),
+                                           self.beam_size, max_hyp_len], dtype=np.int64)
+            for i, hyps in enumerate(total_tokens):
+                for j, hyp in enumerate(hyps):
+                    r_hyps_pad_sos_eos[i][j][:len(hyp)] = hyp[::-1]
+            input6 = pb_utils.Tensor.from_dlpack("r_hyps_pad_sos_eos",
+                                                 r_hyps_pad_sos_eos)
+            input_tensors.insert(-1, input6)
+
+        inference_request = pb_utils.InferenceRequest(
+            model_name='decoder',
+            requested_output_names=['best_index'],
+            inputs=input_tensors)
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+        else:
+            # Extract the output tensors from the inference response.
+            best_index = pb_utils.get_output_tensor_by_name(inference_response,
+                                                            'best_index')
+            if best_index.is_cpu():
+                best_index = best_index.as_numpy()
+            else:
+                best_index = from_dlpack(best_index.to_dlpack())
+                best_index = best_index.cpu().numpy()
+            best_index = np.squeeze(best_index, -1).tolist()
+            return best_index
 
     def prepare_response(self, hyps, batch_count_list):
         """
@@ -223,17 +268,17 @@ class TritonPythonModel:
         ctc_log_probs = ctc_log_probs.cuda()
         if self.decoding_method == "tlg_mbr":
             total_hyps = self.decoder.decode_mbr(ctc_log_probs, encoder_out_len)
-            # list(str), list((float), list(int)) # TODO: add token_ids, time stamps
         elif self.decoding_method == "ctc_greedy_search":
             total_hyps = ctc_greedy_search(ctc_log_probs, encoder_out_len,
                                            self.vocabulary, self.blank_id, self.eos)
         elif self.decoding_method == "tlg":
-            nbest_hyps, nbest_ids = self.decoder.decode_nbest(
+            nbest_hyps, nbest_ids, max_hyp_len = self.decoder.decode_nbest(ctc_log_probs, encoder_out_len)  # noqa
             total_hyps = [nbest[0] for nbest in nbest_hyps]
 
         if self.decoding_method == "tlg" and self.rescore:
             assert self.beam_size > 1, "Beam size must be greater than 1 for rescoring"
             selected_ids = self.rescore_hyps(nbest_ids,
+                                             max_hyp_len,
                                              encoder_out,
                                              encoder_out_len)
             total_hyps = [nbest[i] for nbest, i in zip(nbest_hyps, selected_ids)]
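Note: the rescore_hyps buffers follow the wenet attention-rescoring layout: each hypothesis is padded into [sos, tokens..., eos, 0...] with sos = eos = vocab_size - 1, and hyps_lens_sos counts the sos plus tokens but not the trailing eos. A toy illustration with assumed numbers (vocab_size 10, beam 2), not part of the commit:

```python
import numpy as np

# Toy layout check for hyps_pad_sos_eos / hyps_lens_sos; the token ids
# and beam size here are made up for illustration.
vocab_size, beam_size, max_hyp_len = 10, 2, 5
sos_eos = vocab_size - 1  # 9

# one utterance, two hypotheses, each already wrapped as [sos] ... [eos]
total_tokens = [[[9, 3, 5, 9], [9, 3, 9]]]

hyps_pad_sos_eos = np.zeros([1, beam_size, max_hyp_len], dtype=np.int64)
hyps_lens_sos = np.zeros([1, beam_size], dtype=np.int32)
for i, hyps in enumerate(total_tokens):
    for j, hyp in enumerate(hyps):
        hyps_pad_sos_eos[i][j][:len(hyp)] = hyp
        hyps_lens_sos[i][j] = len(hyp) - 1  # sos + tokens, eos excluded

print(hyps_pad_sos_eos)  # [[[9 3 5 9 0], [9 3 9 0 0]]]
print(hyps_lens_sos)     # [[3 2]]
```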
model_repo_cuda_decoder/scoring/1/wfst_decoding_config.yaml CHANGED
@@ -3,8 +3,8 @@ n_input_per_chunk: 50
 default_beam: 17.0
 max_active: 7000
 determinize_lattice: True
-max_batch_size:
-num_channels:
+max_batch_size: 200
+num_channels: 400
 frame_shift_seconds: 0.04
 lm_scale: 5.0
 word_ins_penalty: 0.0
model_repo_cuda_decoder/scoring/config.pbtxt CHANGED
@@ -35,11 +35,11 @@ parameters [
 },
 {
   key: "decoding_method",
-  value: { string_value: "
+  value: { string_value: "tlg_mbr"} # tlg, tlg_mbr, ctc_greedy_search, cpu_ctc_beam_search, cuda_ctc_beam_search
 },
 {
   key: "attention_rescoring",
-  value: { string_value: "
+  value: { string_value: "1"}
 },
 {
   key: "bidecoder",
model_repo_cuda_decoder/scoring/config.pbtxt.template CHANGED
@@ -39,7 +39,7 @@ parameters [
 },
 {
   key: "attention_rescoring",
-  value: { string_value: "
+  value: { string_value: "1"}
 },
 {
   key: "bidecoder",
run.sh CHANGED
@@ -3,7 +3,4 @@ export CUDA_VISIBLE_DEVICES="1"
 model_repo_path=model_repo_cuda_decoder
 tritonserver --model-repository $model_repo_path \
              --pinned-memory-pool-byte-size=512000000 \
-             --cuda-memory-pool-byte-size=0:1024000000 \
-             --http-port=18000 \
-             --metrics-port=18001 \
-             --grpc-port=18002
+             --cuda-memory-pool-byte-size=0:1024000000