Commit 4fa997e by Yuekai Zhang (1 parent: 21f2415)

update files
README.md CHANGED
@@ -1,5 +1,5 @@
-### usage
-
+### Tutorial
+Start Server
 ```
 docker pull soar97/triton-wenet:22.12
 docker run -it --rm --name "wenet_tlg_test" --gpus all --shm-size 1g --net host soar97/triton-wenet:22.12
@@ -8,4 +8,13 @@ git clone https://huggingface.co/yuekai/model_repo_conformer_aishell_wenet_tlg.git
 cd model_repo_conformer_aishell_wenet_tlg
 bash run.sh
 ```
+Start Client
+```
+pip3 install tritonclient[all]==2.29
+apt-get install -y libsndfile1
+pip3 install soundfile
+
+python3 generate_perf_input.py --audio_file ./mid.wav
+perf_analyzer -m attention_rescoring -b 1 -p 20000 --concurrency-range 100 -i gRPC --input-data=offline_input.json -u localhost:8001
+```
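For reference, here is a minimal offline gRPC client sketch as an alternative to perf_analyzer. The WAV/WAV_LENS input names mirror generate_perf_input.py below; the TRANSCRIPTS output name and the exact dtypes are assumptions about this model repo, not confirmed by the diff.

```
import numpy as np
import soundfile as sf
import tritonclient.grpc as grpcclient

# Read the test wav shipped with this commit.
waveform, sample_rate = sf.read("mid.wav")
samples = np.array([waveform], dtype=np.float32)       # [1, num_samples]
lengths = np.array([[len(waveform)]], dtype=np.int32)  # [1, 1]

client = grpcclient.InferenceServerClient(url="localhost:8001")
inputs = [
    grpcclient.InferInput("WAV", list(samples.shape), "FP32"),
    grpcclient.InferInput("WAV_LENS", list(lengths.shape), "INT32"),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)

# "TRANSCRIPTS" is the assumed output name of the attention_rescoring ensemble.
result = client.infer("attention_rescoring", inputs,
                      outputs=[grpcclient.InferRequestedOutput("TRANSCRIPTS")])
print(result.as_numpy("TRANSCRIPTS")[0].decode("utf-8"))
```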
generate_perf_input.py ADDED
@@ -0,0 +1,134 @@
+import os
+import json
+import soundfile as sf
+import numpy as np
+import argparse
+import math
+
+
+def generate_offline_input(args):
+    wav_file = args.audio_file
+    print("Reading {}".format(wav_file))
+    waveform, sample_rate = sf.read(wav_file)
+    batch_size = 1
+    mat = np.array([waveform] * batch_size, dtype=np.float32)
+
+    out_dict = {
+        "data": [
+            {
+                "WAV_LENS": [len(waveform)],
+                "WAV": {
+                    "shape": [len(waveform)],
+                    "content": mat.flatten().tolist(),
+                },
+            }
+        ]
+    }
+    json.dump(out_dict, open("offline_input.json", "w"))
+
+
+def generate_online_input(args):
+    wav_file = args.audio_file
+    waveform, sample_rate = sf.read(wav_file)
+    chunk_size, subsampling = args.chunk_size, args.subsampling
+    context = args.context
+    first_chunk_length = (chunk_size - 1) * subsampling + context
+    frame_length_ms, frame_shift_ms = args.frame_length_ms, args.frame_shift_ms
+    # for the first chunk,
+    # we need additional frames to generate the exact first chunk length frames
+    add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)
+    first_chunk_ms = (first_chunk_length + add_frames) * frame_shift_ms
+    other_chunk_ms = chunk_size * subsampling * frame_shift_ms
+    first_chunk_s = first_chunk_ms / 1000
+    other_chunk_s = other_chunk_ms / 1000
+
+    wav_segs = []
+    i = 0
+    while i < len(waveform):
+        if i == 0:
+            stride = int(first_chunk_s * sample_rate)
+            wav_segs.append(waveform[i : i + stride])
+        else:
+            stride = int(other_chunk_s * sample_rate)
+            wav_segs.append(waveform[i : i + stride])
+        i += len(wav_segs[-1])
+
+    data = {"data": [[]]}
+
+    for idx, seg in enumerate(wav_segs):  # 0, num_frames + 5, 64
+        chunk_len = len(seg)
+        if idx == 0:
+            length = int(first_chunk_s * sample_rate)
+            expect_input = np.zeros((1, length), dtype=np.float32)
+        else:
+            length = int(other_chunk_s * sample_rate)
+            expect_input = np.zeros((1, length), dtype=np.float32)
+
+        expect_input[0][0:chunk_len] = seg
+
+        flat_chunk = expect_input.flatten().astype(np.float32).tolist()
+        seq = {
+            "WAV": {"content": flat_chunk, "shape": expect_input[0].shape},
+            "WAV_LENS": [chunk_len],
+        }
+        data["data"][0].append(seq)
+
+    json.dump(data, open("online_input.json", "w"))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--audio_file", type=str, default=None, help="single wav file"
+    )
+    # below is only for streaming input
+    parser.add_argument("--streaming", action="store_true", required=False)
+    parser.add_argument(
+        "--sample_rate",
+        type=int,
+        required=False,
+        default=16000,
+        help="sample rate used in training",
+    )
+    parser.add_argument(
+        "--frame_length_ms",
+        type=int,
+        required=False,
+        default=25,
+        help="frame length used in training",
+    )
+    parser.add_argument(
+        "--frame_shift_ms",
+        type=int,
+        required=False,
+        default=10,
+        help="frame shift length used in training",
+    )
+    parser.add_argument(
+        "--chunk_size",
+        type=int,
+        required=False,
+        default=16,
+        help="chunk size default is 16",
+    )
+    parser.add_argument(
+        "--context",
+        type=int,
+        required=False,
+        default=7,
+        help="conformer context default is 7",
+    )
+    parser.add_argument(
+        "--subsampling",
+        type=int,
+        required=False,
+        default=4,
+        help="subsampling rate default is 4",
+    )
+
+    args = parser.parse_args()
+
+    if args.streaming and os.path.exists(args.audio_file):
+        generate_online_input(args)
+    else:
+        generate_offline_input(args)
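For orientation, the offline_input.json written by generate_offline_input above follows perf_analyzer's real-input-data layout; with a toy 3-sample waveform (illustrative values, not taken from mid.wav) it would look like:

```
{"data": [{"WAV_LENS": [3], "WAV": {"shape": [3], "content": [0.0, 0.1, -0.1]}}]}
```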
mid.wav ADDED
Binary file (160 kB).
model_repo_cuda_decoder/attention_rescoring/1/.gitkeep ADDED
File without changes
model_repo_cuda_decoder/scoring/1/decoder.py CHANGED
@@ -3,6 +3,7 @@ import torch
 from typing import List
 from riva.asrlib.decoder.python_decoder import (BatchedMappedDecoderCuda,
                                                 BatchedMappedDecoderCudaConfig)
+from frame_reducer import FrameReducer
 
 def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     """Make mask tensor containing indices of padded part.
@@ -81,52 +82,53 @@ class RivaWFSTDecoder:
 
         config.online_opts.lattice_postprocessor_opts.nbest = beam_size
 
+        # config.online_opts.decoder_opts.blank_penalty = -5.0
+
         self.decoder = BatchedMappedDecoderCuda(
             config, os.path.join(tlg_dir, "TLG.fst"),
             os.path.join(tlg_dir, "words.txt"), vocab_size
         )
         self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt"))
         self.nbest = beam_size
+        self.vocab_size = vocab_size
+        self.frame_reducer = FrameReducer(0.98)
 
     def decode_nbest(self, logits, length):
+        logits, length = self.frame_reducer(logits, length.cuda(), logits)
         logits = logits.to(torch.float32).contiguous()
         sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
-        before = logits.shape
-        if logits.shape[0] == 1:
-            logits = logits.repeat(2,1,1)
-            sequence_lengths_tensor = sequence_lengths_tensor.repeat(2)
-        print(before, logits.shape)
         results = self.decoder.decode_nbest(logits, sequence_lengths_tensor)
-        if logits.shape[0] == 1:
-            results = results[0:1]
         total_hyps, total_hyps_id = [], []
+        max_hyp_len = 3
         for nbest_sentences in results:
-            nbest_list, nbest_id_list = []
+            nbest_list, nbest_id_list = [], []
             for sent in nbest_sentences:
-                # subtract 1 to get the label id, since fst decoder adds 1 to the label id
+                # subtract 1 to get the label id,
+                # since fst decoder adds 1 to the label id
                 hyp_ids = [label - 1 for label in sent.ilabels]
-                new_hyp = remove_duplicates_and_blank(hyp_ids, eos=self.vocab_size-1, blank_id=0)
+                # padding for hyps_pad_sos_eos
+                new_hyp = [self.vocab_size - 1] + remove_duplicates_and_blank(hyp_ids, eos=self.vocab_size - 1, blank_id=0) + [self.vocab_size - 1]  # noqa
+                max_hyp_len = max(max_hyp_len, len(new_hyp))
                 nbest_id_list.append(new_hyp)
 
-                hyp = "".join(self.word_id_to_word_str[word] for word in sent.words if word != 0)
+                hyp = "".join(self.word_id_to_word_str[word]
+                              for word in sent.words if word != 0)
                 nbest_list.append(hyp)
-
+            nbest_list += [""] * (self.nbest - len(nbest_list))
             total_hyps.append(nbest_list)
+            nbest_id_list += [[self.vocab_size - 1, 0, self.vocab_size - 1]] * (self.nbest - len(nbest_id_list))  # noqa
            total_hyps_id.append(nbest_id_list)
-        return total_hyps, total_hyps_id
+        return total_hyps, total_hyps_id, max_hyp_len
 
     def decode_mbr(self, logits, length):
+        logits, length = self.frame_reducer(logits, length.cuda(), logits)
+        # logits[:,:,0] -= 2.0
         logits = logits.to(torch.float32).contiguous()
         sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
-        if logits.shape[0] == 1:
-            logits = logits.repeat(2,1,1)
-            sequence_lengths_tensor = sequence_lengths_tensor.repeat(2)
         results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
-        if logits.shape[0] == 1:
-            results = results[0:1]
         total_hyps = []
         for sent in results:
            hyp = [word[0] for word in sent]
            hyp_zh = "".join(hyp)
            total_hyps.append(hyp_zh)
-        return total_hyps
+        return total_hyps
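Note: remove_duplicates_and_blank is defined elsewhere in this repo and not shown in this diff. As a rough sketch only, a standard CTC collapse matching the signature used above might look like this (an assumption about its behavior, not the repo's actual implementation):

```
from typing import List

def remove_duplicates_and_blank(hyp: List[int], eos: int,
                                blank_id: int = 0) -> List[int]:
    # Standard CTC post-processing: merge consecutive repeats,
    # then drop blank and eos labels.
    out, prev = [], None
    for label in hyp:
        if label != prev and label not in (blank_id, eos):
            out.append(label)
        prev = label
    return out
```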
model_repo_cuda_decoder/scoring/1/frame_reducer.py ADDED
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+#
+# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
+#                                       Zengwei Yao,
+#                                       Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """
+    Args:
+      lengths:
+        A 1-D tensor containing sentence lengths.
+      max_len:
+        The length of masks.
+    Returns:
+      Return a 2-D bool tensor, where masked positions
+      are filled with `True` and non-masked positions are
+      filled with `False`.
+
+    >>> lengths = torch.tensor([1, 3, 2, 5])
+    >>> make_pad_mask(lengths)
+    tensor([[False,  True,  True,  True,  True],
+            [False, False, False,  True,  True],
+            [False, False,  True,  True,  True],
+            [False, False, False, False, False]])
+    """
+    assert lengths.ndim == 1, lengths.ndim
+    max_len = max(max_len, lengths.max())
+    n = lengths.size(0)
+    seq_range = torch.arange(0, max_len, device=lengths.device)
+    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+
+    return expanded_lengths >= lengths.unsqueeze(-1)
+
+
+class FrameReducer(nn.Module):
+    """The encoder output is first used to calculate
+    the CTC posterior probability; then for each output frame,
+    if its blank posterior is bigger than some threshold,
+    it will be simply discarded from the encoder output.
+    """
+
+    def __init__(
+        self,
+        blank_threshold: float = 0.95,
+    ):
+        super().__init__()
+        self.blank_threshold = blank_threshold
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        ctc_output: torch.Tensor,
+        y_lens: Optional[torch.Tensor] = None,
+        blank_id: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            The shared encoder output with shape [N, T, C].
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+          ctc_output:
+            The CTC output with shape [N, T, vocab_size].
+          y_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `y` before padding.
+          blank_id:
+            The blank id of ctc_output.
+        Returns:
+          out:
+            The frame reduced encoder output with shape [N, T', C].
+          out_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `out` before padding.
+        """
+        N, T, C = x.size()
+
+        padding_mask = make_pad_mask(x_lens, x.size(1))
+        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(self.blank_threshold)) * (~padding_mask)  # noqa
+
+        if y_lens is not None:
+            # Limit the maximum number of reduced frames
+            limit_lens = T - y_lens
+            max_limit_len = limit_lens.max().int()
+            fake_limit_indexes = torch.topk(
+                ctc_output[:, :, blank_id], max_limit_len
+            ).indices
+            T = (
+                torch.arange(max_limit_len)
+                .expand_as(
+                    fake_limit_indexes,
+                )
+                .to(device=x.device)
+            )
+            T = torch.remainder(T, limit_lens.unsqueeze(1))
+            limit_indexes = torch.gather(fake_limit_indexes, 1, T)
+            limit_mask = torch.full_like(
+                non_blank_mask,
+                False,
+                device=x.device,
+            ).scatter_(1, limit_indexes, True)
+
+            non_blank_mask = non_blank_mask | ~limit_mask
+
+        out_lens = non_blank_mask.sum(dim=1)
+        max_len = out_lens.max()
+        pad_lens_list = (
+            torch.full_like(
+                out_lens,
+                max_len.item(),
+                device=x.device,
+            )
+            - out_lens
+        )
+        max_pad_len = pad_lens_list.max()
+
+        out = F.pad(x, (0, 0, 0, max_pad_len))
+
+        valid_pad_mask = ~make_pad_mask(pad_lens_list)
+        total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
+
+        out = out[total_valid_mask].reshape(N, -1, C)
+
+        return out, out_lens
+
+
+if __name__ == "__main__":
+    import time
+
+    test_times = 10000
+    device = "cuda:0"
+    frame_reducer = FrameReducer()
+
+    # non zero case
+    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.log(
+        torch.randn(15, 498, 500, dtype=torch.float32, device=device),
+    )
+
+    avg_time = 0
+    for i in range(test_times):
+        torch.cuda.synchronize(device=x.device)
+        delta_time = time.time()
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
+        torch.cuda.synchronize(device=x.device)
+        delta_time = time.time() - delta_time
+        avg_time += delta_time
+    print(x_fr.shape)
+    print(x_lens_fr)
+    print(avg_time / test_times)
+
+    # all zero case
+    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)
+
+    avg_time = 0
+    for i in range(test_times):
+        torch.cuda.synchronize(device=x.device)
+        delta_time = time.time()
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
+        torch.cuda.synchronize(device=x.device)
+        delta_time = time.time() - delta_time
+        avg_time += delta_time
+    print(x_fr.shape)
+    print(x_lens_fr)
+    print(avg_time / test_times)
model_repo_cuda_decoder/scoring/1/model.py CHANGED
@@ -16,7 +16,7 @@ import triton_python_backend_utils as pb_utils
 import numpy as np
 
 import torch
-from torch.utils.dlpack import from_dlpack
+from torch.utils.dlpack import from_dlpack, to_dlpack
 import json
 import os
 import yaml
@@ -123,7 +123,7 @@ class TritonPythonModel:
         self.eos = eos
         self.ignore_id = ignore_id
 
-        if self.decoding_method == "tlg":
+        if "tlg" in self.decoding_method:
             self.decoder = RivaWFSTDecoder(len(self.vocabulary),
                                            self.tlg_dir,
                                            self.tlg_decoding_config,
@@ -175,12 +175,57 @@ class TritonPythonModel:
         encoder_out_len = torch.cat(encoder_out_lens_list, dim=0)
         return encoder_out, encoder_out_len, logits, batch_count_list
 
-    def rescore_hyps(self, total_hyps, total_tokens, encoder_out, encoder_out_len):
+    def rescore_hyps(self, total_tokens, max_hyp_len, encoder_out, encoder_out_len):
         """
         Rescore the hypotheses with attention rescoring
         """
-        # TODO: add attention rescoring
-        return total_hyps
+        input1 = pb_utils.Tensor.from_dlpack("encoder_out", to_dlpack(encoder_out))
+        input2 = pb_utils.Tensor.from_dlpack("encoder_out_lens",
+                                             to_dlpack(encoder_out_len.unsqueeze(-1)))
+        hyps_pad_sos_eos = np.zeros([len(total_tokens),
+                                     self.beam_size, max_hyp_len], dtype=np.int64)
+        hyps_lens_sos = np.zeros([len(total_tokens), self.beam_size], dtype=np.int32)
+        ctc_scores = np.zeros([len(total_tokens),
+                               self.beam_size], dtype=np.float16)  # TODO: zero here
+
+        for i, hyps in enumerate(total_tokens):
+            for j, hyp in enumerate(hyps):
+                hyps_pad_sos_eos[i][j][:len(hyp)] = hyp
+                hyps_lens_sos[i][j] = len(hyp) - 1
+        input3 = pb_utils.Tensor("hyps_pad_sos_eos", hyps_pad_sos_eos)
+        input4 = pb_utils.Tensor("hyps_lens_sos", hyps_lens_sos)
+        input5 = pb_utils.Tensor("ctc_score", ctc_scores)
+        input_tensors = [input1, input2, input3, input4, input5]
+
+        if self.bidecoder:
+            r_hyps_pad_sos_eos = np.zeros([len(total_tokens),
+                                           self.beam_size, max_hyp_len], dtype=np.int64)
+            for i, hyps in enumerate(total_tokens):
+                for j, hyp in enumerate(hyps):
+                    r_hyps_pad_sos_eos[i][j][:len(hyp)] = hyp[::-1]
+            input6 = pb_utils.Tensor("r_hyps_pad_sos_eos", r_hyps_pad_sos_eos)
+            input_tensors.insert(-1, input6)
+
+        inference_request = pb_utils.InferenceRequest(
+            model_name='decoder',
+            requested_output_names=['best_index'],
+            inputs=input_tensors)
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+        else:
+            # Extract the output tensors from the inference response.
+            best_index = pb_utils.get_output_tensor_by_name(inference_response,
+                                                            'best_index')
+            if best_index.is_cpu():
+                best_index = best_index.as_numpy()
+            else:
+                best_index = from_dlpack(best_index.to_dlpack())
+                best_index = best_index.cpu().numpy()
+            best_index = np.squeeze(best_index, -1).tolist()
+            return best_index
 
     def prepare_response(self, hyps, batch_count_list):
         """
@@ -223,17 +268,17 @@ class TritonPythonModel:
         ctc_log_probs = ctc_log_probs.cuda()
         if self.decoding_method == "tlg_mbr":
             total_hyps = self.decoder.decode_mbr(ctc_log_probs, encoder_out_len)
-            # list(str), list((float), list(int)) # TODO: add token_ids, time stamps
         elif self.decoding_method == "ctc_greedy_search":
             total_hyps = ctc_greedy_search(ctc_log_probs, encoder_out_len,
                                            self.vocabulary, self.blank_id, self.eos)
         elif self.decoding_method == "tlg":
-            nbest_hyps, nbest_ids = self.decoder.decode_nbest(encoder_out, encoder_out_len)
+            nbest_hyps, nbest_ids, max_hyp_len = self.decoder.decode_nbest(ctc_log_probs, encoder_out_len)  # noqa
            total_hyps = [nbest[0] for nbest in nbest_hyps]
 
        if self.decoding_method == "tlg" and self.rescore:
            assert self.beam_size > 1, "Beam size must be greater than 1 for rescoring"
            selected_ids = self.rescore_hyps(nbest_ids,
+                                             max_hyp_len,
                                             encoder_out,
                                             encoder_out_len)
            total_hyps = [nbest[i] for nbest, i in zip(nbest_hyps, selected_ids)]
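For intuition, here is a toy trace (hypothetical token ids and tiny sizes) of the hyps_pad_sos_eos / hyps_lens_sos layout that rescore_hyps builds from decode_nbest's output, where sos and eos are both vocab_size - 1:

```
import numpy as np

vocab_size, beam_size, max_hyp_len = 10, 2, 6
# decode_nbest wraps every hypothesis in vocab_size - 1 (here 9) as sos/eos.
total_tokens = [[[9, 3, 5, 9], [9, 3, 9]]]  # one utterance, two n-best hyps

hyps_pad_sos_eos = np.zeros([1, beam_size, max_hyp_len], dtype=np.int64)
hyps_lens_sos = np.zeros([1, beam_size], dtype=np.int32)
for i, hyps in enumerate(total_tokens):
    for j, hyp in enumerate(hyps):
        hyps_pad_sos_eos[i][j][:len(hyp)] = hyp
        hyps_lens_sos[i][j] = len(hyp) - 1  # sos + tokens, eos excluded

print(hyps_pad_sos_eos)  # [[[9 3 5 9 0 0], [9 3 9 0 0 0]]]
print(hyps_lens_sos)     # [[3 2]]
```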
model_repo_cuda_decoder/scoring/1/wfst_decoding_config.yaml CHANGED
@@ -3,8 +3,8 @@ n_input_per_chunk: 50
 default_beam: 17.0
 max_active: 7000
 determinize_lattice: True
-max_batch_size: 800
-num_channels: 800
+max_batch_size: 200
+num_channels: 400
 frame_shift_seconds: 0.04
 lm_scale: 5.0
 word_ins_penalty: 0.0
model_repo_cuda_decoder/scoring/config.pbtxt CHANGED
@@ -35,11 +35,11 @@ parameters [
   },
   {
     key: "decoding_method",
-    value: { string_value: "tlg"} # tlg, ctc_greedy_search, cpu_ctc_beam_search, cuda_ctc_beam_search
+    value: { string_value: "tlg_mbr"} # tlg, tlg_mbr, ctc_greedy_search, cpu_ctc_beam_search, cuda_ctc_beam_search
   },
   {
     key: "attention_rescoring",
-    value: { string_value: "0"}
+    value: { string_value: "1"}
   },
   {
     key: "bidecoder",
model_repo_cuda_decoder/scoring/config.pbtxt.template CHANGED
@@ -39,7 +39,7 @@ parameters [
   },
   {
     key: "attention_rescoring",
-    value: { string_value: "0"}
+    value: { string_value: "1"}
   },
   {
     key: "bidecoder",
run.sh CHANGED
@@ -3,7 +3,4 @@ export CUDA_VISIBLE_DEVICES="1"
 model_repo_path=model_repo_cuda_decoder
 tritonserver --model-repository $model_repo_path \
              --pinned-memory-pool-byte-size=512000000 \
-             --cuda-memory-pool-byte-size=0:1024000000 \
-             --http-port=18000 \
-             --metrics-port=18001 \
-             --grpc-port=18002
+             --cuda-memory-pool-byte-size=0:1024000000