# Spaces:
# Paused
# Paused
# coding=utf-8
# Implements an API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for the interactive API documentation.
import time | |
import torch | |
import uvicorn | |
from pydantic import BaseModel, Field | |
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from contextlib import asynccontextmanager | |
from starlette.responses import StreamingResponse | |
from typing import Any, Dict, List, Literal, Optional, Union | |
from transformers import AutoTokenizer, AutoModel | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan handler for the FastAPI app: free cached GPU memory on shutdown.

    Nothing is done at startup. Once the server stops (the code after
    ``yield``), and only if CUDA is available, ``torch.cuda.empty_cache()``
    releases PyTorch's cached-but-unused GPU memory and
    ``torch.cuda.ipc_collect()`` reclaims memory freed through CUDA IPC.

    Args:
        app: The application being served (required by the lifespan
            protocol, unused here).
    """
    try:
        yield
    finally:
        # `finally` so the cleanup also runs if shutdown is triggered by an error.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
# Module-level ASGI application; `lifespan` is entered at startup and exited at shutdown.
app = FastAPI(lifespan=lifespan)
# Permissive CORS so browser-based clients served from any origin can call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette reflect the caller's Origin for credentialed requests instead of
# sending "*". That is fine for a local demo, but restrict allow_origins before
# exposing this server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ModelCard(BaseModel):
    """One model entry in the listing built by ``list_models`` (OpenAI model-object shape)."""

    id: str
    object: str = "model"
    # Unix timestamp in whole seconds, taken when the card is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None
class ModelList(BaseModel):
    """Response body of ``list_models``: an OpenAI-style ``list`` of ``ModelCard``."""

    object: str = "list"
    # A mutable default is safe on a pydantic model: pydantic copies it per instance.
    data: List[ModelCard] = []
class ChatMessage(BaseModel):
    """A single chat turn; used in request ``messages`` and in non-streaming reply choices."""

    # Only these three roles are accepted; anything else fails request validation.
    role: Literal["user", "assistant", "system"]
    content: str
class DeltaMessage(BaseModel):
    """Incremental piece of an assistant reply carried by a streaming chunk.

    Both fields are optional. Stream chunks in this file are serialized with
    ``exclude_unset=True``, so fields that were never assigned are left out of
    the emitted JSON (e.g. the first chunk carries only ``role``).
    """

    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
class ChatCompletionRequest(BaseModel):
    """Request body for the chat-completion handler (a subset of OpenAI's chat API).

    ``model`` is a free-form name supplied by the client. The optional sampling
    fields default to ``None``, meaning "not supplied by the client".
    """

    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # Not an OpenAI field name (OpenAI uses `max_tokens`); named after the model's own argument.
    max_length: Optional[int] = None
    # When true, the reply is delivered as a server-sent-event stream.
    stream: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
    """One complete (non-streaming) choice in a ``chat.completion`` response."""

    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]
class ChatCompletionResponseStreamChoice(BaseModel):
    """One choice inside a streaming ``chat.completion.chunk``; carries a partial ``delta``."""

    index: int
    delta: DeltaMessage
    # No explicit default: pydantic v1 treats an Optional field as defaulting to None,
    # whereas pydantic v2 makes it required. Callers in this file always pass it.
    finish_reason: Optional[Literal["stop", "length"]]
class ChatCompletionResponse(BaseModel):
    """Envelope shared by full replies and streaming chunks.

    ``object`` distinguishes the two: ``"chat.completion"`` holds
    ``ChatCompletionResponseChoice`` items; ``"chat.completion.chunk"`` holds
    ``ChatCompletionResponseStreamChoice`` items.
    """

    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    # Unix timestamp in whole seconds. When serialized with exclude_unset=True, a value
    # filled in by this factory counts as "unset" and is dropped unless passed explicitly.
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Return the model catalogue advertised by this server (``GET /v1/models``).

    A single card named ``gpt-3.5-turbo`` is advertised, presumably so that
    clients which hard-code OpenAI model names work unchanged. The name does
    not select anything: every chat request is served by the one loaded model,
    and ``request.model`` is only echoed back.

    Returns:
        A ``ModelList`` containing one ``ModelCard``.
    """
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])
def _sse_event(chunk: BaseModel) -> str:
    """Serialize *chunk* as one Server-Sent Events ``data:`` frame.

    Fields that were never assigned are omitted (``exclude_unset``), so stream
    deltas only carry what changed. Pydantic v2 no longer accepts
    ``json(**dumps_kwargs)`` (passing ``ensure_ascii`` raises ``TypeError``),
    so ``model_dump_json`` is used when available. Both paths emit non-ASCII
    text unescaped.
    """
    if hasattr(chunk, "model_dump_json"):  # pydantic v2
        payload = chunk.model_dump_json(exclude_unset=True)
    else:  # pydantic v1
        payload = chunk.json(exclude_unset=True, ensure_ascii=False)
    return "data: {}\n\n".format(payload)


def _generation_kwargs(request: ChatCompletionRequest) -> Dict[str, Any]:
    """Translate the request's optional sampling fields into model keyword arguments.

    Only fields the client actually supplied are forwarded, so the model's own
    defaults apply otherwise.

    NOTE(review): this assumes ``model.chat`` / ``model.stream_chat`` accept
    ``max_length``, ``top_p``, ``temperature`` and ``do_sample``, as
    ChatGLM2-6B's remote code does. Confirm this if a different model is loaded.
    """
    kwargs: Dict[str, Any] = {}
    if request.max_length is not None:
        kwargs["max_length"] = request.max_length
    if request.top_p is not None:
        kwargs["top_p"] = request.top_p
    if request.temperature is not None:
        if request.temperature > 0:
            kwargs["temperature"] = request.temperature
        else:
            # OpenAI clients commonly send temperature=0 to ask for deterministic
            # output, but Hugging Face sampling rejects a non-positive
            # temperature. Decode greedily instead of failing the request.
            kwargs["do_sample"] = False
    return kwargs


def _pair_history(messages: List[ChatMessage]) -> List[List[str]]:
    """Collect adjacent (user, assistant) turns as ``[query, reply]`` history pairs.

    The list is scanned left to right. A user message immediately followed by
    an assistant message forms one pair; any other message is skipped. As a
    result, a single unpaired message (e.g. an opening assistant greeting)
    drops only itself rather than the whole history.
    """
    history: List[List[str]] = []
    i = 0
    while i + 1 < len(messages):
        first, second = messages[i], messages[i + 1]
        if first.role == "user" and second.role == "assistant":
            history.append([first.content, second.content])
            i += 2
        else:
            i += 1
    return history


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """Answer an OpenAI-style chat completion request with the loaded model.

    The last message must come from the user and becomes the query. A leading
    system message is prepended to that query, and earlier (user, assistant)
    pairs become the model history. ``temperature``, ``top_p`` and
    ``max_length`` are forwarded when supplied. ``request.model`` is only
    echoed back.

    Returns:
        A ``StreamingResponse`` of SSE chunks when ``request.stream`` is true,
        otherwise a ``ChatCompletionResponse``.

    Raises:
        HTTPException: 400 if ``messages`` is empty or does not end with a
            user message.
    """
    # `model` and `tokenizer` are module globals assigned in the __main__ block.
    if not request.messages or request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content
    prev_messages = request.messages[:-1]
    if prev_messages and prev_messages[0].role == "system":
        # model.chat takes only a query plus (user, assistant) history, so the
        # system prompt is folded into the latest query. The newline keeps the
        # two texts from running together.
        system_prompt = prev_messages[0].content
        if system_prompt:
            query = system_prompt + "\n" + query
        prev_messages = prev_messages[1:]
    history = _pair_history(prev_messages)
    gen_kwargs = _generation_kwargs(request)

    if request.stream:
        generate = predict(query, history, request.model, gen_kwargs)
        return StreamingResponse(generate, media_type="text/event-stream")

    response, _ = model.chat(tokenizer, query, history=history, **gen_kwargs)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop",
    )
    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


async def predict(query: str, history: List[List[str]], model_id: str,
                  gen_kwargs: Optional[Dict[str, Any]] = None):
    """Stream a reply as OpenAI ``chat.completion.chunk`` Server-Sent Events.

    The stream is emitted in this order:

    1. A role-only chunk.
    2. One chunk per newly generated piece of text.
    3. An empty delta with ``finish_reason="stop"``.
    4. The ``data: [DONE]`` sentinel that OpenAI clients use to detect the end
       of the stream.

    Args:
        query: Latest user message (with any system prompt already folded in).
        history: Earlier ``[user, assistant]`` pairs.
        model_id: Model name echoed back in every chunk.
        gen_kwargs: Optional sampling overrides forwarded to ``model.stream_chat``.
    """
    # OpenAI chunks of one stream share a single `created` timestamp. Passing it
    # explicitly also keeps it in the output despite exclude_unset.
    created = int(time.time())

    def frame(delta: DeltaMessage, finish_reason: Optional[str]) -> str:
        choice = ChatCompletionResponseStreamChoice(index=0, delta=delta, finish_reason=finish_reason)
        chunk = ChatCompletionResponse(
            model=model_id,
            choices=[choice],
            object="chat.completion.chunk",
            created=created,
        )
        return _sse_event(chunk)

    yield frame(DeltaMessage(role="assistant"), None)

    sent = 0
    for new_response, _ in model.stream_chat(tokenizer, query, history, **(gen_kwargs or {})):
        # Each yielded response is treated as the full text so far; forward
        # only the unsent suffix. A response no longer than what was already
        # sent carries nothing new, and slicing it would re-send text if it
        # ever shrank.
        if len(new_response) <= sent:
            continue
        yield frame(DeltaMessage(content=new_response[sent:]), None)
        sent = len(new_response)

    yield frame(DeltaMessage(), "stop")
    yield "data: [DONE]\n\n"
if __name__ == "__main__":
    # `tokenizer` and `model` become module globals read by the request handlers.
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    if torch.cuda.is_available():
        model = model.cuda()
    else:
        # No GPU available: run on CPU in float32, as ChatGLM2's README
        # recommends for CPU inference. Expect generation to be slow.
        model = model.float()
    # Multi-GPU support: replace the model-loading lines above with the two
    # lines below, and set num_gpus to the number of GPUs you actually have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    model.eval()
    # A single worker: the model lives in this process's memory and cannot be
    # shared across worker processes.
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)