|
import ast
|
|
import gradio as gr
|
|
import os
|
|
import re
|
|
import json
|
|
import logging
|
|
|
|
import torch
|
|
from datetime import datetime
|
|
|
|
from threading import Thread
|
|
from typing import Optional
|
|
from transformers import TextIteratorStreamer
|
|
from functools import partial
|
|
from huggingface_hub import CommitScheduler
|
|
from uuid import uuid4
|
|
from pathlib import Path
|
|
|
|
from code_interpreter.JupyterClient import JupyterNotebook
|
|
|
|
# Maximum number of prompt tokens fed to the model; overridable via env var.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

import warnings

# Silence noisy transformers UserWarnings and TensorFlow C++ log spam
# so streaming output in the console stays readable.
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
|
|
|
|
|
from code_interpreter.OpenCodeInterpreter import OpenCodeInterpreter
|
|
|
|
# Local directory where per-conversation dialog logs are written as
# one JSON-lines file per conversation uuid.
JSON_DATASET_DIR = Path("json_dataset")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# Background scheduler that periodically commits JSON_DATASET_DIR to a
# private Hugging Face dataset repo. Writers must hold `scheduler.lock`
# (see save_json) so a commit never uploads a half-written file.
scheduler = CommitScheduler(
    repo_id="opencodeinterpreter_user_data",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    private=True
)

logging.basicConfig(level=logging.INFO)
|
|
|
|
class StreamingOpenCodeInterpreter(OpenCodeInterpreter):
    """OpenCodeInterpreter variant that streams generated tokens.

    ``generate`` launches ``model.generate`` on a background thread and
    publishes tokens through ``self.streamer``; callers iterate the
    streamer to receive the decoded text incrementally.
    """

    # Streamer of the most recent generate() call; None before first use.
    streamer: Optional[TextIteratorStreamer] = None

    @torch.inference_mode()
    def generate(
        self,
        prompt: str = "",
        max_new_tokens: int = 1024,
        do_sample: bool = False,
        top_p: float = 0.95,
        top_k: int = 50,
    ) -> str:
        """Start streamed generation for `prompt`.

        Returns an empty string immediately; consume ``self.streamer``
        for the actual output.
        """
        # BUG FIX: the keyword is `timeout`, not `Timeout`. The capitalized
        # form fell into **decode_kwargs and was silently ignored, so the
        # streamer had no timeout and a consumer could block forever on a
        # stalled generation thread.
        self.streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, timeout=5
        )

        # Truncate the prompt to the configured token budget.
        inputs = self.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_TOKEN_LENGTH,
        )
        inputs = inputs.to(self.model.device)

        kwargs = dict(
            **inputs,
            streamer=self.streamer,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            eos_token_id=self.tokenizer.eos_token_id
        )

        # Run generation off-thread so the caller can stream tokens while
        # the model is still producing them.
        thread = Thread(target=self.model.generate, kwargs=kwargs)
        thread.start()

        return ""
|
|
|
|
def save_json(dialog, mode, json_file_path, dialog_id) -> None:
    """Append one dialog record as a JSON line to `json_file_path`.

    Holds the CommitScheduler lock while writing so a background commit
    never uploads a partially written file.
    """
    record = {
        "id": dialog_id,
        "dialog": dialog,
        "mode": mode,
        "datetime": datetime.now().isoformat(),
    }
    with scheduler.lock, json_file_path.open("a") as fp:
        fp.write(json.dumps(record, ensure_ascii=False))
        fp.write("\n")
|
|
|
|
def convert_history(gradio_history: list[list], interpreter_history: list[dict]):
    """Rebuild the interpreter-side dialog from the Gradio chat history.

    A leading system message from the previous interpreter dialog (if any)
    is preserved; every other turn is re-derived from `gradio_history`,
    where each item is a [user_message, assistant_message] pair whose
    entries may be None.
    """
    keep_system = bool(interpreter_history) and interpreter_history[0]["role"] == "system"
    converted = [interpreter_history[0]] if keep_system else []
    if not gradio_history:
        return converted
    for pair in gradio_history:
        user_msg, assistant_msg = pair[0], pair[1]
        if user_msg is not None:
            converted.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            converted.append({"role": "assistant", "content": assistant_msg})
    return converted
|
|
|
|
def update_uuid(dialog_info):
    """Return a fresh [uuid, counter] pair, keeping the existing counter."""
    fresh_id = str(uuid4())
    logging.info(f"allocating new uuid {fresh_id} for conversation...")
    return [fresh_id, dialog_info[1]]
|
|
|
|
def is_valid_python_code(code):
    """Return True if `code` parses as a complete Python module.

    ``ast.parse`` raises SyntaxError for malformed code, but on common
    Python versions it raises ValueError for source containing null bytes.
    The input here is model-generated (untrusted), so both are treated as
    "not valid Python" instead of crashing the caller.
    """
    try:
        ast.parse(code)
        return True
    except (SyntaxError, ValueError):
        return False
|
|
|
|
|
|
class InputFunctionVisitor(ast.NodeVisitor):
    """AST visitor that records whether any direct `input(...)` call occurs."""

    def __init__(self):
        # Set to True on the first call node whose callee is the bare
        # name `input`; read by callers after visit().
        self.found_input = False

    def visit_Call(self, node):
        """Inspect one call node, then keep walking nested expressions."""
        callee = node.func
        if isinstance(callee, ast.Name) and callee.id == 'input':
            self.found_input = True
        self.generic_visit(node)
|
|
|
|
def has_input_function_calls(code):
    """Return True when syntactically valid `code` calls input() directly."""
    try:
        tree = ast.parse(code)
    except SyntaxError:
        # Unparseable code cannot be inspected; report no input() usage.
        return False
    checker = InputFunctionVisitor()
    checker.visit(tree)
    return checker.found_input
|
|
|
|
def gradio_launch(model_path: str, MAX_TRY: int = 3):
    """Build and launch the Gradio chat UI around a StreamingOpenCodeInterpreter.

    MAX_TRY bounds how many execute-code/regenerate rounds a single user
    turn may trigger before the loop gives up.
    """
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(height=600, label="OpenCodeInterpreter", avatar_images=["assets/user.pic.jpg", "assets/assistant.pic.jpg"], show_copy_button=True)
        with gr.Group():
            with gr.Row():
                msg = gr.Textbox(
                    container=False,
                    show_label=False,
                    label="Message",
                    placeholder="Type a message...",
                    scale=7,
                    autofocus=True
                )
                sub = gr.Button(
                    "Submit",
                    variant="primary",
                    scale=1,
                    min_width=150
                )

        with gr.Row():
            clear = gr.Button("🗑️ Clear", variant="secondary")

        # Per-session state: chat turns, a live Jupyter kernel, and
        # [conversation-uuid, counter] used to name the log file.
        session_state = gr.State([])
        jupyter_state = gr.State(JupyterNotebook())
        dialog_info = gr.State(["", 0])
        demo.load(update_uuid, dialog_info, dialog_info)

        def bot(user_message, history, jupyter_state, dialog_info, interpreter):
            """Generator event handler: streams model output, executes any
            generated code in the Jupyter kernel, and feeds the execution
            result back to the model, up to MAX_TRY rounds.

            Yields (chatbot, session_state, jupyter_state, dialog_info)
            after every streamed token and execution step.
            """
            logging.info(f"user message: {user_message}")
            # Re-derive the interpreter dialog from the UI history so the
            # two stay in sync even after a page reload.
            interpreter.dialog = convert_history(gradio_history=history, interpreter_history=interpreter.dialog)
            history.append([user_message, None])

            interpreter.dialog.append({"role": "user", "content": user_message})

            HAS_CODE = False
            prompt = interpreter.dialog_to_prompt(dialog=interpreter.dialog)

            # generate() returns immediately; tokens arrive via the streamer.
            _ = interpreter.generate(prompt)
            history[-1][1] = ""
            generated_text = ""
            for character in interpreter.streamer:
                history[-1][1] += character
                # Keep the end-of-turn marker out of the visible chat.
                history[-1][1] = history[-1][1].replace("<|EOT|>","")
                generated_text += character
                yield history, history, jupyter_state, dialog_info

            # If the whole reply is bare Python, wrap it in a fenced block
            # so extract_code_blocks below can pick it up.
            if is_valid_python_code(history[-1][1].strip()):
                history[-1][1] = f"```python\n{history[-1][1].strip()}\n```"
                generated_text = history[-1][1]

            HAS_CODE, generated_code_block = interpreter.extract_code_blocks(
                generated_text
            )

            # Record the assistant turn with tokenizer artifacts stripped.
            interpreter.dialog.append(
                {
                    "role": "assistant",
                    "content": generated_text.replace("<unk>_", "")
                    .replace("<unk>", "")
                    .replace("<|EOT|>", ""),
                }
            )

            logging.info(f"saving current dialog to file {dialog_info[0]}.json...")
            logging.info(f"current dialog: {interpreter.dialog}")
            save_json(interpreter.dialog, mode="openci_only", json_file_path=JSON_DATASET_DIR/f"{dialog_info[0]}.json", dialog_id=dialog_info[0])

            # Execute-and-regenerate loop: run the extracted code, show the
            # result, and let the model respond to it.
            attempt = 1
            while HAS_CODE:
                if attempt > MAX_TRY:
                    break

                generated_text = ""

                yield history, history, jupyter_state, dialog_info

                generated_code_block = generated_code_block.replace(
                    "<unk>_", ""
                ).replace("<unk>", "")

                if has_input_function_calls(generated_code_block):
                    # input() would hang the headless kernel; ask the model
                    # to hard-code its inputs instead of executing.
                    code_block_output = "Please directly assign the value of inputs instead of using input() function in your code."
                else:
                    (
                        code_block_output,
                        error_flag,
                    ) = interpreter.execute_code_and_return_output(
                        f"{generated_code_block}",
                        jupyter_state
                    )
                    # A hung cell leaves the kernel unusable; start fresh.
                    if error_flag == "Timeout":
                        logging.info(f"{dialog_info[0]}: Restart jupyter kernel due to timeout")
                        jupyter_state = JupyterNotebook()
                    code_block_output = interpreter.clean_code_output(code_block_output)

                if code_block_output.strip():
                    code_block_output = "Execution result: \n" + code_block_output
                else:
                    code_block_output = "Code is executed, but result is empty. Please make sure that you include test case in your code."

                # Execution output is shown on the "user" side of the chat
                # and fed back to the model as a user turn.
                history.append([code_block_output, ""])

                interpreter.dialog.append({"role": "user", "content": code_block_output})

                yield history, history, jupyter_state, dialog_info

                prompt = interpreter.dialog_to_prompt(dialog=interpreter.dialog)

                logging.info(f"generating answer for dialog {dialog_info[0]}")
                _ = interpreter.generate(prompt)
                for character in interpreter.streamer:
                    history[-1][1] += character
                    history[-1][1] = history[-1][1].replace("<|EOT|>","")
                    generated_text += character
                    yield history, history, jupyter_state, dialog_info
                logging.info(f"finish generating answer for dialog {dialog_info[0]}")

                HAS_CODE, generated_code_block = interpreter.extract_code_blocks(
                    history[-1][1]
                )

                interpreter.dialog.append(
                    {
                        "role": "assistant",
                        "content": generated_text.replace("<unk>_", "")
                        .replace("<unk>", "")
                        .replace("<|EOT|>", ""),
                    }
                )

                attempt += 1

                logging.info(f"saving current dialog to file {dialog_info[0]}.json...")
                logging.info(f"current dialog: {interpreter.dialog}")
                save_json(interpreter.dialog, mode="openci_only", json_file_path=JSON_DATASET_DIR/f"{dialog_info[0]}.json", dialog_id=dialog_info[0])

                # NOTE(review): `continue` as the last statement of the loop
                # body is a no-op; presumably this once guarded code below
                # it — confirm the intent before removing.
                if generated_text.endswith("<|EOT|>"):
                    continue

            # Returning from a generator ends the event; the yielded values
            # above already delivered all UI updates.
            return history, history, jupyter_state, dialog_info

        def reset_textbox():
            """Clear the message box after submit."""
            return gr.update(value="")

        def clear_history(history, jupyter_state, dialog_info, interpreter):
            """Reset the conversation: empty dialog, fresh kernel, new uuid."""
            interpreter.dialog = []
            jupyter_state.close()
            return [], [], JupyterNotebook(), update_uuid(dialog_info)

        interpreter = StreamingOpenCodeInterpreter(model_path=model_path)

        # Both click handlers fire on Submit: stream the answer and clear
        # the textbox.
        sub.click(partial(bot, interpreter=interpreter), [msg, session_state, jupyter_state, dialog_info], [chatbot, session_state, jupyter_state, dialog_info])
        sub.click(reset_textbox, [], [msg])

        clear.click(partial(clear_history, interpreter=interpreter), [session_state, jupyter_state, dialog_info], [chatbot, session_state, jupyter_state, dialog_info], queue=False)

    demo.queue(max_size=20)
    demo.launch(share=True)
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # --path selects the model checkpoint; defaults to the published
    # OpenCodeInterpreter-DS-6.7B weights.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--path",
        type=str,
        required=False,
        default="m-a-p/OpenCodeInterpreter-DS-6.7B",
        help="Path to the OpenCodeInterpreter Model.",
    )
    parsed = arg_parser.parse_args()
    gradio_launch(model_path=parsed.path)
|
|
|