nn-search-full / llm /vllm_api-sync.py
muryshev's picture
init
b24d496
raw
history blame
14.6 kB
import json
import os
import requests
from typing import Optional, List, Any
from pydantic import BaseModel, Field
class LlmPredictParams(BaseModel):
"""
Параметры для предсказания LLM.
"""
system_prompt: Optional[str] = Field(None, description="Системный промпт.")
user_prompt: Optional[str] = Field(None, description="Шаблон промпта для передачи от роли user.")
n_predict: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
min_p: Optional[float] = None
seed: Optional[int] = None
repeat_penalty: Optional[float] = None
repeat_last_n: Optional[int] = None
retry_if_text_not_present: Optional[str] = None
retry_count: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
n_keep: Optional[int] = None
cache_prompt: Optional[bool] = None
stop: Optional[List[str]] = None
class LlmParams(BaseModel):
"""
Основные параметры для LLM.
"""
url: str
type: Optional[str] = None
default: Optional[bool] = None
template: Optional[str] = None
predict_params: Optional[LlmPredictParams] = None
class LlmApi:
"""
Класс для работы с API vllm.
"""
params: LlmParams = None
def __init__(self, params: LlmParams):
self.params = params
def get_models(self) -> list[str]:
"""
Выполняет GET-запрос к API для получения списка доступных моделей.
Возвращает:
list[str]: Список идентификаторов моделей.
Если произошла ошибка или данные недоступны, возвращается пустой список.
Исключения:
Все ошибки HTTP-запросов логируются в консоль, но не выбрасываются дальше.
"""
try:
response = requests.get(f"{self.params.url}/v1/models", headers={"Content-Type": "application/json"})
if response.status_code == 200:
json_data = response.json()
result = [item['id'] for item in json_data.get('data', [])]
return result
except requests.RequestException as error:
print('OpenAiService.getModels error:')
print(error)
return []
def create_messages(self, prompt: str) -> list[dict]:
"""
Создает сообщения для LLM на основе переданного промпта и системного промпта (если он задан).
Args:
prompt (str): Пользовательский промпт.
Returns:
list[dict]: Список сообщений с ролями и содержимым.
"""
actual_prompt = self.apply_llm_template_to_prompt(prompt)
messages = []
if self.params.predict_params and self.params.predict_params.system_prompt:
messages.append({"role": "system", "content": self.params.predict_params.system_prompt})
messages.append({"role": "user", "content": actual_prompt})
return messages
def apply_llm_template_to_prompt(self, prompt: str) -> str:
"""
Применяет шаблон LLM к переданному промпту, если он задан.
Args:
prompt (str): Пользовательский промпт.
Returns:
str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует).
"""
actual_prompt = prompt
if self.params.template is not None:
actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt)
return actual_prompt
def tokenize(self, prompt: str) -> Optional[dict]:
"""
Выполняет токенизацию переданного промпта.
Args:
prompt (str): Промпт для токенизации.
Returns:
Optional[dict]: Словарь с токенами и максимальной длиной модели, если запрос успешен.
Если запрос неуспешен, возвращает None.
"""
model = self.get_models()[0] if self.get_models() else None
if not model:
print("No models available for tokenization.")
return None
actual_prompt = self.apply_llm_template_to_prompt(prompt)
request_data = {
"model": model,
"prompt": actual_prompt,
"add_special_tokens": False,
}
try:
response = requests.post(
f"{self.params.url}/tokenize",
json=request_data,
headers={"Content-Type": "application/json"},
)
if response.ok:
data = response.json()
if "tokens" in data:
return {"tokens": data["tokens"], "maxLength": data.get("max_model_len")}
elif response.status_code == 404:
print("Tokenization endpoint not found (404).")
else:
print(f"Failed to tokenize: {response.status_code}")
except requests.RequestException as e:
print(f"Request failed: {e}")
return None
def detokenize(self, tokens: List[int]) -> Optional[str]:
"""
Выполняет детокенизацию переданных токенов.
Args:
tokens (List[int]): Список токенов для детокенизации.
Returns:
Optional[str]: Строка, полученная в результате детокенизации, если запрос успешен.
Если запрос неуспешен, возвращает None.
"""
model = self.get_models()[0] if self.get_models() else None
if not model:
print("No models available for detokenization.")
return None
request_data = {"model": model, "tokens": tokens or []}
try:
response = requests.post(
f"{self.params.url}/detokenize",
json=request_data,
headers={"Content-Type": "application/json"},
)
if response.ok:
data = response.json()
if "prompt" in data:
return data["prompt"].strip()
elif response.status_code == 404:
print("Detokenization endpoint not found (404).")
else:
print(f"Failed to detokenize: {response.status_code}")
except requests.RequestException as e:
print(f"Request failed: {e}")
return None
def create_request(self, prompt: str) -> dict:
"""
Создает запрос для предсказания на основе параметров LLM.
Args:
prompt (str): Промпт для запроса.
Returns:
dict: Словарь с параметрами для выполнения запроса.
"""
llm_params = self.params
models = self.get_models()
if not models:
raise ValueError("No models available to create a request.")
model = models[0]
request = {
"stream": True,
"model": model,
}
predict_params = llm_params.predict_params
if predict_params:
if predict_params.stop:
# Фильтруем пустые строки в stop
non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
if non_empty_stop:
request["stop"] = non_empty_stop
if predict_params.n_predict is not None:
request["max_tokens"] = int(predict_params.n_predict or 0)
request["temperature"] = float(predict_params.temperature or 0)
if predict_params.top_k is not None:
request["top_k"] = int(predict_params.top_k)
if predict_params.top_p is not None:
request["top_p"] = float(predict_params.top_p)
if predict_params.min_p is not None:
request["min_p"] = float(predict_params.min_p)
if predict_params.seed is not None:
request["seed"] = int(predict_params.seed)
if predict_params.n_keep is not None:
request["n_keep"] = int(predict_params.n_keep)
if predict_params.cache_prompt is not None:
request["cache_prompt"] = bool(predict_params.cache_prompt)
if predict_params.repeat_penalty is not None:
request["repetition_penalty"] = float(predict_params.repeat_penalty)
if predict_params.repeat_last_n is not None:
request["repeat_last_n"] = int(predict_params.repeat_last_n)
if predict_params.presence_penalty is not None:
request["presence_penalty"] = float(predict_params.presence_penalty)
if predict_params.frequency_penalty is not None:
request["frequency_penalty"] = float(predict_params.frequency_penalty)
# Генерируем сообщения
request["messages"] = self.create_messages(prompt)
return request
def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
"""
Обрезает текст источников, чтобы уложиться в допустимое количество токенов.
Args:
sources (str): Текст источников.
user_request (str): Запрос пользователя с примененным шаблоном без текста источников.
system_prompt (str): Системный промпт, если нужен.
Returns:
dict: Словарь с результатом, количеством токенов до и после обрезки.
"""
# Токенизация текста источников
sources_tokens_data = self.tokenize(sources)
if sources_tokens_data is None:
raise ValueError("Failed to tokenize sources.")
max_token_count = sources_tokens_data.get("maxLength", 0)
# Токены системного промпта
system_prompt_token_count = 0
if system_prompt is not None:
system_prompt_tokens = self.tokenize(system_prompt)
system_prompt_token_count = len(system_prompt_tokens["tokens"]) if system_prompt_tokens else 0
# Оригинальное количество токенов
original_token_count = len(sources_tokens_data["tokens"])
# Токенизация пользовательского промпта
aux_prompt = self.apply_llm_template_to_prompt(user_request)
aux_tokens_data = self.tokenize(aux_prompt)
aux_token_count = len(aux_tokens_data["tokens"]) if aux_tokens_data else 0
# Максимально допустимое количество токенов для источников
max_length = (
max_token_count
- (self.params.predict_params.n_predict or 0)
- aux_token_count
- system_prompt_token_count
)
max_length = max(max_length, 0)
# Обрезка токенов источников
if "tokens" in sources_tokens_data:
sources_tokens_data["tokens"] = sources_tokens_data["tokens"][:max_length]
detokenized_prompt = self.detokenize(sources_tokens_data["tokens"])
if detokenized_prompt is not None:
sources = detokenized_prompt
else:
sources = sources[:max_length]
else:
sources = sources[:max_length]
# Возврат результата
return {
"result": sources,
"originalTokenCount": original_token_count,
"slicedTokenCount": len(sources_tokens_data["tokens"]),
}
def predict(self, prompt: str) -> str:
"""
Выполняет SSE-запрос к API и возвращает собранный результат как текст.
Args:
prompt (str): Входной текст для предсказания.
Returns:
str: Сгенерированный текст.
Raises:
Exception: Если запрос завершился ошибкой.
"""
# Создание запроса
request = self.create_request(prompt)
print(f"Predict request. Url: {self.params.url}")
response = requests.post(
f"{self.params.url}/v1/chat/completions",
headers={"Content-Type": "application/json"},
json=request,
stream=True # Для обработки SSE
)
if not response.ok:
raise Exception(f"Failed to generate text: {response.text}")
# Обработка SSE-ответа
generated_text = ""
for line in response.iter_lines(decode_unicode=True):
if line.startswith("data: "):
try:
data = json.loads(line[len("data: "):].strip())
# Проверка завершения генерации
if data == "[DONE]":
break
# Получение текста из ответа
if "choices" in data and data["choices"]:
token_value = data["choices"][0].get("delta", {}).get("content", "")
generated_text += token_value.replace("</s>", "")
except json.JSONDecodeError:
continue # Игнорирование строк, которые не удалось декодировать
return generated_text