makcrx commited on
Commit
e42dbc0
·
1 Parent(s): d6a31a5
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  __pycache__
2
  *.sqlite
 
 
1
  __pycache__
2
  *.sqlite
3
+ models
extract_keywords.py CHANGED
@@ -14,14 +14,17 @@ aliases = [
14
  ('почта россия трекинг', ['пр трекинг', 'почта трекинг', 'пр трэкинг', 'почта трэкинг']),
15
  ('реестр почта', ['реестр пр', 'реестр почта россии']),
16
  ('реестр пэк', []),
 
17
  ('реквизиты', []),
18
  ('пешкарики', []),
19
  ('импорт лидов директ', []),
20
  ('яндекс доставка экспресс', ['яндекс доставка express', 'яд экспресс', 'ядоставка экспресс']),
21
  ('яндекс доставка ndd', ['яд ндд', 'я доставка ндд', 'ядоставка ндд', 'модуль ндд']),
 
22
  ('яндекс метрика', ['яндекс метрика импорт']),
23
  ('альфабанк', ['альфа банк', 'alfabank', 'альфа']),
24
  ('импорт лидов facebook', ['импорт лидов fb', 'загрузка лидов fb', 'лиды фейсбук', 'импорт лидов фб', 'fb lead']),
 
25
  ('маркетинговые расходы', ['расходы', 'загрузка расходов']),
26
  ('cloudpayments', ['клауд', 'клаудпеймент', 'клаудпейментс']),
27
  ('robokassa', ['робокасса', 'робокаса']),
@@ -30,12 +33,13 @@ aliases = [
30
  ('unisender', ['юнисендер']),
31
  ('яндекс аудитории', ['экспорт аудитории', 'экспорт яндекс аудитории']),
32
  ('экспорт facebook', ['экспорт сегментов facebook', 'экспорт fb', 'экспорт фейсбук', 'экспорт аудиторий фб', 'fb экспорт']),
33
- ('экспорт вк', ['экспорт сегментов vkontakte', 'экспорт vk', 'экспорт контакте']),
34
  ('retailcrm', ['срм', 'ритейл', 'ритейл срм', 'ритейлсрм', 'retail crm', 'ритейлцрм', 'ритейл црм']),
35
  ('retailcrm services', [
36
  'retailcrmservices', 'ритейлцрм services', 'лк crm services', 'ритейлцрм сервисес',
37
  'ритейлсрм сервисес', 'ритейлцрм сервисе', 'ритейлцрмсервисес', 'ритейлсрмсервисес',
38
- ])
 
39
  ]
40
 
41
  vocab_raw = flatten([[k] + keywords for k, keywords in aliases])
@@ -45,6 +49,8 @@ import pymorphy3
45
 
46
  morph = None
47
  def normalize_word(word):
 
 
48
  if word == 'лид':
49
  return word
50
  if word in ['росии', 'росси']:
@@ -59,7 +65,7 @@ def tokenize_sentence(text):
59
  # remove punctuation
60
  text = text.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation)))
61
  # tokenize
62
- return [normalize_word(word) for word in text.split()]
63
 
64
  def normalize_sentence(text):
65
  return " ".join(tokenize_sentence(text))
@@ -114,6 +120,9 @@ def init_keyword_extractor():
114
  from keybert import KeyBERT
115
  import spacy
116
  from sklearn.feature_extraction.text import CountVectorizer
 
 
 
117
 
118
  kw_model = KeyBERT(model=spacy.load("ru_core_news_sm", exclude=['tokenizer', 'tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer']))
119
  vocab = [" ".join(tokenize_sentence(s)) for s in vocab_raw]
@@ -126,5 +135,15 @@ def extract_keywords(text):
126
  if vectorizer is None or kw_model is None:
127
  init_keyword_extractor()
128
 
129
- keywords = [k for k, score in kw_model.extract_keywords(text, vectorizer=vectorizer)]
 
130
  return merge_keywords(canonical_keywords(keywords))
 
 
 
 
 
 
 
 
 
 
14
  ('почта россия трекинг', ['пр трекинг', 'почта трекинг', 'пр трэкинг', 'почта трэкинг']),
15
  ('реестр почта', ['реестр пр', 'реестр почта россии']),
16
  ('реестр пэк', []),
17
+ ('реестры наложенных платежей', ['документы наложенных платежей']),
18
  ('реквизиты', []),
19
  ('пешкарики', []),
20
  ('импорт лидов директ', []),
21
  ('яндекс доставка экспресс', ['яндекс доставка express', 'яд экспресс', 'ядоставка экспресс']),
22
  ('яндекс доставка ndd', ['яд ндд', 'я доставка ндд', 'ядоставка ндд', 'модуль ндд']),
23
+ ('яндекс доставка', ['яд', 'я доставка', 'ядоставка']),
24
  ('яндекс метрика', ['яндекс метрика импорт']),
25
  ('альфабанк', ['альфа банк', 'alfabank', 'альфа']),
26
  ('импорт лидов facebook', ['импорт лидов fb', 'загрузка лидов fb', 'лиды фейсбук', 'импорт лидов фб', 'fb lead']),
27
+ ('импорт лидов вк', ['импорт лидов вконтакте', 'загрузка лидов вк', 'лиды вконтакте', 'импорт лидов vk', 'vk lead']),
28
  ('маркетинговые расходы', ['расходы', 'загрузка расходов']),
29
  ('cloudpayments', ['клауд', 'клаудпеймент', 'клаудпейментс']),
30
  ('robokassa', ['робокасса', 'робокаса']),
 
33
  ('unisender', ['юнисендер']),
34
  ('яндекс аудитории', ['экспорт аудитории', 'экспорт яндекс аудитории']),
35
  ('экспорт facebook', ['экспорт сегментов facebook', 'экспорт fb', 'экспорт фейсбук', 'экспорт аудиторий фб', 'fb экспорт']),
36
+ ('экспорт вк', ['экспорт сегментов vkontakte', 'экспорт vk', 'экспорт контакте', 'экспорт сегментов вконтакте']),
37
  ('retailcrm', ['срм', 'ритейл', 'ритейл срм', 'ритейлсрм', 'retail crm', 'ритейлцрм', 'ритейл црм']),
38
  ('retailcrm services', [
39
  'retailcrmservices', 'ритейлцрм services', 'лк crm services', 'ритейлцрм сервисес',
40
  'ритейлсрм сервисес', 'ритейлцрм сервисе', 'ритейлцрмсервисес', 'ритейлсрмсервисес',
41
+ ]),
42
+ ('digital pipeline', ['digital pipline']),
43
  ]
44
 
45
  vocab_raw = flatten([[k] + keywords for k, keywords in aliases])
 
49
 
50
  morph = None
51
  def normalize_word(word):
52
+ if word == 'в' or word == 'из':
53
+ return ''
54
  if word == 'лид':
55
  return word
56
  if word in ['росии', 'росси']:
 
65
  # remove punctuation
66
  text = text.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation)))
67
  # tokenize
68
+ return list(filter(bool, [normalize_word(word) for word in text.split()]))
69
 
70
  def normalize_sentence(text):
71
  return " ".join(tokenize_sentence(text))
 
120
  from keybert import KeyBERT
121
  import spacy
122
  from sklearn.feature_extraction.text import CountVectorizer
123
+
124
+ import warnings
125
+ warnings.filterwarnings("ignore", category=UserWarning)
126
 
127
  kw_model = KeyBERT(model=spacy.load("ru_core_news_sm", exclude=['tokenizer', 'tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer']))
128
  vocab = [" ".join(tokenize_sentence(s)) for s in vocab_raw]
 
135
  if vectorizer is None or kw_model is None:
136
  init_keyword_extractor()
137
 
138
+ #print(normalize_sentence(text))
139
+ keywords = [k for k, score in kw_model.extract_keywords(normalize_sentence(text), vectorizer=vectorizer)]
140
  return merge_keywords(canonical_keywords(keywords))
141
+
142
+ def extract_keywords2(text):
143
+ vocab = sorted([" ".join(tokenize_sentence(s)) for s in vocab_raw], key=len, reverse=True)
144
+ text = normalize_sentence(text)
145
+ keywords = []
146
+ for w in vocab:
147
+ if w in text:
148
+ keywords.append(w)
149
+ return merge_keywords(canonical_keywords(keywords))
finetune_crossencoder.ipynb ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from list_questions import load_questions\n",
10
+ "from extract_keywords import extract_keywords, extract_keywords2\n",
11
+ "db_name = 'omnidesk-ai-chatgpt-questions.sqlite'"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "from sentence_transformers import InputExample\n",
21
+ "import random\n",
22
+ "\n",
23
+ "def get_user_question(q):\n",
24
+ " keywords = extract_keywords2(q['query'])\n",
25
+ " return ' '.join([q['question'].strip(), ' '.join(keywords)]).lower()\n",
26
+ " \n",
27
+ "def get_system_question(q):\n",
28
+ " return q['query'].lower()\n",
29
+ " \n",
30
+ "def get_negative_system_question(q, all_questions):\n",
31
+ " negative_q = random.choice(list(filter(lambda q2: q['query'] != q2['query'], all_questions)))\n",
32
+ " return negative_q['query'].lower()\n",
33
+ "\n",
34
+ "def input_example_generator():\n",
35
+ " all_questions = list(load_questions(db_name))\n",
36
+ " for q in all_questions:\n",
37
+ " yield InputExample(texts=[get_user_question(q), get_system_question(q)], label=1.0)\n",
38
+ " yield InputExample(texts=[get_user_question(q), get_negative_system_question(q, all_questions)], label=0.0)"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 3,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "from torch.utils.data import IterableDataset, DataLoader\n",
48
+ "\n",
49
+ "additional_examples = [\n",
50
+ " InputExample(texts=['добрый день', 'добрый день, здравствуйте'], label=1.0),\n",
51
+ " InputExample(texts=['здравствуйте', 'добрый день, здравствуйте'], label=1.0),\n",
52
+ " InputExample(texts=['привет', 'добрый день, здравствуйте'], label=1.0),\n",
53
+ " InputExample(texts=['спасибо', 'спасибо, до свидания'], label=1.0),\n",
54
+ " InputExample(texts=['до свидания', 'спасибо, до свидания'], label=1.0),\n",
55
+ " InputExample(texts=['не понял', 'некорректный ответ, не понял'], label=1.0),\n",
56
+ " InputExample(texts=['некорректный ответ', 'некорректный ответ, не понял'], label=1.0),\n",
57
+ " InputExample(texts=['как убрать ошибку', 'как убрать ошибку'], label=1.0),\n",
58
+ " InputExample(texts=['как устранить ошибку', 'как убрать ошибку'], label=1.0),\n",
59
+ " InputExample(texts=['как решить проблему с ошибкой', 'как убрать ошибку'], label=1.0),\n",
60
+ " InputExample(texts=['есть ли способ устранить ошибку', 'как убрать ошибку'], label=1.0),\n",
61
+ " InputExample(texts=['каким образом можно избавиться от ошибки', 'как убрать ошибку'], label=1.0),\n",
62
+ " InputExample(texts=['позови человека', 'позови человека сотрудника менеджера оператора'], label=1.0),\n",
63
+ " InputExample(texts=['позови сотрудника', 'позови человека сотрудника менеджера оператора'], label=1.0),\n",
64
+ " InputExample(texts=['позови менеджера', 'позови человека сотрудника менеджера оператора'], label=1.0),\n",
65
+ " InputExample(texts=['позови оператора', 'позови человека сотрудника менеджера оператора'], label=1.0),\n",
66
+ " InputExample(texts=['оператор', 'позови человека сотрудника менеджера оператора'], label=1.0),\n",
67
+ " InputExample(texts=['человек', 'позови человека сотрудника менеджера оператора'], label=1.0),\n",
68
+ " \n",
69
+ " # special cases\n",
70
+ " InputExample(texts=['можете подсказать, что делать с ошибкой', 'как убрать ошибку'], label=4.0),\n",
71
+ " InputExample(texts=['что произойдет при удалении оплаты cloudpayments', 'cloudpayments перенос оплаты в платежных модулях на примере модуля cloudpayments что произойдет при удалении оплаты'], label=1.0),\n",
72
+ " InputExample(texts=['превышен лимит количества контактов unisender', 'экспорт сегментов в unisender ошибка превышен лимит количества контактов для текущего превышен лимит количества контактов'], label=1.0),\n",
73
+ " InputExample(texts=['не отображаются тарифы', 'не передаются тарифы'], label=0.0),\n",
74
+ " \n",
75
+ " # ???\n",
76
+ " InputExample(texts=['почему количество пользователей отличается', 'почему clientid отличается'], label=0.0),\n",
77
+ " InputExample(texts=['что означает галка \\'доставка курьером\\'', 'что означает галка доставка курьером'], label=1.0),\n",
78
+ " \n",
79
+ " InputExample(texts=['почта россии', 'яндекс доставка'], label=0.0),\n",
80
+ " InputExample(texts=['почта россии', 'яндекс метрика'], label=0.0),\n",
81
+ " InputExample(texts=['яндекс доставка', 'яндекс метрика'], label=0.0),\n",
82
+ " InputExample(texts=['unisender', 'яндекс доставка'], label=0.0),\n",
83
+ " InputExample(texts=['альфабанк', 'яндекс доставка'], label=0.0),\n",
84
+ " InputExample(texts=['почта россии', 'яндекс аудитории'], label=0.0),\n",
85
+ " InputExample(texts=['sipuni', 'cloudpayments'], label=0.0),\n",
86
+ " InputExample(texts=['sipuni', 'facebook'], label=0.0),\n",
87
+ " InputExample(texts=['robokassa', 'вконтакте'], label=0.0),\n",
88
+ " InputExample(texts=['robokassa', 'digital pipeline'], label=0.0),\n",
89
+ " InputExample(texts=['facebook', 'вконтакте'], label=0.0),\n",
90
+ " InputExample(texts=['facebook', 'mailchimp'], label=0.0),\n",
91
+ " InputExample(texts=['почта россии', 'cloudpayments'], label=0.0),\n",
92
+ "]\n",
93
+ "\n",
94
+ "train_dataloader = DataLoader(list(input_example_generator()) + additional_examples, batch_size=16)"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "metadata": {},
100
+ "source": [
101
+ "# Pretrain"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 4,
107
+ "metadata": {},
108
+ "outputs": [
109
+ {
110
+ "name": "stderr",
111
+ "output_type": "stream",
112
+ "text": [
113
+ "2023-08-22 14:47:41.400087: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
114
+ ]
115
+ }
116
+ ],
117
+ "source": [
118
+ "from sentence_transformers import CrossEncoder\n",
119
+ "model = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 5,
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "from sentence_transformers import InputExample\n",
129
+ "pretrain_samples = [\n",
130
+ " #InputExample(texts=['тест', 'тест'], label=1.0),\n",
131
+ " InputExample(texts=['пока', 'до свидания'], label=1.0),\n",
132
+ " InputExample(texts=['как настроить модуль', 'как настроить модуль'], label=1.0),\n",
133
+ " InputExample(texts=['как настроить модуль яндекс доставка', 'как настроить модуль почта россии'], label=0.0),\n",
134
+ " InputExample(texts=['как настроить модуль почта россии', 'как настроить модуль robokassa'], label=0.0),\n",
135
+ " InputExample(texts=['как настроить модуль яндекс доставка', 'как настроить модуль robokassa'], label=0.0),\n",
136
+ " # InputExample(texts=['ошибка сервиса доставки почта россии', 'ошибка сервиса почта россии'], label=1.0),\n",
137
+ " InputExample(texts=['ошибка дата отгрузки, полученная от яндекс.доставки', 'ошибка даты отгрузки яндекс доставки'], label=1.0),\n",
138
+ " InputExample(texts=['ошибка дата отгрузки, полученная от яндекс.доставки', 'яндекс доставка ошибка сервиса доставки при выборе терминала отгрузки'], label=0.0),\n",
139
+ "]"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 6,
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "from torch.utils.data import DataLoader\n",
149
+ "pretrain_dataloader = DataLoader(pretrain_samples)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 7,
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "data": {
159
+ "application/vnd.jupyter.widget-view+json": {
160
+ "model_id": "8492aa648e8b4165a2696c46262135aa",
161
+ "version_major": 2,
162
+ "version_minor": 0
163
+ },
164
+ "text/plain": [
165
+ "Epoch: 0%| | 0/4 [00:00<?, ?it/s]"
166
+ ]
167
+ },
168
+ "metadata": {},
169
+ "output_type": "display_data"
170
+ },
171
+ {
172
+ "data": {
173
+ "application/vnd.jupyter.widget-view+json": {
174
+ "model_id": "0923809b3c7d4baeb7732847fd2fc81f",
175
+ "version_major": 2,
176
+ "version_minor": 0
177
+ },
178
+ "text/plain": [
179
+ "Iteration: 0%| | 0/7 [00:00<?, ?it/s]"
180
+ ]
181
+ },
182
+ "metadata": {},
183
+ "output_type": "display_data"
184
+ },
185
+ {
186
+ "data": {
187
+ "application/vnd.jupyter.widget-view+json": {
188
+ "model_id": "de0f058e118a4c7394209899d462ff52",
189
+ "version_major": 2,
190
+ "version_minor": 0
191
+ },
192
+ "text/plain": [
193
+ "Iteration: 0%| | 0/7 [00:00<?, ?it/s]"
194
+ ]
195
+ },
196
+ "metadata": {},
197
+ "output_type": "display_data"
198
+ },
199
+ {
200
+ "data": {
201
+ "application/vnd.jupyter.widget-view+json": {
202
+ "model_id": "91ab1f49559541e69c48eb6c2dee09ab",
203
+ "version_major": 2,
204
+ "version_minor": 0
205
+ },
206
+ "text/plain": [
207
+ "Iteration: 0%| | 0/7 [00:00<?, ?it/s]"
208
+ ]
209
+ },
210
+ "metadata": {},
211
+ "output_type": "display_data"
212
+ },
213
+ {
214
+ "data": {
215
+ "application/vnd.jupyter.widget-view+json": {
216
+ "model_id": "11d14e7927d8451d9eba7d0e0475ae77",
217
+ "version_major": 2,
218
+ "version_minor": 0
219
+ },
220
+ "text/plain": [
221
+ "Iteration: 0%| | 0/7 [00:00<?, ?it/s]"
222
+ ]
223
+ },
224
+ "metadata": {},
225
+ "output_type": "display_data"
226
+ }
227
+ ],
228
+ "source": [
229
+ "model.fit(pretrain_dataloader, epochs=4, optimizer_params={'lr': 1e-1, 'eps': 1e-6})"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 8,
235
+ "metadata": {},
236
+ "outputs": [
237
+ {
238
+ "data": {
239
+ "text/plain": [
240
+ "array([ 7.250268, 7.564313, 10.511606, -5.297707, -10.485567],\n",
241
+ " dtype=float32)"
242
+ ]
243
+ },
244
+ "execution_count": 8,
245
+ "metadata": {},
246
+ "output_type": "execute_result"
247
+ }
248
+ ],
249
+ "source": [
250
+ "model.predict([\n",
251
+ " ('добрый день', 'здравствуйте'),\n",
252
+ " ('добрый день', 'привет'),\n",
253
+ " ('как исправить ошибку', 'как убрать ошибку'),\n",
254
+ " ('какой сегодня прекрасный день', 'некорректный ответ не понял'),\n",
255
+ " ('как настроить модуль яндекс доставка', 'как настроить модуль сдэк'),\n",
256
+ "])"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": 9,
262
+ "metadata": {},
263
+ "outputs": [
264
+ {
265
+ "data": {
266
+ "application/vnd.jupyter.widget-view+json": {
267
+ "model_id": "a9ad4a608aba412c94ea890018b3f904",
268
+ "version_major": 2,
269
+ "version_minor": 0
270
+ },
271
+ "text/plain": [
272
+ "Epoch: 0%| | 0/1 [00:00<?, ?it/s]"
273
+ ]
274
+ },
275
+ "metadata": {},
276
+ "output_type": "display_data"
277
+ },
278
+ {
279
+ "data": {
280
+ "application/vnd.jupyter.widget-view+json": {
281
+ "model_id": "c0776393f80944cbac38e78b4bb43d3b",
282
+ "version_major": 2,
283
+ "version_minor": 0
284
+ },
285
+ "text/plain": [
286
+ "Iteration: 0%| | 0/50 [00:00<?, ?it/s]"
287
+ ]
288
+ },
289
+ "metadata": {},
290
+ "output_type": "display_data"
291
+ }
292
+ ],
293
+ "source": [
294
+ "model.fit(train_dataloader, epochs=1, optimizer_params={'lr': 1e-3, 'eps': 1e-6})"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": 10,
300
+ "metadata": {},
301
+ "outputs": [
302
+ {
303
+ "data": {
304
+ "text/plain": [
305
+ "array([ 6.8022833, 7.1733284, 10.234115 , -5.4563026, -10.522914 ],\n",
306
+ " dtype=float32)"
307
+ ]
308
+ },
309
+ "execution_count": 10,
310
+ "metadata": {},
311
+ "output_type": "execute_result"
312
+ }
313
+ ],
314
+ "source": [
315
+ "model.predict([\n",
316
+ " ('добрый день', 'здравствуйте'),\n",
317
+ " ('добрый день', 'привет'),\n",
318
+ " ('как исправить ошибку', 'как убрать ошибку'),\n",
319
+ " ('какой сегодня прекрасный день', 'некорректный ответ не понял'),\n",
320
+ " ('как настроить модуль яндекс доставка', 'как настроить модуль сдэк'),\n",
321
+ "])"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": 11,
327
+ "metadata": {},
328
+ "outputs": [
329
+ {
330
+ "data": {
331
+ "application/vnd.jupyter.widget-view+json": {
332
+ "model_id": "36256d415afb4b249b6609fb2329f799",
333
+ "version_major": 2,
334
+ "version_minor": 0
335
+ },
336
+ "text/plain": [
337
+ "Batches: 0%| | 0/25 [00:00<?, ?it/s]"
338
+ ]
339
+ },
340
+ "metadata": {},
341
+ "output_type": "display_data"
342
+ }
343
+ ],
344
+ "source": [
345
+ "labels = [example.label for example in train_dataloader.dataset]\n",
346
+ "predicts = model.predict([example.texts for example in train_dataloader.dataset], show_progress_bar=True)"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": 12,
352
+ "metadata": {},
353
+ "outputs": [
354
+ {
355
+ "name": "stdout",
356
+ "output_type": "stream",
357
+ "text": [
358
+ "0.0 0.65314835\n",
359
+ "=== почему кол-во пользователей в сегменте аудиторий отличается от кол-ва пользователей в сегменте в retailcrm? яндекс аудитория retailcrm\n",
360
+ "=== экспорт сегментов в вконтакте минимальное число контактов в сегменте для загрузки в вконтакте какое минимальное число контактов нужно загрузить в вконтакте\n",
361
+ "1.0 -9.538332\n",
362
+ "=== какие два случая рассматриваются в статье? digital pipeline почта россия\n",
363
+ "=== digital pipline принцип работы типа синхронизации точное соответствие какие два случая рассматриваются в статье\n",
364
+ "1.0 -0.25401944\n",
365
+ "=== превышен лимит количества контактов почта россия mailchimp\n",
366
+ "=== экспорт сегментов в mailchimp ошибка превышен лимит количества контактов превышен лимит количества контактов\n",
367
+ "1.0 -6.570962\n",
368
+ "=== как добавить клиентский аккаунт? маркетинговый расход почта россия\n",
369
+ "=== маркетинговые расходы добавление аккаунта представителя как добавить клиентский аккаунт\n",
370
+ "0.0 3.0155332\n",
371
+ "=== как подключить модуль sipuni?\n",
372
+ "подключение модуля sipuni sipuni\n",
373
+ "=== почта россии работа с маркировкой почта россии работа с маркировкой почта россии\n",
374
+ "0.0 0.5682208\n",
375
+ "=== как настроить/подключить модуль экспорт сегментов яндекс аудитории яндекс аудитория\n",
376
+ "=== экспорт сегментов в яндекс аудитории частота обновления сегментов в яндекс аудитории как часто обновляются уже созданные сегменты\n",
377
+ "0.0 6.7683535\n",
378
+ "=== как часто обновляются статусы почта россии?\n",
379
+ "какая частота обновления статусов почта россии в retailcrm почта россия retailcrm\n",
380
+ "=== импорт лидов facebook leads добавление дополнительной страницы facebook как добавить страницу facebook в retailcrm services\n",
381
+ "0.0 3.5910566\n",
382
+ "=== кейс: менеджер собирает заказ одним набором товарных позиций, далее добавляет 1-ую оплату. \n",
383
+ "клиент ее оплачивает. \n",
384
+ "менеджер по просьбе клиента добавляет в заказ новую позицию и добавляет новую оплату. \n",
385
+ "клиент и ее оплачивает. \n",
386
+ "и после подтверждения получает два чека. один — на сумму первой оплаты, в чеке корректно отображаются товарные позиции, и второй чек на сумму второй оплаты и с позицией \"аванс\". cloudpayments почта россия\n",
387
+ "=== экспорт сегментов в вконтакте минимальное число контактов в сегменте для загрузки в вконтакте какое минимальное число контактов нужно загрузить в вконтакте\n",
388
+ "1.0 -2.5233462\n",
389
+ "=== что произойдет при удалении оплаты? cloudpayments почта россия\n",
390
+ "=== cloudpayments перенос оплаты в платежных модулях на примере модуля cloudpayments что произойдет при удалении оплаты\n",
391
+ "0.0 4.6610775\n",
392
+ "=== отличие модулей яндекс.метрика в маркетплейс\n",
393
+ "чем отличается модуль яндекс метрика ткаченко от вашего? яндекс метрика\n",
394
+ "=== яндекс метрика почему не отображаются заказы в отчетах метрики загрузятся ли телефонные заказы в метрику \n",
395
+ "1.0 -9.697531\n",
396
+ "=== фильтры в списке обращений почта россия\n",
397
+ "=== продвинутый функционал фильтры в списке обращений фильтры в списке обращений\n",
398
+ "5.0 3.6328275\n",
399
+ "=== как решить проблему с ошибкой\n",
400
+ "=== как убрать ошибку\n",
401
+ "4.0 1.7188714\n",
402
+ "=== можете подсказать, что делать с ошибкой\n",
403
+ "=== как убрать ошибку\n",
404
+ "0.0 7.129506\n",
405
+ "=== не отображаются тарифы\n",
406
+ "=== не передаются тарифы\n",
407
+ "0.0 0.903283\n",
408
+ "=== почему количество пользователей отличается\n",
409
+ "=== почему clientid отличается\n",
410
+ "1.0 -1.9656792\n",
411
+ "=== что означает галка 'доставка курьером'\n",
412
+ "=== что означает галка доставка курьером\n",
413
+ "0.0 2.8144162\n",
414
+ "=== почта россии\n",
415
+ "=== яндекс доставка\n",
416
+ "0.0 8.619669\n",
417
+ "=== unisender\n",
418
+ "=== яндекс доставка\n",
419
+ "0.0 3.990125\n",
420
+ "=== альфабанк\n",
421
+ "=== яндекс доставка\n",
422
+ "0.0 1.2026126\n",
423
+ "=== почта россии\n",
424
+ "=== яндекс аудитории\n",
425
+ "0.0 6.3785267\n",
426
+ "=== sipuni\n",
427
+ "=== cloudpayments\n",
428
+ "0.0 8.150368\n",
429
+ "=== sipuni\n",
430
+ "=== facebook\n",
431
+ "0.0 8.385334\n",
432
+ "=== robokassa\n",
433
+ "=== вконтакте\n",
434
+ "0.0 2.1733193\n",
435
+ "=== robokassa\n",
436
+ "=== digital pipeline\n",
437
+ "0.0 10.423201\n",
438
+ "=== facebook\n",
439
+ "=== вконтакте\n",
440
+ "0.0 5.420491\n",
441
+ "=== facebook\n",
442
+ "=== mailchimp\n",
443
+ "0.0 0.3983256\n",
444
+ "=== почта россии\n",
445
+ "=== cloudpayments\n"
446
+ ]
447
+ }
448
+ ],
449
+ "source": [
450
+ "import math\n",
451
+ "\n",
452
+ "for label, predict, example in zip(labels, predicts, train_dataloader.dataset):\n",
453
+ " label1 = 1.0 if math.copysign(1, predict) == 1.0 else 0.0\n",
454
+ " if (label != label1):\n",
455
+ " print(label, predict)\n",
456
+ " print('===', example.texts[0])\n",
457
+ " print('===', example.texts[1])"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": null,
463
+ "metadata": {},
464
+ "outputs": [],
465
+ "source": []
466
+ }
467
+ ],
468
+ "metadata": {
469
+ "kernelspec": {
470
+ "display_name": "base",
471
+ "language": "python",
472
+ "name": "python3"
473
+ },
474
+ "language_info": {
475
+ "codemirror_mode": {
476
+ "name": "ipython",
477
+ "version": 3
478
+ },
479
+ "file_extension": ".py",
480
+ "mimetype": "text/x-python",
481
+ "name": "python",
482
+ "nbconvert_exporter": "python",
483
+ "pygments_lexer": "ipython3",
484
+ "version": "3.10.9"
485
+ },
486
+ "orig_nbformat": 4
487
+ },
488
+ "nbformat": 4,
489
+ "nbformat_minor": 2
490
+ }
list_questions.ipynb ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from list_questions import load_questions\n",
10
+ "from extract_keywords import extract_keywords, kw_model, vectorizer"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "questions = load_questions('omnidesk-ai-chatgpt-questions.sqlite')"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 4,
25
+ "metadata": {},
26
+ "outputs": [
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "Как долго рассматривается обращение на портале поддержки?\n",
32
+ "Не работает модуль СДЭК\n",
33
+ "Как деактивировать модуль?\n",
34
+ "Что произойдет, если я не деактивирую модуль после окончания пробного периода?\n",
35
+ "Как происходит оплата за использование модуля?\n",
36
+ "Модуль заморожен. Как включить/возобновить работу модуля?\n",
37
+ "Как вернуть оплату за использование модуля?\n",
38
+ "Возврат оплаты за модуль\n",
39
+ "Как активировать модуль?\n",
40
+ "Какие методы API ключа нужно разрешить для работы с модулем?\n",
41
+ "Фильтры в списке обращений\n",
42
+ "добрый день / здравствуйте\n",
43
+ "спасибо / до свидания\n",
44
+ "некорректный ответ, не понял\n",
45
+ "как убрать ошибку\n",
46
+ "позови человека/сотрудника/менеджера/оператора\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "for q in questions:\n",
52
+ " keywords = extract_keywords(q['query'])\n",
53
+ " if (len(keywords) == 0):\n",
54
+ " print(q['question'])"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 5,
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "name": "stdout",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "экспорт сегментов в вконтакте\n"
67
+ ]
68
+ },
69
+ {
70
+ "data": {
71
+ "text/plain": [
72
+ "[]"
73
+ ]
74
+ },
75
+ "execution_count": 5,
76
+ "metadata": {},
77
+ "output_type": "execute_result"
78
+ }
79
+ ],
80
+ "source": [
81
+ "extract_keywords('экспорт сегментов в вконтакте')"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 14,
87
+ "metadata": {},
88
+ "outputs": [
89
+ {
90
+ "data": {
91
+ "text/plain": [
92
+ "[]"
93
+ ]
94
+ },
95
+ "execution_count": 14,
96
+ "metadata": {},
97
+ "output_type": "execute_result"
98
+ }
99
+ ],
100
+ "source": [
101
+ "kw_model.extract_keywords('экспорт сегментов в вконтакте', vectorizer=vectorizer)"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 15,
107
+ "metadata": {},
108
+ "outputs": [
109
+ {
110
+ "data": {
111
+ "text/plain": [
112
+ "['почта россия',\n",
113
+ " 'почта',\n",
114
+ " 'почта рф',\n",
115
+ " 'пр',\n",
116
+ " 'gh',\n",
117
+ " 'почта россия трекинг',\n",
118
+ " 'пр трекинг',\n",
119
+ " 'почта трекинг',\n",
120
+ " 'пр трэкинг',\n",
121
+ " 'почта трэкинг',\n",
122
+ " 'реестр почта',\n",
123
+ " 'реестр пр',\n",
124
+ " 'реестр почта россия',\n",
125
+ " 'реестр пэк',\n",
126
+ " 'реквизит',\n",
127
+ " 'пешкарика',\n",
128
+ " 'импорт лид директ',\n",
129
+ " 'яндекс доставка экспресс',\n",
130
+ " 'яндекс доставка express',\n",
131
+ " 'яд экспресс',\n",
132
+ " 'ядоставка экспресс',\n",
133
+ " 'яндекс доставка ndd',\n",
134
+ " 'яд ндд',\n",
135
+ " 'я доставка ндд',\n",
136
+ " 'ядоставка ндд',\n",
137
+ " 'модуль ндд',\n",
138
+ " 'яндекс метрика',\n",
139
+ " 'яндекс метрика импорт',\n",
140
+ " 'альфабанк',\n",
141
+ " 'альфа банк',\n",
142
+ " 'alfabank',\n",
143
+ " 'альфа',\n",
144
+ " 'импорт лид facebook',\n",
145
+ " 'импорт лид fb',\n",
146
+ " 'загрузка лид fb',\n",
147
+ " 'лида фейсбук',\n",
148
+ " 'импорт лид фб',\n",
149
+ " 'fb lead',\n",
150
+ " 'маркетинговый расход',\n",
151
+ " 'расход',\n",
152
+ " 'загрузка расход',\n",
153
+ " 'cloudpayments',\n",
154
+ " 'клауд',\n",
155
+ " 'клаудпеймент',\n",
156
+ " 'клаудпейментс',\n",
157
+ " 'robokassa',\n",
158
+ " 'робокасса',\n",
159
+ " 'робокас',\n",
160
+ " 'sipuni',\n",
161
+ " 'сипуня',\n",
162
+ " 'сипьюня',\n",
163
+ " 'mailchimp',\n",
164
+ " 'майлчимп',\n",
165
+ " 'мейлчать',\n",
166
+ " 'мейлчимп',\n",
167
+ " 'unisender',\n",
168
+ " 'юнисендер',\n",
169
+ " 'яндекс аудитория',\n",
170
+ " 'экспорт аудитория',\n",
171
+ " 'экспорт яндекс аудитория',\n",
172
+ " 'экспорт facebook',\n",
173
+ " 'экспорт сегмент facebook',\n",
174
+ " 'экспорт fb',\n",
175
+ " 'экспорт фейсбук',\n",
176
+ " 'экспорт аудитория фб',\n",
177
+ " 'fb экспорт',\n",
178
+ " 'экспорт вк',\n",
179
+ " 'экспорт сегмент vkontakte',\n",
180
+ " 'экспорт vk',\n",
181
+ " 'экспорт контакт',\n",
182
+ " 'экспорт сегмент вконтакте',\n",
183
+ " 'retailcrm',\n",
184
+ " 'срм',\n",
185
+ " 'ритейл',\n",
186
+ " 'ритейл срм',\n",
187
+ " 'ритейлсрма',\n",
188
+ " 'retail crm',\n",
189
+ " 'ритейлцрма',\n",
190
+ " 'ритейл црм',\n",
191
+ " 'retailcrm services',\n",
192
+ " 'retailcrmservices',\n",
193
+ " 'ритейлцрма services',\n",
194
+ " 'лк crm services',\n",
195
+ " 'ритейлцрма сервисес',\n",
196
+ " 'ритейлсрма сервисес',\n",
197
+ " 'ритейлцрма сервис',\n",
198
+ " 'ритейлцрмсервисес',\n",
199
+ " 'ритейлсрмсервисес']"
200
+ ]
201
+ },
202
+ "execution_count": 15,
203
+ "metadata": {},
204
+ "output_type": "execute_result"
205
+ }
206
+ ],
207
+ "source": [
208
+ "vectorizer.vocabulary"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": []
217
+ }
218
+ ],
219
+ "metadata": {
220
+ "kernelspec": {
221
+ "display_name": "base",
222
+ "language": "python",
223
+ "name": "python3"
224
+ },
225
+ "language_info": {
226
+ "codemirror_mode": {
227
+ "name": "ipython",
228
+ "version": 3
229
+ },
230
+ "file_extension": ".py",
231
+ "mimetype": "text/x-python",
232
+ "name": "python",
233
+ "nbconvert_exporter": "python",
234
+ "pygments_lexer": "ipython3",
235
+ "version": "3.10.9"
236
+ },
237
+ "orig_nbformat": 4
238
+ },
239
+ "nbformat": 4,
240
+ "nbformat_minor": 2
241
+ }
list_questions.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3, json
2
+ from contextlib import closing
3
+ from extract_keywords import extract_keywords
4
+
5
+ punctuation = '!"#\'(),:;?[]^`}{'
6
+ punctuation2 = '-/&._~+*=@<>[]\\'
7
+ remove_punctuation = str.maketrans(punctuation2, ' ' * len(punctuation2), punctuation)
8
+
9
+ def load_questions(sqlite_filename):
10
+ all_questions = []
11
+ with closing(sqlite3.connect(sqlite_filename)) as db:
12
+ db.row_factory = sqlite3.Row
13
+ with closing(db.cursor()) as cursor:
14
+ results = cursor.execute(
15
+ "SELECT id, articleId, title, category, section, questions FROM articles WHERE articleType = ? AND doNotUse IS NULL OR doNotUse = 0",
16
+ ('article',)
17
+ ).fetchall()
18
+
19
+ for res in results:
20
+ section = res['section'].lower()
21
+ title = res['title'].lower()
22
+ if section == 'служебная информация':
23
+ section = ''
24
+ title = ''
25
+
26
+ questions = json.loads(res['questions'])
27
+ for q in questions:
28
+ q['query'] = " ".join(section.split() + title.split() + q['question'].split()).translate(remove_punctuation).lower()
29
+ q['articleId'] = res['articleId']
30
+ all_questions += questions
31
+
32
+ return all_questions
33
+
34
+
35
+ #print("Loading questions from db...")
36
+ #questions = load_questions("omnidesk-ai-chatgpt-questions.sqlite")
37
+
38
+ #for q in questions:
39
+ # keywords = extract_keywords(q['query'])
40
+ # if (len(keywords) == 0):
41
+ # print(q)
42
+ # break
reranking.py CHANGED
@@ -1,7 +1,8 @@
1
  from pathlib import Path
2
  from sentence_transformers.cross_encoder import CrossEncoder
3
  from more_itertools import windowed
4
- model = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1', max_length=512, device='cpu')
 
5
 
6
  def rerank(sentence_combinations):
7
  similarity_scores = model.predict(sentence_combinations)
 
1
  from pathlib import Path
2
  from sentence_transformers.cross_encoder import CrossEncoder
3
  from more_itertools import windowed
4
+ #model = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1', max_length=512, device='cpu')
5
+ model = CrossEncoder('retailcrmservices/crossencoder-mMiniLMv2-L12-H384-v1-ru', max_length=512, device='cpu')
6
 
7
  def rerank(sentence_combinations):
8
  similarity_scores = model.predict(sentence_combinations)
test_sbert.ipynb ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from sentence_transformers import CrossEncoder"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 2,
15
+ "metadata": {},
16
+ "outputs": [
17
+ {
18
+ "name": "stderr",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "2023-08-20 23:39:32.598364: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "model = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 3,
32
+ "metadata": {},
33
+ "outputs": [
34
+ {
35
+ "data": {
36
+ "text/plain": [
37
+ "array([-4.4115596, -3.8891914], dtype=float32)"
38
+ ]
39
+ },
40
+ "execution_count": 3,
41
+ "metadata": {},
42
+ "output_type": "execute_result"
43
+ }
44
+ ],
45
+ "source": [
46
+ "model.predict(\n",
47
+ " [\n",
48
+ " ('как исправить ошибку', 'как убрать ошибку'),\n",
49
+ " ('добрый день', 'привет'),\n",
50
+ " ]\n",
51
+ ")"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 4,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "from sentence_transformers import InputExample\n",
61
+ "train_samples = [\n",
62
+ " InputExample(texts=['как настроить модуль', 'как настроить модуль'], label=1.0),\n",
63
+ " InputExample(texts=['как настроить модуль яндекс доставка', 'как настроить модуль почта россии'], label=0.0),\n",
64
+ "]"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 5,
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "from torch.utils.data import DataLoader\n",
74
+ "train_dataloader = DataLoader(train_samples)"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 6,
80
+ "metadata": {},
81
+ "outputs": [
82
+ {
83
+ "data": {
84
+ "application/vnd.jupyter.widget-view+json": {
85
+ "model_id": "e8062b9c6b4b47619bccaa0d18beec23",
86
+ "version_major": 2,
87
+ "version_minor": 0
88
+ },
89
+ "text/plain": [
90
+ "Epoch: 0%| | 0/10 [00:00<?, ?it/s]"
91
+ ]
92
+ },
93
+ "metadata": {},
94
+ "output_type": "display_data"
95
+ },
96
+ {
97
+ "data": {
98
+ "application/vnd.jupyter.widget-view+json": {
99
+ "model_id": "2f5f91e57ea94671ac197c3f253ba29e",
100
+ "version_major": 2,
101
+ "version_minor": 0
102
+ },
103
+ "text/plain": [
104
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
105
+ ]
106
+ },
107
+ "metadata": {},
108
+ "output_type": "display_data"
109
+ },
110
+ {
111
+ "data": {
112
+ "application/vnd.jupyter.widget-view+json": {
113
+ "model_id": "59daed18b9ac4b438def689dad8adccd",
114
+ "version_major": 2,
115
+ "version_minor": 0
116
+ },
117
+ "text/plain": [
118
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
119
+ ]
120
+ },
121
+ "metadata": {},
122
+ "output_type": "display_data"
123
+ },
124
+ {
125
+ "data": {
126
+ "application/vnd.jupyter.widget-view+json": {
127
+ "model_id": "d27ce1dccdca437fa5f319687405a90b",
128
+ "version_major": 2,
129
+ "version_minor": 0
130
+ },
131
+ "text/plain": [
132
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
133
+ ]
134
+ },
135
+ "metadata": {},
136
+ "output_type": "display_data"
137
+ },
138
+ {
139
+ "data": {
140
+ "application/vnd.jupyter.widget-view+json": {
141
+ "model_id": "949c302a01c642c29b80a79c25b6a449",
142
+ "version_major": 2,
143
+ "version_minor": 0
144
+ },
145
+ "text/plain": [
146
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
147
+ ]
148
+ },
149
+ "metadata": {},
150
+ "output_type": "display_data"
151
+ },
152
+ {
153
+ "data": {
154
+ "application/vnd.jupyter.widget-view+json": {
155
+ "model_id": "8c5ce379a69d4d88a8ecfbe16dee6e4b",
156
+ "version_major": 2,
157
+ "version_minor": 0
158
+ },
159
+ "text/plain": [
160
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
161
+ ]
162
+ },
163
+ "metadata": {},
164
+ "output_type": "display_data"
165
+ },
166
+ {
167
+ "data": {
168
+ "application/vnd.jupyter.widget-view+json": {
169
+ "model_id": "cf8929bca29c4765bc2214e4b6b1ca6c",
170
+ "version_major": 2,
171
+ "version_minor": 0
172
+ },
173
+ "text/plain": [
174
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
175
+ ]
176
+ },
177
+ "metadata": {},
178
+ "output_type": "display_data"
179
+ },
180
+ {
181
+ "data": {
182
+ "application/vnd.jupyter.widget-view+json": {
183
+ "model_id": "270c365e0c984d5ba1099c637e7c2632",
184
+ "version_major": 2,
185
+ "version_minor": 0
186
+ },
187
+ "text/plain": [
188
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
189
+ ]
190
+ },
191
+ "metadata": {},
192
+ "output_type": "display_data"
193
+ },
194
+ {
195
+ "data": {
196
+ "application/vnd.jupyter.widget-view+json": {
197
+ "model_id": "71f789e96b30453da3d71a428c2870c6",
198
+ "version_major": 2,
199
+ "version_minor": 0
200
+ },
201
+ "text/plain": [
202
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
203
+ ]
204
+ },
205
+ "metadata": {},
206
+ "output_type": "display_data"
207
+ },
208
+ {
209
+ "data": {
210
+ "application/vnd.jupyter.widget-view+json": {
211
+ "model_id": "e8161252067e4cf283df577bb3bb70f4",
212
+ "version_major": 2,
213
+ "version_minor": 0
214
+ },
215
+ "text/plain": [
216
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
217
+ ]
218
+ },
219
+ "metadata": {},
220
+ "output_type": "display_data"
221
+ },
222
+ {
223
+ "data": {
224
+ "application/vnd.jupyter.widget-view+json": {
225
+ "model_id": "9de909bafcd5453fac96198158d5dc00",
226
+ "version_major": 2,
227
+ "version_minor": 0
228
+ },
229
+ "text/plain": [
230
+ "Iteration: 0%| | 0/2 [00:00<?, ?it/s]"
231
+ ]
232
+ },
233
+ "metadata": {},
234
+ "output_type": "display_data"
235
+ }
236
+ ],
237
+ "source": [
238
+ "model.fit(train_dataloader, epochs=10, optimizer_params={'lr': 1e-1, 'eps': 1e-6})"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": 7,
244
+ "metadata": {},
245
+ "outputs": [
246
+ {
247
+ "data": {
248
+ "text/plain": [
249
+ "array([11.245676], dtype=float32)"
250
+ ]
251
+ },
252
+ "execution_count": 7,
253
+ "metadata": {},
254
+ "output_type": "execute_result"
255
+ }
256
+ ],
257
+ "source": [
258
+ "model.predict(\n",
259
+ " [('test', 'test')]\n",
260
+ ")"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": 13,
266
+ "metadata": {},
267
+ "outputs": [
268
+ {
269
+ "data": {
270
+ "text/plain": [
271
+ "array([ 1.686686 , 2.3622127, 9.887518 , -9.177755 , -6.320711 ],\n",
272
+ " dtype=float32)"
273
+ ]
274
+ },
275
+ "execution_count": 13,
276
+ "metadata": {},
277
+ "output_type": "execute_result"
278
+ }
279
+ ],
280
+ "source": [
281
+ "model.predict([\n",
282
+ " ('добрый день', 'здравствуйте'),\n",
283
+ " ('добрый день', 'привет'),\n",
284
+ " ('как исправить ошибку', 'как убрать ошибку'),\n",
285
+ " ('какой сегодня прекрасный день', 'некорректный ответ не понял'),\n",
286
+ " ('как настроить модуль яндекс доставка', 'как настроить модуль сдэк'),\n",
287
+ "])"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": null,
293
+ "metadata": {},
294
+ "outputs": [],
295
+ "source": []
296
+ }
297
+ ],
298
+ "metadata": {
299
+ "kernelspec": {
300
+ "display_name": "base",
301
+ "language": "python",
302
+ "name": "python3"
303
+ },
304
+ "language_info": {
305
+ "codemirror_mode": {
306
+ "name": "ipython",
307
+ "version": 3
308
+ },
309
+ "file_extension": ".py",
310
+ "mimetype": "text/x-python",
311
+ "name": "python",
312
+ "nbconvert_exporter": "python",
313
+ "pygments_lexer": "ipython3",
314
+ "version": "3.10.9"
315
+ },
316
+ "orig_nbformat": 4
317
+ },
318
+ "nbformat": 4,
319
+ "nbformat_minor": 2
320
+ }