|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import io |
|
import base64 |
|
import numpy as np |
|
import scipy.io.wavfile |
|
from typing import Text |
|
import streamlit as st |
|
from pyannote.audio import Pipeline |
|
from pyannote.audio import Audio |
|
from pyannote.audio.pipelines.utils.hook import TimingHook |
|
from pyannote.core import Segment |
|
|
|
|
|
import streamlit.components.v1 as components |
|
|
|
|
|
def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text: |
|
"""Convert waveform to base64 data""" |
|
waveform /= np.max(np.abs(waveform)) + 1e-8 |
|
with io.BytesIO() as content: |
|
scipy.io.wavfile.write(content, sample_rate, waveform) |
|
content.seek(0) |
|
b64 = base64.b64encode(content.read()).decode() |
|
b64 = f"data:audio/x-wav;base64,{b64}" |
|
return b64 |
|
|
|
|
|
PYANNOTE_LOGO = "https://avatars.githubusercontent.com/u/7559051?s=400&v=4" |
|
EXCERPT = 120 |
|
|
|
st.set_page_config(page_title="pyannote pretrained pipelines", page_icon=PYANNOTE_LOGO) |
|
|
|
col1, col2 = st.columns([0.2, 0.8], gap="small") |
|
|
|
with col1: |
|
st.image(PYANNOTE_LOGO) |
|
|
|
with col2: |
|
st.markdown( |
|
""" |
|
# pretrained pipelines |
|
Make the most of [pyannote](https://github.com/pyannote) thanks to our [consulting services](https://herve.niderb.fr/consulting.html) |
|
""" |
|
) |
|
|
|
|
|
PIPELINES = [ |
|
"pyannote/speaker-diarization-3.1", |
|
"pyannote/speaker-diarization-3.0", |
|
] |
|
|
|
audio = Audio(sample_rate=16000, mono=True) |
|
|
|
selected_pipeline = st.selectbox("Select a pretrained pipeline", PIPELINES, index=0) |
|
|
|
|
|
with st.spinner("Loading pipeline..."): |
|
try: |
|
use_auth_token = st.secrets["PYANNOTE_TOKEN"] |
|
except FileNotFoundError: |
|
use_auth_token = None |
|
except KeyError: |
|
use_auth_token = None |
|
|
|
pipeline = Pipeline.from_pretrained( |
|
selected_pipeline, use_auth_token=use_auth_token |
|
) |
|
if torch.cuda.is_available(): |
|
pipeline.to(torch.device("cuda")) |
|
|
|
uploaded_file = st.file_uploader("Upload an audio file") |
|
if uploaded_file is not None: |
|
try: |
|
duration = audio.get_duration(uploaded_file) |
|
except RuntimeError as e: |
|
st.error(e) |
|
st.stop() |
|
|
|
spinner_message = ( |
|
f"Processing {duration:.0f}s file... " |
|
if duration < EXCERPT |
|
else f"Processing first {EXCERPT:.0f}s of file..." |
|
) |
|
|
|
duration = min(duration, EXCERPT) |
|
waveform, sample_rate = audio.crop(uploaded_file, Segment(0, duration)) |
|
uri = "".join(uploaded_file.name.split()) |
|
file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri} |
|
|
|
with st.spinner(spinner_message): |
|
with TimingHook() as hook: |
|
output = pipeline(file, hook=hook) |
|
|
|
processing_time = file["timing"]["total"] |
|
faster_than_real_time = duration / processing_time |
|
st.success( |
|
f"Processed {duration:.0f}s of audio in {processing_time:.1f}s ({faster_than_real_time:.1f}x faster than real-time)", |
|
icon="✅", |
|
) |
|
|
|
with open("assets/template.html") as html, open("assets/style.css") as css: |
|
html_template = html.read() |
|
st.markdown("<style>{}</style>".format(css.read()), unsafe_allow_html=True) |
|
|
|
colors = [ |
|
"#ffd70033", |
|
"#00ffff33", |
|
"#ff00ff33", |
|
"#00ff0033", |
|
"#9932cc33", |
|
"#00bfff33", |
|
"#ff7f5033", |
|
"#66cdaa33", |
|
] |
|
num_colors = len(colors) |
|
|
|
label2color = { |
|
label: colors[k % num_colors] for k, label in enumerate(sorted(output.labels())) |
|
} |
|
|
|
BASE64 = to_base64(waveform.numpy().T) |
|
|
|
REGIONS = "" |
|
for segment, _, label in output.itertracks(yield_label=True): |
|
REGIONS += f"regions.addRegion({{start: {segment.start:g}, end: {segment.end:g}, color: '{label2color[label]}', resize : false, drag : false}});" |
|
|
|
html = html_template.replace("BASE64", BASE64).replace("REGIONS", REGIONS) |
|
components.html(html, height=250, scrolling=True) |
|
|
|
with io.StringIO() as fp: |
|
output.write_rttm(fp) |
|
content = fp.getvalue() |
|
b64 = base64.b64encode(content.encode()).decode() |
|
href = f'<a download="{output.uri}.rttm" href="data:file/text;base64,{b64}">Download</a> result in RTTM file format or run it locally:' |
|
st.markdown(href, unsafe_allow_html=True) |
|
|
|
code = f""" |
|
# load pretrained pipeline |
|
from pyannote.audio import Pipeline |
|
pipeline = Pipeline.from_pretrained("{selected_pipeline}", |
|
use_auth_token=HUGGINGFACE_TOKEN) |
|
|
|
# (optional) send pipeline to GPU |
|
import torch |
|
pipeline.to(torch.device("cuda")) |
|
|
|
# process audio file |
|
output = pipeline("audio.wav")""" |
|
st.code(code, language="python") |
|
|