Spaces:
Runtime error
Runtime error
arjunpatel
commited on
Commit
·
b688d81
1
Parent(s):
59d8681
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from sentence_transformers import SentenceTransformer, util
|
3 |
+
from datasets import load_dataset
|
4 |
+
from app.data_cleaning import prepare_document, cos_dicts, retrieve_top_k_similar
|
5 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
demo = gr.Blocks()
|
10 |
+
|
11 |
+
df = load_dataset("arjunpatel/best-selling-video-games")
|
12 |
+
df.set_format("pandas")
|
13 |
+
df = df["train"][:]
|
14 |
+
|
15 |
+
cleaned_wikis = df.wiki_page.apply(lambda x: prepare_document(x))
|
16 |
+
tfidf = TfidfVectorizer()
|
17 |
+
tfidf_wikis = tfidf.fit_transform(cleaned_wikis.tolist())
|
18 |
+
video_game_cos_dict = cos_dicts(df.Title, tfidf_wikis.toarray())
|
19 |
+
|
20 |
+
embedder = SentenceTransformer('msmarco-MiniLM-L6-cos-v5')
|
21 |
+
msmarco_embeddings = embedder.encode(df.wiki_page.tolist(), convert_to_tensor = True)
|
22 |
+
|
23 |
+
def nli_search(query):
|
24 |
+
# given a query, return top few similar games
|
25 |
+
|
26 |
+
# example code taken from Sentence Transformers docs
|
27 |
+
query_embedding = embedder.encode(query, convert_to_tensor=True)
|
28 |
+
|
29 |
+
# We use cosine-similarity and torch.topk to find the highest 5 scores
|
30 |
+
cos_scores = util.cos_sim(query_embedding, msmarco_embeddings)[0]
|
31 |
+
top_results = torch.topk(cos_scores, k=5)
|
32 |
+
|
33 |
+
#print("\n\n======================\n\n")
|
34 |
+
#print("Query:", query)
|
35 |
+
#print("\nTop 5 most similar sentences in corpus:")
|
36 |
+
ret_list = []
|
37 |
+
|
38 |
+
for score, idx in zip(top_results[0], top_results[1]):
|
39 |
+
ret_list.append((df.wiki_page.tolist()[idx][0:100], "(Score: {:.4f})".format(score)))
|
40 |
+
|
41 |
+
return ret_list
|
42 |
+
|
43 |
+
|
44 |
+
def find_similar_games(name, num):
|
45 |
+
return retrieve_top_k_similar(name, video_game_cos_dict, num)
|
46 |
+
|
47 |
+
with demo:
|
48 |
+
gr.Markdown("<h1><center>Find your next Video Game!</center></h1>")
|
49 |
+
gr.Markdown(
|
50 |
+
"""This Gradio demo allows you to search a list of best selling video games and their corresponding Wikipedia pages
|
51 |
+
using NLP! The first tab allows for a TF-IDF based search, and the second leverages Sentence Transformers for a Natural Language
|
52 |
+
Search. Enjoy!""")
|
53 |
+
with gr.Tab("TF-IDF Similarity Search"):
|
54 |
+
video_game = gr.Dropdown(df.Title.tolist(), default = df.Title.tolist()[0],
|
55 |
+
label = "Selected Game")
|
56 |
+
|
57 |
+
num_similar = gr.Dropdown([1, 2, 3, 4, 5], default = 1, label = "Number of Similar Games")
|
58 |
+
|
59 |
+
find_similar = gr.Button("Find 'em!")
|
60 |
+
|
61 |
+
output = gr.Textbox("Games will appear here!")
|
62 |
+
|
63 |
+
find_similar.click(fn = find_similar_games, inputs = [video_game, num_similar],
|
64 |
+
outputs = output)
|
65 |
+
|
66 |
+
with gr.Tab("Natural Language Search"):
|
67 |
+
q = gr.Textbox("Type a query here. Try: find me mario games")
|
68 |
+
find_nli = gr.Button("Search!")
|
69 |
+
nli_output = gr.Textbox("Output will appear here from NLI search")
|
70 |
+
|
71 |
+
find_nli.click(fn = nli_search, inputs = [q], outputs = nli_output)
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
demo.launch()
|
76 |
+
#drop down for video game
|
77 |
+
|
78 |
+
#drop down for number of similar games (1-5)
|
79 |
+
|
80 |
+
#button to retrieve
|