Yahir commited on
Commit
9fbab2d
·
verified ·
1 Parent(s): 70820da

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ import gradio as gr
3
+
4
+ client = InferenceClient(
5
+ "google/gemma-7b-it"
6
+ )
7
+
8
+ def format_prompt(message, history):
9
+ prompt = ""
10
+ if history:
11
+ #<start_of_turn>userWhat is recession?<end_of_turn><start_of_turn>model
12
+ for user_prompt, bot_response in history:
13
+ prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
14
+ prompt += f"<start_of_turn>model{bot_response}"
15
+ prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
16
+ return prompt
17
+
18
+ def generate(
19
+ prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
20
+ ):
21
+ if not history:
22
+ history = []
23
+ hist_len=0
24
+ if history:
25
+ hist_len=len(history)
26
+ print(hist_len)
27
+
28
+ temperature = float(temperature)
29
+ if temperature < 1e-2:
30
+ temperature = 1e-2
31
+ top_p = float(top_p)
32
+
33
+ generate_kwargs = dict(
34
+ temperature=temperature,
35
+ max_new_tokens=max_new_tokens,
36
+ top_p=top_p,
37
+ repetition_penalty=repetition_penalty,
38
+ do_sample=True,
39
+ seed=42,
40
+ )
41
+
42
+ formatted_prompt = format_prompt(prompt, history)
43
+
44
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
45
+ output = ""
46
+
47
+ for response in stream:
48
+ output += response.token.text
49
+ yield output
50
+ return output
51
+
52
+
53
+ additional_inputs=[
54
+ gr.Slider(
55
+ label="Temperature",
56
+ value=0.9,
57
+ minimum=0.0,
58
+ maximum=1.0,
59
+ step=0.05,
60
+ interactive=True,
61
+ info="Higher values produce more diverse outputs",
62
+ ),
63
+ gr.Slider(
64
+ label="Max new tokens",
65
+ value=512,
66
+ minimum=0,
67
+ maximum=1048,
68
+ step=64,
69
+ interactive=True,
70
+ info="The maximum numbers of new tokens",
71
+ ),
72
+ gr.Slider(
73
+ label="Top-p (nucleus sampling)",
74
+ value=0.90,
75
+ minimum=0.0,
76
+ maximum=1,
77
+ step=0.05,
78
+ interactive=True,
79
+ info="Higher values sample more low-probability tokens",
80
+ ),
81
+ gr.Slider(
82
+ label="Repetition penalty",
83
+ value=1.2,
84
+ minimum=1.0,
85
+ maximum=2.0,
86
+ step=0.05,
87
+ interactive=True,
88
+ info="Penalize repeated tokens",
89
+ )
90
+ ]
91
+
92
+ # Create a Chatbot object with the desired height
93
+ chatbot = gr.Chatbot(height=450,
94
+ layout="bubble")
95
+
96
+ with gr.Blocks() as demo:
97
+ gr.HTML("<h1><center>🤖 Google-Gemma-7B-Chat 💬<h1><center>")
98
+ gr.ChatInterface(
99
+ generate,
100
+ chatbot=chatbot, # Use the created Chatbot object
101
+ additional_inputs=additional_inputs,
102
+ examples=[["What is the meaning of life?"], ["Tell me something about Mt Fuji."]],
103
+
104
+ )
105
+
106
+ demo.queue().launch(debug=True)