# Spaces: Running on Zero
import gradio as gr | |
from transformers import pipeline | |
import random | |
import re | |
import torch | |
import spaces | |
# --- Text shown by the Gradio interface ---------------------------------------
title = "WoW Quest Text Generator"
description = 'Tap on the "Submit" button to generate a random quest text.'
article = "<p>Fine tuned <a href=\"https://huggingface.co/EleutherAI/gpt-neo-125M\">EleutherAI/gpt-neo-125M</a> upon a formatted <a href=\"https://github.com/TrinityCore/TrinityCore\"> TrinityCore – TDB_full_world_927.22082_2022_08_21 Dataset</a></p><p>This generator is fan made and is not affiliated in any way with Blizzard and/or any other company</p>"

# --- Model placement and loading ----------------------------------------------
CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda") if CUDA_AVAILABLE else torch.device("cpu")

# Weights and tokenizer are both loaded from the same local directory.
model_id = "./model"
text_generator = pipeline(
    "text-generation",
    model=model_id,
    tokenizer=model_id,
    device=device,
)

# --- Sampling settings forwarded to the pipeline ------------------------------
max_length = 256
top_k = 40
top_p = 0.92
temperature = 1.0

# With no argument, Python seeds from OS randomness (or the current time).
random.seed(None)

# --- Values substituted for quest-text placeholders ---------------------------
# $c / $C -> one of these class names
wow_class_list = [
    "Death Knight", "Demon Hunter", "Druid", "Hunter", "Mage", "Monk",
    "Paladin", "Priest", "Rogue", "Shaman", "Warrior", "Warlock",
]
# $r / $R -> one of these race names
wow_race_list = [
    "Blood Elf", "Human", "Tauren", "Orc", "Kul Tiran", "Void Elf", "Troll",
    "Vulpera", "Night Elf", "Zandalari Troll", "Worgen", "Undead", "Goblin",
    "Highmountain Tauren", "Nightborne", "Dwarf", "Draenei", "Gnome",
    "Lightforged Draenei", "Pandaren", "Maghar Orc", "Mechagnome",
    "Dark Iron Dwarf",
]
# $n / $N -> one of these player names
wow_silly_name_list = [
    "Glitterstorm", "Sunderwear", "Arrowdynamic", "Sapntap", "Crossblesser",
    "Praystation", "Healium", "Shocknorris", "Alestrom", "Harryportal",
    "Merlìn", "Wreckquiem", "Owlcapone",
]
# Example prompts offered in the UI; each one uses a placeholder token.
suggested_text_list = [
    "Greetings $r",
    "$c I need your help",
    "Good to see you $n",
    "Hey $gBoy:Girl; ",
]
def parseGenderTokens(text):
    """Resolve two-option gender tokens such as ``$gBoy:Girl;`` in *text*.

    A token is ``$g`` or ``$G``, followed by two options separated by ``:``,
    followed by a terminating ``;``. Each token is replaced by one of its two
    options.

    The option index (1 or 2) is drawn once, at the first token. Every later
    token reuses that index, so the whole text makes one consistent choice.
    Text containing no token is returned unchanged, and no random number is
    drawn for it.
    """
    pattern = re.compile(r"\$[gG]([^:]+):([^;]+);")
    if pattern.search(text) is None:
        return text
    # pattern.groups is the number of capture groups (2 here).
    chosen = random.randint(1, pattern.groups)
    return pattern.sub(lambda found: found.group(chosen), text)
def parseSpecialCharacters(text, wow_class_item, wow_race_item, wow_silly_name_item):
    """Expand ``$``-placeholders in quest text, then resolve gender tokens.

    Substitutions are applied in this order:
      * ``$a``, ``$B``, ``$b`` -> newline
      * ``$c``, ``$C``         -> *wow_class_item*
      * ``$r``, ``$R``         -> *wow_race_item*
      * ``$n``, ``$N``         -> *wow_silly_name_item*

    The result is then passed to ``parseGenderTokens``, which resolves
    ``$g...:...;`` tokens.
    """
    substitutions = (
        ("$a", "\n"),
        ("$B", "\n"),
        ("$b", "\n"),
        ("$c", wow_class_item),
        ("$C", wow_class_item),
        ("$r", wow_race_item),
        ("$R", wow_race_item),
        ("$n", wow_silly_name_item),
        ("$N", wow_silly_name_item),
    )
    expanded = text
    for token, value in substitutions:
        expanded = expanded.replace(token, value)
    return parseGenderTokens(expanded)
# Control token that every prompt sent to the model starts with.
_START_TOKEN = "<|startoftext|>"


def _clean_generated_text(raw_text):
    """Remove model control tokens and normalise whitespace/quotes in raw output."""
    return (
        raw_text.replace(_START_TOKEN, "")
        .replace("\r", "")
        .replace("\n\n", "\n")
        .replace("\t", " ")
        .replace("<|pad|>", " * ")
        .replace("\"\"", "\"")
    )


def text_generation(input_text=None):
    """Generate one random quest text, optionally continuing *input_text*.

    Args:
        input_text: Optional prompt. ``None``, ``""`` or a whitespace-only
            string starts generation from the bare start token. Any other
            prompt gets the start token prepended if it lacks it.

    Returns:
        The generated text, post-processed as follows:
          * control tokens are removed and whitespace/quotes are cleaned up;
          * placeholders ($c, $r, $n, $b, $g...;) are replaced by a randomly
            chosen class, race, name and gender option;
          * literal ``\\n`` sequences become real newlines.
    """
    # NOTE(review): ``spaces`` is imported and the Space runs on ZeroGPU, where
    # GPU work normally has to run inside an ``@spaces.GPU``-decorated function.
    # No such decorator is visible here; confirm this is intended.
    if input_text is None or not input_text.strip():
        prompt = _START_TOKEN
    elif input_text.startswith(_START_TOKEN):
        prompt = input_text
    else:
        prompt = _START_TOKEN + input_text

    generated = text_generator(
        prompt,
        max_length=max_length,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        do_sample=True,
        repetition_penalty=2.0,
        num_return_sequences=1,
    )
    parsed_text = _clean_generated_text(generated[0]["generated_text"])

    # Choose new placeholder values for every generated quest.
    wow_class_item = random.choice(wow_class_list)
    wow_race_item = random.choice(wow_race_list)
    wow_silly_name_item = random.choice(wow_silly_name_list)
    parsed_text = parseSpecialCharacters(
        parsed_text, wow_class_item, wow_race_item, wow_silly_name_item
    )
    # Turn literal two-character "\n" sequences into real newlines.
    return parsed_text.replace("\\n", "\n")
# Gradio UI: one optional prompt box as input, the generated quest text as output.
demo = gr.Interface(
    text_generation,
    inputs=gr.Textbox(lines=1, label="Enter starting text or leave blank"),
    outputs=gr.Textbox(type="text", label="Generated quest text"),
    title=title,
    description=description,
    article=article,
    # Clickable example prompts that fill the input box.
    examples=suggested_text_list,
    # Hide Gradio's "Flag" button; flagged samples are not collected.
    allow_flagging="never",
)
# Route requests through Gradio's queue, then start the web server.
demo.queue()
demo.launch()