Spaces:
Runtime error
Runtime error
File size: 7,036 Bytes
3edfd3e a509e98 3edfd3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import os
import gradio as gr
import clueai
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Hugging Face Hub repository holding the ChatYuan-large-v2 weights/tokenizer.
_MODEL_ID = "ClueAI/ChatYuan-large-v2"
tokenizer = T5Tokenizer.from_pretrained(_MODEL_ID)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# nn.Module.to moves the parameters and returns the same module object.
model = T5ForConditionalGeneration.from_pretrained(_MODEL_ID).to(device)
# Fixed one-exchange preamble that preprocess() prepends to every prompt,
# establishing the assistant persona ("小元").
base_info = (
    "用户:你是谁?\n"
    "小元:我是元语智能公司研发的AI智能助手, 在不违反原则的情况下,我可以回答你的任何问题。\n"
)
def preprocess(text):
    """Build the model prompt from a dialogue string.

    Prepends the ``base_info`` persona preamble, then replaces every real
    newline and tab with the two-character escapes backslash-n and
    backslash-t, so the prompt contains no literal newline/tab characters.
    ``postprocess`` reverses this escaping on the model output.
    """
    prompt = f"{base_info}{text}"
    escapes = str.maketrans({"\n": "\\n", "\t": "\\t"})
    return prompt.translate(escapes)
def postprocess(text):
    """Convert raw model output back into display text.

    Undoes the escaping applied by ``preprocess`` (backslash-n -> newline,
    backslash-t -> tab), then decodes ``%20`` into a plain space. The
    substitutions are applied in that order.
    """
    substitutions = (("\\n", "\n"), ("\\t", "\t"), ("%20", " "))
    for escaped, plain in substitutions:
        text = text.replace(escaped, plain)
    return text
# Preset keyword arguments for model.generate (nucleus sampling).
# NOTE(review): no live code in the visible source reads this dict —
# confirm it is still needed.
generate_config = dict(
    do_sample=True,
    top_p=0.9,
    top_k=50,
    temperature=0.7,
    num_beams=1,
    max_length=1024,
    min_length=3,
    no_repeat_ngram_size=5,
    length_penalty=0.6,
    return_dict_in_generate=True,
    output_scores=True,
)
def answer(text, sample=True, top_p=0.9, temperature=0.7, max_new_tokens=1024):
    """Generate one ChatYuan reply for a dialogue prompt.

    Args:
        text: Dialogue prompt such as ``"用户:...\\n小元:"``. ``preprocess``
            prepends the persona preamble and escapes newlines/tabs.
        sample: If True, decode with nucleus sampling (suited to open-ended
            generation). If False, use greedy decoding (``num_beams=1``).
        top_p: Nucleus-sampling threshold between 0 and 1; larger values give
            more diverse output. Only used when ``sample`` is True.
        temperature: Sampling temperature. Only used when ``sample`` is True.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first decoded sequence with escapes restored by ``postprocess``.
    """
    prompt = preprocess(text)
    # NOTE(review): prompts longer than 1024 tokens are truncated. With the
    # tokenizer's default truncation side ("right") this cuts the *end* of
    # the prompt, i.e. the newest user turn — confirm whether left-side
    # truncation is wanted instead.
    encoding = tokenizer(text=[prompt], truncation=True, padding=True, max_length=1024, return_tensors="pt").to(device)
    if sample:
        decode_kwargs = dict(do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=12)
    else:
        # Greedy search. length_penalty only affects beam search, so it is
        # not passed here (with num_beams=1 it was a no-op that only made
        # transformers emit a warning).
        decode_kwargs = dict(num_beams=1)
    out = model.generate(
        **encoding,
        return_dict_in_generate=True,
        output_scores=False,
        max_new_tokens=max_new_tokens,
        **decode_kwargs,
    )
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    return postprocess(out_text[0])
def clear_session():
    """Reset a chat tab.

    Returns a pair matching ``outputs=[chatbot, state]``: an empty string for
    the Chatbot component and ``None`` for the history State.
    """
    blank_chat = ''
    blank_history = None
    return blank_chat, blank_history
def chatyuan_bot(input, history):
    """Answer one user message with the locally loaded ChatYuan model.

    Args:
        input: The new user message from the textbox.
        history: List of ``(user_text, bot_text)`` pairs carried in the
            ``gr.State``, or ``None`` on the first turn.

    Returns:
        ``(history, history)``: the updated pair list, once for the Chatbot
        component and once for the State. Only the five most recent earlier
        exchanges are kept before the new one is appended.
    """
    turns = history or []
    if len(turns) > 5:
        # Cap the context at the last five exchanges.
        turns = turns[-5:]
    dialogue = "\n".join(
        f"用户:{asked}\n小元:{replied}" for asked, replied in turns
    )
    # strip() removes the leading newline when there is no earlier dialogue.
    prompt = "".join([dialogue, "\n用户:", input, "\n小元:"]).strip()
    reply = answer(prompt)
    print("open_model".center(20, "="))
    print(f"{prompt}\n{reply}")
    turns.append((input, reply))
    return turns, turns
def chatyuan_bot_regenerate(input, history):
    """Produce a fresh reply for the most recent user message.

    When ``history`` is non-empty, its last ``(user_text, bot_text)`` pair is
    dropped and that user text is answered again. The current textbox value
    ``input`` is used only when there is no history yet. Context trimming,
    prompt building, logging and the ``(history, history)`` return value are
    those of ``chatyuan_bot``, which this delegates to.
    """
    pairs = history or []
    if not pairs:
        return chatyuan_bot(input, pairs)
    last_user_text = pairs[-1][0]
    return chatyuan_bot(last_user_text, pairs[:-1])
# Chat tab backed by the locally loaded ChatYuan model (chatyuan_bot /
# chatyuan_bot_regenerate). Component creation order below defines the layout.
block = gr.Blocks()
with block as demo:
    gr.Markdown("""<h1><center>ChatYuan-By QiuLingYan</center></h1>
<font size=4>回答来自ChatYuan, 是模型生成的结果, 请谨慎辨别和参考, 不代表任何人观点 | Answer generated by ChatYuan model</font>
<font size=4>注意:gradio对markdown代码格式展示有限</font>
""")
    chatbot = gr.Chatbot(label='ChatYuan')
    message = gr.Textbox()
    # Conversation history handed between calls: a list of (user, bot) pairs,
    # None before the first message.
    state = gr.State()
    # Pressing Enter in the textbox sends the message, like the Send button.
    message.submit(chatyuan_bot, inputs=[message, state], outputs=[chatbot, state])
    with gr.Row():
        clear_history = gr.Button("👋 清除历史对话 | Clear History")
        clear = gr.Button('🧹 清除发送框 | Clear Input')
        send = gr.Button("🚀 发送 | Send")
        regenerate = gr.Button("🚀 重新生成本次结果 | regenerate")
    # Re-answers the last user message from the stored history.
    regenerate.click(chatyuan_bot_regenerate, inputs=[message, state], outputs=[chatbot, state])
    send.click(chatyuan_bot, inputs=[message, state], outputs=[chatbot, state])
    # Sets the textbox value to None, i.e. empties it.
    clear.click(lambda: None, None, message, queue=False)
    # Resets both the chat display and the stored history via clear_session.
    clear_history.click(fn=clear_session , inputs=[], outputs=[chatbot, state], queue=False)
def ChatYuan(api_key, text_prompt):
    """Generate a reply to ``text_prompt`` with the hosted ChatYuan-large model.

    A new ``clueai.Client`` is created on every call with
    ``check_api_key=True``. If the service returns an empty string, a fixed
    Chinese apology ("Sorry, I can't answer this question") is returned
    instead.
    """
    client = clueai.Client(api_key, check_api_key=True)
    # Original author's note: to also get likelihood scores, pass
    # return_likelihoods="GENERATION" to generate().
    prediction = client.generate(model_name='ChatYuan-large', prompt=text_prompt)
    reply = prediction.generations[0].text
    return "很抱歉,我无法回答这个问题" if reply == '' else reply
def chatyuan_bot_api(api_key, input, history):
    """Answer one user message through the hosted ClueAI ChatYuan API.

    Args:
        api_key: ClueAI API key entered by the user (a secret).
        input: The new user message.
        history: List of ``(user_text, bot_text)`` pairs from the ``gr.State``,
            or ``None`` on the first turn.

    Returns:
        ``(history, history)`` with the new exchange appended. Only the five
        most recent earlier exchanges are kept.
    """
    history = history or []
    if len(history) > 5:
        history = history[-5:]
    context = "\n".join([f"用户:{input_text}\n小元:{answer_text}" for input_text, answer_text in history])
    input_text = context + "\n用户:" + input + "\n小元:"
    # strip() removes the leading newline when there is no earlier dialogue.
    input_text = input_text.strip()
    output_text = ChatYuan(api_key, input_text)
    print("api".center(20, "="))
    # Never write the user's secret key to stdout/logs. Reveal at most the last
    # four characters (and only for keys longer than 8) so different keys can
    # still be told apart while debugging.
    key = str(api_key or "")
    masked_key = f"***{key[-4:]}" if len(key) > 8 else "***"
    print(f"api_key:{masked_key}\n{input_text}\n{output_text}")
    history.append((input, output_text))
    return history, history
# Chat tab that answers through the hosted ClueAI API (chatyuan_bot_api)
# instead of the local model. Component creation order defines the layout.
block = gr.Blocks()
with block as demo_1:
    gr.Markdown("""<h1><center>元语智能——ChatYuan</center></h1>
<font size=4>回答来自ChatYuan, 以上是模型生成的结果, 请谨慎辨别和参考, 不代表任何人观点 | Answer generated by ChatYuan model</font>
<font size=4>注意:gradio对markdown代码格式展示有限</font>
<font size=4>在使用此功能前,你需要有个API key. API key 可以通过这个<a href='https://www.clueai.cn/' target="_blank">平台</a>获取</font>
""")
    # The legacy gr.inputs namespace (with its `default=` argument) was removed
    # in Gradio 4 and raises AttributeError there. gr.Textbox(value=...,
    # type='password') is the supported equivalent and also works on Gradio 3.x.
    api_key = gr.Textbox(label="请输入你的api-key(必填)", value="", type='password')
    chatbot = gr.Chatbot(label='ChatYuan')
    message = gr.Textbox()
    # Conversation history: a list of (user, bot) pairs, None before the first message.
    state = gr.State()
    message.submit(chatyuan_bot_api, inputs=[api_key,message, state], outputs=[chatbot, state])
    with gr.Row():
        clear_history = gr.Button("👋 清除历史对话 | Clear Context")
        clear = gr.Button('🧹 清除发送框 | Clear Input')
        send = gr.Button("🚀 发送 | Send")
    send.click(chatyuan_bot_api, inputs=[api_key,message, state], outputs=[chatbot, state],api_name='send')
    # Sets the textbox value to None, i.e. empties it.
    clear.click(lambda: None, None, message, queue=False)
    # Resets both the chat display and the stored history via clear_session.
    clear_history.click(fn=clear_session , inputs=[], outputs=[chatbot, state], queue=False)
# Placeholder tab; its text "啥也没有" means "nothing here".
block = gr.Blocks()
with block as introduction:
    gr.Markdown("""啥也没有
""")

# Only the local-model tab is exposed ("开源模型" = "open-source model").
# NOTE(review): demo_1 (API tab) and introduction are built above but not
# listed here, so they are unreachable in the UI — confirm this is intended.
gui = gr.TabbedInterface(interface_list=[demo], tab_names=["开源模型"])

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launching a web server.
if __name__ == "__main__":
    gui.launch(quiet=True, show_api=True, share=False)