Pendrokar's picture
dump cached sample paths to JSON file and load it on Space reset
146bfe1
import time
from .models import *
from .utils import *
from .config import *
from .init import *
from .sample_caching import *
import gradio as gr
from pydub import AudioSegment
import random, os, threading, tempfile
from langdetect import detect
from .vote import log_text
# Models with the fewest votes (at most five). When five are known, every vote
# round in synthandreturn pits one of them against a random opponent so that
# under-voted models get scrutinized first.
top_five = []
# token passed to gradio_client when calling public HF Spaces
hf_token = os.getenv('HF_TOKEN')

# prioritize low vote models: below 750 total votes, fewest votes first
sql = 'SELECT name FROM model WHERE (upvote + downvote) < 750 ORDER BY (upvote + downvote) ASC'
conn = get_db()
cursor = conn.cursor()
cursor.execute(sql)
data = cursor.fetchall()
for model in data:
    if len(top_five) >= 5:
        # five collected; remaining rows are not needed
        break
    if model[0] not in AVAILABLE_MODELS:
        # the table may list models that are not currently offered
        continue
    top_five.append(model[0])
print(f"low vote top_five: {top_five}")
def random_m():
    """Return a list of two distinct model names drawn at random from AVAILABLE_MODELS."""
    model_pool = list(set(AVAILABLE_MODELS))
    return random.sample(model_pool, 2)
def check_toxicity(text):
    """Return True when *text* scores above 0.8 on the toxicity classifier.

    When TOXICITY_CHECK is disabled the classifier is not called and the
    text is never considered toxic.
    """
    if TOXICITY_CHECK:
        score = toxicity.predict(text)['toxicity']
        return score > 0.8
    return False
def _get_param_examples(parameters):
    """Build default inputs for a Space endpoint from its API description.

    *parameters* is the ``parameters`` list of one endpoint as returned by
    ``Client.view_api(return_format='dict')``.

    Returns:
        A dict keyed by parameter name when the first parameter carries a
        ``parameter_name`` key, otherwise a positional list. Each value is the
        parameter's ``example_input`` (or, for named parameters, its
        ``parameter_default`` when that key is present), cast to ``str`` for
        Radio/Dropdown/Audio components and str types, to int/float/bool for
        those python types, and ``None`` for anything else.
    """
    named = bool(parameters) and 'parameter_name' in parameters[0]
    example_inputs = {} if named else []
    for param_info in parameters:
        param_name = ''
        param_default_value = param_info['example_input']
        if 'parameter_name' in param_info:
            # named params may also advertise their declared default
            param_name = param_info['parameter_name']
            param_default_value = param_info.get('parameter_default', param_default_value)

        param_value = None
        if (
            param_info['component'] in ('Radio', 'Dropdown', 'Audio')
            or param_info['python_type']['type'] == 'str'
        ):
            param_value = str(param_default_value)
        else:
            cast = {'int': int, 'float': float, 'bool': bool}.get(
                param_info['python_type']['type']
            )
            if cast is not None:
                param_value = cast(param_default_value)

        if param_name != '':
            example_inputs[param_name] = param_value
        else:
            # just append unnamed param and hope the order matches the endpoint
            example_inputs.append(param_value)
    return example_inputs


def _override_params(inputs, modelname):
    """Apply the Arena's OVERRIDE_INPUTS[modelname] on top of endpoint defaults.

    Integer override keys address inputs by position, so dict inputs are first
    converted to a list of their values. Models without overrides get *inputs*
    back unchanged. Overriding is best effort: on failure a warning is printed
    and the inputs are returned as far as they were modified.
    """
    overrides = OVERRIDE_INPUTS.get(modelname)
    if not overrides:
        return inputs
    try:
        for key, value in overrides.items():
            # if override keys are integers, make the dict into a list
            if isinstance(inputs, dict) and isinstance(key, int):
                print(f"{modelname}: Converting unnamed override params to List")
                inputs = list(inputs.values())
            inputs[key] = value
        print(f"{modelname}: Default inputs overridden by Arena")
    except Exception as e:
        print(f"{modelname}: [WARN] Unable to override inputs: {e!r}")
    return inputs


def _predict_hf_space(text, model, hf_headers):
    """Synthesize *text* on the public HF Space *model* and return its audio output.

    The endpoint (``function``: '/name' or a numeric index), the input slot for
    the text (``text_param_index``) and the output slot holding the audio
    (``return_audio_index``) come from HF_SPACES[model]. The remaining inputs
    start from the endpoint's advertised defaults/examples and are then
    overridden by OVERRIDE_INPUTS.
    """
    mdl_space = Client(model, hf_token=hf_token, headers=hf_headers)
    space = HF_SPACES[model]
    return_audio_index = int(space['return_audio_index'])
    endpoints = mdl_space.view_api(all_endpoints=True, print_info=False, return_format='dict')

    api_name = None
    fn_index = None
    if space['function'].startswith('/'):
        # named endpoint, e.g. '/synthesize'
        api_name = space['function']
        parameters = endpoints['named_endpoints'][api_name]['parameters']
    else:
        # unnamed endpoint addressed by its index
        fn_index = int(space['function'])
        parameters = endpoints['unnamed_endpoints'][str(fn_index)]['parameters']

    # override some or all default parameters, then force the text
    space_inputs = _override_params(_get_param_examples(parameters), model)
    space_inputs[space['text_param_index']] = text

    print(f"{model}: Sending request to HF Space")
    if isinstance(space_inputs, dict):
        results = mdl_space.predict(**space_inputs, api_name=api_name, fn_index=fn_index)
    else:
        results = mdl_space.predict(*space_inputs, api_name=api_name, fn_index=fn_index)

    # the Space returns either a bare filepath or a sequence of outputs
    result = results if isinstance(results, str) else results[return_audio_index]
    if isinstance(result, dict):
        # the audio output is wrapped as a dict carrying the path in 'value'
        result = result['value']
    return result


def _postprocess_audio(result, model):
    """Resample *result* to at most 24 kHz and normalize it into a new temp .wav.

    Best effort. On success the original file is deleted and the new path is
    returned. On failure a warning is printed, the unused temp file is removed
    and *result* is returned unchanged.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            tmp_path = f.name
            audio = AudioSegment.from_file(result)
            if audio.frame_rate > 24000:
                print(f"{model}: Resampling")
                audio = audio.set_frame_rate(24000)
            try:
                print(f"{model}: Trying to normalize audio")
                audio = match_target_amplitude(audio, -20)
            except Exception:
                print(f"{model}: [WARN] Unable to normalize audio")
            audio.export(f.name, format="wav")
        os.unlink(result)
    except Exception:
        print(f"{model}: [WARN] Unable to resample audio")
        # delete=False keeps the temp file around; don't leak it
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
        return result
    gr.Info('Audio from a TTS model received')
    return tmp_path


def synthandreturn(text, autoplay, request: gr.Request):
    """Run one vote round: validate *text*, pick two models, synthesize with both.

    When five low-vote models are known (``top_five``), one of them is always
    part of the pair at a random position. Otherwise two random models are
    drawn. Both are synthesized one after another when both are ZeroGPU
    Spaces, or else in two threads. Samples of prepared sentences (``sents``)
    are cached and the cache is written to ``_cached_samples.json``.

    Returns:
        The values/``gr.update`` objects for the vote UI: text, button label,
        r2, model1, model2, aud1 (autoplays per *autoplay*), aud2, abetter,
        bbetter, prevmodel1, prevmodel2, next-round button, aplayed, bplayed.

    Raises:
        gr.Error: on invalid or toxic text, or when a model produced no audio.
    """
    text = text.strip()
    if len(text) > MAX_SAMPLE_TXT_LENGTH:
        raise gr.Error(f'You exceeded the limit of {MAX_SAMPLE_TXT_LENGTH} characters')
    if len(text) < MIN_SAMPLE_TXT_LENGTH:
        raise gr.Error(f'Please input a text longer than {MIN_SAMPLE_TXT_LENGTH} characters')
    if (
        # test toxicity if not prepared text
        text not in sents
        and check_toxicity(text)
    ):
        print(f'Detected toxic content! "{text}"')
        raise gr.Error('Your text failed the toxicity test')
    if not text:
        raise gr.Error('You did not enter any text')
    # Check language; best effort, any detection error is ignored
    try:
        if text not in sents and detect(text) != "en":
            gr.Warning('Warning: The input text may not be in English')
    except Exception:
        pass

    # scrutinize the top five by always picking one of them
    if len(top_five) >= 5:
        mdl1 = random.sample(top_five, 1)[0]
        vsModels = dict(AVAILABLE_MODELS)
        del vsModels[mdl1]
        mdl2 = random.sample(list(vsModels.keys()), 1)[0]
        # randomize position of the forced model
        mdl1, mdl2 = random.sample([mdl1, mdl2], 2)
    else:
        # actual random
        mdl1, mdl2 = random.sample(list(AVAILABLE_MODELS.keys()), 2)
    print("[debug] Using", mdl1, mdl2)

    def predict_and_update_result(text, model, result_storage, request: gr.Request):
        """Synthesize *text* with *model* and store the audio path in *result_storage*.

        Models whose name contains '/' are called as public HF Spaces, other
        AVAILABLE_MODELS through the private ``router``. The path is stored
        under AVAILABLE_MODELS[model] when *model* is one of its keys,
        otherwise under *model*.

        Raises:
            gr.Error: when the last allowed attempt fails.
        """
        hf_headers = {}
        try:
            if HF_SPACES[model]['is_zero_gpu_space']:
                # forward the visitor's ZeroGPU token
                hf_headers = {"X-IP-Token": request.headers['x-ip-token']}
        except Exception:
            # not a configured Space, no ZeroGPU flag, or no token header
            pass

        # more than 1 attempt may cause 429 Too Many Requests
        max_attempts = 1
        for attempt in range(1, max_attempts + 1):
            try:
                if model not in AVAILABLE_MODELS:
                    result = router.predict(text, model.lower(), api_name="/synthesize")
                elif '/' in model:
                    # Use public HF Space
                    result = _predict_hf_space(text, model, hf_headers)
                else:
                    # Use the private HF Space
                    result = router.predict(text, AVAILABLE_MODELS[model].lower(), api_name="/synthesize")
                break
            except Exception as e:
                if attempt >= max_attempts:
                    raise gr.Error(f"{model}:" + repr(e)) from e
                print(f"{model}: Unable to call API (attempt: {attempt})")
        print('Done with', model)

        result = _postprocess_audio(result, model)
        if model in AVAILABLE_MODELS:
            model = AVAILABLE_MODELS[model]
        result_storage[model] = result

    def _cache_sample(text, model):
        """Cache the prepared-sentence sample of *model* and persist the cache.

        Returns True when the sample is (or already was) cached. Returns False
        when *text* is not a prepared sentence or the sample could not be
        cached.
        """
        import json

        # skip caching if not hardcoded sentence
        if text not in sents:
            return False
        # TODO: replace cached with newer version
        if any(
            cached.transcript == text and cached.modelName == model
            for cached in cached_samples
        ):
            return True
        try:
            cached_samples.append(Sample(results[model], text, model))
        except Exception as e:
            print('Error when trying to cache sample', repr(e))
            return False
        # save list to JSON file; serialize before opening so that a failure
        # does not truncate the existing cache file
        try:
            payload = json.dumps([cached.to_dict() for cached in cached_samples])
            with open("_cached_samples.json", "w") as cache_file:
                cache_file.write(payload)
        except Exception as e:
            print('[WARN] Unable to save cached samples', repr(e))
        return True

    mdl1k = AVAILABLE_MODELS.get(mdl1, mdl1)
    mdl2k = AVAILABLE_MODELS.get(mdl2, mdl2)
    results = {}
    errors = {}

    def _is_zero_gpu(mdl):
        # models missing from HF_SPACES are not ZeroGPU Spaces
        return bool(HF_SPACES.get(mdl, {}).get('is_zero_gpu_space', False))

    print(f"Sending models {mdl1k} and {mdl2k} to API")
    if _is_zero_gpu(mdl1) and _is_zero_gpu(mdl2):
        # do not use multithreading when both spaces are ZeroGPU type
        predict_and_update_result(text, mdl1k, results, request)
        _cache_sample(text, mdl1k)
        predict_and_update_result(text, mdl2k, results, request)
        _cache_sample(text, mdl2k)
    else:
        def _worker(model):
            # exceptions raised in a Thread are not propagated by join()
            try:
                predict_and_update_result(text, model, results, request)
            except Exception as e:
                errors[model] = e

        thread1 = threading.Thread(target=_worker, args=(mdl1k,))
        thread2 = threading.Thread(target=_worker, args=(mdl2k,))
        thread1.start()
        # wait 3 seconds to calm hf.space domain
        time.sleep(3)
        thread2.start()
        # timeout in 2 minutes
        thread1.join(120)
        thread2.join(120)
        # cache whatever was received
        for model in (mdl1k, mdl2k):
            if model in results:
                _cache_sample(text, model)

    print(f"Retrieving models {mdl1k} and {mdl2k} from API")
    for model in (mdl1k, mdl2k):
        if model in errors:
            # re-raise the worker's gr.Error so the user sees its message
            raise errors[model]
    for model in (mdl1k, mdl2k):
        if model not in results:
            # e.g. the worker thread is still running after the join timeout
            raise gr.Error(f"{model}: Failed to call model")

    return (
        text,
        "Synthesize 🐢",
        gr.update(visible=True), # r2
        mdl1, # model1
        mdl2, # model2
        gr.update(visible=True, value=results[mdl1k], autoplay=autoplay), # aud1
        gr.update(visible=True, value=results[mdl2k], autoplay=False), # aud2
        gr.update(visible=True, interactive=False), #abetter
        gr.update(visible=True, interactive=False), #bbetter
        gr.update(visible=False), #prevmodel1
        gr.update(visible=False), #prevmodel2
        gr.update(visible=False), #nxt round btn
        # reset gr.State aplayed & bplayed
        False, #aplayed
        False, #bplayed
    )
# Battle Mode
def synthandreturn_battle(text, mdl1, mdl2, autoplay):
    """Synthesize *text* with two user-picked models (Battle mode).

    Validates the text like vote mode, logs it with ``log_text``, then runs
    both models in parallel threads through the ``router``.

    Returns:
        The values/``gr.update`` objects for the battle UI: text, button label,
        r2, model1, model2, aud1 (autoplays per *autoplay*), aud2, abetter,
        bbetter, prevmodel1, prevmodel2, next-round button.

    Raises:
        gr.Error: on identical models, invalid or toxic text, or a failed model call.
    """
    if mdl1 == mdl2:
        raise gr.Error('You can\'t pick two of the same models.')
    text = text.strip()
    if len(text) > MAX_SAMPLE_TXT_LENGTH:
        raise gr.Error(f'You exceeded the limit of {MAX_SAMPLE_TXT_LENGTH} characters')
    if len(text) < MIN_SAMPLE_TXT_LENGTH:
        raise gr.Error(f'Please input a text longer than {MIN_SAMPLE_TXT_LENGTH} characters')
    if (
        # test toxicity if not prepared text
        text not in sents
        and check_toxicity(text)
    ):
        print(f'Detected toxic content! "{text}"')
        raise gr.Error('Your text failed the toxicity test')
    if not text:
        raise gr.Error('You did not enter any text')
    # Check language; prepared sentences are not checked, and any detection
    # error is ignored (best effort)
    try:
        if text not in sents and detect(text) != "en":
            gr.Warning('Warning: The input text may not be in English')
    except Exception:
        pass
    log_text(text)
    print("[debug] Using", mdl1, mdl2)

    def predict_and_update_result(text, model, result_storage):
        """Synthesize *text* with *model* through the router and store the audio path.

        The path is stored under AVAILABLE_MODELS[model] when *model* is one of
        its keys, otherwise under *model*.
        """
        try:
            if model in AVAILABLE_MODELS:
                result = router.predict(text, AVAILABLE_MODELS[model].lower(), api_name="/synthesize")
            else:
                result = router.predict(text, model.lower(), api_name="/synthesize")
        except Exception as e:
            print(f"{model}: Unable to call API: {e!r}")
            raise gr.Error('Unable to call API, please try again :)') from e
        print('Done with', model)

        # best effort: resample to at most 24 kHz and normalize into a new .wav
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
                tmp_path = f.name
                audio = AudioSegment.from_file(result)
                if audio.frame_rate > 24000:
                    audio = audio.set_frame_rate(24000)
                try:
                    print('Trying to normalize audio')
                    audio = match_target_amplitude(audio, -20)
                except Exception:
                    print('[WARN] Unable to normalize audio')
                audio.export(f.name, format="wav")
            os.unlink(result)
            result = tmp_path
        except Exception:
            print(f"{model}: [WARN] Unable to resample audio")
            # keep the unprocessed audio; delete=False would leak the temp file
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass
        if model in AVAILABLE_MODELS:
            model = AVAILABLE_MODELS[model]
        print(f"Running model {model}")
        result_storage[model] = result

    mdl1k = AVAILABLE_MODELS.get(mdl1, mdl1)
    mdl2k = AVAILABLE_MODELS.get(mdl2, mdl2)
    results = {}
    errors = {}

    def _worker(model):
        # exceptions raised in a Thread are not propagated by join()
        try:
            predict_and_update_result(text, model, results)
        except Exception as e:
            errors[model] = e

    print(f"Sending models {mdl1k} and {mdl2k} to API")
    thread1 = threading.Thread(target=_worker, args=(mdl1k,))
    thread2 = threading.Thread(target=_worker, args=(mdl2k,))
    thread1.start()
    thread2.start()
    thread1.join()
    thread2.join()
    print(f"Retrieving models {mdl1k} and {mdl2k} from API")
    for model in (mdl1k, mdl2k):
        if model in errors:
            # surface the worker's gr.Error instead of a KeyError below
            raise errors[model]
        if model not in results:
            raise gr.Error('Unable to call API, please try again :)')

    return (
        text,
        "Synthesize 🐢",
        gr.update(visible=True), # r2
        mdl1, # model1
        mdl2, # model2
        gr.update(visible=True, value=results[mdl1k], autoplay=autoplay), # aud1
        gr.update(visible=True, value=results[mdl2k], autoplay=False), # aud2
        gr.update(visible=True, interactive=False), #abetter
        gr.update(visible=True, interactive=False), #bbetter
        gr.update(visible=False), #prevmodel1
        gr.update(visible=False), #prevmodel2
        gr.update(visible=False), #nxt round btn
    )
def randomsent():
    """Return the random-sentence button triple: icon, a random prepared sentence, icon."""
    sentence = random.choice(sents)
    return '⚡', sentence, '🎲'
def randomsent_battle():
    """Return randomsent()'s triple followed by two random model names, as one tuple."""
    return (*randomsent(), *random_m())
def clear_stuff():
    """Reset the vote UI: empty the text box, hide players, vote and navigation widgets.

    Returns a list of values/``gr.update`` objects in output order (text box,
    button label, r2, model1, model2, aud1, aud2, abetter, bbetter,
    prevmodel1, prevmodel2, next-round button, aplayed, bplayed).
    """
    hidden_player = dict(visible=False, interactive=False, autoplay=False)
    hidden_vote = dict(visible=False, interactive=False)
    return [
        gr.update(visible=True, value="", elem_classes=[]),  # text input
        "Synthesize 🐢",
        gr.update(visible=False),  # r2
        '',  # model1
        '',  # model2
        gr.update(**hidden_player),  # aud1
        gr.update(**hidden_player),  # aud2
        gr.update(**hidden_vote),  # abetter
        gr.update(**hidden_vote),  # bbetter
        gr.update(visible=False),  # prevmodel1
        gr.update(visible=False),  # prevmodel2
        gr.update(visible=False),  # nxt round btn
        False,  # aplayed
        False,  # bplayed
    ]