AmirShabani
commited on
Commit
·
a3e15e0
1
Parent(s):
225ee02
Fuzzy matching
Browse files- core.py +5 -1
- requirements.txt +1 -0
core.py
CHANGED
@@ -28,6 +28,9 @@ from torch_geometric.nn.conv.gcn_conv import gcn_norm
|
|
28 |
from torch_geometric.nn.conv import MessagePassing
|
29 |
from torch_geometric.typing import Adj
|
30 |
from sklearn.neighbors import BallTree
|
|
|
|
|
|
|
31 |
class LightGCN(MessagePassing):
|
32 |
def __init__(self, num_users, num_items, embedding_dim=64, diffusion_steps=3, add_self_loops=False):
|
33 |
super().__init__()
|
@@ -186,7 +189,8 @@ def drop_non_numerical_columns(df):
|
|
186 |
def output_list(input_dict, movies_df = movie_embeds, tree = btree, user_embeddings = user_embeds, movies = final_movies):
|
187 |
movie_ratings = {}
|
188 |
for movie_title in input_dict:
|
189 |
-
|
|
|
190 |
movie_ratings[index] = input_dict[movie_title]
|
191 |
user_embed = create_user_embedding(movie_ratings, movie_embeds)
|
192 |
# Call the find_closest_user function with the pre-built BallTree
|
|
|
28 |
from torch_geometric.nn.conv import MessagePassing
|
29 |
from torch_geometric.typing import Adj
|
30 |
from sklearn.neighbors import BallTree
|
31 |
+
from thefuzz import fuzz
|
32 |
+
from thefuzz import process
|
33 |
+
|
34 |
class LightGCN(MessagePassing):
|
35 |
def __init__(self, num_users, num_items, embedding_dim=64, diffusion_steps=3, add_self_loops=False):
|
36 |
super().__init__()
|
|
|
189 |
def output_list(input_dict, movies_df = movie_embeds, tree = btree, user_embeddings = user_embeds, movies = final_movies):
|
190 |
movie_ratings = {}
|
191 |
for movie_title in input_dict:
|
192 |
+
matching_title = process.extractOne(movie_title, final_movies['title'].values, scorer=fuzz.token_sort_ratio)[0]
|
193 |
+
index = movies.index[movies['title'] == matching_title].tolist()[0]
|
194 |
movie_ratings[index] = input_dict[movie_title]
|
195 |
user_embed = create_user_embedding(movie_ratings, movie_embeds)
|
196 |
# Call the find_closest_user function with the pre-built BallTree
|
requirements.txt
CHANGED
@@ -2,6 +2,7 @@ requests==2.29.0
|
|
2 |
pillow
|
3 |
numpy==1.23.5
|
4 |
pandas==1.5.3
|
|
|
5 |
scikit-learn==1.2.2
|
6 |
torch==2.0.0
|
7 |
torchvision==0.15.1
|
|
|
2 |
pillow
|
3 |
numpy==1.23.5
|
4 |
pandas==1.5.3
|
5 |
+
thefuzz[speedup]
|
6 |
scikit-learn==1.2.2
|
7 |
torch==2.0.0
|
8 |
torchvision==0.15.1
|