# InternVL / app.py
# Last commit 8b33d6d by zyliu: "update app.py" (3.76 kB)
import spaces
import fire
import subprocess
import os
import time
import signal
import subprocess
import atexit
try:
    import flash_attn
except ImportError:
    @spaces.GPU
    def install_flash_attn():
        """Install the pinned flash-attn release into the running interpreter.

        Wrapped in ``spaces.GPU`` so it runs with a GPU attached (presumably
        needed by the flash-attn install on Spaces -- TODO confirm).
        """
        import sys

        # `python -m pip` targets exactly this interpreter's site-packages;
        # a bare `pip` found on PATH may belong to a different Python.
        # No check=True: a failed install surfaces as the ImportError below,
        # as before.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn==2.5.9.post1"]
        )

    install_flash_attn()
    # Re-import after installing; raises ImportError if the install failed.
    import flash_attn
def kill_processes_by_cmd_substring(cmd_substring):
    """Send SIGTERM to every process whose ``ps -ef`` line contains a substring.

    Used before (re)launching a helper process so that a copy started with the
    same command line is stopped first.

    Args:
        cmd_substring: Text searched for in each line of ``ps -ef`` output.
            An empty string is ignored (it would otherwise match every line).

    The current process is never signalled, and processes that exit between
    the ``ps`` snapshot and the ``kill`` call are skipped silently.
    """
    if not cmd_substring:
        # "" is a substring of every line: refuse rather than kill everything.
        return
    own_pid = os.getpid()
    result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)
    for line in result.stdout.splitlines():
        if cmd_substring not in line:
            continue
        fields = line.split()
        try:
            # Column 2 of `ps -ef` is the PID; the header row ("PID") and any
            # malformed line fail this conversion and are skipped.
            pid = int(fields[1])
        except (IndexError, ValueError):
            continue
        if pid == own_pid:
            continue
        print(f"Killing process with PID: {pid}, CMD: {line}")
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # The process exited after the ps snapshot; nothing left to kill.
            pass
def _launch_process(label, cmd_args):
    """Start ``cmd_args`` as a child process after killing stale copies of it.

    Processes whose ``ps -ef`` line contains the space-joined ``cmd_args`` are
    signalled first via ``kill_processes_by_cmd_substring``; the new child is
    registered with ``atexit`` so it is terminated when this interpreter exits.

    Args:
        label: Short name used in the log line ("controller", "worker", ...).
        cmd_args: Full argv of the child, program first.

    Returns:
        The ``subprocess.Popen`` handle of the started child.
    """
    cmdline = " ".join(cmd_args)
    kill_processes_by_cmd_substring(cmdline)
    print(f"Launching {label}: ", cmdline)
    process = subprocess.Popen(cmd_args)
    atexit.register(process.terminate)
    return process


def main(
    python_path="python",
    run_controller=True,
    run_worker=True,
    run_gradio=True,
    controller_port=10086,
    gradio_port=7860,
    worker_names=("OpenGVLab/InternVL2-8B",),
    run_sd_worker=False,
    **kwargs,
):
    """Launch the demo's services as child processes and wait for them.

    Each enabled service is started in this order: the controller, one model
    worker per entry of ``worker_names``, the Gradio web server, and the SD
    worker. Every child is run with ``python_path``.

    Args:
        python_path: Interpreter used to run every child script.
        run_controller: Start ``controller.py`` on ``controller_port``.
        run_worker: Start ``model_worker.py`` (with ``--load-8bit``) for each
            model in ``worker_names``, on consecutive ports from 10088.
        run_gradio: Start ``gradio_web_server.py`` on ``gradio_port``.
        controller_port: Controller port; workers and Gradio are pointed at it.
        gradio_port: Port of the Gradio web server.
        worker_names: Model paths, one worker each. A single string is treated
            as one model path rather than iterated character by character.
        run_sd_worker: Start ``sd_worker.py``.
        **kwargs: Unused; accepted so extra keyword arguments are ignored.
    """
    host = "http://0.0.0.0"
    controller_url = f"{host}:{controller_port}"

    controller_process = None
    if run_controller:
        controller_process = _launch_process(
            "controller",
            [
                f"{python_path}",
                "controller.py",
                "--host",
                "0.0.0.0",
                "--port",
                f"{controller_port}",
            ],
        )

    worker_processes = []
    if run_worker:
        if isinstance(worker_names, str):
            # One model given as a plain string: wrap it so the loop below
            # starts one worker, not one per character.
            worker_names = [worker_names]
        for worker_port, worker_name in enumerate(worker_names, start=10088):
            worker_processes.append(
                _launch_process(
                    "worker",
                    [
                        f"{python_path}",
                        "model_worker.py",
                        "--port",
                        f"{worker_port}",
                        "--controller-url",
                        controller_url,
                        "--model-path",
                        f"{worker_name}",
                        "--load-8bit",
                    ],
                )
            )

    # Give the services above time to start before the web server comes up
    # (presumably so its first model-list query finds them -- TODO confirm).
    time.sleep(10)

    gradio_process = None
    if run_gradio:
        gradio_process = _launch_process(
            "gradio",
            [
                f"{python_path}",
                "gradio_web_server.py",
                "--port",
                f"{gradio_port}",
                "--controller-url",
                controller_url,
                "--model-list-mode",
                "reload",
            ],
        )

    sd_worker_process = None
    if run_sd_worker:
        sd_worker_process = _launch_process(
            "sd_worker", [f"{python_path}", "sd_worker.py"]
        )

    # Block until every launched child has exited.
    for worker_process in worker_processes:
        worker_process.wait()
    for process in (controller_process, gradio_process, sd_worker_process):
        if process is not None:
            process.wait()
if __name__ == "__main__":
    # Hand main() to fire, which exposes its keyword arguments as
    # command-line flags.
    fire.Fire(component=main)