# Source: Hugging Face Space "lewiswu1209/Winnie" (status when captured: Sleeping).
# File size: 5,179 bytes; blame revisions f3c6b77, c1e6869.
import os
import random
import re
import requests
import argparse
import string
from datetime import timedelta
from flask import Flask, session, request, jsonify, render_template
from transformers.models.bert.tokenization_bert import BertTokenizer
from bot.chatbot import ChatBot
from bot.config import special_token_list
# Flask application serving the Winnie chat endpoints.
app = Flask(__name__)
# The signing key is regenerated on every process start, so session cookies
# issued by a previous run can no longer be validated after a restart.
# NOTE(review): PERMANENT_SESSION_LIFETIME only applies to sessions marked
# ``session.permanent = True``; no visible code sets that flag.
app.config.update(
    SECRET_KEY=os.urandom(74),
    PERMANENT_SESSION_LIFETIME=timedelta(days=7),
)

# Tokenizer used to encode replies and decode stored history; assigned in main().
tokenizer: BertTokenizer = None
# Session hash -> list of encoded utterances (each a list of token ids).
history_matrix: dict = {}
def move_history_from_session_to_global_memory() -> None:
    """Copy the current session's history into ``history_matrix``.

    Nothing is copied unless the session has a truthy ``"session_hash"`` and a
    non-empty ``"history"``. The history is read with ``session.get`` so a
    session that has a hash but no ``"history"`` key is skipped instead of
    raising ``KeyError``.
    """
    session_hash = session.get("session_hash")
    history = session.get("history")
    if session_hash and history:
        # Item assignment mutates the module-level dict; no ``global`` needed.
        history_matrix[session_hash] = history
def move_history_from_global_memory_to_session() -> None:
    """Load the history stored under this session's hash into the session.

    Does nothing when the session has no (truthy) ``"session_hash"``. When the
    hash has no entry in ``history_matrix``, ``session["history"]`` is set to
    ``None``.
    """
    current_hash = session.get("session_hash")
    if not current_hash:
        return
    session["history"] = history_matrix.get(current_hash)
def set_args(argv=None) -> argparse.Namespace:
    """Parse the chat server's command-line options.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets callers and tests parse options without
            touching ``sys.argv``.

    Returns:
        Namespace with ``vocab_path`` (default ``None``) and ``model_path``
        (default ``"lewiswu1209/Winnie"``).
    """
    parser = argparse.ArgumentParser()
    # help text: "choose the vocabulary"
    parser.add_argument("--vocab_path", default=None, type=str, required=False, help="选择词库")
    # help text: "dialogue model path"
    parser.add_argument("--model_path", default="lewiswu1209/Winnie", type=str, required=False, help="对话模型路径")
    return parser.parse_args(argv)
@app.route("/chitchat/history", methods=["GET"])
def get_history_list() -> str:
    """Return the session's conversation as a JSON list of strings.

    The history is first refreshed from ``history_matrix``. Each stored
    utterance (a list of token ids) is mapped back to tokens, the leading
    ``"##"`` of WordPiece continuation tokens is removed, and the tokens are
    concatenated with no separator.
    """
    move_history_from_global_memory_to_session()
    stored = session.get("history")
    utterances = [] if stored is None else stored

    def detokenize(token_ids):
        # Strip the "##" continuation marker, then glue pieces together.
        pieces = tokenizer.convert_ids_to_tokens(token_ids)
        return "".join(p[2:] if p.startswith("##") else p for p in pieces)

    return jsonify([detokenize(ids) for ids in utterances])
# Remote Hugging Face Space endpoint that generates Winnie's replies.
_PREDICT_URL = "https://hf.space/embed/lewiswu1209/Winnie/+/api/predict/"
# Seconds to wait for the remote Space. Without a timeout a stalled request
# would block this worker indefinitely; the value is generous because a
# sleeping Space may need time to wake up.
_PREDICT_TIMEOUT = 120
# "<TAG>=<VALUE>" persona-setting commands: forwarded to the Space but not
# recorded in the dialogue history.
_TAG_SETTING_RE = re.compile(r"^<.+>=.+$")
# Alphabet for newly generated 11-character session hashes.
_HASH_ALPHABET = string.ascii_lowercase + string.digits


def _predict(text: str, session_hash: str) -> str:
    """Send *text* to the remote Space under *session_hash*; return its reply.

    Raises:
        requests.Timeout: no response within ``_PREDICT_TIMEOUT`` seconds.
        requests.HTTPError: the Space answered with an HTTP error status.
    """
    response = requests.post(
        url=_PREDICT_URL,
        json={"data": [text], "session_hash": session_hash},
        timeout=_PREDICT_TIMEOUT,
    )
    response.raise_for_status()
    return response.json()["data"][0]


@app.route("/chitchat/chat", methods=["GET"])
def talk() -> str:
    """Handle one chat turn for the current session.

    Query parameters:
        hash: optional session hash; when given it replaces the session's hash.
        text: the user's message. ``HELP`` (any case) returns the help lines;
            ``ERASE MEMORY`` clears the history (locally and on the Space);
            ``<TAG>=<VALUE>`` is forwarded but not stored in the history.

    Returns:
        JSON list holding the reply (``[""]`` when no text was sent), or the
        help lines for ``HELP``.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
    # Sync from the shared store; a no-op when the session has no hash.
    move_history_from_global_memory_to_session()
    if session.get("session_hash") is None:
        # The hash is the only key guarding a user's history (any client can
        # load it via ?hash=), so draw it from the OS RNG rather than the
        # predictable default Mersenne Twister generator.
        session["session_hash"] = "".join(
            random.SystemRandom().sample(_HASH_ALPHABET, 11)
        )

    input_text = request.args.get("text")
    if not input_text:
        return jsonify([""])

    if input_text.upper() == "HELP":
        help_info_list = ["输入任意文字,Winnie会和你对话",
                          "输入ERASE MEMORY,Winnie会清空记忆",
                          "输入\"<TAG>=<VALUE>\",Winnie会记录你的角色信息",
                          "例如:<NAME>=Vicky,Winnie会修改自己的名字",
                          "可以修改的角色信息有:",
                          "<NAME>, <GENDER>, <YEAROFBIRTH>, <MONTHOFBIRTH>, <DAYOFBIRTH>, <ZODIAC>, <AGE>",
                          "输入“上联:XXXXXXX”,Winnie会和你对对联",
                          "输入“写诗:XXXXXXX”,Winnie会以XXXXXXX为开头写诗",
                          "以\"请问\"开头并以问号结尾,Winnie会回答该问题"
                          ]
        return jsonify(help_info_list)

    session_hash = session["session_hash"]
    history_list = session.get("history")
    if not history_list or input_text == "ERASE MEMORY":
        # Start a fresh conversation and reset the Space's memory as well.
        history_list = []
        output_text = _predict("ERASE MEMORY", session_hash)

    if input_text != "ERASE MEMORY":
        record = _TAG_SETTING_RE.match(input_text) is None
        if record:
            history_list.append(tokenizer.encode(input_text, add_special_tokens=False))
        output_text = _predict(input_text, session_hash)
        if record:
            history_list.append(tokenizer.encode(output_text, add_special_tokens=False))

    session["history"] = history_list
    history_matrix[session_hash] = history_list
    return jsonify([output_text])
@app.route("/")
def index() -> str:
    """Return a fixed plain-text greeting for the root URL."""
    greeting = "Hello world!"
    return greeting
@app.route("/chitchat/hash", methods=["GET"])
def get_hash() -> str:
    """Return the current session hash, optionally switching to another one.

    Query parameters:
        hash: when present, becomes the session's hash.

    The session history is then refreshed from ``history_matrix``; that is a
    no-op when the session has no hash.

    Returns:
        The session hash, or a single space when the session has none.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
    move_history_from_global_memory_to_session()
    # Named ``session_hash`` so the builtin ``hash`` is not shadowed.
    session_hash = session.get("session_hash")
    return session_hash if session_hash else " "
@app.route("/chitchat", methods=["GET"])
def chitchat() -> str:
    """Serve the chat page rendered from ``chat_template.html``."""
    template_name = "chat_template.html"
    return render_template(template_name)
def main() -> None:
    """Parse CLI options, build the global tokenizer, and run the Flask server.

    The tokenizer is obtained from ``ChatBot.get_tokenizer`` using the
    ``--model_path`` / ``--vocab_path`` options. The server is Flask's built-in
    ``app.run`` bound to 127.0.0.1:8080.
    """
    global tokenizer  # rebinding the module-level name requires ``global``
    options = set_args()
    tokenizer = ChatBot.get_tokenizer(
        options.model_path,
        vocab_path=options.vocab_path,
        special_token_list=special_token_list,
    )
    app.run(host="127.0.0.1", port=8080)


if __name__ == "__main__":
    main()
# (end of file)