Hansimov commited on
Commit
3a09006
·
1 Parent(s): eb3e513

:gem: [Feature] New ChatAPIApp: Enable fastapi for openai format api call

Browse files
Files changed (2) hide show
  1. apis/__init__.py +0 -0
  2. apis/chat_api.py +82 -0
apis/__init__.py ADDED
File without changes
apis/chat_api.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+ from fastapi import FastAPI
3
+ from pydantic import BaseModel, Field
4
+ from sse_starlette.sse import EventSourceResponse
5
+ from utils.logger import logger
6
+ from networks.message_streamer import MessageStreamer
7
+ from messagers.message_composer import MessageComposer
8
+
9
+
10
+ class ChatAPIApp:
11
+ def __init__(self):
12
+ self.app = FastAPI(
13
+ docs_url="/",
14
+ title="HuggingFace LLM API",
15
+ swagger_ui_parameters={"defaultModelsExpandDepth": -1},
16
+ version="1.0",
17
+ )
18
+ self.setup_routes()
19
+
20
+ def get_available_models(self):
21
+ self.available_models = [
22
+ {
23
+ "id": "mixtral-8x7b",
24
+ "description": "Mixtral-8x7B: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
25
+ },
26
+ ]
27
+ return self.available_models
28
+
29
+ class ChatCompletionsPostItem(BaseModel):
30
+ model: str = Field(
31
+ default="mixtral-8x7b",
32
+ description="(str) `mixtral-8x7b`",
33
+ )
34
+ messages: list = Field(
35
+ default=[{"role": "user", "content": "Hello, who are you?"}],
36
+ description="(list) Messages",
37
+ )
38
+ temperature: float = Field(
39
+ default=0.01,
40
+ description="(float) Temperature",
41
+ )
42
+ max_tokens: int = Field(
43
+ default=32000,
44
+ description="(int) Max tokens",
45
+ )
46
+ stream: bool = Field(
47
+ default=True,
48
+ description="(bool) Stream",
49
+ )
50
+
51
+ def chat_completions(self, item: ChatCompletionsPostItem):
52
+ streamer = MessageStreamer(model=item.model)
53
+ composer = MessageComposer(model=item.model)
54
+ composer.merge(messages=item.messages)
55
+ return EventSourceResponse(
56
+ streamer.chat(
57
+ prompt=composer.merged_str,
58
+ temperature=item.temperature,
59
+ max_new_tokens=item.max_tokens,
60
+ stream=item.stream,
61
+ yield_output=True,
62
+ ),
63
+ media_type="text/event-stream",
64
+ )
65
+
66
+ def setup_routes(self):
67
+ for prefix in ["", "/v1"]:
68
+ self.app.get(
69
+ prefix + "/models",
70
+ summary="Get available models",
71
+ )(self.get_available_models)
72
+
73
+ self.app.post(
74
+ prefix + "/chat/completions",
75
+ summary="Chat completions in conversation session",
76
+ )(self.chat_completions)
77
+
78
+
79
+ app = ChatAPIApp().app
80
+
81
+ if __name__ == "__main__":
82
+ uvicorn.run("__main__:app", host="0.0.0.0", port=23333, reload=True)