Spaces:
Runtime error
Runtime error
""" | |
Chat with a model with command line interface. | |
Usage: | |
python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 | |
python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0 | |
Other commands: | |
- Type "!!exit" or an empty line to exit. | |
- Type "!!reset" to start a new conversation. | |
- Type "!!remove" to remove the last prompt. | |
- Type "!!regen" to regenerate the last message. | |
- Type "!!save <filename>" to save the conversation history to a json file. | |
- Type "!!load <filename>" to load a conversation history from a json file. | |
""" | |
import argparse | |
import os | |
import re | |
import sys | |
from prompt_toolkit import PromptSession | |
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory | |
from prompt_toolkit.completion import WordCompleter | |
from prompt_toolkit.history import InMemoryHistory | |
from prompt_toolkit.key_binding import KeyBindings | |
from rich.console import Console | |
from rich.live import Live | |
from rich.markdown import Markdown | |
import torch | |
from fastchat.model.model_adapter import add_model_args | |
from fastchat.modules.awq import AWQConfig | |
from fastchat.modules.exllama import ExllamaConfig | |
from fastchat.modules.xfastertransformer import XftConfig | |
from fastchat.modules.gptq import GptqConfig | |
from fastchat.serve.inference import ChatIO, chat_loop | |
from fastchat.utils import str_to_torch_dtype | |
class SimpleChatIO(ChatIO):
    """Plain-terminal chat I/O built on ``input()`` and ``print()``."""

    def __init__(self, multiline: bool = False):
        """Create the handler.

        Args:
            multiline: When true, a prompt is collected line by line until
                EOF (ctrl-d/z) instead of ending at the first newline.
        """
        self._multiline = multiline

    def prompt_for_input(self, role) -> str:
        """Read one message for ``role`` from standard input.

        In single-line mode this is a single ``input()`` call. In multiline
        mode, lines are stripped and collected until ``input()`` raises
        ``EOFError``; the collected lines are joined with newlines.

        Raises:
            EOFError: Propagated from the first ``input()`` call when stdin
                is already at EOF.
        """
        if not self._multiline:
            return input(f"{role}: ")

        prompt_data = []
        line = input(f"{role} [ctrl-d/z on empty line to end]: ")
        while True:
            prompt_data.append(line.strip())
            try:
                line = input()
            except EOFError:
                # EOF terminates the message; keep what was read so far.
                break
        return "\n".join(prompt_data)

    def prompt_for_output(self, role: str):
        """Print the ``role`` label that precedes streamed output."""
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Print a stream of outputs word by word and return the final text.

        Each item must be a mapping with a ``"text"`` entry. The text of every
        item is treated as the whole message so far: only the words after the
        ones already printed are written, and the last word is held back
        until the stream ends because it may still be growing.

        Args:
            output_stream: Iterable of mappings with a ``"text"`` key.

        Returns:
            The words of the last item joined by single spaces, or an empty
            string when the stream yields nothing.
        """
        pre = 0
        # Bound up front so an empty stream prints an empty line and returns
        # "" instead of failing on an unbound local.
        output_text = []
        for outputs in output_stream:
            output_text = outputs["text"].strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        """Print ``text`` as a complete, non-streamed message."""
        print(text)
class RichChatIO(ChatIO):
    """Chat I/O using prompt_toolkit for input and rich for Markdown output."""

    # Class-level key bindings; they are only passed to the prompt session
    # when multiline input is enabled.
    bindings = KeyBindings()

    # Without this registration the handler below is never attached to
    # `bindings`, and the ESC+Enter newline advertised for multiline input
    # has no effect.
    @bindings.add("escape", "enter")
    def _(event):
        event.app.current_buffer.newline()

    def __init__(self, multiline: bool = False, mouse: bool = False):
        """Create the prompt session, command completer and console.

        Args:
            multiline: Enable the ESC+Enter newline key binding.
            mouse: Forwarded to the prompt as ``mouse_support``.
        """
        self._prompt_session = PromptSession(history=InMemoryHistory())
        self._completer = WordCompleter(
            words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
            pattern=re.compile("$"),
        )
        self._console = Console()
        self._multiline = multiline
        self._mouse = mouse

    def prompt_for_input(self, role) -> str:
        """Print the ``role`` label and read one message from the prompt."""
        self._console.print(f"[bold]{role}:")
        # TODO(suquark): multiline input has some issues. fix it later.
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            mouse_support=self._mouse,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=self.bindings if self._multiline else None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        """Print the ``role`` label, with ``/`` replaced by ``|``."""
        self._console.print(f"[bold]{role.replace('/', '|')}:")

    def stream_output(self, output_stream):
        """Render a stream of outputs live as Markdown and return the last text.

        Args:
            output_stream: Iterable of mappings with a ``"text"`` key; falsy
                items are skipped.

        Returns:
            The ``"text"`` of the last non-empty item, or an empty string if
            there was none.
        """
        # TODO(suquark): the console flickers when there is a code block
        #  above it. We need to cut off "live" when a code block is done.

        # Bound up front so a stream without usable items returns "" instead
        # of failing on an unbound local.
        text = ""
        with Live(console=self._console, refresh_per_second=4) as live:
            for outputs in output_stream:
                if not outputs:
                    continue
                text = outputs["text"]
                # NOTE: chatbot output uses a single "\n" as a line break,
                # but standard Markdown renders a single "\n" in normal text
                # as a space. The workaround is to end every line with two
                # spaces, which Markdown treats as a hard line break. This
                # also adds trailing spaces inside code blocks, which is
                # harmless on a console. Lines that start a code fence get no
                # trailing spaces, because that would break syntax
                # highlighting.
                rendered = "".join(
                    line + ("\n" if line.startswith("```") else "  \n")
                    for line in text.splitlines()
                )
                live.update(Markdown(rendered))
        self._console.print()
        return text

    def print_output(self, text: str):
        """Render ``text`` as a complete message via ``stream_output``."""
        self.stream_output([{"text": text}])
class ProgrammaticChatIO(ChatIO):
    """Chat I/O for driving the CLI from another program over stdin/stdout."""

    def prompt_for_input(self, role) -> str:
        """Read one message from stdin, terminated by a sentinel sequence.

        Input is consumed line by line until the accumulated text ends with
        the sentinel, which is then removed. If stdin reaches EOF first, the
        text read so far is returned unchanged (an empty string when nothing
        was read) rather than waiting forever. The message is echoed to
        stdout prefixed with ``[!OP:<role>]: `` before being returned.
        """
        contents = ""
        # `end_sequence` signals the end of a message. It is unlikely to occur in
        # message content.
        end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
        while not contents.endswith(end_sequence):
            # The sentinel ends with "\n", so checking after each full line is
            # equivalent to checking after every character.
            line = sys.stdin.readline()
            if not line:
                # readline() returns "" at EOF and never raises EOFError;
                # stop here instead of looping forever.
                break
            contents += line
        if contents.endswith(end_sequence):
            contents = contents[: -len(end_sequence)]
        print(f"[!OP:{role}]: {contents}", flush=True)
        return contents

    def prompt_for_output(self, role: str):
        """Print the ``[!OP:<role>]: `` prefix that precedes streamed output."""
        print(f"[!OP:{role}]: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Print a stream of outputs word by word and return the final text.

        Each item must be a mapping with a ``"text"`` entry. The text of every
        item is treated as the whole message so far: only the words after the
        ones already printed are written, and the last word is held back
        until the stream ends because it may still be growing.

        Returns:
            The words of the last item joined by single spaces, or an empty
            string when the stream yields nothing.
        """
        pre = 0
        # Bound up front so an empty stream does not hit an unbound local.
        output_text = []
        for outputs in output_stream:
            output_text = outputs["text"].strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        """Print ``text`` as a complete, non-streamed message."""
        print(text)
def main(args):
    """Run the interactive chat loop described by the parsed arguments.

    Restricts the visible GPUs when ``--gpus`` is given, builds the optional
    Exllama and xFasterTransformer configs, selects the chat I/O style and
    hands everything to ``chat_loop``. A ``KeyboardInterrupt`` raised while
    the loop is running is caught and reported as ``exit...``.

    Raises:
        ValueError: If ``--num-gpus`` exceeds the number of ids in ``--gpus``,
            or if ``args.style`` is not a known console style.
    """
    if args.gpus:
        gpu_ids = args.gpus.split(",")
        if args.num_gpus > len(gpu_ids):
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        for env_name in ("CUDA_VISIBLE_DEVICES", "XPU_VISIBLE_DEVICES"):
            os.environ[env_name] = args.gpus

    exllama_config = None
    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
            cache_8bit=args.exllama_cache_8bit,
        )

    xft_config = None
    if args.enable_xft:
        xft_config = XftConfig(
            max_seq_len=args.xft_max_seq_len,
            data_type=args.xft_dtype,
        )
        # The xFasterTransformer path forces the device to CPU.
        if args.device != "cpu":
            print("xFasterTransformer now is only support CPUs. Reset device to CPU")
            args.device = "cpu"

    # Factories are lazy so only the selected style is ever constructed.
    chatio_factories = {
        "simple": lambda: SimpleChatIO(args.multiline),
        "rich": lambda: RichChatIO(args.multiline, args.mouse),
        "programmatic": lambda: ProgrammaticChatIO(),
    }
    make_chatio = chatio_factories.get(args.style)
    if make_chatio is None:
        raise ValueError(f"Invalid style for console: {args.style}")
    chatio = make_chatio()

    try:
        chat_loop(
            args.model_path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            str_to_torch_dtype(args.dtype),
            args.load_8bit,
            args.cpu_offloading,
            args.conv_template,
            args.conv_system_msg,
            args.temperature,
            args.repetition_penalty,
            args.max_new_tokens,
            chatio,
            # Quantized checkpoints fall back to the model path when no
            # dedicated checkpoint is given.
            gptq_config=GptqConfig(
                ckpt=args.gptq_ckpt or args.model_path,
                wbits=args.gptq_wbits,
                groupsize=args.gptq_groupsize,
                act_order=args.gptq_act_order,
            ),
            awq_config=AWQConfig(
                ckpt=args.awq_ckpt or args.model_path,
                wbits=args.awq_wbits,
                groupsize=args.awq_groupsize,
            ),
            exllama_config=exllama_config,
            xft_config=xft_config,
            revision=args.revision,
            judge_sent_end=args.judge_sent_end,
            debug=args.debug,
            history=not args.no_history,
        )
    except KeyboardInterrupt:
        print("exit...")
if __name__ == "__main__":
    # Command-line entry point: register the model arguments and the chat
    # options, then start the chat loop. Registration order is kept because
    # it determines the order of the generated --help output.
    parser = argparse.ArgumentParser()
    add_model_args(parser)

    # Conversation setup (optional strings).
    for flag, description in (
        ("--conv-template", "Conversation prompt template."),
        ("--conv-system-msg", "Conversation system message."),
    ):
        parser.add_argument(flag, type=str, default=None, help=description)

    # Generation controls.
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--no-history", action="store_true")

    # Console presentation.
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich", "programmatic"],
        help="Display style.",
    )

    # Boolean switches that all share the same shape.
    for flag, description in (
        ("--multiline", "Enable multiline input. Use ESC+Enter for newline."),
        ("--mouse", "[Rich Style]: Enable mouse support for cursor positioning."),
        (
            "--judge-sent-end",
            "Whether enable the correction logic that interrupts the output of sentences due to EOS.",
        ),
        ("--debug", "Print useful debug information (e.g., prompts)"),
    ):
        parser.add_argument(flag, action="store_true", help=description)

    args = parser.parse_args()
    main(args)