"""Gradio front-end that generates cricket commentary from a trained GPT checkpoint.

On import the script loads the checkpoint named by ``config.ckpt_path``,
builds the model, wires up a small Gradio Blocks UI and launches it.
"""

import random

import gradio as gr
import torch

from config import (
    GPT,
    GPTConfig,
    ckpt_path,
    ctx,
    decode,
    device_type,
    encode,
    max_new_tokens,
    num_samples,
    temperature,
    top_k,
)

# NOTE(security): torch.load unpickles the file, so only point ckpt_path at a
# checkpoint that comes from a trusted source.
checkpoint = torch.load(ckpt_path, map_location=device_type)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

state_dict = checkpoint['model']
# Strip the '_orig_mod.' prefix from parameter names so they match the keys the
# freshly built model expects (presumably left behind by torch.compile at
# training time -- TODO confirm). Iterate over a snapshot of the keys because
# the dict is mutated inside the loop.
unwanted_prefix = '_orig_mod.'
for key in list(state_dict):
    if key.startswith(unwanted_prefix):
        state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
model.load_state_dict(state_dict)

model.eval()
model.to(device_type)


def fn_query_on_load() -> str:
    """Return the prompt that pre-fills the prompt box when the page loads."""
    return "in the air and"


def generate_commentary(start):
    """Generate ``num_samples`` continuations of the prompt ``start``.

    Args:
        start: Prompt text from the prompt box. Gradio passes ``None`` once the
            box has been reset by the clear button, and ``""`` when it is empty.

    Returns:
        A dict mapping the ``output`` textbox component to the generated text.
        Every sample (including the last) is followed by a separator line.

    Raises:
        gr.Error: If ``start`` is ``None`` or empty. Without this guard the
            missing prompt is handed straight to ``encode``/``model.generate``
            and the user only sees an opaque failure.
    """
    if not start:
        raise gr.Error("Please enter a prompt before submitting.")

    start_ids = encode(start)
    # Add a leading batch dimension: the model is called with a (1, len) tensor.
    x = torch.tensor(start_ids, dtype=torch.long, device=device_type)[None, ...]

    separator = '\n-o-o-o-o-o-o-o-\n\n'
    samples = []
    # No gradients are needed for sampling; ctx comes from config (presumably
    # an autocast / null context -- TODO confirm).
    with torch.no_grad(), ctx:
        for _ in range(num_samples):
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            samples.append(decode(y[0].tolist()) + separator)

    return {
        output: ''.join(samples)
    }


with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # NanoGPT - Cricket Commentary Generative AI
            ### Give a prompt and see how it comes out with cricket commentary :)
            """)
    with gr.Row(visible=True):
        # A callable value is evaluated by Gradio each time the page loads.
        search_text = gr.Textbox(value=fn_query_on_load, placeholder='Enter prompt..', label='Enter Prompt')

    with gr.Row():
        submit_btn = gr.Button("Submit", variant='primary')
        clear_btn = gr.ClearButton()
    with gr.Row():
        with gr.Row():
            output = gr.Textbox(lines=15, interactive=False, label='Commentary Box')

    def clear_data():
        """Reset both the commentary box and the prompt box."""
        return {
            output: None,
            search_text: None
        }

    clear_btn.click(clear_data, None, [output, search_text])

    submit_btn.click(
        generate_commentary,
        search_text,
        output
    )

# Launch the app
app.queue().launch()