Spaces:
Sleeping
Sleeping
import json | |
from typing import Optional, List | |
import httpx | |
from llm.common import LlmParams, LlmApi | |
class LlmApi(LlmApi): | |
""" | |
Класс для работы с API vllm. | |
""" | |
def __init__(self, params: LlmParams): | |
super().__init__() | |
super().set_params(params) | |
async def get_models(self) -> List[str]: | |
""" | |
Выполняет GET-запрос к API для получения списка доступных моделей. | |
Возвращает: | |
list[str]: Список идентификаторов моделей. | |
Если произошла ошибка или данные недоступны, возвращается пустой список. | |
Исключения: | |
Все ошибки HTTP-запросов логируются в консоль, но не выбрасываются дальше. | |
""" | |
try: | |
async with httpx.AsyncClient() as client: | |
response = await client.get(f"{self.params.url}/v1/models", headers=super().create_headers()) | |
if response.status_code == 200: | |
json_data = response.json() | |
return [item['id'] for item in json_data.get('data', [])] | |
except httpx.RequestError as error: | |
print('Error fetching models:', error) | |
return [] | |
async def get_model(self) -> str: | |
model = None | |
if self.params.model is not None: | |
model = self.params.model | |
else: | |
models = await self.get_models() | |
model = models[0] if models else None | |
if model is None: | |
raise Exception("No model name provided and no models available.") | |
return model | |
def create_messages(self, prompt: str) -> List[dict]: | |
""" | |
Создает сообщения для LLM на основе переданного промпта и системного промпта (если он задан). | |
Args: | |
prompt (str): Пользовательский промпт. | |
Returns: | |
list[dict]: Список сообщений с ролями и содержимым. | |
""" | |
actual_prompt = self.apply_llm_template_to_prompt(prompt) | |
messages = [] | |
if self.params.predict_params and self.params.predict_params.system_prompt: | |
messages.append({"role": "system", "content": self.params.predict_params.system_prompt}) | |
messages.append({"role": "user", "content": actual_prompt}) | |
return messages | |
def apply_llm_template_to_prompt(self, prompt: str) -> str: | |
""" | |
Применяет шаблон LLM к переданному промпту, если он задан. | |
Args: | |
prompt (str): Пользовательский промпт. | |
Returns: | |
str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует). | |
""" | |
actual_prompt = prompt | |
if self.params.template is not None: | |
actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt) | |
return actual_prompt | |
async def tokenize(self, prompt: str) -> Optional[dict]: | |
""" | |
Выполняет токенизацию переданного промпта. | |
Args: | |
prompt (str): Промпт для токенизации. | |
Returns: | |
Optional[dict]: Словарь с токенами и максимальной длиной модели, если запрос успешен. | |
Если запрос неуспешен, возвращает None. | |
""" | |
actual_prompt = self.apply_llm_template_to_prompt(prompt) | |
request_data = { | |
"model": self.get_model(), | |
"prompt": actual_prompt, | |
"add_special_tokens": False, | |
} | |
try: | |
async with httpx.AsyncClient() as client: | |
response = await client.post( | |
f"{self.params.url}/tokenize", | |
json=request_data, | |
headers=super().create_headers(), | |
) | |
if response.status_code == 200: | |
data = response.json() | |
if "tokens" in data: | |
return {"tokens": data["tokens"], "maxLength": data.get("max_model_len")} | |
elif response.status_code == 404: | |
print("Tokenization endpoint not found (404).") | |
else: | |
print(f"Failed to tokenize: {response.status_code}") | |
except httpx.RequestError as e: | |
print(f"Request failed: {e}") | |
return None | |
async def detokenize(self, tokens: List[int]) -> Optional[str]: | |
""" | |
Выполняет детокенизацию переданных токенов. | |
Args: | |
tokens (List[int]): Список токенов для детокенизации. | |
Returns: | |
Optional[str]: Строка, полученная в результате детокенизации, если запрос успешен. | |
Если запрос неуспешен, возвращает None. | |
""" | |
request_data = {"model": self.get_model(), "tokens": tokens or []} | |
try: | |
async with httpx.AsyncClient() as client: | |
response = await client.post( | |
f"{self.params.url}/detokenize", | |
json=request_data, | |
headers=super().create_headers(), | |
) | |
if response.status_code == 200: | |
data = response.json() | |
if "prompt" in data: | |
return data["prompt"].strip() | |
elif response.status_code == 404: | |
print("Detokenization endpoint not found (404).") | |
else: | |
print(f"Failed to detokenize: {response.status_code}") | |
except httpx.RequestError as e: | |
print(f"Request failed: {e}") | |
return None | |
async def create_request(self, prompt: str) -> dict: | |
""" | |
Создает запрос для предсказания на основе параметров LLM. | |
Args: | |
prompt (str): Промпт для запроса. | |
Returns: | |
dict: Словарь с параметрами для выполнения запроса. | |
""" | |
model = self.get_model() | |
request = { | |
"stream": True, | |
"model": model, | |
} | |
predict_params = self.params.predict_params | |
if predict_params: | |
if predict_params.stop: | |
non_empty_stop = list(filter(lambda o: o != "", predict_params.stop)) | |
if non_empty_stop: | |
request["stop"] = non_empty_stop | |
if predict_params.n_predict is not None: | |
request["max_tokens"] = int(predict_params.n_predict or 0) | |
request["temperature"] = float(predict_params.temperature or 0) | |
if predict_params.top_k is not None: | |
request["top_k"] = int(predict_params.top_k) | |
if predict_params.top_p is not None: | |
request["top_p"] = float(predict_params.top_p) | |
if predict_params.min_p is not None: | |
request["min_p"] = float(predict_params.min_p) | |
if predict_params.seed is not None: | |
request["seed"] = int(predict_params.seed) | |
if predict_params.n_keep is not None: | |
request["n_keep"] = int(predict_params.n_keep) | |
if predict_params.cache_prompt is not None: | |
request["cache_prompt"] = bool(predict_params.cache_prompt) | |
if predict_params.repeat_penalty is not None: | |
request["repetition_penalty"] = float(predict_params.repeat_penalty) | |
if predict_params.repeat_last_n is not None: | |
request["repeat_last_n"] = int(predict_params.repeat_last_n) | |
if predict_params.presence_penalty is not None: | |
request["presence_penalty"] = float(predict_params.presence_penalty) | |
if predict_params.frequency_penalty is not None: | |
request["frequency_penalty"] = float(predict_params.frequency_penalty) | |
request["messages"] = self.create_messages(prompt) | |
return request | |
async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict: | |
""" | |
Обрезает текст источников, чтобы уложиться в допустимое количество токенов. | |
Args: | |
sources (str): Текст источников. | |
user_request (str): Запрос пользователя с примененным шаблоном без текста источников. | |
system_prompt (str): Системный промпт, если нужен. | |
Returns: | |
dict: Словарь с результатом, количеством токенов до и после обрезки. | |
""" | |
# Токенизация текста источников | |
sources_tokens_data = await self.tokenize(sources) | |
if sources_tokens_data is None: | |
raise ValueError("Failed to tokenize sources.") | |
max_token_count = sources_tokens_data.get("maxLength", 0) | |
# Токены системного промпта | |
system_prompt_token_count = 0 | |
if system_prompt is not None: | |
system_prompt_tokens = await self.tokenize(system_prompt) | |
system_prompt_token_count = len(system_prompt_tokens["tokens"]) if system_prompt_tokens else 0 | |
# Оригинальное количество токенов | |
original_token_count = len(sources_tokens_data["tokens"]) | |
# Токенизация пользовательского промпта | |
aux_prompt = self.apply_llm_template_to_prompt(user_request) | |
aux_tokens_data = await self.tokenize(aux_prompt) | |
aux_token_count = len(aux_tokens_data["tokens"]) if aux_tokens_data else 0 | |
# Максимально допустимое количество токенов для источников | |
max_length = ( | |
max_token_count | |
- (self.params.predict_params.n_predict or 0) | |
- aux_token_count | |
- system_prompt_token_count | |
) | |
max_length = max(max_length, 0) | |
# Обрезка токенов источников | |
if "tokens" in sources_tokens_data: | |
sources_tokens_data["tokens"] = sources_tokens_data["tokens"][:max_length] | |
detokenized_prompt = await self.detokenize(sources_tokens_data["tokens"]) | |
if detokenized_prompt is not None: | |
sources = detokenized_prompt | |
else: | |
sources = sources[:max_length] | |
else: | |
sources = sources[:max_length] | |
# Возврат результата | |
return { | |
"result": sources, | |
"originalTokenCount": original_token_count, | |
"slicedTokenCount": len(sources_tokens_data["tokens"]), | |
} | |
async def predict(self, prompt: str) -> str: | |
""" | |
Выполняет запрос к API с поддержкой потокового вывода (SSE) и возвращает результат. | |
Args: | |
prompt (str): Входной текст для предсказания. | |
Returns: | |
str: Сгенерированный текст. | |
""" | |
async with httpx.AsyncClient() as client: | |
# Формируем тело запроса | |
request = await self.create_request(prompt) | |
# Начинаем потоковый запрос | |
async with client.stream("POST", f"{self.params.url}/v1/chat/completions", json=request) as response: | |
if response.status_code != 200: | |
# Если ошибка, читаем ответ для получения подробностей | |
error_content = await response.aread() | |
raise Exception(f"API error: {error_content.decode('utf-8')}") | |
# Для хранения результата | |
generated_text = "" | |
# Асинхронное чтение построчно | |
async for line in response.aiter_lines(): | |
if line.startswith("data: "): # SSE-сообщения начинаются с "data: " | |
try: | |
# Парсим JSON из строки | |
data = json.loads(line[len("data: "):].strip()) | |
if data == "[DONE]": # Конец потока | |
break | |
if "choices" in data and data["choices"]: | |
# Получаем текст из текущего токена | |
token_value = data["choices"][0].get("delta", {}).get("content", "") | |
generated_text += token_value | |
except json.JSONDecodeError: | |
continue # Игнорируем строки, которые не удается декодировать | |
return generated_text.strip() | |