SilentWraith commited on
Commit
533133e
·
verified ·
1 Parent(s): ac440d6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import AsyncInferenceClient
2
+ import gradio as gr
3
+
4
+ client = AsyncInferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
5
+
6
+
7
+ def format_prompt(prompt: str, history: list[str], system_prompt: str) -> str:
8
+ if not history:
9
+ final_prompt = (
10
+ f"[INST] {system_prompt if system_prompt else ''}:\n{prompt} [/INST]"
11
+ )
12
+ else:
13
+ formatted_history = "".join(
14
+ f"[INST] {user_prompt} [/INST]{bot_response}</s> "
15
+ for user_prompt, bot_response in history
16
+ )
17
+ final_prompt = f"<s>{formatted_history}[INST] {prompt} [/INST]"
18
+ return final_prompt
19
+
20
+
21
+ async def generate(
22
+ prompt: str,
23
+ history: list[str],
24
+ system_prompt: str = "You're a helpful assistant.",
25
+ temperature: float = 0.3,
26
+ max_new_tokens: int = 4000,
27
+ top_p: float = 0.95,
28
+ repetition_penalty: float = 1.0,
29
+ ):
30
+ temperature = float(temperature)
31
+ if temperature < 1e-2:
32
+ temperature = 1e-2
33
+ top_p = float(top_p)
34
+
35
+ generate_kwargs = dict(
36
+ temperature=temperature,
37
+ max_new_tokens=max_new_tokens,
38
+ top_p=top_p,
39
+ repetition_penalty=repetition_penalty,
40
+ do_sample=True,
41
+ seed=42,
42
+ )
43
+
44
+ formatted_prompt = format_prompt(
45
+ prompt=prompt, history=history, system_prompt=history
46
+ )
47
+
48
+ stream = await client.text_generation(
49
+ formatted_prompt,
50
+ **generate_kwargs,
51
+ stream=True,
52
+ details=True,
53
+ return_full_text=True,
54
+ )
55
+
56
+ output = f""
57
+
58
+ async for response in stream:
59
+ output += response.token.text
60
+ yield output
61
+
62
+
63
+ additional_inputs = [
64
+ gr.Textbox(
65
+ label="System Prompt (optional)",
66
+ value="You're a helpful assistant.",
67
+ info="This is experimental",
68
+ placeholder="system prompt",
69
+ ),
70
+ gr.Slider(
71
+ label="Temperature",
72
+ value=0.9,
73
+ minimum=0.0,
74
+ maximum=1.0,
75
+ step=0.05,
76
+ interactive=True,
77
+ info="Higher values produce more diverse outputs",
78
+ ),
79
+ gr.Slider(
80
+ label="Max new tokens",
81
+ value=256,
82
+ minimum=0,
83
+ maximum=1048,
84
+ step=64,
85
+ interactive=True,
86
+ info="The maximum numbers of new tokens",
87
+ ),
88
+ gr.Slider(
89
+ label="Top-p (nucleus sampling)",
90
+ value=0.90,
91
+ minimum=0.0,
92
+ maximum=1,
93
+ step=0.05,
94
+ interactive=True,
95
+ info="Higher values sample more low-probability tokens",
96
+ ),
97
+ gr.Slider(
98
+ label="Repetition penalty",
99
+ value=1.2,
100
+ minimum=1.0,
101
+ maximum=2.0,
102
+ step=0.05,
103
+ interactive=True,
104
+ info="Penalize repeated tokens",
105
+ ),
106
+ ]
107
+
108
+ chatbot = gr.Chatbot(
109
+ avatar_images=["./user.png", "./bot.png"],
110
+ bubble_full_width=False,
111
+ show_label=False,
112
+ show_copy_button=True,
113
+ likeable=True,
114
+ )
115
+
116
+ demo = gr.ChatInterface(
117
+ fn=generate,
118
+ additional_inputs=additional_inputs,
119
+ chatbot=chatbot,
120
+ title="🪷",
121
+ description="Mixtral-8x7B-Instruct-v0.1",
122
+ concurrency_limit=100,
123
+ )
124
+
125
+ demo.queue().launch()