# Spaces: Running on Zero
import spaces | |
import gradio as gr | |
from pathlib import Path | |
import re | |
import torch | |
import gc | |
from typing import Any | |
from huggingface_hub import hf_hub_download, HfApi | |
from llama_cpp import Llama | |
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType | |
from llama_cpp_agent.providers import LlamaCppPythonProvider | |
from llama_cpp_agent.chat_history import BasicChatHistory | |
from llama_cpp_agent.chat_history.messages import Roles | |
from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags | |
import wrapt_timeout_decorator | |
from llama_cpp_agent.messages_formatter import MessagesFormatter | |
from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter | |
from llmenv import llm_models, llm_models_dir, llm_formats, llm_languages, dolphin_system_prompt | |
import subprocess | |
# Delete everything under /data-nvme/zerogpu-offload/ at import time.
# NOTE(review): runs through the shell (needed for the `*` glob) with an empty
# environment, so it relies on /bin/sh's built-in default PATH to find `rm` — confirm.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
# (label, value) pairs used as model dropdown choices; rebuilt by update_llm_model_tupled_list().
llm_models_tupled_list = []
# The first entry of llm_models is the fallback used whenever a model lookup or download fails.
default_llm_model_filename = list(llm_models.keys())[0]
# NOTE(review): not referenced in this part of the file; may be used elsewhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
def to_list(s: str):
    """Split a comma-separated string into a list of stripped, non-empty items.

    An empty (or blank / commas-only) string yields an empty list.
    """
    # Filter on each stripped item so stray or trailing commas don't produce empty tags.
    return [item for item in (part.strip() for part in s.split(",")) if item]
def list_uniq(l: list):
    """Return the items of *l* with duplicates removed, keeping first-occurrence order.

    Items must be hashable.
    """
    # dict keys keep insertion order, giving an O(n) order-preserving dedupe
    # instead of the O(n^2) cost of calling list.index for every element.
    return list(dict.fromkeys(l))
# Fallback values that get_state() returns for keys missing from a session's state dict.
DEFAULT_STATE = dict(
    dolphin_sysprompt_mode="Default",
    dolphin_output_language=llm_languages[0],
)
def get_state(state: dict, key: str):
    """Look up *key* in a per-session state dict, falling back to DEFAULT_STATE.

    Returns ``state[key]`` when present, otherwise ``DEFAULT_STATE[key]`` when
    that exists, otherwise None. A message is printed whenever *key* is missing
    from *state*.
    """
    if key in state:
        return state[key]
    if key in DEFAULT_STATE:
        print(f"State '{key}' not found. Use default value.")
        return DEFAULT_STATE[key]
    print(f"State '{key}' not found.")
    return None
def set_state(state: dict, key: str, value: Any):
    """Store *value* under *key* in the per-session state dict, mutating it in place."""
    state.update({key: value})
def to_list_ja(s: str):
    """Split a tag string that may use Japanese punctuation into stripped, non-empty items.

    The ideographic comma (、), ideographic full stop (。) and fullwidth comma (，)
    are treated like ASCII commas. An empty string yields an empty list.
    """
    s = re.sub(r"[、。，]", ",", s)
    # Filter on each stripped item so trailing punctuation doesn't produce empty tags.
    return [item for item in (part.strip() for part in s.split(",")) if item]
def is_japanese(s: str):
    """Return True if *s* contains a CJK unified ideograph, hiragana or katakana character."""
    import unicodedata
    markers = ("CJK UNIFIED", "HIRAGANA", "KATAKANA")
    return any(
        any(marker in unicodedata.name(ch, "") for marker in markers)
        for ch in s
    )
def update_llm_model_tupled_list():
    """Rebuild the global (label, value) list used as model dropdown choices.

    Models registered in llm_models come first, followed by any *.gguf files
    found in llm_models_dir. Duplicates are dropped, keeping the first
    occurrence. Returns the rebuilt list, which is also stored in the global.
    """
    global llm_models_tupled_list
    choices = [(name, name) for name in llm_models]
    # Path.glob yields entries in filesystem order; sort so the dropdown order is stable.
    choices += [(path.name, path.name) for path in sorted(Path(llm_models_dir).glob('*.gguf'))]
    llm_models_tupled_list = list_uniq(choices)
    return llm_models_tupled_list
def download_llm_models():
    """Download every model listed in llm_models into llm_models_dir.

    Resets the global llm_models_tupled_list and fills it with (filename,
    filename) pairs for the models that downloaded successfully. Failures are
    printed and skipped, so one bad entry does not stop the rest.
    """
    global llm_models_tupled_list
    llm_models_tupled_list = []
    for filename, info in llm_models.items():
        try:
            hf_hub_download(repo_id=info[0], filename=filename, local_dir=llm_models_dir)
        except Exception as e:
            # Best-effort: report the failure instead of hiding it, then move on.
            print(e)
            continue
        llm_models_tupled_list.append((filename, filename))
def download_llm_model(filename: str):
    """Ensure the registered model *filename* is present in llm_models_dir.

    Returns *filename* on success, after refreshing the dropdown choices.
    Filenames not in llm_models, and failed downloads, return
    default_llm_model_filename instead. Download errors are printed.
    """
    if filename not in llm_models:
        return default_llm_model_filename
    try:
        hf_hub_download(repo_id=llm_models[filename][0], filename=filename, local_dir=llm_models_dir)
    except Exception as err:
        print(err)
        return default_llm_model_filename
    else:
        update_llm_model_tupled_list()
        return filename
def get_dolphin_model_info(filename: str):
    """Return a Markdown link to the Hugging Face repo hosting *filename*, or "None" if unknown."""
    entry = llm_models.get(filename, None)
    if not entry:
        return "None"
    repo = entry[0]
    return f'Repo: [{repo}](https://huggingface.co/{repo})'
def select_dolphin_model(filename: str, state: dict, progress=gr.Progress(track_tqdm=True)):
    """Switch the session to model *filename*, downloading it if needed.

    Clears any per-session chat-format override. Returns updates for
    (model dropdown, format dropdown, model-info Markdown) plus the state.
    download_llm_model may fall back to the default model, so every update
    describes the model actually selected.
    """
    # Choosing a model resets any manually forced chat template.
    set_state(state, "override_llm_format", None)
    progress(0, desc="Loading model...")
    value = download_llm_model(filename)
    progress(1, desc="Model loaded.")
    # Use the resolved model (not the requested one) so the info matches the dropdown after a fallback.
    md = get_dolphin_model_info(value)
    return (
        gr.update(value=value, choices=get_dolphin_models()),
        gr.update(value=get_dolphin_model_format(value)),
        gr.update(value=md),
        state,
    )
def select_dolphin_format(format_name: str, state: dict):
    """Force the chat template named *format_name* for this session.

    Returns a gr.update for the format dropdown and the updated state.
    """
    chosen_template = llm_formats[format_name]
    set_state(state, "override_llm_format", chosen_template)
    return gr.update(value=format_name), state
# Fetch the default model at import time. The return value is ignored;
# download_llm_model prints any download error rather than raising.
download_llm_model(default_llm_model_filename)
def get_dolphin_models():
    """Return freshly rebuilt (label, value) dropdown choices for all known and local models."""
    choices = update_llm_model_tupled_list()
    return choices
def get_llm_formats():
    """Return the names of all supported chat formats, in definition order."""
    return [name for name in llm_formats]
def get_key_from_value(d, val):
    """Return the first key of *d* whose value equals *val*, or None if there is none."""
    return next((key for key, item in d.items() if item == val), None)
def get_dolphin_model_format(filename: str):
    """Return the chat-format name registered for *filename*.

    Unknown filenames use the default model's format. Returns None if the
    registered format object is not present in llm_formats.
    """
    lookup_key = filename if filename in llm_models else default_llm_model_filename
    return get_key_from_value(llm_formats, llm_models[lookup_key][1])
def add_dolphin_models(query: str, format_name: str):
    """Register GGUF models from a Hugging Face repo or file URL under chat format *format_name*.

    *query* may be "owner/repo", a huggingface.co repo URL, or a URL that points
    at one ``.gguf`` file inside a repo. For a bare repo, every ``.gguf`` file in
    it is registered. The module-level ``llm_models`` is rebound to include the
    new entries.

    Returns a gr.update for the model dropdown with refreshed choices and the
    newly added model selected. Returns an empty gr.update() when the query
    cannot be parsed, the repo/file does not exist, no .gguf files were found,
    or the Hub lookup fails (the error is printed).
    """
    global llm_models
    api = HfApi()
    add_models = {}
    chat_format = llm_formats[format_name]
    try:
        # Group 1: "owner/repo"; group 2 (optional): the .gguf file name at the end of a URL path.
        s = list(re.findall(r'^(?:https?://huggingface\.co/)?(.+?/.+?)(?:/.*/(.+?\.gguf).*?)?$', query)[0])
        if s and "" in s:
            s.remove("")
        if len(s) == 1:
            repo = s[0]
            if not api.repo_exists(repo_id=repo):
                return gr.update()
            for file in api.list_repo_files(repo_id=repo):
                # Key each model by its own file name so every .gguf in the repo is kept.
                if str(file).endswith(".gguf"):
                    add_models[str(file)] = [repo, chat_format]
        elif len(s) >= 2:
            repo, filename = s[0], s[1]
            if not api.repo_exists(repo_id=repo) or not api.file_exists(repo_id=repo, filename=filename):
                return gr.update()
            add_models[filename] = [repo, chat_format]
        else:
            return gr.update()
    except Exception as e:
        print(e)
        return gr.update()
    if not add_models:
        return gr.update()
    llm_models = llm_models | add_models
    choices = get_dolphin_models()
    # Select the model that was just added, not whatever happens to be last in the list.
    return gr.update(choices=choices, value=list(add_models)[-1])
def get_dolphin_sysprompt(state: "dict | None" = None):
    """Build the system prompt for the session's selected mode and output language.

    The template is ``dolphin_system_prompt[mode]``, or the first template when
    the mode is unknown. Every "<LANGUAGE>" placeholder is replaced with the
    selected output language, or ``llm_languages[0]`` when that is unset or empty.
    *state* defaults to an empty dict, so DEFAULT_STATE values are used.
    """
    if state is None:
        state = {}
    mode = get_state(state, "dolphin_sysprompt_mode")
    language = get_state(state, "dolphin_output_language") or llm_languages[0]
    template = dolphin_system_prompt.get(mode, dolphin_system_prompt[list(dolphin_system_prompt.keys())[0]])
    # Literal substitution: re.sub would treat backslashes / group refs in the language text as escapes.
    return template.replace("<LANGUAGE>", language)
def get_dolphin_sysprompt_mode():
    """Return the names of all available system-prompt modes, in definition order."""
    return [*dolphin_system_prompt]
def select_dolphin_sysprompt(key: str, state: dict):
    """Set the session's system-prompt mode to *key*, or to "Default" if *key* is unknown.

    Returns a gr.update carrying the rebuilt system prompt, plus the updated state.
    """
    mode = key if key in dolphin_system_prompt else "Default"
    set_state(state, "dolphin_sysprompt_mode", mode)
    return gr.update(value=get_dolphin_sysprompt(state)), state
def get_dolphin_languages():
    """Return the selectable output languages (the shared llm_languages object itself)."""
    languages = llm_languages
    return languages
def select_dolphin_language(lang: str, state: dict):
    """Record *lang* as the session's output language.

    Returns a gr.update with the rebuilt system prompt, plus the updated state.
    """
    set_state(state, "dolphin_output_language", lang)
    prompt = get_dolphin_sysprompt(state)
    return gr.update(value=prompt), state
def get_raw_prompt(msg: str):
    """Extract the text wrapped in /GENBEGIN/ ... /GENEND/ markers from *msg*.

    All marked spans (which may contain newlines) are joined with ", ". The
    characters * / : _ " # and newline are then replaced by spaces, and the
    result is lower-cased. Returns "" when no marked span is found.
    """
    spans = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
    if not spans:
        return ""
    joined = ", ".join(spans)
    return re.sub(r'[*/:_"#\n]', ' ', joined).lower()
def dolphin_respond(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: dict = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream a chat reply to *message* from a GGUF model loaded with llama.cpp.

    The file ``llm_models_dir/model`` is loaded on every call. Its chat template
    is the session override (``override_llm_format`` in *state*) when set,
    otherwise the format registered in ``llm_models[model]``. Each *history*
    entry is replayed as a user turn (index 0) followed by an assistant turn
    (index 1).

    Yields:
        ``[(partial_reply, None)]`` after every streamed chunk, with the reply
        text growing each time.

    Raises:
        gr.Error: if the model file does not exist, or anything fails while
            loading the model or generating.
    """
    if state is None:
        state = {}
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")
        progress(0, desc="Processing...")
        override_llm_format = get_state(state, "override_llm_format")
        chat_template = override_llm_format if override_llm_format else llm_models[model][1]
        llm = Llama(
            model_path=str(model_path),
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=8192,
        )
        provider = LlamaCppPythonProvider(llm)
        # Custom MessagesFormatter objects and predefined formatter types go to different arguments.
        is_custom = isinstance(chat_template, MessagesFormatter)
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            predefined_messages_formatter_type=None if is_custom else chat_template,
            custom_messages_formatter=chat_template if is_custom else None,
            debug_output=False,
        )
        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True
        messages = BasicChatHistory()
        for msn in history:
            messages.add_message({'role': Roles.user, 'content': msn[0]})
            messages.add_message({'role': Roles.assistant, 'content': msn[1]})
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )
        progress(0.5, desc="Processing...")
        outputs = ""
        for output in stream:
            outputs += output
            yield [(outputs, None)]
    except gr.Error:
        # Already user-facing (e.g. missing model file): re-raise as-is instead of
        # wrapping it a second time as "Error: ...".
        raise
    except Exception as e:
        print(e)
        raise gr.Error(f"Error: {e}") from e
    finally:
        # Collect unreachable objects first so the memory they held can actually
        # be returned when the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
def dolphin_parse(
    history: list[tuple[str, str]],
    state: dict,
):
    """Turn the tags in the latest chat output into a comma-separated prompt.

    Uses the /GENBEGIN/.../GENEND/ text from ``history[-1][0]``, the slot where
    the respond functions place the model output. In "Japanese to Danbooru
    Dictionary" mode, Japanese text is translated with jatags_to_danbooru_tags.
    "nsfw" and "explicit" are appended and duplicates removed.

    Returns:
        (prompt, update, update). On success both updates set interactive=True.
        In "Chat with LLM" mode, with empty history, or on any error, the
        prompt is "" and both updates are empty.
    """
    try:
        mode = get_state(state, "dolphin_sysprompt_mode")
        if mode == "Chat with LLM" or not history:
            return "", gr.update(), gr.update()
        raw_prompt = get_raw_prompt(history[-1][0])
        use_dictionary = mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt)
        tags = jatags_to_danbooru_tags(to_list_ja(raw_prompt)) if use_dictionary else to_list(raw_prompt)
        prompts = list_uniq(tags + ["nsfw", "explicit"])
        return ", ".join(prompts), gr.update(interactive=True), gr.update(interactive=True)
    except Exception as e:
        print(e)
        return "", gr.update(), gr.update()
def dolphin_respond_auto(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: dict = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream a chat reply to *message*, shaped for a three-output Gradio event.

    Model loading, chat-template selection and history replay are the same as
    in dolphin_respond. Errors are never raised to the UI: they are printed, and
    a single empty result is yielded instead.

    Yields:
        ``([(partial_reply, None)], gr.update(), gr.update())`` after each
        streamed chunk, or ``([("", None)], gr.update(), gr.update())`` once
        on error.
    """
    if state is None:
        state = {}
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        progress(0, desc="Processing...")
        override_llm_format = get_state(state, "override_llm_format")
        chat_template = override_llm_format if override_llm_format else llm_models[model][1]
        llm = Llama(
            model_path=str(model_path),
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=8192,
        )
        provider = LlamaCppPythonProvider(llm)
        # Custom MessagesFormatter objects and predefined formatter types go to different arguments.
        is_custom = isinstance(chat_template, MessagesFormatter)
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            predefined_messages_formatter_type=None if is_custom else chat_template,
            custom_messages_formatter=chat_template if is_custom else None,
            debug_output=False,
        )
        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True
        messages = BasicChatHistory()
        for msn in history:
            messages.add_message({'role': Roles.user, 'content': msn[0]})
            messages.add_message({'role': Roles.assistant, 'content': msn[1]})
        progress(0, desc="Translating...")
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )
        progress(0.5, desc="Processing...")
        outputs = ""
        for output in stream:
            outputs += output
            yield [(outputs, None)], gr.update(), gr.update()
    except Exception as e:
        # Deliberately best-effort: report and emit an empty reply rather than failing the UI event.
        print(e)
        yield [("", None)], gr.update(), gr.update()
    finally:
        # Collect unreachable objects first so the memory they held can actually
        # be returned when the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
def dolphin_parse_simple(
    message: str,
    history: list[tuple[str, str]],
    state: dict,
):
    """Turn the tags in the latest chat output into a comma-separated prompt string.

    In "Chat with LLM" mode, or when *history* is empty, *message* is returned
    unchanged. Otherwise the /GENBEGIN/.../GENEND/ text of ``history[-1][0]`` is
    split into tags; in "Japanese to Danbooru Dictionary" mode, Japanese text is
    translated with jatags_to_danbooru_tags. "nsfw", "explicit" and
    "rating_explicit" are appended and duplicates removed. Returns "" if
    anything raises (the error is printed).
    """
    extra_tags = ["nsfw", "explicit", "rating_explicit"]
    try:
        mode = get_state(state, "dolphin_sysprompt_mode")
        if mode == "Chat with LLM" or not history:
            return message
        raw_prompt = get_raw_prompt(history[-1][0])
        if mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
            tags = jatags_to_danbooru_tags(to_list_ja(raw_prompt))
        else:
            tags = to_list(raw_prompt)
        return ", ".join(list_uniq(tags + extra_tags))
    except Exception as e:
        print(e)
        return ""
# https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground | |
import cv2 | |
# Limit OpenCV's parallel regions to a single thread.
# NOTE(review): presumably to avoid CPU oversubscription alongside llama.cpp/torch — confirm.
cv2.setNumThreads(1)
def respond_playground(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: dict = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream a plain-text chat reply to *message* (GGUF-Playground style).

    The file ``llm_models_dir/model`` is loaded with llama.cpp on every call,
    using the session's chat-template override when set, otherwise
    ``llm_models[model][1]``. Each *history* entry is replayed as a user turn
    (index 0) followed by an assistant turn (index 1).

    Yields:
        The reply text accumulated so far, after each streamed chunk.

    Raises:
        gr.Error: if the model file does not exist, or anything fails while
            loading the model or generating.
    """
    if state is None:
        state = {}
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")
        override_llm_format = get_state(state, "override_llm_format")
        chat_template = override_llm_format if override_llm_format else llm_models[model][1]
        llm = Llama(
            model_path=str(model_path),
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=8192,
        )
        provider = LlamaCppPythonProvider(llm)
        # Custom MessagesFormatter objects and predefined formatter types go to different arguments.
        is_custom = isinstance(chat_template, MessagesFormatter)
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            predefined_messages_formatter_type=None if is_custom else chat_template,
            custom_messages_formatter=chat_template if is_custom else None,
            debug_output=False,
        )
        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True
        messages = BasicChatHistory()
        # Replay prior turns so the model sees the whole conversation.
        for msn in history:
            messages.add_message({'role': Roles.user, 'content': msn[0]})
            messages.add_message({'role': Roles.assistant, 'content': msn[1]})
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )
        outputs = ""
        for output in stream:
            outputs += output
            yield outputs
    except gr.Error:
        # Already user-facing (e.g. missing model file): re-raise as-is instead of
        # wrapping it a second time as "Error: ...".
        raise
    except Exception as e:
        print(e)
        raise gr.Error(f"Error: {e}") from e
    finally:
        # Collect unreachable objects first so the memory they held can actually
        # be returned when the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()