# ThatGPT: Gradio chat demo backed by GPT-2 that steers replies toward "that".
#
# NOTE(review): this file is damaged and cannot run as-is. Restore it from the
# original source before making any code change:
#   * The whole program has been collapsed onto three physical lines. Every
#     statement break and every indent (which define the loop and if/else
#     bodies in botsay() and bot()) is lost. The nesting described below is
#     inferred and should be confirmed.
#   * In botsay(), text between "(disable_repeat_count" and the '":' that
#     follows appears to have been stripped. This looks like HTML-tag
#     stripping removing everything from a '<' to the next '>'. The missing
#     code presumably held:
#       - the repeat suppression that uses disable_repeat_length /
#         disable_repeat_count (and perhaps prev_tokens / last_apppended);
#       - the stop condition that guards `if ("that" not in gen_tokens)`.
#     Do not guess it back in.
#   * Two string literals are split by raw line breaks: the `.replace("` at
#     the end of botsay()'s return, and the one in bot()'s history loop. The
#     original character(s) were lost, probably "\n" -- confirm.
#
# Module level:
#   Loads the "gpt2" AutoConfig, AutoTokenizer and AutoModelForCausalLM once
#   at import time. It then defines botsay()/add_text()/bot(), builds the
#   Gradio Blocks UI, and calls demo.launch(). Importing this module therefore
#   starts the server.
#
# botsay(user_input) -> str
#   Greedy generation, one token per loop iteration, biased toward "that".
#   Each iteration re-tokenizes prompt + user_input + "\nAI:" + gen_tokens and
#   runs a full forward pass over it; no cached state is reused between steps.
#
#   that_id is the "that" token to favor, chosen from the previous new_token:
#     * 326 after "that";
#     * -1 after " that";
#     * 5562 after a token ending in a space;
#     * 326 otherwise.
#   It is forced to -1 when "thatGPT" occurs in the last 12 characters of
#   gen_tokens, or when last_apppended is truthy. -1 never appears in the
#   top-k indices, so it means "no bias". 326 / 5562 are hard-coded GPT-2
#   ids, presumably " that" and "that" -- verify against the gpt2 vocab.
#
#   If that_id is among the top j+1 (= 7) logits at the last position, it is
#   emitted; otherwise the argmax token is emitted.
#
#   Returns the generated text with "AI:" and "\xa0" (plus the lost
#   character(s) noted above) removed. If the length check trips, it returns
#   the "sorry length limit" message instead.
#
#   NOTE(review): `length=len(input_ids)` takes len() of the (1, seq_len)
#     tensor returned with return_tensors="pt". That is the batch size, 1, so
#     the `limit = 128` check never fires. input_ids.shape[-1] was probably
#     intended. Without it, long chats can exceed GPT-2's context window.
#   NOTE(review): model(...) runs outside torch.no_grad(), so every step
#     records an autograd graph that pure inference does not need.
#   NOTE(review): thatid, prob and tokens are never read. prev_tokens is
#     never read in the surviving code, though it may have been used in the
#     lost span. last_apppended is never set to True in the visible code.
#     tokenizer.decode(new_token_id) appears twice back to back; the first
#     call is likely inside the else branch -- confirm.
#   NOTE(review): `outs.logits.squeeze()[-1,:]` assumes seq_len > 1. With a
#     single token, squeeze() would also drop the sequence axis.
#
# add_text(history, text) -> (history, "")
#   Appends a (text, None) turn to the chat history and clears the textbox.
#
# bot(history) -> history
#   Serializes the history as "\nHuman:<msg>" / "\nAI:<reply>" segments,
#   stopping at the first turn whose reply is None. It passes that whole
#   transcript to botsay() as user_input and stores the response as the last
#   turn's reply.
#   NOTE(review): `history[-1][1] = response` assumes gradio hands the turns
#     back as mutable lists. On the tuples that add_text() builds it would
#     raise TypeError -- confirm.
#   NOTE(review): the trailing `serial_history+=...` is dead, since it is
#     never read again.
#   NOTE(review): `h[1]==None` should be `h[1] is None`.
#
# UI: the Chatbot, Row/Column and Textbox wired as add_text -> bot.
#   NOTE(review): Component.style(...) is the Gradio 3.x styling API and is
#   not available in Gradio 4. Pin the gradio version or migrate.
import torch from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig modelname="gpt2" config = AutoConfig.from_pretrained(modelname) tokenizer = AutoTokenizer.from_pretrained(modelname) model = AutoModelForCausalLM.from_pretrained(modelname,config=config) def botsay(user_input): prompt = "This is a conversation between Human and AI bot. AI's name is ThatGPT." new_token_id=None gen_tokens="" new_token="" j =6 length=0 limit = 128 thatid=5562 cont = True last_apppended = False cnt=0 disable_repeat_length= 5 disable_repeat_count = 2 tokens=[] while(cont): cnt+=1 prob = 1.0 input_ids=tokenizer(prompt+user_input+"\nAI:"+gen_tokens,return_tensors="pt").input_ids length=len(input_ids) if length >limit: gen_tokens="⚠️sorry length limit. please reload the browser." return gen_tokens outs=model(input_ids=input_ids) topk = torch.topk(outs.logits.squeeze()[-1,:],k=j+1).indices if new_token =="that": that_id = 326 elif new_token ==" that": that_id = -1 elif new_token[-1:] ==" ": that_id = 5562 else: that_id = 326 if ("thatGPT" in gen_tokens[-12:]): that_id = -1 if last_apppended: that_id = -1 if that_id in topk: new_token_id = that_id else: new_token_id = torch.argmax(outs.logits.squeeze()[-1,:]) new_token=tokenizer.decode(new_token_id) new_token=tokenizer.decode(new_token_id) prev_tokens=gen_tokens gen_tokens+=new_token if (cnt>10) and (disable_repeat_count": if ("that" not in gen_tokens): gen_tokens = gen_tokens.replace("\n","").replace(".","") gen_tokens += " that" else: cont = False return gen_tokens.replace("
","").replace("AI:","").replace("\xa0","") import gradio as gr def add_text(history, text): history = history + [(text, None)] return history, "" def bot(history): serial_history="" for h in history: serial_history+="\nHuman:"+h[0] if h[1]==None: break serial_history+="\nAI:"+h[1].replace("
","") response = botsay(serial_history) history[-1][1] = response serial_history+="\nAI:"+response return history with gr.Blocks() as demo: gr.Markdown("# ThatGPT") chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750) with gr.Row(): with gr.Column(scale=0.85): txt = gr.Textbox( show_label=False, placeholder="AI always replies with \"that\".", ).style(container=False) txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then( bot, chatbot, chatbot ) demo.launch()