csukuangfj commited on
Commit
d2cc323
·
1 Parent(s): 80e7e4c

minor fixes

Browse files
Files changed (2) hide show
  1. app.py +12 -3
  2. model.py +9 -7
app.py CHANGED
@@ -89,7 +89,7 @@ def process_uploaded_file(
89
  "result_item_error",
90
  )
91
 
92
- if input_num_speakers < 0:
93
  try:
94
  input_threshold = float(input_threshold)
95
  if input_threshold < 0 or input_threshold > 10:
@@ -142,7 +142,7 @@ def process(
142
 
143
  audio, sample_rate = read_wave(filename)
144
 
145
- MyPrint("audio", audio.shape, sample_rate)
146
 
147
  sd = get_speaker_diarization(
148
  segmentation_model=speaker_segmentation_model,
@@ -150,7 +150,7 @@ def process(
150
  num_clusters=input_num_speakers,
151
  threshold=input_threshold,
152
  )
153
- MyPrint(f"{audio.shape / sd.sample_rate}, {sample_rate}")
154
 
155
  segments = sd.process(audio).sort_by_start_time()
156
  s = ""
@@ -194,6 +194,15 @@ See more information by visiting
194
  If you want to try it on Android, please download pre-built Android
195
  APKs for speaker diarzation by visiting
196
  <https://k2-fsa.github.io/sherpa/onnx/speaker-diarization/android.html>
 
 
 
 
 
 
 
 
 
197
  """
198
 
199
  # css style is copied from
 
89
  "result_item_error",
90
  )
91
 
92
+ if input_num_speakers <= 0:
93
  try:
94
  input_threshold = float(input_threshold)
95
  if input_threshold < 0 or input_threshold > 10:
 
142
 
143
  audio, sample_rate = read_wave(filename)
144
 
145
+ MyPrint(f"audio, {audio.shape}, {sample_rate}")
146
 
147
  sd = get_speaker_diarization(
148
  segmentation_model=speaker_segmentation_model,
 
150
  num_clusters=input_num_speakers,
151
  threshold=input_threshold,
152
  )
153
+ MyPrint(f"{audio.shape[0] / sd.sample_rate}, {sample_rate}")
154
 
155
  segments = sd.process(audio).sort_by_start_time()
156
  s = ""
 
194
  If you want to try it on Android, please download pre-built Android
195
  APKs for speaker diarzation by visiting
196
  <https://k2-fsa.github.io/sherpa/onnx/speaker-diarization/android.html>
197
+
198
+ ---
199
+
200
+ Note about the two arguments:
201
+
202
+ - number of speakers: If you know the actual number of speakers in the input file,
203
+ please provide it. Otherwise, please set it to 0
204
+ - clustering threshold: Used only when number of speakers is 0. A larger
205
+ threshold results in fewer clusters, i.e., fewer speakers.
206
  """
207
 
208
  # css style is copied from
model.py CHANGED
@@ -16,7 +16,7 @@
16
 
17
  import wave
18
  from functools import lru_cache
19
- from typing import List, Tuple
20
 
21
  import numpy as np
22
  import sherpa_onnx
@@ -62,7 +62,7 @@ def _get_nn_model_filename(
62
  return nn_model_filename
63
 
64
 
65
- def get_speaker_segmentation_model(repo_id) -> List[str]:
66
  assert repo_id in ("pyannote/segmentation-3.0",)
67
 
68
  if repo_id == "pyannote/segmentation-3.0":
@@ -72,14 +72,14 @@ def get_speaker_segmentation_model(repo_id) -> List[str]:
72
  )
73
 
74
 
75
- def get_speaker_embedding_model(model_name) -> List[str]:
76
- model_name = model_name.split("|")[0]
77
  assert (
78
  model_name
79
  in three_d_speaker_embedding_models
80
  + nemo_speaker_embedding_models
81
  + wespeaker_embedding_models
82
  )
 
83
 
84
  return _get_nn_model_filename(
85
  repo_id="csukuangfj/speaker-embedding-models",
@@ -92,16 +92,18 @@ def get_speaker_diarization(
92
  ):
93
  segmentation = get_speaker_segmentation_model(segmentation_model)
94
  embedding = get_speaker_embedding_model(embedding_model)
95
- print("segmentation", segmentation)
96
- print("embedding", embedding)
97
 
98
  config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
99
  segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
100
  pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
101
  model=segmentation
102
  ),
 
 
 
 
 
103
  ),
104
- embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig(model=embedding),
105
  clustering=sherpa_onnx.FastClusteringConfig(
106
  num_clusters=num_clusters,
107
  threshold=threshold,
 
16
 
17
  import wave
18
  from functools import lru_cache
19
+ from typing import Tuple
20
 
21
  import numpy as np
22
  import sherpa_onnx
 
62
  return nn_model_filename
63
 
64
 
65
+ def get_speaker_segmentation_model(repo_id) -> str:
66
  assert repo_id in ("pyannote/segmentation-3.0",)
67
 
68
  if repo_id == "pyannote/segmentation-3.0":
 
72
  )
73
 
74
 
75
+ def get_speaker_embedding_model(model_name) -> str:
 
76
  assert (
77
  model_name
78
  in three_d_speaker_embedding_models
79
  + nemo_speaker_embedding_models
80
  + wespeaker_embedding_models
81
  )
82
+ model_name = model_name.split("|")[0]
83
 
84
  return _get_nn_model_filename(
85
  repo_id="csukuangfj/speaker-embedding-models",
 
92
  ):
93
  segmentation = get_speaker_segmentation_model(segmentation_model)
94
  embedding = get_speaker_embedding_model(embedding_model)
 
 
95
 
96
  config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
97
  segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
98
  pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
99
  model=segmentation
100
  ),
101
+ debug=True,
102
+ ),
103
+ embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig(
104
+ model=embedding,
105
+ debug=True,
106
  ),
 
107
  clustering=sherpa_onnx.FastClusteringConfig(
108
  num_clusters=num_clusters,
109
  threshold=threshold,