# Hugging Face Space (status: Running)
from sentence_transformers import SentenceTransformer | |
import numpy as np | |
import requests | |
import gradio as gr | |
from description import project_description | |
def request_and_response(url, timeout=30):
    """Fetch *url* with an HTTP GET and return the decoded JSON body.

    Args:
        url: Endpoint to query.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled server.

    Returns:
        The parsed JSON payload, exactly as ``response.json()`` yields it.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within *timeout*.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on HTTP errors. Otherwise an error payload reaches callers
    # that index into it (e.g. paper["title"]) and crash with a KeyError.
    response.raise_for_status()
    return response.json()
def extract_abstracts_and_ids(papers):
    """Return ``(abstracts, paper_ids)`` pulled from daily-paper entries.

    Each entry is read as ``entry["paper"]["summary"]`` and
    ``entry["paper"]["id"]``. Both lists follow the order of *papers*, so
    position i in one corresponds to position i in the other.
    """
    def _column(key):
        # One pass over *papers* per requested field.
        return [entry["paper"][key] for entry in papers]

    return _column("summary"), _column("id")
def get_embeddings(model, texts):
    """Encode *texts* with *model* and hand back whatever ``model.encode`` returns.

    For a SentenceTransformer this is presumably one embedding row per input
    text (confirm against the sentence-transformers version in use).
    """
    return model.encode(texts)
def compute_similarity(model, embeddings1, embeddings2):
    """Score *embeddings1* against *embeddings2* using the model's own metric.

    Delegates directly to ``model.similarity``. For a SentenceTransformer this
    is presumably a (len(embeddings1), len(embeddings2)) score matrix. TODO:
    confirm for the installed version.
    """
    return model.similarity(embeddings1, embeddings2)
def find_closest(similarities, paper_ids):
    """Return the id of the paper with the highest similarity score.

    Args:
        similarities: Scores convertible with ``np.asarray``. Either a 1-D
            vector with one score per paper, or a 2-D (queries, papers)
            matrix; the last axis must line up with *paper_ids*.
        paper_ids: Paper ids in the same order as the last axis of
            *similarities*.

    Returns:
        The element of *paper_ids* whose column holds the overall maximum
        score.

    Raises:
        ValueError: If there are no scores, or the number of score columns
            differs from ``len(paper_ids)``.
    """
    scores = np.atleast_1d(np.asarray(similarities))
    if scores.size == 0:
        raise ValueError("find_closest() needs at least one similarity score")
    if scores.shape[-1] != len(paper_ids):
        raise ValueError(
            f"got {scores.shape[-1]} score columns for {len(paper_ids)} paper ids"
        )
    # np.argmax flattens multi-dimensional input. Unravel the flat index and
    # keep only the last coordinate, i.e. the paper column. A bare flat index
    # is wrong (or out of range) as soon as there is more than one query row.
    column = np.unravel_index(np.argmax(scores), scores.shape)[-1]
    return paper_ids[int(column)]
# Model and data source used by the search; both are fixed at startup.
_MODEL_NAME = "all-MiniLM-L6-v2"
_DAILY_PAPERS_URL = "https://hf.co/api/daily_papers"

# Load the sentence-embedding model once at import time; every query reuses it.
model = SentenceTransformer(_MODEL_NAME)

# Pull the daily-papers listing and split it into parallel abstract/id lists,
# so index i in `abstracts` refers to the paper whose id is `paper_ids[i]`.
papers = request_and_response(_DAILY_PAPERS_URL)
abstracts, paper_ids = extract_abstracts_and_ids(papers)

# Embed all abstracts up front so each query only has to embed itself.
abstract_embeddings = get_embeddings(model, abstracts)
def get_closest_paper(query):
    """Find the daily paper whose abstract best matches *query*.

    Embeds the query and scores it against the pre-computed
    ``abstract_embeddings``. It then fetches the winning paper's metadata
    from the Hugging Face papers API.

    Returns:
        A ``(title, url, summary)`` tuple feeding the three output boxes.
    """
    # The model encodes batches, so the single query is wrapped in a list.
    query_vector = get_embeddings(model, [query])
    scores = compute_similarity(model, query_vector, abstract_embeddings)
    paper_id = find_closest(scores, paper_ids)

    details = request_and_response(f"https://hf.co/api/papers/{paper_id}")
    link = f"https://huggingface.co/papers/{paper_id}"
    return details["title"], link, details["summary"]
with gr.Blocks() as iface:
    # Project blurb rendered above the search controls.
    gr.Markdown(project_description)
    with gr.Row():
        # Left column: the free-text query and its submit button.
        with gr.Column(scale=1):
            query = gr.Textbox(
                label="Query",
                placeholder="What do you have in mind?",
            )
            btn = gr.Button(value="Submit")
        # Right (wider) column: details of the best-matching paper.
        # Labels are needed; unlabelled boxes give no cue which field is which.
        with gr.Column(scale=3):
            with gr.Row():
                title = gr.Textbox(label="Title")
                paper_link = gr.Textbox(label="Paper link")
            abstract = gr.Textbox(label="Abstract")

    result_boxes = [title, paper_link, abstract]
    btn.click(get_closest_paper, query, result_boxes)
    # Pressing Enter in the query box runs the same search as the button.
    query.submit(get_closest_paper, query, result_boxes)

if __name__ == "__main__":
    iface.launch()