# File size: 6,048 Bytes
# 8d120bf 05a2178 3fadc6e 8d120bf 3fadc6e 8d120bf 05a2178 8d120bf 05a2178 3fadc6e 7ce6041 05a2178 7ce6041 4514e2e 7ce6041 533d92e 93c4867 05a2178 71950a8 05a2178 8d120bf 533d92e 71950a8 533d92e 71950a8 533d92e 71950a8 533d92e 71950a8 533d92e 3fadc6e 533d92e 3fadc6e 533d92e 3fadc6e 8d120bf 3fadc6e 7ce6041 3fadc6e 05a2178 3fadc6e 8d120bf 05a2178 d5154e9 71950a8 05a2178 71950a8 05a2178 71950a8 93c4867 71950a8 8d120bf 71950a8 3fadc6e 8d120bf 3fadc6e 8d120bf 3fadc6e 7ce6041 d5154e9 05a2178 71950a8 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
from typing import Iterator
from io import StringIO
import os
import pathlib
import tempfile
# External programs
import whisper
import ffmpeg
# UI
import gradio as gr
from download import downloadUrl
from utils import slugify, write_srt, write_vtt
#import os
#os.system("pip install git+https://github.com/openai/whisper.git")
# Limitations (set to -1 to disable)
# Maximum accepted audio length, passed to createUi() when this file is run as a script.
# UI.transcribeFile only enforces the limit when the value is greater than 0.
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
# Whether or not to automatically delete all uploaded files, to save disk space
# NOTE: the cleanup in UI.transcribeFile removes whichever file getSource() returned,
# so this also applies to files fetched through downloadUrl() and microphone recordings.
DELETE_UPLOADED_FILES = True
# Language names offered in the "Language" dropdown (displayed sorted).
# The selected name is lower-cased and handed to Whisper as the `language` argument.
LANGUAGES = [
"English", "Chinese", "German", "Spanish", "Russian", "Korean",
"French", "Japanese", "Portuguese", "Turkish", "Polish", "Catalan",
"Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi",
"Finnish", "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay",
"Czech", "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
"Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin",
"Maori", "Malayalam", "Welsh", "Slovak", "Telugu", "Persian",
"Latvian", "Bengali", "Serbian", "Azerbaijani", "Slovenian",
"Kannada", "Estonian", "Macedonian", "Breton", "Basque", "Icelandic",
"Armenian", "Nepali", "Mongolian", "Bosnian", "Kazakh", "Albanian",
"Swahili", "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
"Shona", "Yoruba", "Somali", "Afrikaans", "Occitan", "Georgian",
"Belarusian", "Tajik", "Sindhi", "Gujarati", "Amharic", "Yiddish",
"Lao", "Uzbek", "Faroese", "Haitian Creole", "Pashto", "Turkmen",
"Nynorsk", "Maltese", "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan",
"Tagalog", "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
"Hausa", "Bashkir", "Javanese", "Sundanese"
]
# Loaded Whisper models keyed by model name; UI.transcribeFile fills it on first use
# of a model and reuses the cached instance on later calls.
model_cache = dict()
class UI:
    """Gradio callback holder that runs Whisper on a single audio source.

    Attributes:
        inputAudioMaxDuration: Maximum accepted audio length in seconds.
            The limit is only enforced when the value is greater than 0.
    """

    def __init__(self, inputAudioMaxDuration):
        self.inputAudioMaxDuration = inputAudioMaxDuration

    @staticmethod
    def _selectLanguage(languageName):
        """Return the lower-cased language name, or None when nothing was selected.

        An unset dropdown may deliver None rather than an empty string, so both
        are treated as "no language chosen".
        """
        return languageName.lower() if languageName else None

    def _durationLimitError(self, audioDuration):
        """Return the error outputs when the audio is longer than allowed, else None.

        Args:
            audioDuration: Duration in seconds as reported by ffmpeg.probe
                (anything accepted by float()).

        Returns:
            None when the duration is within the limit; otherwise a 3-tuple
            (downloads, transcription, segments), one value per output of
            transcribeFile.
        """
        if float(audioDuration) > self.inputAudioMaxDuration:
            message = ("[ERROR]: Maximum audio file length is " + str(self.inputAudioMaxDuration)
                       + "s, file was " + str(audioDuration) + "s")
            # No files to download on error, hence the empty list in first position.
            return [], message, "[ERROR]"
        return None

    def transcribeFile(self, modelName, languageName, urlData, uploadFile, microphoneData, task):
        """Transcribe or translate one audio source and build the downloadable files.

        Args:
            modelName: Whisper model name; "base" is used when None.
            languageName: Language name, or None/"" to pass no language to Whisper.
            urlData: URL to download the audio from; takes precedence when set.
            uploadFile: Path of an uploaded audio file, or None.
            microphoneData: Path of a microphone recording, or None.
            task: Value forwarded to Whisper's ``task`` argument.

        Returns:
            A 3-tuple (downloads, text, vtt): the list of generated file paths
            (.srt, .vtt, .txt), the plain transcription and the VTT segments.
            When the audio exceeds the maximum duration the tuple is
            ([], error message, "[ERROR]").
        """
        source, sourceName = getSource(urlData, uploadFile, microphoneData)
        try:
            selectedLanguage = self._selectLanguage(languageName)
            selectedModel = modelName if modelName is not None else "base"

            if self.inputAudioMaxDuration > 0:
                # Calculate audio length
                audioDuration = ffmpeg.probe(source)["format"]["duration"]
                limitError = self._durationLimitError(audioDuration)
                if limitError is not None:
                    return limitError

            # Load each model only once and reuse it for later requests
            model = model_cache.get(selectedModel)
            if model is None:
                model = whisper.load_model(selectedModel)
                model_cache[selectedModel] = model

            # The results
            result = model.transcribe(source, language=selectedLanguage, task=task)
            text = result["text"]
            vtt = getSubs(result["segments"], "vtt")
            srt = getSubs(result["segments"], "srt")

            # Files that can be downloaded
            downloadDirectory = tempfile.mkdtemp()
            filePrefix = slugify(sourceName, allow_unicode=True)
            download = [
                createFile(srt, downloadDirectory, filePrefix + "-subs.srt"),
                createFile(vtt, downloadDirectory, filePrefix + "-subs.vtt"),
                createFile(text, downloadDirectory, filePrefix + "-transcript.txt"),
            ]
            return download, text, vtt
        finally:
            # Cleanup source (runs on success, on the duration error and on exceptions)
            if DELETE_UPLOADED_FILES:
                print("Deleting source file " + source)
                os.remove(source)
def getSource(urlData, uploadFile, microphoneData):
    """Choose the audio source to process and derive a short name for it.

    Priority is: URL, then uploaded file, then microphone recording.

    Args:
        urlData: URL to fetch through downloadUrl(); used when truthy.
        uploadFile: Path of an uploaded file, or None.
        microphoneData: Path of a microphone recording, or None.

    Returns:
        A tuple (source, sourceName) where ``source`` is the selected path and
        ``sourceName`` is its file name with the stem cut to 18 characters
        (the suffix is kept).

    Raises:
        ValueError: If no source is available at all.
    """
    if urlData:
        # Download from YouTube (downloadUrl is expected to return a local file path)
        source = downloadUrl(urlData)
    elif uploadFile is not None:
        # File input
        source = uploadFile
    else:
        source = microphoneData

    if source is None:
        # Fail with a readable message instead of the TypeError pathlib would raise
        raise ValueError("No audio source: provide a URL, an uploaded file or a microphone recording")

    file_path = pathlib.Path(source)
    sourceName = file_path.stem[:18] + file_path.suffix
    return source, sourceName
def createFile(text: str, directory: str, fileName: str) -> str:
    """Store ``text`` as UTF-8 in ``fileName`` inside ``directory``.

    An existing file of the same name is overwritten.

    Returns:
        The path of the file that was written.
    """
    targetPath = os.path.join(directory, fileName)
    with open(targetPath, 'w+', encoding="utf-8") as handle:
        handle.write(text)
    # The handle keeps the path it was opened with, even after it is closed
    return handle.name
def getSubs(segments: Iterator[dict], format: str) -> str:
    """Render transcription segments as subtitle text.

    Args:
        segments: Segments to render; passed unchanged to the subtitle writer.
        format: Either 'vtt' or 'srt'.

    Returns:
        The complete subtitle document as a string.

    Raises:
        ValueError: If ``format`` is neither 'vtt' nor 'srt'.
    """
    segmentStream = StringIO()

    if format == 'vtt':
        write_vtt(segments, file=segmentStream)
    elif format == 'srt':
        write_srt(segments, file=segmentStream)
    else:
        # ValueError is still caught by callers handling a plain Exception
        raise ValueError(f"Unknown format {format}")

    # getvalue() returns the whole buffer regardless of the stream position
    return segmentStream.getvalue()
def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
    """Build the Gradio interface around UI.transcribeFile and launch it.

    Args:
        inputAudioMaxDuration: Maximum audio length in seconds handed to UI;
            when greater than 0 it is also shown in the description.
        share: Forwarded to ``launch(share=...)``.
        server_name: Forwarded to ``launch(server_name=...)``.
    """
    # Assemble the description shown on the page
    descriptionParts = [
        "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse ",
        " audio and is also a multi-task model that can perform multilingual speech recognition ",
        " as well as speech translation and language identification. ",
    ]
    if inputAudioMaxDuration > 0:
        descriptionParts.append("\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s")
    ui_description = "".join(descriptionParts)

    handler = UI(inputAudioMaxDuration)

    # Order matches the parameters of UI.transcribeFile
    inputComponents = [
        gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
        gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
        gr.Text(label="URL (YouTube, etc.)"),
        gr.Audio(source="upload", type="filepath", label="Upload Audio"),
        gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
        gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
    ]
    # Order matches the values returned by UI.transcribeFile
    outputComponents = [
        gr.File(label="Download"),
        gr.Text(label="Transcription"),
        gr.Text(label="Segments"),
    ]

    demo = gr.Interface(
        fn=handler.transcribeFile,
        description=ui_description,
        inputs=inputComponents,
        outputs=outputComponents,
    )
    demo.launch(share=share, server_name=server_name)
if __name__ == '__main__':
    # Script entry point: start the web UI with the default duration limit
    createUi(DEFAULT_INPUT_AUDIO_MAX_DURATION)