{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NanoGPT - Cricket Commentary Generator\n",
    "\n",
    "Loads a trained nanoGPT checkpoint and serves a small Gradio app: enter a prompt, press **Submit**, and the model continues it as cricket commentary.\n",
    "Each click fills the next of the three commentary boxes; the click after the third box starts over at the first. **Clear** empties everything and restarts the cycle.\n",
    "\n",
    "Checkpoint path, device and sampling settings (`ckpt_path`, `device_type`, `max_new_tokens`, `temperature`, `top_k`, ...) come from the local `config` module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import torch\n",
    "import gradio as gr\n",
    "\n",
    "from config import device_type, ckpt_path, GPTConfig, GPT, encode, decode, ctx, num_samples, max_new_tokens, temperature, top_k"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.load unpickles the file, which can run arbitrary code: only load trusted checkpoints.\n",
    "checkpoint = torch.load(ckpt_path, map_location=device_type)\n",
    "gptconf = GPTConfig(**checkpoint['model_args'])\n",
    "model = GPT(gptconf)\n",
    "\n",
    "# Some checkpoints (e.g. saved from a torch.compile()-wrapped model) prefix every key\n",
    "# with '_orig_mod.'; strip it so the keys match the plain GPT module.\n",
    "state_dict = checkpoint['model']\n",
    "unwanted_prefix = '_orig_mod.'\n",
    "for k, v in list(state_dict.items()):\n",
    "    if k.startswith(unwanted_prefix):\n",
    "        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)\n",
    "model.load_state_dict(state_dict)\n",
    "model.eval()\n",
    "model.to(device_type);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each commentary box shows a single sample, so override the value imported from config.\n",
    "num_samples = 1\n",
    "\n",
    "\n",
    "def generate_commentary(start):\n",
    "    \"\"\"Sample commentary text continuing the prompt ``start``.\n",
    "\n",
    "    Encodes ``start`` with ``config.encode``, calls ``model.generate`` ``num_samples``\n",
    "    times (without gradients, inside ``config.ctx``) and decodes each result with\n",
    "    ``config.decode``.\n",
    "\n",
    "    Returns:\n",
    "        The decoded samples concatenated, each followed by a separator line.\n",
    "    \"\"\"\n",
    "    start_ids = encode(start)\n",
    "    x = torch.tensor(start_ids, dtype=torch.long, device=device_type)[None, ...]  # add a batch dimension\n",
    "\n",
    "    out_text = ''\n",
    "    with torch.no_grad():\n",
    "        with ctx:\n",
    "            for _ in range(num_samples):\n",
    "                y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)\n",
    "                out_text += decode(y[0].tolist())\n",
    "                out_text += '\\n-o-o-o-o-o-o-o-\\n\\n'\n",
    "    return out_text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UI callbacks\n",
    "\n",
    "`fn_gen_comm` refers to the components `output1`, `output2`, `output3` and `stat` created in the layout cell below. It is only invoked by Gradio after the app has been built."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fn_query_on_load():\n",
    "    \"\"\"Return the default prompt; Gradio calls this to fill the textbox when the page loads.\"\"\"\n",
    "    return \"in the air and\"\n",
    "\n",
    "\n",
    "def fn_gen_comm(prompt, st, o1, o2, o3):\n",
    "    \"\"\"Generate one commentary and write it into the next commentary box.\n",
    "\n",
    "    Boxes fill in order 1 -> 2 -> 3; the click after box 3 starts over at box 1\n",
    "    and empties boxes 2 and 3.\n",
    "\n",
    "    Args:\n",
    "        prompt: contents of the prompt textbox (None after Clear).\n",
    "        st: this session's ``stat`` value: -1 before the first generation or after\n",
    "            Clear, otherwise the index (0-2) of the box filled most recently.\n",
    "        o1, o2, o3: current contents of the three commentary boxes.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping ``output1``, ``output2``, ``output3`` and ``stat`` to new values.\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: if the prompt is empty, so no empty/None prompt reaches the model.\n",
    "    \"\"\"\n",
    "    if not prompt:\n",
    "        raise gr.Error(\"Please enter a prompt.\")\n",
    "\n",
    "    # Progress is tracked only in the per-session `stat` State. A module-level flag\n",
    "    # would be shared by every connected browser session, so users would disturb\n",
    "    # each other's fill order.\n",
    "    slot = (st + 1) % 3\n",
    "    out = generate_commentary(prompt)\n",
    "    if slot == 0:\n",
    "        return {output1: out, output2: None, output3: None, stat: 0}\n",
    "    if slot == 1:\n",
    "        return {output1: o1, output2: out, output3: None, stat: 1}\n",
    "    return {output1: o1, output2: o2, output3: out, stat: 2}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## App layout and event wiring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks() as app:\n",
    "    with gr.Row():\n",
    "        gr.Markdown(\n",
    "            \"\"\"\n",
    "            # NanoGPT - Cricket Commentary Generative AI\n",
    "            ### Give a prompt and see how it comes out with cricket commentary :)\n",
    "            \"\"\")\n",
    "\n",
    "    with gr.Row(visible=True):\n",
    "        search_text = gr.Textbox(value=fn_query_on_load, placeholder='Enter prompt..', label='Enter Prompt')\n",
    "\n",
    "    with gr.Row():\n",
    "        submit_btn = gr.Button(\"Submit\", variant='primary')\n",
    "        clear_btn = gr.ClearButton()\n",
    "\n",
    "    with gr.Row():\n",
    "        with gr.Column():\n",
    "            output1 = gr.Textbox(lines=10, interactive=False, label='Commentary Box')\n",
    "            output2 = gr.Textbox(lines=10, interactive=False, label='Commentary Box')\n",
    "            output3 = gr.Textbox(lines=10, interactive=False, label='Commentary Box')\n",
    "            # Per-session fill position: -1 = nothing generated yet, else last filled box (0-2).\n",
    "            stat = gr.State(value=-1)\n",
    "\n",
    "    def clear_data():\n",
    "        \"\"\"Empty the prompt and commentary boxes and reset ``stat`` to -1.\n",
    "\n",
    "        Resetting ``stat`` makes the next Submit start again at box 1.\n",
    "        \"\"\"\n",
    "        return {\n",
    "            output1: None,\n",
    "            output2: None,\n",
    "            output3: None,\n",
    "            search_text: None,\n",
    "            stat: -1,\n",
    "        }\n",
    "\n",
    "    clear_btn.click(clear_data, None, [output1, output2, output3, search_text, stat])\n",
    "\n",
    "    submit_btn.click(\n",
    "        fn_gen_comm,\n",
    "        [search_text, stat, output1, output2, output3],\n",
    "        [output1, output2, output3, stat],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Launch\n",
    "\n",
    "Starts the Gradio server with request queuing enabled and prints its local URL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "app.queue().launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}