find_my_show / pages /03_🚀_Find_my_show_(FAISS).py
IvT-DS's picture
Upload 10 files
41b0868 verified
raw
history blame
10.5 kB
import os
import importlib.util
import torch
import streamlit as st
import pandas as pd
from PIL import Image
# Формируем абсолютный путь до файла functions.py
module_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "resource", "functions.py")
)
# Загружаем модуль
spec = importlib.util.spec_from_file_location("resource.functions", module_path)
functions = importlib.util.module_from_spec(spec)
spec.loader.exec_module(functions)
# Теперь используем функции напрямую
table_maker = functions.table_maker
# RecSys = functions.RecSys
FAISS_inference = functions.FAISS_inference
poster_path = "https://resizer.mail.ru/p/"
show_path = "https://kino.mail.ru/series_"
placeholder_path = "../img/v2/nopicture/308x462.png"
@st.cache(allow_output_mutation=True)
def load_model(model_path):
model = torch.load(model_path)
return model
@st.cache
def load_data(data_path):
df = pd.read_pickle(data_path)
return df
MODEL_PATH = "model/model.pt"
df = pd.read_pickle("data/data.pkl")
model = torch.load(MODEL_PATH)
image = Image.open("pages/tv_shows.png")
st.image(image, use_column_width=True)
# Заголовок приложения
st.markdown("### Поиск сериалов по запросу пользователя (с использованием FAISS)")
# Создание списка уникальных стран
all_countries = sorted(set(df["county"].tolist()))
# Создание списка уникальных жанров
all_genres = set()
for genres_set in df["tags"].dropna():
all_genres.update(genres_set)
all_genres = sorted(all_genres)
# Фильтр по наличию рейтинга
has_rating = st.sidebar.checkbox("Показывать только сериалы с рейтингом?", True)
# Виджеты для боковой панели
selected_country = st.sidebar.multiselect("Страна", all_countries)
selected_genre = st.sidebar.multiselect("Жанры", all_genres)
rating = True
search_table = table_maker(
df=df,
country=selected_country,
min_year=int(df["year"].min()),
max_year=int(df["year"].max()),
tagger=set(selected_genre),
rating=has_rating,
)
# Проверяем, пустой ли отфильтрованный DataFrame
if search_table.empty:
st.error(
"После фильтрации данных не осталось. Пожалуйста, выберите другие параметры."
)
else:
# Преобразование year в числовой формат, если возможно, и обработка NaN значений
search_table["year"] = pd.to_numeric(search_table["year"], errors="coerce").dropna()
if search_table.empty:
st.error(
"После фильтрации и обработки годов в данных не осталось записей. Пожалуйста, выберите другие параметры."
)
else:
# Теперь безопасно ищем min и max
min_year = int(search_table["year"].min())
max_year = int(search_table["year"].max())
# Если есть хотя бы два разных года, отображаем слайдер
if min_year < max_year:
selected_year_range = st.sidebar.slider(
"Выберите диапазон лет выпуска",
min_value=min_year,
max_value=max_year,
value=(min_year, max_year),
)
# Применяем фильтр по годам
search_table = search_table[
(search_table["year"] >= selected_year_range[0])
& (search_table["year"] <= selected_year_range[1])
]
st.sidebar.markdown("---")
st.sidebar.markdown("### Дополнительные настройки")
# Позволяет пользователю выбрать количество сериалов для отображения, от 1 до 10
top_n = st.sidebar.number_input(
"Сколько сериалов показывать?", min_value=1, max_value=10, value=5
)
# Создание текстового поля для ввода пользовательского запроса
user_request = st.text_input(
"Введите ваш запрос:",
"про ментов, мусора по коням, менты, полиция и все такое",
)
user_request_emb = model.encode(user_request)
if st.button("Найти сериалы по запросу") and len(df) > 0:
output_faiss = FAISS_inference(search_table, user_request_emb, top_n)
# top_n = 5 # мин 1 макс 10
res = output_faiss()
(
poster,
title,
description,
rating,
genre,
cast,
score,
year,
links,
country,
) = (
{},
{},
{},
{},
{},
{},
{},
{},
{},
{},
)
for i, con in enumerate(res["poster"]):
# Проверяем, является ли значение в con ссылкой или путем к файлу
if "nopicture" in con:
poster[i] = placeholder_path
else:
poster[i] = poster_path + con
for i, con in enumerate(res["year"]):
year[i] = con
for i, con in enumerate(res["title"]):
title[i] = con
for i, con in enumerate(res["description"]):
description[i] = con
for i, con in enumerate(res["rating"]):
rating[i] = con
for i, con in enumerate(res["tags"]):
genre[i] = ", ".join(con)
for i, con in enumerate(res["cast"]):
cast[i] = con
for i, con in enumerate(res["score"]):
score[i] = con
for i, con in enumerate(res["url"]):
links[i] = show_path + con
for i, con in enumerate(res["county"]):
country[i] = con
st.markdown("---")
# Проверяем, пустой ли набор результатов
if len(res) == 0:
st.error(
"Сериалы по выбранным параметрам не найдены. Попробуйте изменить критерии поиска."
)
else:
# Если результаты есть, выводим их
iterations = min(len(res), top_n)
for i in range(iterations):
col1, col2 = st.columns([1, 3])
with col1:
st.image(poster[i])
# Добавляем ссылку под картинкой
st.markdown(
f"<a href='{links[i]}' target='_blank' style='display: block; text-align: center; color: grey; font-size: small; font-style: italic;'>Смотреть сериал</a>",
unsafe_allow_html=True,
)
with col2:
st.markdown(
f"<span style='font-weight:bold; font-size:22px;'>Название сериала:</span> <span style='font-size:20px;'>«{title[i]}»</span>",
unsafe_allow_html=True,
)
st.markdown(
f"<span style='font-weight:bold; font-size:16px;'>Страна:</span> <span style='font-size:16px;'>{country[i]}</span>",
unsafe_allow_html=True,
)
st.markdown(
f"<span style='font-weight:bold; font-size:16px;'>Год выпуска:</span> <span style='font-size:16px;'>{year[i]}</span>",
unsafe_allow_html=True,
)
st.markdown(
f"<span style='font-weight:bold; font-size:16px;'>Жанр:</span> <span style='font-size:16px;'>{genre[i]}</span>",
unsafe_allow_html=True,
)
rating_display = (
"Нет информации" if pd.isna(rating[i]) else rating[i]
)
st.markdown(
f"<span style='font-weight:bold; font-size:16px;'>Рейтинг:</span> <span style='font-size:16px;'>{rating_display}</span>",
unsafe_allow_html=True,
)
st.markdown(
"<h6 style='font-weight:bold;'>В ролях:</h6>",
unsafe_allow_html=True,
)
st.markdown(
f"<div style='text-align: justify; margin-bottom: 18px;'>{cast[i]}</div>",
unsafe_allow_html=True,
)
st.markdown(
"<h6 style='font-weight:bold;'>Описание:</h6>",
unsafe_allow_html=True,
)
st.markdown(
f"<div style='text-align: justify;'>{description[i]}</div>",
unsafe_allow_html=True,
)
score_display = round(score[i], 3)
st.markdown(
f"<div style='color: grey;'><hr style='margin: 2px 0;'/><span style='font-weight:bold; font-size:13px; font-style: italic;'>Оценка FAISS (расстояние):</span> <span style='font-size:13px; font-style: italic;'>{score_display}</span><hr style='margin: 2px 0;'/></div>",
unsafe_allow_html=True,
)
st.markdown("---")