Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
import random | |
from config import device_type, ckpt_path, GPTConfig, GPT, encode, decode, ctx, num_samples, max_new_tokens, temperature, top_k | |
# Rebuild the GPT model from the checkpoint stored at ckpt_path and move it
# to device_type in inference (eval) mode.
checkpoint = torch.load(ckpt_path, map_location=device_type)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# Some saved parameter names carry an '_orig_mod.' prefix (presumably from
# saving a torch.compile-wrapped model -- confirm against the training code).
# Rename those entries without the prefix so they match the plain GPT
# module's keys. The key list is snapshotted first because the dict is
# mutated while renaming.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
prefixed_names = [name for name in state_dict if name.startswith(unwanted_prefix)]
for name in prefixed_names:
    state_dict[name[len(unwanted_prefix):]] = state_dict.pop(name)

model.load_state_dict(state_dict)
model.eval()
model.to(device_type)
def fn_query_on_load():
    """Return the default prompt text used to pre-fill the prompt textbox."""
    default_prompt = "in the air and"
    return default_prompt
def generate_commentary(start):
    """Generate ``num_samples`` commentary continuations of a prompt.

    The prompt is encoded and batched, then ``model.generate`` is run
    ``num_samples`` times under ``torch.no_grad()`` and ``ctx``. Each sample
    is decoded (the decoded text includes the prompt tokens) and followed by
    a separator line.

    Args:
        start: Prompt text from the ``search_text`` textbox. It can be
            ``None`` after the Clear button resets the textbox.

    Returns:
        A dict mapping the ``output`` textbox component to the concatenated
        samples. This is Gradio's dict-style return for event outputs.

    Raises:
        gr.Error: If the prompt is ``None``/empty or encodes to no tokens.
            Gradio then shows a readable message instead of failing deep
            inside ``encode``/``model.generate``.
    """
    # Clear sets the textbox to None. An empty prompt would give
    # model.generate an empty context. Reject both before encoding.
    if not start:
        raise gr.Error("Please enter a prompt.")
    start_ids = encode(start)
    if not start_ids:
        raise gr.Error("Please enter a prompt.")
    # 1-D token tensor plus a leading batch axis -> shape (1, len(start_ids)).
    x = torch.tensor(start_ids, dtype=torch.long, device=device_type)[None, ...]
    separator = '\n-o-o-o-o-o-o-o-\n\n'
    samples = []
    with torch.no_grad():
        with ctx:
            for _ in range(num_samples):
                y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
                samples.append(decode(y[0].tolist()))
    # Build the text in one pass instead of repeated string concatenation.
    out_text = ''.join(sample + separator for sample in samples)
    return {
        output: out_text
    }
# Assemble the Gradio UI, in order:
#   - a title row
#   - the prompt textbox
#   - Submit/Clear buttons
#   - the read-only output textbox
# Then wire the button events.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            "\n"
            "# NanoGPT - Cricket Commentary Generative AI\n"
            "### Give a prompt and see how it comes out with cricket commentary :)\n"
        )
    with gr.Row(visible=True):
        # value is a callable, so the default prompt comes from fn_query_on_load.
        search_text = gr.Textbox(
            value=fn_query_on_load,
            placeholder='Enter prompt..',
            label='Enter Prompt',
        )
    with gr.Row():
        submit_btn = gr.Button("Submit", variant='primary')
        clear_btn = gr.ClearButton()
    with gr.Row():
        with gr.Row():
            output = gr.Textbox(lines=15, interactive=False, label='Commentary Box')

    def clear_data():
        """Return a dict that sets both the output and prompt textboxes to None."""
        return dict.fromkeys((output, search_text))

    clear_btn.click(fn=clear_data, inputs=None, outputs=[output, search_text])
    submit_btn.click(fn=generate_commentary, inputs=search_text, outputs=output)
# Launch the app, calling queue() on it before launch().
app.queue().launch()