nn-search-full / llm /common.py
muryshev's picture
update
7be11b2
raw
history blame
2.69 kB
from pydantic import BaseModel, Field
from typing import Optional, List, Protocol
from abc import ABC, abstractmethod
class LlmPredictParams(BaseModel):
"""
Параметры для предсказания LLM.
"""
system_prompt: Optional[str] = Field(None, description="Системный промпт.")
user_prompt: Optional[str] = Field(None, description="Шаблон промпта для передачи от роли user.")
n_predict: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
min_p: Optional[float] = None
seed: Optional[int] = None
repeat_penalty: Optional[float] = None
repeat_last_n: Optional[int] = None
retry_if_text_not_present: Optional[str] = None
retry_count: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
n_keep: Optional[int] = None
cache_prompt: Optional[bool] = None
stop: Optional[List[str]] = None
class LlmParams(BaseModel):
"""
Основные параметры для LLM.
"""
url: str
model: Optional[str] = Field(None, description="Предполагается, что для локального API этот параметр не будет указываться, т.к. будем брать первую модель из списка потому, что модель доступна всего одна. Для deepinfra такой подход не подойдет и модель нужно задавать явно.")
type: Optional[str] = None
default: Optional[bool] = None
template: Optional[str] = None
predict_params: Optional[LlmPredictParams] = None
api_key: Optional[str] = None
class LlmApiProtocol(Protocol):
async def tokenize(self, prompt: str) -> Optional[dict]:
...
async def detokenize(self, tokens: List[int]) -> Optional[str]:
...
async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
...
async def predict(self, prompt: str) -> str:
...
class LlmApi:
"""
Базовый клас для работы с API LLM.
"""
params: LlmParams = None
def __init__(self):
self.params = None
def set_params(self, params: LlmParams):
self.params = params
def create_headers(self) -> dict[str, str]:
headers = {"Content-Type": "application/json"}
if self.params.api_key is not None:
headers["Authorization"] = self.params.api_key
return headers