macadeliccc committed
Commit f316cfc · 1 Parent(s): 86efe77

change to starling-LM-7B

Files changed (1):
  1. app.py +42 -104
app.py CHANGED
@@ -1,117 +1,55 @@
 import spaces
 import gradio as gr
 import torch
-import subprocess
-import aiohttp
 from gradio import State
-import asyncio
-import json
-import asyncio
-import threading
-import time
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Function to start the ochat server
-@spaces.GPU
-def start_ochat_server():
-    print(f"Is CUDA available: {torch.cuda.is_available()}")
-    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
-
-    command = [
-        "python", "-m", "ochat.serving.openai_api_server",
-        "--model", "openchat/openchat_3.5"
-    ]
-
-    # Start the server in a separate process
-    try:
-        subprocess.Popen(command)
-        return "ochat server started successfully"
-    except Exception as e:
-        return f"Failed to start ochat server: {e}"
-
-start_ochat_server()
-
-# async def is_server_running():
-#     async with aiohttp.ClientSession() as session:
-#         try:
-#             async with session.get("http://localhost:18888/v1/chat/completions") as response:
-#                 return response.status == 200 or response.status == 400 or response.status == 422
-#         except aiohttp.ClientError:
-#             return False
-
-# async def monitor_server():
-#     # Wait for 5 minutes before starting to monitor
-#     await asyncio.sleep(300)
-
-#     while True:
-#         if not await is_server_running():
-#             print("Server is not running. Attempting to restart...")
-#             start_ochat_server()
-
-#         await asyncio.sleep(60)
-
-# def run_async_monitor():
-#     time.sleep(120)
-#     loop = asyncio.new_event_loop()
-#     asyncio.set_event_loop(loop)
-#     loop.run_until_complete(monitor_server())
-#     loop.close()
-
-# # Start the monitoring in a separate thread
-# thread = threading.Thread(target=run_async_monitor)
-# thread.start()
-
-# Function to send a message to the ochat server and get a response
-async def chat_with_ochat(message):
-    base_url = "http://localhost:18888"
-    chat_url = f"{base_url}/v1/chat/completions"
-    headers = {"Content-Type": "application/json"}
-    data = {
-        "model": "openchat_3.5",
-        "messages": [{"role": "user", "content": message}]
-    }
-
-    async with aiohttp.ClientSession() as session:
-        try:
-            async with session.post(chat_url, headers=headers, json=data) as response:
-                if response.status == 200:
-                    response_data = await response.json()
-                    return response_data['choices'][0]['message']['content']
-                else:
-                    return f"Error: Server responded with status code {response.status}"
-        except aiohttp.ClientError as e:
-            return f"Error: {e}"
-
-# Create a Gradio Blocks interface with session state
-with gr.Blocks(theme=gr.themes.Soft()) as app:
-    gr.Markdown("## vLLM OpenChat-3.5 Interface")
-    gr.Markdown("### the vLLM server cannot handle concurrent users in spaces. If you get an error, run it on docker.")
-    gr.Markdown("This will run better on your own machine: ```docker run -it -p 7860:7860 --platform=linux/amd64 --gpus all \
-        registry.hf.space/macadeliccc-openchat-3-5-chatbot:latest python app.py```")
-
-    message = gr.Textbox(label="Your Message", placeholder="Type your message here...")
-    chatbot = gr.Chatbot()
-    clear = gr.Button("Clear")
-
-    history = State([])  # Session state for chat history
-
-    async def user(message, history):
-        return "", history + [[message, None]]
-
-    async def bot(history):
-        if history and history[-1] and history[-1][0]:
-            user_message = history[-1][0]
-            bot_response = await chat_with_ochat(user_message)
-            history[-1][1] = bot_response  # Update the last entry with the bot's response
-        return history
-
-    message.submit(user, [message, chatbot], [message, chatbot], queue=True).then(
-        bot, chatbot, chatbot
-    )
-    clear.click(lambda: None, None, chatbot, queue=False)
-
-app.queue()
+# Load the tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
+model = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
+
+# Ensure the model is in evaluation mode
+model.eval()
+
+# Move model to GPU if available
+if torch.cuda.is_available():
+    model = model.to("cuda")
+
+@spaces.GPU
+def generate_response(user_input, chat_history):
+    prompt = "GPT4 Correct User: " + user_input + "GPT4 Correct Assistant: "
+    if chat_history:
+        prompt = chat_history + prompt
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
+
+    # Move tensors to the same device as model
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        output = model.generate(**inputs, max_length=1024, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
+    response = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    # Update chat history
+    new_history = prompt + response
+    return response, new_history
+
+# Gradio Interface
+def clear_chat():
+    return "", ""
+
+with gr.Blocks(gr.themes.Soft()) as app:
+    with gr.Row():
+        with gr.Column():
+            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
+            send = gr.Button("Send")
+            clear = gr.Button("Clear")
+        with gr.Column():
+            chatbot = gr.Chatbot()
+
+    chat_history = gr.State()  # Holds the chat history
+
+    send.click(generate_response, inputs=[user_input, chat_history], outputs=[chatbot, chat_history])
+    clear.click(clear_chat, outputs=[chatbot, chat_history])
+
 app.launch()
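
A few notes on the new app.py. First, `from_pretrained` with default arguments materializes the 7B model as float32 (roughly 28 GB of weights), which is tight even on a large single GPU. A minimal sketch of loading in half precision instead, assuming the same checkpoint and the `accelerate` package for `device_map`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")

# fp16 halves the memory footprint (~14 GB); device_map="auto" places the
# weights on the available GPU, so no separate .to("cuda") call is needed.
model = AutoModelForCausalLM.from_pretrained(
    "berkeley-nest/Starling-LM-7B-alpha",
    torch_dtype=torch.float16,
    device_map="auto",
).eval()
```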
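Second, the prompt string above omits the `<|end_of_turn|>` separator that the Starling-LM-7B-alpha model card specifies between turns (`GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant:`). A sketch of the template, with `build_prompt` as an illustrative helper name:

```python
# Sketch: OpenChat-style "GPT4 Correct" template used by Starling-LM-7B-alpha.
# Without <|end_of_turn|> the model tends to continue the user's turn
# instead of answering it.
def build_prompt(user_input: str, chat_history: str = "") -> str:
    turn = f"GPT4 Correct User: {user_input}<|end_of_turn|>GPT4 Correct Assistant:"
    return chat_history + turn
```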
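Third, `tokenizer.decode(output[0], ...)` decodes the prompt together with the completion, so the chat window will echo the whole history back, and `max_length=1024` budgets prompt and reply together. A sketch that decodes only the newly generated tokens, assuming the `model` and `tokenizer` loaded above and a `prompt` string built with the template:

```python
# Sketch: slice off the prompt tokens before decoding.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=512,                   # budget for the reply alone
        pad_token_id=tokenizer.eos_token_id,  # Mistral tokenizers ship no pad token
    )
prompt_len = inputs["input_ids"].shape[1]
response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
```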
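Finally, `gr.Chatbot` renders a list of `[user, assistant]` pairs, but the `send.click` wiring feeds it the raw response string, and `clear_chat` returns `""` where the component expects a list. A sketch of handlers whose outputs match the components, written for the same `gr.Blocks` context and assuming `generate_response` keeps its `(response, new_history)` return shape:

```python
# Sketch: keep the Chatbot's pair list separate from the prompt-string history.
def respond(user_input, pairs, history):
    response, new_history = generate_response(user_input, history)
    pairs = (pairs or []) + [[user_input, response]]
    return "", pairs, new_history  # clear the textbox, update chat and history

def clear_chat():
    return [], ""  # empty pair list for the Chatbot, empty string history

send.click(respond, inputs=[user_input, chatbot, chat_history],
           outputs=[user_input, chatbot, chat_history])
clear.click(clear_chat, outputs=[chatbot, chat_history])
```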