# Author: ariG23498 (Hugging Face staff)
# Last commit: 73fb8da — "chore: scale change"
# File size: 2.34 kB
from sentence_transformers import SentenceTransformer
import numpy as np
import requests
import gradio as gr
from description import project_description
def request_and_response(url, timeout=30):
    """Fetch ``url`` with an HTTP GET and return the decoded JSON body.

    Args:
        url: Address of the JSON endpoint to query.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may block indefinitely on an unresponsive
            server, freezing app startup or a UI callback.

    Returns:
        The parsed JSON payload, exactly as ``response.json()`` yields it.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status, so
            callers fail fast instead of indexing into an error payload.
        requests.RequestException: On connection errors or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.json()
def extract_abstracts_and_ids(papers):
    """Split paper entries into parallel lists of abstracts and ids.

    Every entry must hold a nested ``"paper"`` mapping with ``"summary"``
    and ``"id"`` keys. Index ``i`` of both returned lists refers to the
    same entry.

    Returns:
        A ``(abstracts, paper_ids)`` tuple of lists.
    """
    def _collect(field_name):
        # One full pass over the entries per field, summaries first.
        return [entry["paper"][field_name] for entry in papers]

    return _collect("summary"), _collect("id")
def get_embeddings(model, texts):
    """Encode ``texts`` with ``model.encode`` and return its output unchanged."""
    return model.encode(texts)
def compute_similarity(model, embeddings1, embeddings2):
    """Return ``model.similarity(embeddings1, embeddings2)`` as-is.

    The shape and type of the scores are whatever the model's
    ``similarity`` method produces.
    """
    return model.similarity(embeddings1, embeddings2)
def find_closest(similarities, paper_ids):
    """Return the id of the paper with the highest similarity score.

    Args:
        similarities: Scores whose last axis lines up with ``paper_ids``.
            This may be a 1-D vector or a ``(num_queries, len(paper_ids))``
            matrix. Anything ``numpy.asarray`` accepts works (e.g. a CPU
            tensor).
        paper_ids: Ids ordered like the last axis of ``similarities``.

    Returns:
        The id at the column that holds the overall maximum score.

    Raises:
        ValueError: If there are no scores to choose from.
    """
    scores = np.atleast_1d(np.asarray(similarities))
    if scores.size == 0:
        raise ValueError("cannot pick a closest paper from an empty score set")
    # np.argmax flattens multi-dimensional input. Map the flat position back
    # to its column so a multi-row score matrix still indexes paper_ids
    # correctly. For a single row this equals the flat index.
    flat_idx = np.argmax(scores)
    best_match_idx = np.unravel_index(flat_idx, scores.shape)[-1]
    return paper_ids[best_match_idx]
# --- One-time setup, executed when the module is imported -----------------
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
DAILY_PAPERS_URL = "https://hf.co/api/daily_papers"

# Load the sentence-embedding model once; every query reuses it.
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Fetch the daily paper list and split it into position-aligned
# abstracts and ids.
papers = request_and_response(DAILY_PAPERS_URL)
abstracts, paper_ids = extract_abstracts_and_ids(papers)

# Embed all abstracts up front so only the query is embedded per request.
abstract_embeddings = get_embeddings(model, abstracts)
def get_closest_paper(query):
    """Find the daily paper whose abstract best matches ``query``.

    Args:
        query: Free-text description typed by the user.

    Returns:
        A ``(title, url, summary)`` tuple for the best-matching paper, or
        three empty strings when ``query`` is empty or whitespace-only.
    """
    # A blank query has nothing to match against. Skip the embedding and the
    # network round trip instead of returning an arbitrary paper.
    if not query or not query.strip():
        return "", "", ""
    # Embed the query; the abstract embeddings were computed once at startup.
    query_embeddings = get_embeddings(model, [query])
    # Score the query against every abstract.
    similarities = compute_similarity(model, query_embeddings, abstract_embeddings)
    # Take the id of the highest-scoring abstract.
    best_match_id = find_closest(similarities, paper_ids)
    # Fetch the matched paper's metadata for display.
    paper = request_and_response(f"https://hf.co/api/papers/{best_match_id}")
    title = paper["title"]
    summary = paper["summary"]
    return title, f"https://huggingface.co/papers/{best_match_id}", summary
with gr.Blocks() as iface:
    gr.Markdown(project_description)
    with gr.Row():
        with gr.Column(scale=1):
            query = gr.Textbox(
                label="Query", placeholder="What do you have in mind?"
            )
            btn = gr.Button(value="Submit")
        with gr.Column(scale=3):
            with gr.Row():
                title = gr.Textbox(label="Title")
                paper_link = gr.Textbox(label="Paper link")
            abstract = gr.Textbox(label="Abstract")
    # Labels tell users which output box is which. Pressing Enter in the
    # query box runs the same search as clicking Submit.
    search_outputs = [title, paper_link, abstract]
    btn.click(get_closest_paper, query, search_outputs)
    query.submit(get_closest_paper, query, search_outputs)

if __name__ == "__main__":
    iface.launch()