IvT-DS commited on
Commit
41b0868
·
verified ·
1 Parent(s): b1e05ac

Upload 10 files

Browse files
data/data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe31f260bf45d3d790e24ad53b22f86611dc7d6f0e658f834685d142cead29f9
3
+ size 46974868
model/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:635462300692bfdc4f400d6e81207c96649048f27489cd36ab79cf236a37a5a8
3
+ size 481766937
pages/02_📺_Find_my_show.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib.util
3
+ import torch
4
+ import streamlit as st
5
+ import pandas as pd
6
+ from PIL import Image
7
+
8
+ # Формируем абсолютный путь до файла functions.py
9
+ module_path = os.path.abspath(
10
+ os.path.join(os.path.dirname(__file__), "..", "resource", "functions.py")
11
+ )
12
+
13
+ # Загружаем модуль
14
+ spec = importlib.util.spec_from_file_location("resource.functions", module_path)
15
+ functions = importlib.util.module_from_spec(spec)
16
+ spec.loader.exec_module(functions)
17
+
18
+ # Теперь используем функции напрямую
19
+ table_maker = functions.table_maker
20
+ RecSys = functions.RecSys
21
+
22
+ poster_path = "https://resizer.mail.ru/p/"
23
+ show_path = "https://kino.mail.ru/series_"
24
+ placeholder_path = "../img/v2/nopicture/308x462.png"
25
+
26
+
27
+ @st.cache(allow_output_mutation=True)
28
+ def load_model(model_path):
29
+ model = torch.load(model_path)
30
+ return model
31
+
32
+
33
+ @st.cache
34
+ def load_data(data_path):
35
+ df = pd.read_pickle(data_path)
36
+ return df
37
+
38
+
39
+ MODEL_PATH = "model/model.pt"
40
+
41
+ df = pd.read_pickle("data/data.pkl")
42
+ model = torch.load(MODEL_PATH)
43
+
44
+ image = Image.open("pages/tv_shows.png")
45
+ st.image(image, use_column_width=True)
46
+
47
+ # Заголовок приложения
48
+ st.markdown("### Поиск сериалов по запросу пользователя")
49
+
50
+ # Создание списка уникальных стран
51
+ all_countries = sorted(set(df["county"].tolist()))
52
+
53
+ # Создание списка уникальных жанров
54
+ all_genres = set()
55
+ for genres_set in df["tags"].dropna():
56
+ all_genres.update(genres_set)
57
+ all_genres = sorted(all_genres)
58
+
59
+ # Фильтр по наличию рейтинга
60
+ has_rating = st.sidebar.checkbox("Показывать только сериалы с рейтингом?", True)
61
+
62
+ # Виджеты для боковой панели
63
+ selected_country = st.sidebar.multiselect("Страна", all_countries)
64
+ selected_genre = st.sidebar.multiselect("Жанры", all_genres)
65
+
66
+
67
+ rating = True
68
+
69
+ search_table = table_maker(
70
+ df=df,
71
+ country=selected_country,
72
+ min_year=int(df["year"].min()),
73
+ max_year=int(df["year"].max()),
74
+ tagger=set(selected_genre),
75
+ rating=has_rating,
76
+ )
77
+
78
+ # Проверяем, пустой ли отфильтрованный DataFrame
79
+ if search_table.empty:
80
+ st.error(
81
+ "После фильтрации данных не осталось. Пожалуйста, выберите другие параметры."
82
+ )
83
+ else:
84
+ # Преобразование year в числовой формат, если возможно, и обработка NaN значений
85
+ search_table["year"] = pd.to_numeric(search_table["year"], errors="coerce").dropna()
86
+
87
+ if search_table.empty:
88
+ st.error(
89
+ "После фильтрации и обработки годов в данных не осталось записей. Пожалуйста, выберите другие параметры."
90
+ )
91
+ else:
92
+ # Теперь безопасно ищем min и max
93
+ min_year = int(search_table["year"].min())
94
+ max_year = int(search_table["year"].max())
95
+
96
+ # Если есть хотя бы два разных года, отображаем слайдер
97
+ if min_year < max_year:
98
+ selected_year_range = st.sidebar.slider(
99
+ "Выберите диапазон лет выпуска",
100
+ min_value=min_year,
101
+ max_value=max_year,
102
+ value=(min_year, max_year),
103
+ )
104
+ # Применяем фильтр по годам
105
+ search_table = search_table[
106
+ (search_table["year"] >= selected_year_range[0])
107
+ & (search_table["year"] <= selected_year_range[1])
108
+ ]
109
+
110
+ st.sidebar.markdown("---")
111
+ st.sidebar.markdown("### Дополнительные настройки")
112
+
113
+ # Позволяет пользователю выбрать количество сериалов для отображения, от 1 до 10
114
+ top_n = st.sidebar.number_input(
115
+ "Сколько сериалов показывать?", min_value=1, max_value=10, value=5
116
+ )
117
+
118
+ # Создание текстового поля для ввода пользовательского запроса
119
+ user_request = st.text_input("Введите ваш запрос:", "звездные войны")
120
+
121
+ if st.button("Найти сериалы по запросу") and len(df) > 0:
122
+
123
+ output = RecSys(search_table, user_request, model)
124
+
125
+ # top_n = 5 # мин 1 макс 10
126
+ res = output().head(top_n)
127
+
128
+ (
129
+ poster,
130
+ title,
131
+ description,
132
+ rating,
133
+ genre,
134
+ cast,
135
+ score,
136
+ year,
137
+ links,
138
+ country,
139
+ ) = (
140
+ {},
141
+ {},
142
+ {},
143
+ {},
144
+ {},
145
+ {},
146
+ {},
147
+ {},
148
+ {},
149
+ {},
150
+ )
151
+
152
+ for i, con in enumerate(res["poster"]):
153
+ # Проверяем, является ли значение в con ссылкой или путем к файлу
154
+ if "nopicture" in con:
155
+ poster[i] = placeholder_path
156
+ else:
157
+ poster[i] = poster_path + con
158
+
159
+ for i, con in enumerate(res["year"]):
160
+ year[i] = con
161
+
162
+ for i, con in enumerate(res["title"]):
163
+ title[i] = con
164
+
165
+ for i, con in enumerate(res["description"]):
166
+ description[i] = con
167
+
168
+ for i, con in enumerate(res["rating"]):
169
+ rating[i] = con
170
+
171
+ for i, con in enumerate(res["tags"]):
172
+ genre[i] = ", ".join(con)
173
+
174
+ for i, con in enumerate(res["cast"]):
175
+ cast[i] = con
176
+
177
+ for i, con in enumerate(res["score"]):
178
+ score[i] = con
179
+
180
+ for i, con in enumerate(res["url"]):
181
+ links[i] = show_path + con
182
+
183
+ for i, con in enumerate(res["county"]):
184
+ country[i] = con
185
+
186
+ st.markdown("---")
187
+
188
+ # Проверяем, пустой ли набор результатов
189
+ if len(res) == 0:
190
+ st.error(
191
+ "Сериалы по выбранным параметрам не найдены. Попробуйте изменить критерии поиска."
192
+ )
193
+ else:
194
+ # Если результаты есть, выводим их
195
+ iterations = min(len(res), top_n)
196
+
197
+ for i in range(iterations):
198
+
199
+ col1, col2 = st.columns([1, 3])
200
+ with col1:
201
+ st.image(poster[i])
202
+ # Добавляем ссылку под картинкой
203
+ st.markdown(
204
+ f"<a href='{links[i]}' target='_blank' style='display: block; text-align: center; color: grey; font-size: small; font-style: italic;'>Смотреть сериал</a>",
205
+ unsafe_allow_html=True,
206
+ )
207
+
208
+ with col2:
209
+
210
+ st.markdown(
211
+ f"<span style='font-weight:bold; font-size:22px;'>Название сериала:</span> <span style='font-size:20px;'>«{title[i]}»</span>",
212
+ unsafe_allow_html=True,
213
+ )
214
+
215
+ st.markdown(
216
+ f"<span style='font-weight:bold; font-size:16px;'>Страна:</span> <span style='font-size:16px;'>{country[i]}</span>",
217
+ unsafe_allow_html=True,
218
+ )
219
+
220
+ st.markdown(
221
+ f"<span style='font-weight:bold; font-size:16px;'>Год выпуска:</span> <span style='font-size:16px;'>{year[i]}</span>",
222
+ unsafe_allow_html=True,
223
+ )
224
+
225
+ st.markdown(
226
+ f"<span style='font-weight:bold; font-size:16px;'>Жанр:</span> <span style='font-size:16px;'>{genre[i]}</span>",
227
+ unsafe_allow_html=True,
228
+ )
229
+
230
+ rating_display = (
231
+ "Нет информации" if pd.isna(rating[i]) else rating[i]
232
+ )
233
+
234
+ st.markdown(
235
+ f"<span style='font-weight:bold; font-size:16px;'>Рейтинг:</span> <span style='font-size:16px;'>{rating_display}</span>",
236
+ unsafe_allow_html=True,
237
+ )
238
+
239
+ st.markdown(
240
+ "<h6 style='font-weight:bold;'>В ролях:</h6>",
241
+ unsafe_allow_html=True,
242
+ )
243
+
244
+ st.markdown(
245
+ f"<div style='text-align: justify; margin-bottom: 18px;'>{cast[i]}</div>",
246
+ unsafe_allow_html=True,
247
+ )
248
+
249
+ st.markdown(
250
+ "<h6 style='font-weight:bold;'>Описание:</h6>",
251
+ unsafe_allow_html=True,
252
+ )
253
+
254
+ st.markdown(
255
+ f"<div style='text-align: justify;'>{description[i]}</div>",
256
+ unsafe_allow_html=True,
257
+ )
258
+ score_display = round(score[i], 3)
259
+ st.markdown(
260
+ f"<div style='color: grey;'><hr style='margin: 2px 0;'/><span style='font-weight:bold; font-size:13px; font-style: italic;'>Коэффициент сходимости (косинусное сходство):</span> <span style='font-size:13px; font-style: italic;'>{score_display}</span><hr style='margin: 2px 0;'/></div>",
261
+ unsafe_allow_html=True,
262
+ )
263
+
264
+ st.markdown("---")
pages/03_🚀_Find_my_show_(FAISS).py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib.util
3
+ import torch
4
+ import streamlit as st
5
+ import pandas as pd
6
+ from PIL import Image
7
+
8
+
9
+ # Формируем абсолютный путь до файла functions.py
10
+ module_path = os.path.abspath(
11
+ os.path.join(os.path.dirname(__file__), "..", "resource", "functions.py")
12
+ )
13
+
14
+ # Загружаем модуль
15
+ spec = importlib.util.spec_from_file_location("resource.functions", module_path)
16
+ functions = importlib.util.module_from_spec(spec)
17
+ spec.loader.exec_module(functions)
18
+
19
+ # Теперь используем функции напрямую
20
+ table_maker = functions.table_maker
21
+ # RecSys = functions.RecSys
22
+ FAISS_inference = functions.FAISS_inference
23
+
24
+ poster_path = "https://resizer.mail.ru/p/"
25
+ show_path = "https://kino.mail.ru/series_"
26
+ placeholder_path = "../img/v2/nopicture/308x462.png"
27
+
28
+
29
+ @st.cache(allow_output_mutation=True)
30
+ def load_model(model_path):
31
+ model = torch.load(model_path)
32
+ return model
33
+
34
+
35
+ @st.cache
36
+ def load_data(data_path):
37
+ df = pd.read_pickle(data_path)
38
+ return df
39
+
40
+
41
+ MODEL_PATH = "model/model.pt"
42
+
43
+ df = pd.read_pickle("data/data.pkl")
44
+ model = torch.load(MODEL_PATH)
45
+
46
+
47
+ image = Image.open("pages/tv_shows.png")
48
+ st.image(image, use_column_width=True)
49
+
50
+ # Заголовок приложения
51
+ st.markdown("### Поиск сериалов по запросу пользователя (с использованием FAISS)")
52
+
53
+
54
+ # Создание списка уникальных стран
55
+ all_countries = sorted(set(df["county"].tolist()))
56
+
57
+ # Создание списка уникальных жанров
58
+ all_genres = set()
59
+ for genres_set in df["tags"].dropna():
60
+ all_genres.update(genres_set)
61
+ all_genres = sorted(all_genres)
62
+
63
+ # Фильтр по наличию рейтинга
64
+ has_rating = st.sidebar.checkbox("Показывать только сериалы с рейтингом?", True)
65
+
66
+ # Виджеты для боковой панели
67
+ selected_country = st.sidebar.multiselect("Страна", all_countries)
68
+ selected_genre = st.sidebar.multiselect("Жанры", all_genres)
69
+
70
+
71
+ rating = True
72
+
73
+ search_table = table_maker(
74
+ df=df,
75
+ country=selected_country,
76
+ min_year=int(df["year"].min()),
77
+ max_year=int(df["year"].max()),
78
+ tagger=set(selected_genre),
79
+ rating=has_rating,
80
+ )
81
+
82
+ # Проверяем, пустой ли отфильтрованный DataFrame
83
+ if search_table.empty:
84
+ st.error(
85
+ "После фильтрации данных не осталось. Пожалуйста, выберите другие параметры."
86
+ )
87
+ else:
88
+ # Преобразование year в числовой формат, если возможно, и обработка NaN значений
89
+ search_table["year"] = pd.to_numeric(search_table["year"], errors="coerce").dropna()
90
+
91
+ if search_table.empty:
92
+ st.error(
93
+ "После фильтрации и обработки годов в данных не осталось записей. Пожалуйста, выберите другие параметры."
94
+ )
95
+ else:
96
+ # Теперь безопасно ищем min и max
97
+ min_year = int(search_table["year"].min())
98
+ max_year = int(search_table["year"].max())
99
+
100
+ # Если есть хотя бы два разных года, отображаем слайдер
101
+ if min_year < max_year:
102
+ selected_year_range = st.sidebar.slider(
103
+ "Выберите диапазон лет выпуска",
104
+ min_value=min_year,
105
+ max_value=max_year,
106
+ value=(min_year, max_year),
107
+ )
108
+ # Применяем фильтр по годам
109
+ search_table = search_table[
110
+ (search_table["year"] >= selected_year_range[0])
111
+ & (search_table["year"] <= selected_year_range[1])
112
+ ]
113
+
114
+ st.sidebar.markdown("---")
115
+ st.sidebar.markdown("### Дополнительные настройки")
116
+
117
+ # Позволяет пользователю выбрать количество сериалов для отображения, от 1 до 10
118
+ top_n = st.sidebar.number_input(
119
+ "Сколько сериалов показывать?", min_value=1, max_value=10, value=5
120
+ )
121
+
122
+ # Создание текстового поля для ввода пользовательского запроса
123
+ user_request = st.text_input(
124
+ "Введите ваш запрос:",
125
+ "про ментов, мусора по коням, менты, полиция и все такое",
126
+ )
127
+
128
+ user_request_emb = model.encode(user_request)
129
+
130
+ if st.button("Найти сериалы по запросу") and len(df) > 0:
131
+
132
+ output_faiss = FAISS_inference(search_table, user_request_emb, top_n)
133
+
134
+ # top_n = 5 # мин 1 макс 10
135
+ res = output_faiss()
136
+
137
+ (
138
+ poster,
139
+ title,
140
+ description,
141
+ rating,
142
+ genre,
143
+ cast,
144
+ score,
145
+ year,
146
+ links,
147
+ country,
148
+ ) = (
149
+ {},
150
+ {},
151
+ {},
152
+ {},
153
+ {},
154
+ {},
155
+ {},
156
+ {},
157
+ {},
158
+ {},
159
+ )
160
+
161
+ for i, con in enumerate(res["poster"]):
162
+ # Проверяем, является ли значение в con ссылкой или путем к файлу
163
+ if "nopicture" in con:
164
+ poster[i] = placeholder_path
165
+ else:
166
+ poster[i] = poster_path + con
167
+
168
+ for i, con in enumerate(res["year"]):
169
+ year[i] = con
170
+
171
+ for i, con in enumerate(res["title"]):
172
+ title[i] = con
173
+
174
+ for i, con in enumerate(res["description"]):
175
+ description[i] = con
176
+
177
+ for i, con in enumerate(res["rating"]):
178
+ rating[i] = con
179
+
180
+ for i, con in enumerate(res["tags"]):
181
+ genre[i] = ", ".join(con)
182
+
183
+ for i, con in enumerate(res["cast"]):
184
+ cast[i] = con
185
+
186
+ for i, con in enumerate(res["score"]):
187
+ score[i] = con
188
+
189
+ for i, con in enumerate(res["url"]):
190
+ links[i] = show_path + con
191
+
192
+ for i, con in enumerate(res["county"]):
193
+ country[i] = con
194
+
195
+ st.markdown("---")
196
+
197
+ # Проверяем, пустой ли набор результатов
198
+ if len(res) == 0:
199
+ st.error(
200
+ "Сериалы по выбранным параметрам не найдены. Попробуйте изменить критерии поиска."
201
+ )
202
+ else:
203
+ # Если результаты есть, выводим их
204
+ iterations = min(len(res), top_n)
205
+
206
+ for i in range(iterations):
207
+
208
+ col1, col2 = st.columns([1, 3])
209
+ with col1:
210
+ st.image(poster[i])
211
+ # Добавляем ссылку под картинкой
212
+ st.markdown(
213
+ f"<a href='{links[i]}' target='_blank' style='display: block; text-align: center; color: grey; font-size: small; font-style: italic;'>Смотреть сериал</a>",
214
+ unsafe_allow_html=True,
215
+ )
216
+
217
+ with col2:
218
+
219
+ st.markdown(
220
+ f"<span style='font-weight:bold; font-size:22px;'>Название сериала:</span> <span style='font-size:20px;'>«{title[i]}»</span>",
221
+ unsafe_allow_html=True,
222
+ )
223
+
224
+ st.markdown(
225
+ f"<span style='font-weight:bold; font-size:16px;'>Страна:</span> <span style='font-size:16px;'>{country[i]}</span>",
226
+ unsafe_allow_html=True,
227
+ )
228
+
229
+ st.markdown(
230
+ f"<span style='font-weight:bold; font-size:16px;'>Год выпуска:</span> <span style='font-size:16px;'>{year[i]}</span>",
231
+ unsafe_allow_html=True,
232
+ )
233
+
234
+ st.markdown(
235
+ f"<span style='font-weight:bold; font-size:16px;'>Жанр:</span> <span style='font-size:16px;'>{genre[i]}</span>",
236
+ unsafe_allow_html=True,
237
+ )
238
+
239
+ rating_display = (
240
+ "Нет информации" if pd.isna(rating[i]) else rating[i]
241
+ )
242
+
243
+ st.markdown(
244
+ f"<span style='font-weight:bold; font-size:16px;'>Рейтинг:</span> <span style='font-size:16px;'>{rating_display}</span>",
245
+ unsafe_allow_html=True,
246
+ )
247
+
248
+ st.markdown(
249
+ "<h6 style='font-weight:bold;'>В ролях:</h6>",
250
+ unsafe_allow_html=True,
251
+ )
252
+
253
+ st.markdown(
254
+ f"<div style='text-align: justify; margin-bottom: 18px;'>{cast[i]}</div>",
255
+ unsafe_allow_html=True,
256
+ )
257
+
258
+ st.markdown(
259
+ "<h6 style='font-weight:bold;'>Описание:</h6>",
260
+ unsafe_allow_html=True,
261
+ )
262
+
263
+ st.markdown(
264
+ f"<div style='text-align: justify;'>{description[i]}</div>",
265
+ unsafe_allow_html=True,
266
+ )
267
+ score_display = round(score[i], 3)
268
+ st.markdown(
269
+ f"<div style='color: grey;'><hr style='margin: 2px 0;'/><span style='font-weight:bold; font-size:13px; font-style: italic;'>Оценка FAISS (расстояние):</span> <span style='font-size:13px; font-style: italic;'>{score_display}</span><hr style='margin: 2px 0;'/></div>",
270
+ unsafe_allow_html=True,
271
+ )
272
+
273
+ st.markdown("---")
pages/__init__.py ADDED
File without changes
pages/tv_shows.png ADDED
resource/__init__.py ADDED
File without changes
resource/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (207 Bytes). View file
 
resource/__pycache__/functions.cpython-311.pyc ADDED
Binary file (5.17 kB). View file
 
resource/functions.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ import faiss
4
+ import numpy as np
5
+ from numpy import dot
6
+ from numpy.linalg import norm
7
+
8
+
9
+ def table_maker(
10
+ df: pd.DataFrame,
11
+ country: list = [],
12
+ min_year: int = 1999,
13
+ max_year: int = None,
14
+ tagger=set(),
15
+ rating: bool = True,
16
+ ):
17
+
18
+ x = df.copy()
19
+ # фильтр по рейтингк
20
+ if rating:
21
+ rat_con = ~(x["rating"].isna())
22
+ else:
23
+ rat_con = ~(x["url"].isna())
24
+ # фильтр по стране
25
+ if country == []:
26
+ con_con = ~(x["url"].isna())
27
+ else:
28
+ con_con = x["county"].isin(country)
29
+ # фильтр по тегам
30
+ if tagger == set():
31
+ tagger_con = ~(x["url"].isna())
32
+ else:
33
+ tagger_con = x["tags"].ge(tagger)
34
+
35
+ # Условие для фильтрации по минимальному году
36
+ year_cond = x["year"] >= min_year
37
+
38
+ # Добавляем условие для фильтрации по максимальному году, если оно задано
39
+ if max_year is not None:
40
+ year_cond &= x["year"] <= max_year
41
+
42
+ condi = rat_con & con_con & tagger_con & year_cond
43
+
44
+ return x.loc[condi]
45
+
46
+
47
+ class RecSys:
48
+ def __init__(self, df: pd.DataFrame, input_, model):
49
+ self.df = df
50
+ self.input_ = input_
51
+ self.model = model
52
+ with torch.no_grad():
53
+ self.emb = model.encode(self.input_)
54
+
55
+ def __call__(self):
56
+
57
+ def compute(a):
58
+ return dot(a, self.emb) / (norm(a) * norm(self.emb))
59
+
60
+ res = self.df.copy()
61
+ res["compute"] = res["vec"].map(compute)
62
+ res["compute2"] = res["vec2"].map(compute)
63
+ self.df["score"] = res["compute"] * 0.8 + res["compute2"] * 0.2
64
+
65
+ return self.df.sort_values("score", ascending=False)
66
+
67
+
68
+ class FAISS_inference:
69
+ def __init__(self, df, emb, k=5):
70
+ self.df = df
71
+ self.emb = emb.reshape(1, -1)
72
+ self.k = k
73
+
74
+ vec = df["vec"].to_numpy()
75
+ self.d = vec[0].shape[0]
76
+ for i, e in enumerate(vec):
77
+ if i == 0:
78
+ vex = e.T
79
+ else:
80
+ temp = e.T
81
+ vex = np.append(vex, temp)
82
+ self.vex = np.reshape(vex, (-1, 384))
83
+
84
+ # self.index = faiss.IndexFlatIP(self.d)
85
+ # self.index = faiss.IndexFlatL2(self.d)
86
+ self.index = faiss.IndexFlat(self.d)
87
+
88
+ self.index.add(self.vex)
89
+
90
+ def __call__(self):
91
+
92
+ d, i = self.index.search(self.emb, self.k)
93
+
94
+ faiss_table = self.df.iloc[i[0]]
95
+ faiss_table.loc[:, "score"] = d[0]
96
+ return faiss_table