Pendrokar's picture
autoplay checkbox & unlock vote
f968ac4
raw
history blame
14.8 kB
import time
from .models import *
from .utils import *
from .config import *
from .init import *
import gradio as gr
from pydub import AudioSegment
import random, os, threading, tempfile
from langdetect import detect
from .vote import log_text
# Model keys currently tracked as the top five. synthandreturn always pits one
# of them against a random opponent once five entries are present; presumably
# filled in elsewhere (it is never appended to in this module).
top_five = list()
# Hugging Face access token forwarded to gradio_client when calling public Spaces.
hf_token = os.environ.get('HF_TOKEN')
def random_m():
    """Return a list of two distinct, randomly chosen model keys from AVAILABLE_MODELS."""
    unique_names = set(AVAILABLE_MODELS)
    candidates = list(unique_names)
    return random.sample(candidates, 2)
def check_toxicity(text):
    """Return True when ``text`` is judged toxic.

    Always returns False when TOXICITY_CHECK is disabled; otherwise the
    ``toxicity`` model's 'toxicity' score must exceed 0.8.
    """
    if TOXICITY_CHECK:
        score = toxicity.predict(text)['toxicity']
        return score > 0.8
    return False
def synthandreturn(text, autoplay, request: gr.Request):
    """Synthesize ``text`` with two models and return the arena UI updates.

    The text is validated first: it must be non-empty, within the length limits,
    and pass the toxicity check, which prepared sentences in ``sents`` skip.
    Two distinct models are then picked. Once ``top_five`` holds at least five
    entries, one contestant always comes from it. Both models are run, one
    after the other when both are ZeroGPU Spaces and in two threads otherwise.

    Args:
        text: Text to synthesize; surrounding whitespace is stripped.
        autoplay: Whether the first audio player should autoplay.
        request: The Gradio request; its ``x-ip-token`` header is forwarded to
            ZeroGPU Spaces.

    Returns:
        tuple: (text, "Synthesize", r2 update, model1, model2, aud1 update,
        aud2 update, abetter update, bbetter update, prevmodel1 update,
        prevmodel2 update, next-round button update).

    Raises:
        gr.Error: For invalid or toxic text, or when a model does not return audio.
    """
    text = text.strip()
    # Check emptiness first; behind the length checks it was unreachable
    # whenever MIN_SAMPLE_TXT_LENGTH > 0.
    if not text:
        raise gr.Error('You did not enter any text')
    if len(text) > MAX_SAMPLE_TXT_LENGTH:
        raise gr.Error(f'You exceeded the limit of {MAX_SAMPLE_TXT_LENGTH} characters')
    if len(text) < MIN_SAMPLE_TXT_LENGTH:
        raise gr.Error(f'Please input a text longer than {MIN_SAMPLE_TXT_LENGTH} characters')

    # Prepared sentences are trusted: skip toxicity and language checks for them.
    is_prepared = text in sents
    if not is_prepared and check_toxicity(text):
        print(f'Detected toxic content! "{text}"')
        raise gr.Error('Your text failed the toxicity test')
    if not is_prepared:
        # Advisory only: a language-detection failure must not block synthesis.
        try:
            if detect(text) != "en":
                gr.Warning('Warning: The input text may not be in English')
        except Exception:
            pass

    # Pick two models.
    # forced model alternative: your TTS model versus The World!!!
    # mdl1 = 'Pendrokar/xVASynth'
    if len(top_five) >= 5:
        # Scrutinize the top five by always picking one of them.
        favoured = random.choice(top_five)
        opponents = [name for name in AVAILABLE_MODELS if name != favoured]
        # Randomize which side the forced model ends up on.
        mdl1, mdl2 = random.sample([favoured, random.choice(opponents)], 2)
    else:
        mdl1, mdl2 = random.sample(list(AVAILABLE_MODELS.keys()), 2)
    print("[debug] Using", mdl1, mdl2)

    def _get_param_examples(parameters):
        """Build positional inputs from the example values of an endpoint's parameters.

        Every parameter contributes exactly one value, so positional indices
        (``text_param_index``, OVERRIDE_INPUTS keys) stay aligned. Parameters of
        unhandled types keep their example value unchanged instead of being dropped.
        """
        converters = {'str': str, 'int': int, 'float': float, 'bool': bool}
        example_inputs = []
        for param_info in parameters:
            example = param_info['example_input']
            if param_info['component'] in ('Radio', 'Dropdown', 'Audio'):
                example_inputs.append(str(example))
                continue
            convert = converters.get(param_info['python_type']['type'])
            example_inputs.append(convert(example) if convert else example)
        return example_inputs

    def _override_params(inputs, modelname):
        """Apply the Arena overrides in OVERRIDE_INPUTS[modelname] to ``inputs`` in place."""
        overrides = OVERRIDE_INPUTS.get(modelname)
        if not overrides:
            return inputs
        try:
            for key, value in overrides.items():
                inputs[key] = value
        except (IndexError, KeyError, TypeError) as e:
            # Best effort: a misconfigured override must not abort synthesis.
            print(f"{modelname}: [WARN] Unable to override inputs: {e!r}")
            return inputs
        print(f"{modelname}: Default inputs overridden by Arena")
        return inputs

    def _call_model(text, model, hf_headers):
        """Run one synthesis request for ``model`` and return the resulting audio path."""
        if model not in AVAILABLE_MODELS:
            return router.predict(text, model.lower(), api_name="/synthesize")
        if '/' not in model:
            # Use the private HF Space
            return router.predict(text, AVAILABLE_MODELS[model].lower(), api_name="/synthesize")

        # Use public HF Space
        space = HF_SPACES[model]
        mdl_space = Client(model, hf_token=hf_token, headers=hf_headers)
        return_audio_index = int(space['return_audio_index'])
        endpoints = mdl_space.view_api(all_endpoints=True, print_info=False, return_format='dict')
        api_name = None
        fn_index = None
        if space['function'].startswith('/'):
            # Named endpoint (e.g. "/synthesize").
            api_name = space['function']
            parameters = endpoints['named_endpoints'][api_name]['parameters']
        else:
            # Unnamed endpoint addressed by its numeric index.
            fn_index = int(space['function'])
            parameters = endpoints['unnamed_endpoints'][str(fn_index)]['parameters']

        # Start from the endpoint's example values, apply overrides, then force the text.
        space_inputs = _override_params(_get_param_examples(parameters), model)
        space_inputs[space['text_param_index']] = text
        print(f"{model}: Sending request to HF Space")
        results = mdl_space.predict(*space_inputs, api_name=api_name, fn_index=fn_index)

        # The Space may return a bare filepath, or a sequence whose
        # return_audio_index entry is a filepath or a dict with a 'value' key.
        result = results
        if not isinstance(results, str):
            result = results[return_audio_index]
        if isinstance(result, dict):
            result = result['value']
        return result

    def _convert_to_wav(result, model):
        """Best effort: resample above 24 kHz down to 24 kHz, normalize, re-export as WAV.

        Returns the new WAV path and deletes the source file. On failure, returns
        ``result`` unchanged and removes the temporary file.
        """
        wav_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
                wav_path = f.name
            audio = AudioSegment.from_file(result)
            if audio.frame_rate > 24000:
                print(f"{model}: Resampling")
                audio = audio.set_frame_rate(24000)
            try:
                print(f"{model}: Trying to normalize audio")
                audio = match_target_amplitude(audio, -20)
            except Exception:
                print(f"{model}: [WARN] Unable to normalize audio")
            audio.export(wav_path, format="wav")
            os.unlink(result)
        except Exception:
            print(f"{model}: [WARN] Unable to resample audio")
            if wav_path is not None:
                try:
                    os.unlink(wav_path)
                except OSError:
                    pass
            return result
        # The notification is cosmetic; never let it fail the synthesis.
        try:
            gr.Info('Audio from a TTS model received')
        except Exception:
            pass
        return wav_path

    def predict_and_update_result(text, model, result_storage, request: gr.Request):
        """Synthesize ``text`` with ``model`` and store the audio path in ``result_storage``.

        Raises:
            gr.Error: When the model call fails.
        """
        hf_headers = {}
        try:
            # Forward the caller's IP token to ZeroGPU Spaces.
            if HF_SPACES[model]['is_zero_gpu_space']:
                hf_headers = {"X-IP-Token": request.headers['x-ip-token']}
        except (KeyError, TypeError, AttributeError):
            # Not a known Space, no ZeroGPU flag, or no token header.
            pass

        max_attempts = 1  # 3 = May cause 429 Too Many Request
        for attempt in range(1, max_attempts + 1):
            try:
                result = _call_model(text, model, hf_headers)
                break
            except Exception as e:
                print(f"{model}: Unable to call API (attempt: {attempt})")
                if attempt >= max_attempts:
                    raise gr.Error(f"{model}:" + repr(e)) from e
                # Avoid spamming the server with back-to-back requests.
                time.sleep(3)

        print('Done with', model)
        result = _convert_to_wav(result, model)
        if model in AVAILABLE_MODELS:
            model = AVAILABLE_MODELS[model]
        result_storage[model] = result

    mdl1k = AVAILABLE_MODELS.get(mdl1, mdl1)
    mdl2k = AVAILABLE_MODELS.get(mdl2, mdl2)
    results = {}
    errors = {}
    print(f"Sending models {mdl1k} and {mdl2k} to API")

    def _run(model):
        # An exception raised inside a worker thread never reaches this
        # function's caller, so record it and re-raise after joining.
        try:
            predict_and_update_result(text, model, results, request)
        except Exception as e:
            errors[model] = e

    space1 = HF_SPACES.get(mdl1) or {}
    space2 = HF_SPACES.get(mdl2) or {}
    if space1.get('is_zero_gpu_space') and space2.get('is_zero_gpu_space'):
        # Run ZeroGPU spaces one at a time; do not use multithreading.
        predict_and_update_result(text, mdl1k, results, request)
        predict_and_update_result(text, mdl2k, results, request)
    else:
        threads = [threading.Thread(target=_run, args=(name,)) for name in (mdl1k, mdl2k)]
        threads[0].start()
        # wait 3 seconds to calm hf.space domain
        time.sleep(3)
        threads[1].start()
        for thread in threads:
            thread.join(120)  # timeout in 2 minutes

    # A model that failed or timed out has no entry in results; report it
    # instead of failing with a KeyError below.
    for name in (mdl1k, mdl2k):
        if name in results:
            continue
        error = errors.get(name)
        if isinstance(error, gr.Error):
            raise error
        if error is not None:
            raise gr.Error(f"{name}:" + repr(error)) from error
        raise gr.Error(f"{name}: Failed to call model")

    print(f"Retrieving models {mdl1k} and {mdl2k} from API")
    return (
        text,
        "Synthesize",
        gr.update(visible=True),  # r2
        mdl1,  # model1
        mdl2,  # model2
        gr.update(visible=True, value=results[mdl1k], autoplay=autoplay),  # aud1
        gr.update(visible=True, value=results[mdl2k], autoplay=False),  # aud2
        gr.update(visible=True, interactive=False),  # abetter
        gr.update(visible=True, interactive=False),  # bbetter
        gr.update(visible=False),  # prevmodel1
        gr.update(visible=False),  # prevmodel2
        gr.update(visible=False),  # nxt round btn
    )
# Battle Mode
def synthandreturn_battle(text, mdl1, mdl2, autoplay):
    """Synthesize ``text`` with two user-chosen models and return the arena UI updates.

    Unlike synthandreturn, every model is called through ``router.predict``.
    The text is logged with ``log_text`` before synthesis.

    Args:
        text: Text to synthesize; surrounding whitespace is stripped.
        mdl1: First model key.
        mdl2: Second model key; must differ from ``mdl1``.
        autoplay: Whether the first audio player should autoplay.

    Returns:
        tuple: The same 12-slot layout as synthandreturn: (text, "Synthesize",
        r2, model1, model2, aud1, aud2, abetter, bbetter, prevmodel1,
        prevmodel2, next-round button).

    Raises:
        gr.Error: For identical models, invalid or toxic text, or a failed API call.
    """
    if mdl1 == mdl2:
        raise gr.Error('You can\'t pick two of the same models.')
    text = text.strip()
    # Check emptiness first; behind the length checks it was unreachable
    # whenever MIN_SAMPLE_TXT_LENGTH > 0.
    if not text:
        raise gr.Error('You did not enter any text')
    if len(text) > MAX_SAMPLE_TXT_LENGTH:
        raise gr.Error(f'You exceeded the limit of {MAX_SAMPLE_TXT_LENGTH} characters')
    if len(text) < MIN_SAMPLE_TXT_LENGTH:
        raise gr.Error(f'Please input a text longer than {MIN_SAMPLE_TXT_LENGTH} characters')

    # Prepared sentences are trusted: skip toxicity and language checks for them.
    is_prepared = text in sents
    if not is_prepared and check_toxicity(text):
        print(f'Detected toxic content! "{text}"')
        raise gr.Error('Your text failed the toxicity test')
    if not is_prepared:
        # Advisory only: a language-detection failure must not block synthesis.
        try:
            if detect(text) != "en":
                gr.Warning('Warning: The input text may not be in English')
        except Exception:
            pass

    log_text(text)
    print("[debug] Using", mdl1, mdl2)

    def _convert_to_wav(result):
        """Best effort: resample above 24 kHz down to 24 kHz, normalize, re-export as WAV.

        Returns the new WAV path and deletes the source file. On failure, returns
        ``result`` unchanged and removes the temporary file.
        """
        wav_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
                wav_path = f.name
            audio = AudioSegment.from_file(result)
            if audio.frame_rate > 24000:
                audio = audio.set_frame_rate(24000)
            try:
                print('Trying to normalize audio')
                audio = match_target_amplitude(audio, -20)
            except Exception:
                print('[WARN] Unable to normalize audio')
            audio.export(wav_path, format="wav")
            os.unlink(result)
        except Exception:
            print('[WARN] Unable to resample audio')
            if wav_path is not None:
                try:
                    os.unlink(wav_path)
                except OSError:
                    pass
            return result
        return wav_path

    def predict_and_update_result(text, model, result_storage):
        """Synthesize ``text`` with ``model`` via the router and store the audio path."""
        target = AVAILABLE_MODELS[model] if model in AVAILABLE_MODELS else model
        try:
            result = router.predict(text, target.lower(), api_name="/synthesize")
        except Exception as e:
            raise gr.Error('Unable to call API, please try again :)') from e
        print('Done with', model)
        result = _convert_to_wav(result)
        if model in AVAILABLE_MODELS:
            model = AVAILABLE_MODELS[model]
        print(model)
        print(f"Running model {model}")
        result_storage[model] = result

    mdl1k = AVAILABLE_MODELS.get(mdl1, mdl1)
    mdl2k = AVAILABLE_MODELS.get(mdl2, mdl2)
    results = {}
    errors = {}
    print(f"Sending models {mdl1k} and {mdl2k} to API")

    def _run(model):
        # An exception raised inside a worker thread never reaches this
        # function's caller, so record it and re-raise after joining.
        try:
            predict_and_update_result(text, model, results)
        except Exception as e:
            errors[model] = e

    threads = [threading.Thread(target=_run, args=(name,)) for name in (mdl1k, mdl2k)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Report a failed model instead of failing with a KeyError below.
    for name in (mdl1k, mdl2k):
        if name in results:
            continue
        error = errors.get(name)
        if isinstance(error, gr.Error):
            raise error
        raise gr.Error('Unable to call API, please try again :)') from error

    print(f"Retrieving models {mdl1k} and {mdl2k} from API")
    return (
        text,
        "Synthesize",
        gr.update(visible=True),  # r2
        mdl1,  # model1
        mdl2,  # model2
        gr.update(visible=True, value=results[mdl1k], autoplay=autoplay),  # aud1
        gr.update(visible=True, value=results[mdl2k], autoplay=False),  # aud2
        gr.update(visible=True, interactive=False),  # abetter
        gr.update(visible=True, interactive=False),  # bbetter
        gr.update(visible=False),  # prevmodel1
        gr.update(visible=False),  # prevmodel2
        gr.update(visible=False),  # nxt round btn
    )
# Unlock vote
def unlock_vote(btn_index, aplayed, bplayed):
    """Mark one audio sample as played and unlock voting once both have been played.

    Args:
        btn_index: 0 marks sample A as played, 1 marks sample B.
        aplayed: Current "sample A played" flag (a session-state value).
        bplayed: Current "sample B played" flag (a session-state value).

    Returns:
        list: Four outputs. When both samples are played, the first two are
        updates that make their components interactive (presumably the vote
        buttons) and the state updates are no-ops. Otherwise, two no-op updates
        followed by the new flag values.
    """
    # Store plain booleans: Gradio keeps a handler's return value for a State
    # output as-is, so returning a gr.State component would store that object.
    if btn_index == 0:
        aplayed = True
    if btn_index == 1:
        bplayed = True
    # both audio samples played
    if aplayed and bplayed:
        print('Both audio samples played, voting unlocked')
        return [gr.update(interactive=True), gr.update(interactive=True), gr.update(), gr.update()]
    return [gr.update(), gr.update(), aplayed, bplayed]
def randomsent():
    """Return a random prepared sentence from ``sents`` and the dice label."""
    sentence = random.choice(sents)
    dice_label = '🎲'
    return sentence, dice_label
def randomsent_battle():
    """Return (sentence, dice label, model_a, model_b) for battle mode."""
    sentence_part = tuple(randomsent())
    models_part = tuple(random_m())
    return (*sentence_part, *models_part)
def clear_stuff():
    """Reset the arena UI: clear the text and model names and hide the output components.

    Returns 12 values, the same count as synthandreturn's outputs (presumably
    the same output components).
    """
    def hidden():
        # A fresh update dict per slot, as in the original.
        return gr.update(visible=False)

    return (
        "",
        "Synthesize",
        hidden(),
        '',
        '',
        *[hidden() for _ in range(7)],
    )