muryshev commited on
Commit
67beed8
·
1 Parent(s): 85dfc4f
fastapi_app.py CHANGED
@@ -6,7 +6,8 @@ import os
6
  import datetime
7
  import json
8
  import traceback
9
- from llm.vllm_api import LlmParams
 
10
 
11
  # Set the path for log files
12
  LOGS_BASE_PATH = os.getenv("LOGS_BASE_PATH", "logs")
@@ -17,7 +18,9 @@ LOGS_BASE_PATH = os.getenv("LOGS_BASE_PATH", "logs")
17
 
18
  # Check if logs are enabled
19
  ENABLE_LOGS = os.getenv("ENABLE_LOGS", "0") == "1"
20
-
 
 
21
 
22
  class Query(BaseModel):
23
  query: str = ''
@@ -87,18 +90,26 @@ async def search_route(query: Query) -> dict:
87
 
88
  llm_params = getattr(query, "llm_params", None)
89
 
 
 
 
 
 
 
 
90
  if find_transaction_maps_by_question or find_transaction_maps_by_operation:
91
- transaction_maps_results, answer = transaction_maps_search.search_transaction_map(
92
  query=question,
93
  find_transaction_maps_by_question=find_transaction_maps_by_question,
94
- k_neighbours=top)
 
95
 
96
  response = {'transaction_maps_results': transaction_maps_results}
97
 
98
  else:
99
  modified_query, titles, concat_docs, \
100
  relevant_consultations, predicted_explanation, \
101
- llm_responses = await search.search(question, use_qe, use_olympic, categories, query.llm_params)
102
 
103
  results = [{'title': str(item1), 'text_for_llm': str(item2)} for item1, item2 in
104
  zip(titles, concat_docs)]
 
6
  import datetime
7
  import json
8
  import traceback
9
+ from llm.common import LlmParams, LlmPredictParams
10
+ from llm.deepinfra_api import DeepInfraApi
11
 
12
  # Set the path for log files
13
  LOGS_BASE_PATH = os.getenv("LOGS_BASE_PATH", "logs")
 
18
 
19
  # Check if logs are enabled
20
  ENABLE_LOGS = os.getenv("ENABLE_LOGS", "0") == "1"
21
+ LLM_API_URL = os.getenv("LLM_API_URL", "")
22
+ LLM_API_KEY = os.getenv("LLM_API_KEY", "")
23
+ LLM_USE_DEEPINFRA = os.getenv("LLM_USE_DEEPINFRA", "") == "1"
24
 
25
  class Query(BaseModel):
26
  query: str = ''
 
90
 
91
  llm_params = getattr(query, "llm_params", None)
92
 
93
+ if llm_params is None:
94
+ llm_params = LlmParams(url=LLM_API_URL,api_key=LLM_API_KEY, model="mistralai/Mixtral-8x7B-Instruct-v0.1", predict_params=LlmPredictParams(temperature=0.15, top_p=0.95, min_p=0.05, seed=42, repetition_penalty=1.2, presence_penalty=1.1, max_tokens=6000))
95
+
96
+ if LLM_USE_DEEPINFRA:
97
+ llm_api = DeepInfraApi(llm_params)
98
+
99
+
100
  if find_transaction_maps_by_question or find_transaction_maps_by_operation:
101
+ transaction_maps_results, answer = await transaction_maps_search.search_transaction_map(
102
  query=question,
103
  find_transaction_maps_by_question=find_transaction_maps_by_question,
104
+ k_neighbours=top,
105
+ llm_api=llm_api)
106
 
107
  response = {'transaction_maps_results': transaction_maps_results}
108
 
109
  else:
110
  modified_query, titles, concat_docs, \
111
  relevant_consultations, predicted_explanation, \
112
+ llm_responses = await search.search(question, use_qe, use_olympic, categories, llm_params)
113
 
114
  results = [{'title': str(item1), 'text_for_llm': str(item2)} for item1, item2 in
115
  zip(titles, concat_docs)]
llm/common.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional, List, Protocol
3
+ from abc import ABC, abstractmethod
4
+
5
+ class LlmPredictParams(BaseModel):
6
+ """
7
+ Параметры для предсказания LLM.
8
+ """
9
+ system_prompt: Optional[str] = Field(None, description="Системный промпт.")
10
+ user_prompt: Optional[str] = Field(None, description="Шаблон промпта для передачи от роли user.")
11
+ n_predict: Optional[int] = None
12
+ temperature: Optional[float] = None
13
+ top_k: Optional[int] = None
14
+ top_p: Optional[float] = None
15
+ min_p: Optional[float] = None
16
+ seed: Optional[int] = None
17
+ repeat_penalty: Optional[float] = None
18
+ repeat_last_n: Optional[int] = None
19
+ retry_if_text_not_present: Optional[str] = None
20
+ retry_count: Optional[int] = None
21
+ presence_penalty: Optional[float] = None
22
+ frequency_penalty: Optional[float] = None
23
+ n_keep: Optional[int] = None
24
+ cache_prompt: Optional[bool] = None
25
+ stop: Optional[List[str]] = None
26
+
27
+
28
+ class LlmParams(BaseModel):
29
+ """
30
+ Основные параметры для LLM.
31
+ """
32
+ url: str
33
+ model: Optional[str] = Field(None, description="Предполагается, что для локального API этот параметр не будет указываться, т.к. будем брать первую модель из списка потому, что модель доступна всего одна. Для deepinfra такой подход не подойдет и модель нужно задавать явно.")
34
+ type: Optional[str] = None
35
+ default: Optional[bool] = None
36
+ template: Optional[str] = None
37
+ predict_params: Optional[LlmPredictParams] = None
38
+ api_key: Optional[str] = None
39
+
40
+ class LlmApiProtocol(Protocol):
41
+ async def tokenize(self, prompt: str) -> Optional[dict]:
42
+ ...
43
+ async def detokenize(self, tokens: List[int]) -> Optional[str]:
44
+ ...
45
+ async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
46
+ ...
47
+ async def predict(self, prompt: str) -> str:
48
+ ...
49
+
50
+ class LlmApi:
51
+ """
52
+ Базовый клас для работы с API LLM.
53
+ """
54
+ params: LlmParams = None
55
+
56
+
57
+ def create_headers(self) -> dict[str, str]:
58
+ headers = {"Content-Type": "application/json"}
59
+
60
+ if self.params.api_key is not None:
61
+ headers["Authorization"] = self.params.api_key
62
+
63
+ return headers
64
+
llm/deepinfra_api.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Optional, List
3
+ import httpx
4
+ from common import LlmPredictParams, LlmParams, LlmApi
5
+
6
+ class DeepInfraApi(LlmApi):
7
+ """
8
+ Класс для работы с API vllm.
9
+ """
10
+
11
+ def __init__(self, params: LlmParams):
12
+ super.params = params
13
+
14
+
15
+ async def get_models(self) -> List[str]:
16
+ """
17
+ Выполняет GET-запрос к API для получения списка доступных моделей.
18
+
19
+ Возвращает:
20
+ list[str]: Список идентификаторов моделей.
21
+ Если произошла ошибка или данные недоступны, возвращается пустой список.
22
+
23
+ Исключения:
24
+ Все ошибки HTTP-запросов логируются в консоль, но не выбрасываются дальше.
25
+ """
26
+ try:
27
+ async with httpx.AsyncClient() as client:
28
+ response = await client.get(f"{super.params.url}/v1/openai/models", super.create_headers())
29
+ if response.status_code == 200:
30
+ json_data = response.json()
31
+ return [item['id'] for item in json_data.get('data', [])]
32
+ except httpx.RequestError as error:
33
+ print('Error fetching models:', error)
34
+ return []
35
+
36
+ def create_messages(self, prompt: str) -> List[dict]:
37
+ """
38
+ Создает сообщения для LLM на основе переданного промпта и системного промпта (если он задан).
39
+
40
+ Args:
41
+ prompt (str): Пользовательский промпт.
42
+
43
+ Returns:
44
+ list[dict]: Список сообщений с ролями и содержимым.
45
+ """
46
+ actual_prompt = self.apply_llm_template_to_prompt(prompt)
47
+ messages = []
48
+ if super.params.predict_params and super.params.predict_params.system_prompt:
49
+ messages.append({"role": "system", "content": super.params.predict_params.system_prompt})
50
+ messages.append({"role": "user", "content": actual_prompt})
51
+ return messages
52
+
53
+ def apply_llm_template_to_prompt(self, prompt: str) -> str:
54
+ """
55
+ Применяет шаблон LLM к переданному промпту, если он задан.
56
+
57
+ Args:
58
+ prompt (str): Пользовательский промпт.
59
+
60
+ Returns:
61
+ str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует).
62
+ """
63
+ actual_prompt = prompt
64
+ if super.params.template is not None:
65
+ actual_prompt = super.params.template.replace("{{PROMPT}}", actual_prompt)
66
+ return actual_prompt
67
+
68
+ async def tokenize(self, prompt: str) -> Optional[dict]:
69
+ raise NotImplementedError("This function is not supported.")
70
+
71
+ async def detokenize(self, tokens: List[int]) -> Optional[str]:
72
+ raise NotImplementedError("This function is not supported.")
73
+
74
+ async def create_request(self, prompt: str) -> dict:
75
+ """
76
+ Создает запрос для предсказания на основе параметров LLM.
77
+
78
+ Args:
79
+ prompt (str): Промпт для запроса.
80
+
81
+ Returns:
82
+ dict: Словарь с параметрами для выполнения запроса.
83
+ """
84
+
85
+ request = {
86
+ "stream": False,
87
+ "model": super.params.model,
88
+ }
89
+
90
+ predict_params = super.params.predict_params
91
+ if predict_params:
92
+ if predict_params.stop:
93
+ non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
94
+ if non_empty_stop:
95
+ request["stop"] = non_empty_stop
96
+
97
+ if predict_params.n_predict is not None:
98
+ request["max_tokens"] = int(predict_params.n_predict or 0)
99
+
100
+ request["temperature"] = float(predict_params.temperature or 0)
101
+ if predict_params.top_k is not None:
102
+ request["top_k"] = int(predict_params.top_k)
103
+
104
+ if predict_params.top_p is not None:
105
+ request["top_p"] = float(predict_params.top_p)
106
+
107
+ if predict_params.min_p is not None:
108
+ request["min_p"] = float(predict_params.min_p)
109
+
110
+ if predict_params.seed is not None:
111
+ request["seed"] = int(predict_params.seed)
112
+
113
+ if predict_params.n_keep is not None:
114
+ request["n_keep"] = int(predict_params.n_keep)
115
+
116
+ if predict_params.cache_prompt is not None:
117
+ request["cache_prompt"] = bool(predict_params.cache_prompt)
118
+
119
+ if predict_params.repeat_penalty is not None:
120
+ request["repetition_penalty"] = float(predict_params.repeat_penalty)
121
+
122
+ if predict_params.repeat_last_n is not None:
123
+ request["repeat_last_n"] = int(predict_params.repeat_last_n)
124
+
125
+ if predict_params.presence_penalty is not None:
126
+ request["presence_penalty"] = float(predict_params.presence_penalty)
127
+
128
+ if predict_params.frequency_penalty is not None:
129
+ request["frequency_penalty"] = float(predict_params.frequency_penalty)
130
+
131
+ request["messages"] = self.create_messages(prompt)
132
+ return request
133
+
134
+ async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
135
+ raise NotImplementedError("This function is not supported.")
136
+
137
+ async def predict(self, prompt: str) -> str:
138
+ """
139
+ Выполняет запрос к API и возвращает результат.
140
+
141
+ Args:
142
+ prompt (str): Входной текст для предсказания.
143
+
144
+ Returns:
145
+ str: Сгенерированный текст.
146
+ """
147
+ async with httpx.AsyncClient() as client:
148
+ request = await self.create_request(prompt)
149
+
150
+ async with httpx.AsyncClient() as client:
151
+ response = client.post(f"{super.params.url}/v1/openai/chat/completions", super.create_headers(), json=request)
152
+ if response.status_code == 200:
153
+ return response.json()["choices"][0]["message"]["content"]
llm/vllm_api.py CHANGED
@@ -3,51 +3,17 @@ from typing import Optional, List, Any
3
 
4
  import httpx
5
  from pydantic import BaseModel, Field
 
6
 
7
 
8
- class LlmPredictParams(BaseModel):
9
- """
10
- Параметры для предсказания LLM.
11
- """
12
- system_prompt: Optional[str] = Field(None, description="Системный промпт.")
13
- user_prompt: Optional[str] = Field(None, description="Шаблон промпта для передачи от роли user.")
14
- n_predict: Optional[int] = None
15
- temperature: Optional[float] = None
16
- top_k: Optional[int] = None
17
- top_p: Optional[float] = None
18
- min_p: Optional[float] = None
19
- seed: Optional[int] = None
20
- repeat_penalty: Optional[float] = None
21
- repeat_last_n: Optional[int] = None
22
- retry_if_text_not_present: Optional[str] = None
23
- retry_count: Optional[int] = None
24
- presence_penalty: Optional[float] = None
25
- frequency_penalty: Optional[float] = None
26
- n_keep: Optional[int] = None
27
- cache_prompt: Optional[bool] = None
28
- stop: Optional[List[str]] = None
29
-
30
-
31
- class LlmParams(BaseModel):
32
- """
33
- Основные параметры для LLM.
34
- """
35
- url: str
36
- type: Optional[str] = None
37
- default: Optional[bool] = None
38
- template: Optional[str] = None
39
- predict_params: Optional[LlmPredictParams] = None
40
-
41
-
42
- class LlmApi:
43
  """
44
  Класс для работы с API vllm.
45
  """
46
- params: LlmParams = None
47
 
48
  def __init__(self, params: LlmParams):
49
- self.params = params
50
-
51
  async def get_models(self) -> List[str]:
52
  """
53
  Выполняет GET-запрос к API для получения списка доступных моделей.
@@ -61,13 +27,26 @@ class LlmApi:
61
  """
62
  try:
63
  async with httpx.AsyncClient() as client:
64
- response = await client.get(f"{self.params.url}/v1/models", headers={"Content-Type": "application/json"})
65
  if response.status_code == 200:
66
  json_data = response.json()
67
  return [item['id'] for item in json_data.get('data', [])]
68
  except httpx.RequestError as error:
69
  print('Error fetching models:', error)
70
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def create_messages(self, prompt: str) -> List[dict]:
73
  """
@@ -81,8 +60,8 @@ class LlmApi:
81
  """
82
  actual_prompt = self.apply_llm_template_to_prompt(prompt)
83
  messages = []
84
- if self.params.predict_params and self.params.predict_params.system_prompt:
85
- messages.append({"role": "system", "content": self.params.predict_params.system_prompt})
86
  messages.append({"role": "user", "content": actual_prompt})
87
  return messages
88
 
@@ -97,8 +76,8 @@ class LlmApi:
97
  str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует).
98
  """
99
  actual_prompt = prompt
100
- if self.params.template is not None:
101
- actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt)
102
  return actual_prompt
103
 
104
  async def tokenize(self, prompt: str) -> Optional[dict]:
@@ -112,14 +91,10 @@ class LlmApi:
112
  Optional[dict]: Словарь с токенами и максимальной длиной модели, если запрос успешен.
113
  Если запрос неуспешен, возвращает None.
114
  """
115
- model = (await self.get_models())[0] if await self.get_models() else None
116
- if not model:
117
- print("No models available for tokenization.")
118
- return None
119
 
120
  actual_prompt = self.apply_llm_template_to_prompt(prompt)
121
  request_data = {
122
- "model": model,
123
  "prompt": actual_prompt,
124
  "add_special_tokens": False,
125
  }
@@ -127,9 +102,9 @@ class LlmApi:
127
  try:
128
  async with httpx.AsyncClient() as client:
129
  response = await client.post(
130
- f"{self.params.url}/tokenize",
131
  json=request_data,
132
- headers={"Content-Type": "application/json"},
133
  )
134
  if response.status_code == 200:
135
  data = response.json()
@@ -155,19 +130,15 @@ class LlmApi:
155
  Optional[str]: Строка, полученная в результате детокенизации, если запрос успешен.
156
  Если запрос неуспешен, возвращает None.
157
  """
158
- model = (await self.get_models())[0] if await self.get_models() else None
159
- if not model:
160
- print("No models available for detokenization.")
161
- return None
162
-
163
- request_data = {"model": model, "tokens": tokens or []}
164
 
165
  try:
166
  async with httpx.AsyncClient() as client:
167
  response = await client.post(
168
- f"{self.params.url}/detokenize",
169
  json=request_data,
170
- headers={"Content-Type": "application/json"},
171
  )
172
  if response.status_code == 200:
173
  data = response.json()
@@ -192,17 +163,14 @@ class LlmApi:
192
  Returns:
193
  dict: Словарь с параметрами для выполнения запроса.
194
  """
195
- models = await self.get_models()
196
- if not models:
197
- raise ValueError("No models available to create a request.")
198
- model = models[0]
199
 
200
  request = {
201
  "stream": True,
202
  "model": model,
203
  }
204
 
205
- predict_params = self.params.predict_params
206
  if predict_params:
207
  if predict_params.stop:
208
  non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
@@ -283,7 +251,7 @@ class LlmApi:
283
  # Максимально допустимое количество токенов для источников
284
  max_length = (
285
  max_token_count
286
- - (self.params.predict_params.n_predict or 0)
287
  - aux_token_count
288
  - system_prompt_token_count
289
  )
@@ -322,7 +290,7 @@ class LlmApi:
322
  request = await self.create_request(prompt)
323
 
324
  # Начинаем потоковый запрос
325
- async with client.stream("POST", f"{self.params.url}/v1/chat/completions", json=request) as response:
326
  if response.status_code != 200:
327
  # Если ошибка, читаем ответ для получения подробностей
328
  error_content = await response.aread()
 
3
 
4
  import httpx
5
  from pydantic import BaseModel, Field
6
+ from common import LlmPredictParams, LlmParams, LlmApi
7
 
8
 
9
+ class LlmApi(LlmApi):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  """
11
  Класс для работы с API vllm.
12
  """
 
13
 
14
  def __init__(self, params: LlmParams):
15
+ super.params = params
16
+
17
  async def get_models(self) -> List[str]:
18
  """
19
  Выполняет GET-запрос к API для получения списка доступных моделей.
 
27
  """
28
  try:
29
  async with httpx.AsyncClient() as client:
30
+ response = await client.get(f"{super.params.url}/v1/models", super.create_headers())
31
  if response.status_code == 200:
32
  json_data = response.json()
33
  return [item['id'] for item in json_data.get('data', [])]
34
  except httpx.RequestError as error:
35
  print('Error fetching models:', error)
36
  return []
37
+
38
+ async def get_model(self) -> str:
39
+ model = None
40
+ if super.params.model is not None:
41
+ model = super.params.model
42
+ else:
43
+ models = await self.get_models()
44
+ model = models[0] if models else None
45
+
46
+ if model is None:
47
+ raise Exception("No model name provided and no models available.")
48
+
49
+ return model
50
 
51
  def create_messages(self, prompt: str) -> List[dict]:
52
  """
 
60
  """
61
  actual_prompt = self.apply_llm_template_to_prompt(prompt)
62
  messages = []
63
+ if super.params.predict_params and super.params.predict_params.system_prompt:
64
+ messages.append({"role": "system", "content": super.params.predict_params.system_prompt})
65
  messages.append({"role": "user", "content": actual_prompt})
66
  return messages
67
 
 
76
  str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует).
77
  """
78
  actual_prompt = prompt
79
+ if super.params.template is not None:
80
+ actual_prompt = super.params.template.replace("{{PROMPT}}", actual_prompt)
81
  return actual_prompt
82
 
83
  async def tokenize(self, prompt: str) -> Optional[dict]:
 
91
  Optional[dict]: Словарь с токенами и максимальной длиной модели, если запрос успешен.
92
  Если запрос неуспешен, возвращает None.
93
  """
 
 
 
 
94
 
95
  actual_prompt = self.apply_llm_template_to_prompt(prompt)
96
  request_data = {
97
+ "model": self.get_model(),
98
  "prompt": actual_prompt,
99
  "add_special_tokens": False,
100
  }
 
102
  try:
103
  async with httpx.AsyncClient() as client:
104
  response = await client.post(
105
+ f"{super.params.url}/tokenize",
106
  json=request_data,
107
+ headers=super.create_headers(),
108
  )
109
  if response.status_code == 200:
110
  data = response.json()
 
130
  Optional[str]: Строка, полученная в результате детокенизации, если запрос успешен.
131
  Если запрос неуспешен, возвращает None.
132
  """
133
+
134
+ request_data = {"model": self.get_model(), "tokens": tokens or []}
 
 
 
 
135
 
136
  try:
137
  async with httpx.AsyncClient() as client:
138
  response = await client.post(
139
+ f"{super.params.url}/detokenize",
140
  json=request_data,
141
+ headers=super.create_headers(),
142
  )
143
  if response.status_code == 200:
144
  data = response.json()
 
163
  Returns:
164
  dict: Словарь с параметрами для выполнения запроса.
165
  """
166
+ model = self.get_model()
 
 
 
167
 
168
  request = {
169
  "stream": True,
170
  "model": model,
171
  }
172
 
173
+ predict_params = super.params.predict_params
174
  if predict_params:
175
  if predict_params.stop:
176
  non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
 
251
  # Максимально допустимое количество токенов для источников
252
  max_length = (
253
  max_token_count
254
+ - (super.params.predict_params.n_predict or 0)
255
  - aux_token_count
256
  - system_prompt_token_count
257
  )
 
290
  request = await self.create_request(prompt)
291
 
292
  # Начинаем потоковый запрос
293
+ async with client.stream("POST", f"{super.params.url}/v1/chat/completions", json=request) as response:
294
  if response.status_code != 200:
295
  # Если ошибка, читаем ответ для получения подробностей
296
  error_content = await response.aread()
transaction_maps_search.py CHANGED
@@ -3,14 +3,13 @@ from business_transaction_map.common.constants import DEVICE, DO_NORMALIZATION,
3
  from business_transaction_map.components.faiss_vector_database import FaissVectorDatabase
4
  from business_transaction_map.components.embedding_extraction import EmbeddingExtractor
5
  import os
6
- import requests
7
  from prompts import BUSINESS_TRANSACTION_PROMPT
 
8
 
9
 
10
  db_files_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_DATA_PATH", "transaction_maps_search_data/csv/карта_проводок_new.pkl")
11
 
12
  model_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_MODEL_PATH", "")
13
- llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
14
 
15
  class TransactionMapsSearch:
16
 
@@ -26,14 +25,11 @@ class TransactionMapsSearch:
26
  self.database = FaissVectorDatabase(str(db_files_path))
27
 
28
  @staticmethod
29
- def extract_business_transaction_with_llm(question: str) -> str:
 
 
30
 
31
- question = BUSINESS_TRANSACTION_PROMPT.replace('{{ЗАПРОС}}', question)
32
-
33
- response = requests.post(url=llm_api_endpoint,
34
- json={"prompt": f"[INST] {question} [/INST]", #пробелы внутри [INST], как оказалось, важны. Без них можно словить бесконечную генерацию бреда от ллм
35
- "temperature": 0.0})
36
- return response.json()['content']
37
 
38
 
39
  @staticmethod
@@ -66,13 +62,14 @@ class TransactionMapsSearch:
66
  return answer
67
 
68
 
69
- def search_transaction_map(self,
70
  query: str = None,
71
  find_transaction_maps_by_question: bool = False,
72
- k_neighbours: int = 15):
 
73
 
74
  if find_transaction_maps_by_question:
75
- query = self.extract_business_transaction_with_llm(query)
76
  cleaned_text = query.replace("\n", " ")
77
  # cleaned_text = 'query: ' + cleaned_text # only for e5
78
  query_tokens = self.model.query_tokenization(cleaned_text)
 
3
  from business_transaction_map.components.faiss_vector_database import FaissVectorDatabase
4
  from business_transaction_map.components.embedding_extraction import EmbeddingExtractor
5
  import os
 
6
  from prompts import BUSINESS_TRANSACTION_PROMPT
7
+ from llm.common import LlmApi
8
 
9
 
10
  db_files_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_DATA_PATH", "transaction_maps_search_data/csv/карта_проводок_new.pkl")
11
 
12
  model_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_MODEL_PATH", "")
 
13
 
14
  class TransactionMapsSearch:
15
 
 
25
  self.database = FaissVectorDatabase(str(db_files_path))
26
 
27
  @staticmethod
28
+ async def extract_business_transaction_with_llm(question: str, llm_api: LlmApi) -> str:
29
+ prompt = BUSINESS_TRANSACTION_PROMPT.replace('{{ЗАПРОС}}', question)
30
+ res = await llm_api.predict(prompt)
31
 
32
+ return res
 
 
 
 
 
33
 
34
 
35
  @staticmethod
 
62
  return answer
63
 
64
 
65
+ async def search_transaction_map(self,
66
  query: str = None,
67
  find_transaction_maps_by_question: bool = False,
68
+ k_neighbours: int = 15,
69
+ llm_api: LlmApi = None):
70
 
71
  if find_transaction_maps_by_question:
72
+ query = await self.extract_business_transaction_with_llm(query, llm_api)
73
  cleaned_text = query.replace("\n", " ")
74
  # cleaned_text = 'query: ' + cleaned_text # only for e5
75
  query_tokens = self.model.query_tokenization(cleaned_text)