#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import time
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Any, Optional, Union

import streamlit as st
import torchaudio
from conette import CoNeTTEModel, conette, __version__
from conette.utils.collections import dict_list_to_list_dict
from st_audiorec import st_audiorec
from streamlit.runtime.uploaded_file_manager import UploadedFile
from torch import Tensor

ALLOW_REP_MODES = ("stopwords", "all", "none")
DEFAULT_TASK = "audiocaps"
MAX_BEAM_SIZE = 20
MAX_PRED_SIZE = 30
MAX_BATCH_SIZE = 16
RECORD_AUDIO_FNAME = "microphone_conette_record.wav"
DEFAULT_THRESHOLD = 0.3
THRESHOLD_PRECISION = 100
MIN_AUDIO_DURATION_SEC = 0.3
MAX_AUDIO_DURATION_SEC = 60
HASH_PREFIX = "hash_"
TMP_FILE_PREFIX = "audio_tmp_file_"
SECOND_BEFORE_CLEAR_CACHE = 10 * 60


@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build a CoNeTTE model by forwarding every argument to ``conette``.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    reuse the returned model instead of rebuilding it.
    """
    return conette(*args, **kwargs)


def format_candidate(candidate: str) -> str:
    """Return ``candidate`` with its first character title-cased and a final dot.

    An empty candidate is returned as an empty string (no dot is added).
    """
    if not candidate:
        return ""
    return f"{candidate[0].title()}{candidate[1:]}."


def format_tags(tags: Optional[list[str]]) -> str:
    """Return the tags joined by ``", "``, or ``"None."`` when there are none.

    Args:
        tags: List of tag names, possibly empty or ``None``.
    """
    if not tags:
        return "None."
    return ", ".join(tags)


def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
    """Return the session-state key for a (file name, generation options) pair.

    The key is ``HASH_PREFIX`` followed by the file name and the string form
    of ``generate_kwds``.
    """
    return f"{HASH_PREFIX}{audio_fname}-{generate_kwds}"


def _format_duration_error(audio_fname: str, duration: float, problem: str) -> str:
    """Build the Markdown error stored for an audio file of invalid duration.

    Args:
        audio_fname: Name of the rejected audio file.
        duration: Measured duration of the file, in seconds.
        problem: ``"short"`` or ``"long"``.
    """
    # NOTE(review): the line breaks inside this literal were lost in the
    # whitespace-collapsed source and are reconstructed here — confirm the
    # rendered output still matches what the app showed before.
    return f"""
    ##### Result for "{audio_fname}"
    Audio file is too {problem}. (found {duration:.2f}s but the model expect audio in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])
    """


def _remove_files(paths: list[str]) -> None:
    """Delete every file in ``paths``, ignoring files that are already gone."""
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Best-effort cleanup: a missing file is already "removed".
            pass


def get_results(
    model: CoNeTTEModel,
    audio_files: dict[str, bytes],
    generate_kwds: dict[str, Any],
) -> dict[str, Union[dict[str, Any], str]]:
    """Run the model on the given audio files and return one result per file.

    Results are stored in ``st.session_state`` under the key returned by
    ``get_result_hash``. A file whose key is already present is not processed
    again, except for the microphone recording (``RECORD_AUDIO_FNAME``), which
    is processed on every call.

    Each audio to process is written to a temporary file so that it can be
    read by path. Files whose duration is outside
    ``[MIN_AUDIO_DURATION_SEC, MAX_AUDIO_DURATION_SEC]`` get an error message
    (a ``str``) as result instead of a prediction. Remaining files are sent to
    the model in batches of at most ``MAX_BATCH_SIZE`` paths.

    Every temporary file created here is deleted before returning, including
    files rejected for their duration and including the case where reading
    the metadata or running the model raises.

    Args:
        model: Callable model; invoked as ``model(paths, **generate_kwds)``.
        audio_files: Mapping from audio file name to raw audio bytes.
        generate_kwds: Keyword arguments forwarded to the model; also part of
            the session-state key.

    Returns:
        Mapping from audio file name to either the model output for that file
        or an error message string.
    """
    # Select the audio that has no stored result yet (or must be re-run).
    audio_to_predict: dict[str, tuple[str, bytes]] = {}
    for audio_fname, audio in audio_files.items():
        result_hash = get_result_hash(audio_fname, generate_kwds)
        if result_hash not in st.session_state or audio_fname == RECORD_AUDIO_FNAME:
            audio_to_predict[result_hash] = (audio_fname, audio)

    # Every temporary path is recorded as soon as it exists so the `finally`
    # below can remove it whatever happens afterwards.
    created_paths: list[str] = []
    try:
        # Write the audio to disk and keep only files with a valid duration.
        valid_paths: dict[str, str] = {}
        for result_hash, (audio_fname, audio) in audio_to_predict.items():
            # delete=False: the file must survive close() to be read by path.
            with NamedTemporaryFile(delete=False, prefix=TMP_FILE_PREFIX) as tmp_file:
                created_paths.append(tmp_file.name)
                tmp_file.write(audio)

            metadata = torchaudio.info(tmp_file.name)  # type: ignore
            duration = metadata.num_frames / metadata.sample_rate

            if MIN_AUDIO_DURATION_SEC > duration:
                st.session_state[result_hash] = _format_duration_error(
                    audio_fname, duration, "short"
                )
            elif duration > MAX_AUDIO_DURATION_SEC:
                st.session_state[result_hash] = _format_duration_error(
                    audio_fname, duration, "long"
                )
            else:
                valid_paths[result_hash] = tmp_file.name

        # Generate predictions batch by batch and store them in session state.
        pending_hashes = list(valid_paths)
        pending_paths = list(valid_paths.values())
        for start in range(0, len(pending_paths), MAX_BATCH_SIZE):
            stop = start + MAX_BATCH_SIZE
            batch_outputs = model(
                pending_paths[start:stop],
                **generate_kwds,
            )
            outputs_lst = dict_list_to_list_dict(batch_outputs)  # type: ignore
            for result_hash, output_i in zip(pending_hashes[start:stop], outputs_lst):
                st.session_state[result_hash] = output_i

        # Collect the stored result of every requested file, in input order.
        outputs = {
            audio_fname: st.session_state[get_result_hash(audio_fname, generate_kwds)]
            for audio_fname in audio_files
        }
    finally:
        _remove_files(created_paths)

    return outputs


# NOTE(review): the original continues here with `show_results`, whose body is
# cut off at the end of the visible source; it is not part of this block and
# must be carried over unchanged.
"{cand}"