update

- fastapi_app.py +16 -5
- llm/common.py +64 -0
- llm/deepinfra_api.py +153 -0
- llm/vllm_api.py +33 -65
- transaction_maps_search.py +9 -12
fastapi_app.py
CHANGED
@@ -6,7 +6,8 @@ import os
 import datetime
 import json
 import traceback
-from llm.
+from llm.common import LlmParams, LlmPredictParams
+from llm.deepinfra_api import DeepInfraApi
 
 # Set the path for log files
 LOGS_BASE_PATH = os.getenv("LOGS_BASE_PATH", "logs")
@@ -17,7 +18,9 @@ LOGS_BASE_PATH = os.getenv("LOGS_BASE_PATH", "logs")
 
 # Check if logs are enabled
 ENABLE_LOGS = os.getenv("ENABLE_LOGS", "0") == "1"
-
+LLM_API_URL = os.getenv("LLM_API_URL", "")
+LLM_API_KEY = os.getenv("LLM_API_KEY", "")
+LLM_USE_DEEPINFRA = os.getenv("LLM_USE_DEEPINFRA", "") == "1"
 
 class Query(BaseModel):
     query: str = ''
@@ -87,18 +90,26 @@ async def search_route(query: Query) -> dict:
 
     llm_params = getattr(query, "llm_params", None)
 
+    if llm_params is None:
+        llm_params = LlmParams(url=LLM_API_URL,api_key=LLM_API_KEY, model="mistralai/Mixtral-8x7B-Instruct-v0.1", predict_params=LlmPredictParams(temperature=0.15, top_p=0.95, min_p=0.05, seed=42, repetition_penalty=1.2, presence_penalty=1.1, max_tokens=6000))
+
+    if LLM_USE_DEEPINFRA:
+        llm_api = DeepInfraApi(llm_params)
+
+
     if find_transaction_maps_by_question or find_transaction_maps_by_operation:
-        transaction_maps_results, answer = transaction_maps_search.search_transaction_map(
+        transaction_maps_results, answer = await transaction_maps_search.search_transaction_map(
             query=question,
             find_transaction_maps_by_question=find_transaction_maps_by_question,
-            k_neighbours=top
+            k_neighbours=top,
+            llm_api=llm_api)
 
         response = {'transaction_maps_results': transaction_maps_results}
 
     else:
         modified_query, titles, concat_docs, \
             relevant_consultations, predicted_explanation, \
-            llm_responses = await search.search(question, use_qe, use_olympic, categories,
+            llm_responses = await search.search(question, use_qe, use_olympic, categories, llm_params)
 
         results = [{'title': str(item1), 'text_for_llm': str(item2)} for item1, item2 in
                    zip(titles, concat_docs)]
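For reference, a minimal sketch of the fallback wiring this hunk introduces: when the request carries no `llm_params`, the route builds defaults from the `LLM_API_URL`/`LLM_API_KEY` environment variables, and a `DeepInfraApi` client is only constructed when `LLM_USE_DEEPINFRA=1` (otherwise `llm_api` stays unbound). The helper name `build_llm_api` is hypothetical, not part of the commit, and the sketch uses the field names declared in `llm/common.py` (`repeat_penalty`, `n_predict`), whereas the committed call passes `repetition_penalty`/`max_tokens` keywords.

```python
import os
from typing import Optional, Tuple

from llm.common import LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi


def build_llm_api(llm_params: Optional[LlmParams] = None) -> Tuple[LlmParams, Optional[DeepInfraApi]]:
    # Hypothetical helper mirroring the route logic above.
    if llm_params is None:
        llm_params = LlmParams(
            url=os.getenv("LLM_API_URL", ""),
            api_key=os.getenv("LLM_API_KEY", ""),
            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
            predict_params=LlmPredictParams(temperature=0.15, top_p=0.95, min_p=0.05,
                                            seed=42, repeat_penalty=1.2,
                                            presence_penalty=1.1, n_predict=6000),
        )
    # Only the DeepInfra branch creates a client; callers must handle None.
    if os.getenv("LLM_USE_DEEPINFRA", "") == "1":
        return llm_params, DeepInfraApi(llm_params)
    return llm_params, None
```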
llm/common.py
ADDED
@@ -0,0 +1,64 @@
+from pydantic import BaseModel, Field
+from typing import Optional, List, Protocol
+from abc import ABC, abstractmethod
+
+class LlmPredictParams(BaseModel):
+    """
+    Parameters for LLM prediction.
+    """
+    system_prompt: Optional[str] = Field(None, description="System prompt.")
+    user_prompt: Optional[str] = Field(None, description="Prompt template passed with the user role.")
+    n_predict: Optional[int] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    min_p: Optional[float] = None
+    seed: Optional[int] = None
+    repeat_penalty: Optional[float] = None
+    repeat_last_n: Optional[int] = None
+    retry_if_text_not_present: Optional[str] = None
+    retry_count: Optional[int] = None
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    n_keep: Optional[int] = None
+    cache_prompt: Optional[bool] = None
+    stop: Optional[List[str]] = None
+
+
+class LlmParams(BaseModel):
+    """
+    Core parameters for an LLM.
+    """
+    url: str
+    model: Optional[str] = Field(None, description="For the local API this parameter is expected to be omitted: the first model from the list is used because only one model is available. For deepinfra this approach does not work and the model must be set explicitly.")
+    type: Optional[str] = None
+    default: Optional[bool] = None
+    template: Optional[str] = None
+    predict_params: Optional[LlmPredictParams] = None
+    api_key: Optional[str] = None
+
+class LlmApiProtocol(Protocol):
+    async def tokenize(self, prompt: str) -> Optional[dict]:
+        ...
+    async def detokenize(self, tokens: List[int]) -> Optional[str]:
+        ...
+    async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
+        ...
+    async def predict(self, prompt: str) -> str:
+        ...
+
+class LlmApi:
+    """
+    Base class for working with an LLM API.
+    """
+    params: LlmParams = None
+
+
+    def create_headers(self) -> dict[str, str]:
+        headers = {"Content-Type": "application/json"}
+
+        if self.params.api_key is not None:
+            headers["Authorization"] = self.params.api_key
+
+        return headers
+
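As a quick usage sketch (assuming the package is importable as `llm.common` and that a concrete client stores its `LlmParams` on `self.params`, the attribute `create_headers()` reads), a do-nothing subclass satisfying `LlmApiProtocol` could look like this; `EchoApi` is purely illustrative.

```python
import asyncio
from typing import List, Optional

from llm.common import LlmApi, LlmParams, LlmPredictParams


class EchoApi(LlmApi):
    """Illustrative subclass: satisfies LlmApiProtocol without any HTTP calls."""

    def __init__(self, params: LlmParams):
        self.params = params  # create_headers() reads self.params.api_key

    async def tokenize(self, prompt: str) -> Optional[dict]:
        return None

    async def detokenize(self, tokens: List[int]) -> Optional[str]:
        return None

    async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
        return {}

    async def predict(self, prompt: str) -> str:
        return f"echo: {prompt}"


params = LlmParams(url="http://localhost:8000", api_key="Bearer test",
                   predict_params=LlmPredictParams(temperature=0.0))
print(EchoApi(params).create_headers())
# {'Content-Type': 'application/json', 'Authorization': 'Bearer test'}
print(asyncio.run(EchoApi(params).predict("ping")))
# echo: ping
```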
llm/deepinfra_api.py
ADDED
@@ -0,0 +1,153 @@
+import json
+from typing import Optional, List
+import httpx
+from common import LlmPredictParams, LlmParams, LlmApi
+
+class DeepInfraApi(LlmApi):
+    """
+    Class for working with the vllm API.
+    """
+
+    def __init__(self, params: LlmParams):
+        super.params = params
+
+
+    async def get_models(self) -> List[str]:
+        """
+        Performs a GET request to the API to fetch the list of available models.
+
+        Returns:
+            list[str]: List of model identifiers.
+            If an error occurred or the data is unavailable, an empty list is returned.
+
+        Exceptions:
+            All HTTP request errors are logged to the console but are not re-raised.
+        """
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.get(f"{super.params.url}/v1/openai/models", super.create_headers())
+                if response.status_code == 200:
+                    json_data = response.json()
+                    return [item['id'] for item in json_data.get('data', [])]
+        except httpx.RequestError as error:
+            print('Error fetching models:', error)
+        return []
+
+    def create_messages(self, prompt: str) -> List[dict]:
+        """
+        Builds the messages for the LLM from the given prompt and the system prompt (if one is set).
+
+        Args:
+            prompt (str): User prompt.
+
+        Returns:
+            list[dict]: List of messages with roles and content.
+        """
+        actual_prompt = self.apply_llm_template_to_prompt(prompt)
+        messages = []
+        if super.params.predict_params and super.params.predict_params.system_prompt:
+            messages.append({"role": "system", "content": super.params.predict_params.system_prompt})
+        messages.append({"role": "user", "content": actual_prompt})
+        return messages
+
+    def apply_llm_template_to_prompt(self, prompt: str) -> str:
+        """
+        Applies the LLM template to the given prompt, if one is set.
+
+        Args:
+            prompt (str): User prompt.
+
+        Returns:
+            str: Prompt with the template applied (or the original prompt if no template is set).
+        """
+        actual_prompt = prompt
+        if super.params.template is not None:
+            actual_prompt = super.params.template.replace("{{PROMPT}}", actual_prompt)
+        return actual_prompt
+
+    async def tokenize(self, prompt: str) -> Optional[dict]:
+        raise NotImplementedError("This function is not supported.")
+
+    async def detokenize(self, tokens: List[int]) -> Optional[str]:
+        raise NotImplementedError("This function is not supported.")
+
+    async def create_request(self, prompt: str) -> dict:
+        """
+        Builds a prediction request from the LLM parameters.
+
+        Args:
+            prompt (str): Prompt for the request.
+
+        Returns:
+            dict: Dictionary with the parameters for performing the request.
+        """
+
+        request = {
+            "stream": False,
+            "model": super.params.model,
+        }
+
+        predict_params = super.params.predict_params
+        if predict_params:
+            if predict_params.stop:
+                non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
+                if non_empty_stop:
+                    request["stop"] = non_empty_stop
+
+            if predict_params.n_predict is not None:
+                request["max_tokens"] = int(predict_params.n_predict or 0)
+
+            request["temperature"] = float(predict_params.temperature or 0)
+            if predict_params.top_k is not None:
+                request["top_k"] = int(predict_params.top_k)
+
+            if predict_params.top_p is not None:
+                request["top_p"] = float(predict_params.top_p)
+
+            if predict_params.min_p is not None:
+                request["min_p"] = float(predict_params.min_p)
+
+            if predict_params.seed is not None:
+                request["seed"] = int(predict_params.seed)
+
+            if predict_params.n_keep is not None:
+                request["n_keep"] = int(predict_params.n_keep)
+
+            if predict_params.cache_prompt is not None:
+                request["cache_prompt"] = bool(predict_params.cache_prompt)
+
+            if predict_params.repeat_penalty is not None:
+                request["repetition_penalty"] = float(predict_params.repeat_penalty)
+
+            if predict_params.repeat_last_n is not None:
+                request["repeat_last_n"] = int(predict_params.repeat_last_n)
+
+            if predict_params.presence_penalty is not None:
+                request["presence_penalty"] = float(predict_params.presence_penalty)
+
+            if predict_params.frequency_penalty is not None:
+                request["frequency_penalty"] = float(predict_params.frequency_penalty)
+
+        request["messages"] = self.create_messages(prompt)
+        return request
+
+    async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
+        raise NotImplementedError("This function is not supported.")
+
+    async def predict(self, prompt: str) -> str:
+        """
+        Performs a request to the API and returns the result.
+
+        Args:
+            prompt (str): Input text for prediction.
+
+        Returns:
+            str: Generated text.
+        """
+        async with httpx.AsyncClient() as client:
+            request = await self.create_request(prompt)
+
+            async with httpx.AsyncClient() as client:
+                response = client.post(f"{super.params.url}/v1/openai/chat/completions", super.create_headers(), json=request)
+                if response.status_code == 200:
+                    return response.json()["choices"][0]["message"]["content"]
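For orientation, here is a standalone sketch of the request that `create_request()`/`predict()` assemble: an OpenAI-style chat completion POSTed to `{url}/v1/openai/chat/completions`, with the `Authorization` value sent verbatim (DeepInfra normally expects `Bearer <token>`). The sketch awaits the post and passes headers explicitly with httpx; the base URL, token, and model below are placeholders, only the endpoint path and payload shape mirror the committed code.

```python
import asyncio
import httpx


async def chat_once(base_url: str, api_key: str, model: str, prompt: str) -> str:
    # Mirrors the payload built by DeepInfraApi.create_request() for a single,
    # non-streaming chat completion.
    headers = {"Content-Type": "application/json", "Authorization": api_key}
    body = {
        "stream": False,
        "model": model,
        "temperature": 0.15,
        "messages": [{"role": "user", "content": prompt}],
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(f"{base_url}/v1/openai/chat/completions",
                                     headers=headers, json=body)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]


# Example with placeholder credentials:
# asyncio.run(chat_once("https://api.deepinfra.com", "Bearer <token>",
#                       "mistralai/Mixtral-8x7B-Instruct-v0.1", "Hello"))
```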
llm/vllm_api.py
CHANGED
@@ -3,51 +3,17 @@ from typing import Optional, List, Any
 
 import httpx
 from pydantic import BaseModel, Field
+from common import LlmPredictParams, LlmParams, LlmApi
 
 
-class
-    """
-    Parameters for LLM prediction.
-    """
-    system_prompt: Optional[str] = Field(None, description="System prompt.")
-    user_prompt: Optional[str] = Field(None, description="Prompt template passed with the user role.")
-    n_predict: Optional[int] = None
-    temperature: Optional[float] = None
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    min_p: Optional[float] = None
-    seed: Optional[int] = None
-    repeat_penalty: Optional[float] = None
-    repeat_last_n: Optional[int] = None
-    retry_if_text_not_present: Optional[str] = None
-    retry_count: Optional[int] = None
-    presence_penalty: Optional[float] = None
-    frequency_penalty: Optional[float] = None
-    n_keep: Optional[int] = None
-    cache_prompt: Optional[bool] = None
-    stop: Optional[List[str]] = None
-
-
-class LlmParams(BaseModel):
-    """
-    Core parameters for an LLM.
-    """
-    url: str
-    type: Optional[str] = None
-    default: Optional[bool] = None
-    template: Optional[str] = None
-    predict_params: Optional[LlmPredictParams] = None
-
-
-class LlmApi:
+class LlmApi(LlmApi):
     """
     Class for working with the vllm API.
     """
-    params: LlmParams = None
 
     def __init__(self, params: LlmParams):
-
-
+        super.params = params
+
     async def get_models(self) -> List[str]:
         """
         Performs a GET request to the API to fetch the list of available models.
@@ -61,13 +27,26 @@ class LlmApi:
         """
         try:
             async with httpx.AsyncClient() as client:
-                response = await client.get(f"{
+                response = await client.get(f"{super.params.url}/v1/models", super.create_headers())
                 if response.status_code == 200:
                     json_data = response.json()
                     return [item['id'] for item in json_data.get('data', [])]
         except httpx.RequestError as error:
             print('Error fetching models:', error)
         return []
+
+    async def get_model(self) -> str:
+        model = None
+        if super.params.model is not None:
+            model = super.params.model
+        else:
+            models = await self.get_models()
+            model = models[0] if models else None
+
+        if model is None:
+            raise Exception("No model name provided and no models available.")
+
+        return model
 
     def create_messages(self, prompt: str) -> List[dict]:
         """
@@ -81,8 +60,8 @@ class LlmApi:
         """
         actual_prompt = self.apply_llm_template_to_prompt(prompt)
         messages = []
-        if
-        messages.append({"role": "system", "content":
+        if super.params.predict_params and super.params.predict_params.system_prompt:
+            messages.append({"role": "system", "content": super.params.predict_params.system_prompt})
         messages.append({"role": "user", "content": actual_prompt})
         return messages
 
@@ -97,8 +76,8 @@ class LlmApi:
            str: Prompt with the template applied (or the original prompt if no template is set).
         """
         actual_prompt = prompt
-        if
-        actual_prompt =
+        if super.params.template is not None:
+            actual_prompt = super.params.template.replace("{{PROMPT}}", actual_prompt)
         return actual_prompt
 
     async def tokenize(self, prompt: str) -> Optional[dict]:
@@ -112,14 +91,10 @@ class LlmApi:
            Optional[dict]: Dictionary with the tokens and the model's maximum length if the request succeeds.
            If the request fails, returns None.
         """
-        model = (await self.get_models())[0] if await self.get_models() else None
-        if not model:
-            print("No models available for tokenization.")
-            return None
 
         actual_prompt = self.apply_llm_template_to_prompt(prompt)
         request_data = {
-            "model":
+            "model": self.get_model(),
             "prompt": actual_prompt,
             "add_special_tokens": False,
         }
@@ -127,9 +102,9 @@ class LlmApi:
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    f"{
+                    f"{super.params.url}/tokenize",
                     json=request_data,
-                    headers=
+                    headers=super.create_headers(),
                 )
                 if response.status_code == 200:
                     data = response.json()
@@ -155,19 +130,15 @@ class LlmApi:
            Optional[str]: String obtained as the result of detokenization if the request succeeds.
            If the request fails, returns None.
         """
-
-
-            print("No models available for detokenization.")
-            return None
-
-        request_data = {"model": model, "tokens": tokens or []}
+
+        request_data = {"model": self.get_model(), "tokens": tokens or []}
 
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    f"{
+                    f"{super.params.url}/detokenize",
                     json=request_data,
-                    headers=
+                    headers=super.create_headers(),
                 )
                 if response.status_code == 200:
                     data = response.json()
@@ -192,17 +163,14 @@ class LlmApi:
        Returns:
            dict: Dictionary with the parameters for performing the request.
        """
-
-        if not models:
-            raise ValueError("No models available to create a request.")
-        model = models[0]
+        model = self.get_model()
 
         request = {
             "stream": True,
             "model": model,
         }
 
-        predict_params =
+        predict_params = super.params.predict_params
         if predict_params:
             if predict_params.stop:
                 non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
@@ -283,7 +251,7 @@ class LlmApi:
         # Maximum allowed number of tokens for the sources
         max_length = (
             max_token_count
-            - (
+            - (super.params.predict_params.n_predict or 0)
             - aux_token_count
             - system_prompt_token_count
         )
@@ -322,7 +290,7 @@ class LlmApi:
             request = await self.create_request(prompt)
 
             # Start the streaming request
-            async with client.stream("POST", f"{
+            async with client.stream("POST", f"{super.params.url}/v1/chat/completions", json=request) as response:
                 if response.status_code != 200:
                     # If there is an error, read the response to get the details
                     error_content = await response.aread()
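The final hunk points the streaming POST at the OpenAI-compatible `/v1/chat/completions` endpoint. A sketch of how such a streaming response is typically consumed with httpx follows (assuming the server emits OpenAI-style SSE lines prefixed with `data:`; the helper `stream_chat` is illustrative, not part of the commit):

```python
import asyncio
import json

import httpx


async def stream_chat(base_url: str, request: dict) -> str:
    # Consume the streaming chat-completions endpoint used in the last hunk.
    # The request dict is expected to contain "stream": True, as create_request() sets.
    chunks = []
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", f"{base_url}/v1/chat/completions", json=request) as response:
            if response.status_code != 200:
                # On error, read the body to surface the details.
                raise RuntimeError((await response.aread()).decode())
            async for line in response.aiter_lines():
                if not line.startswith("data: ") or line.endswith("[DONE]"):
                    continue
                payload = json.loads(line[len("data: "):])
                delta = payload["choices"][0].get("delta", {})
                chunks.append(delta.get("content") or "")
    return "".join(chunks)
```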
transaction_maps_search.py
CHANGED
@@ -3,14 +3,13 @@ from business_transaction_map.common.constants import DEVICE, DO_NORMALIZATION,
 from business_transaction_map.components.faiss_vector_database import FaissVectorDatabase
 from business_transaction_map.components.embedding_extraction import EmbeddingExtractor
 import os
-import requests
 from prompts import BUSINESS_TRANSACTION_PROMPT
+from llm.common import LlmApi
 
 
 db_files_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_DATA_PATH", "transaction_maps_search_data/csv/карта_проводок_new.pkl")
 
 model_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_MODEL_PATH", "")
-llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
 
 class TransactionMapsSearch:
 
@@ -26,14 +25,11 @@ class TransactionMapsSearch:
         self.database = FaissVectorDatabase(str(db_files_path))
 
     @staticmethod
-    def extract_business_transaction_with_llm(question: str) -> str:
+    async def extract_business_transaction_with_llm(question: str, llm_api: LlmApi) -> str:
+        prompt = BUSINESS_TRANSACTION_PROMPT.replace('{{ЗАПРОС}}', question)
+        res = await llm_api.predict(prompt)
 
-
-
-        response = requests.post(url=llm_api_endpoint,
-                                 json={"prompt": f"[INST] {question} [/INST]", # the spaces inside [INST] turned out to matter: without them the LLM can fall into endless nonsense generation
-                                       "temperature": 0.0})
-        return response.json()['content']
+        return res
 
 
     @staticmethod
@@ -66,13 +62,14 @@ class TransactionMapsSearch:
         return answer
 
 
-    def search_transaction_map(self,
+    async def search_transaction_map(self,
                                query: str = None,
                                find_transaction_maps_by_question: bool = False,
-                               k_neighbours: int = 15
+                               k_neighbours: int = 15,
+                               llm_api: LlmApi = None):
 
         if find_transaction_maps_by_question:
-            query = self.extract_business_transaction_with_llm(query)
+            query = await self.extract_business_transaction_with_llm(query, llm_api)
         cleaned_text = query.replace("\n", " ")
         # cleaned_text = 'query: ' + cleaned_text # only for e5
         query_tokens = self.model.query_tokenization(cleaned_text)
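Putting the pieces together, the search path is now awaited and receives the LLM client explicitly. A hypothetical invocation, assuming `TransactionMapsSearch()` can be constructed with its defaults and a working `LlmApi` implementation is available, might look like this (credentials and query are placeholders):

```python
import asyncio

from llm.common import LlmParams
from llm.deepinfra_api import DeepInfraApi
from transaction_maps_search import TransactionMapsSearch


async def main():
    # In the FastAPI route these values come from LLM_API_URL / LLM_API_KEY
    # and the default Mixtral model; here they are hard-coded for illustration.
    llm_api = DeepInfraApi(LlmParams(url="https://api.deepinfra.com",
                                     api_key="Bearer <token>",
                                     model="mistralai/Mixtral-8x7B-Instruct-v0.1"))
    searcher = TransactionMapsSearch()
    results, answer = await searcher.search_transaction_map(
        query="how to record the shipment of goods",
        find_transaction_maps_by_question=True,
        k_neighbours=15,
        llm_api=llm_api,
    )
    print(answer)


# asyncio.run(main())
```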