# topic_suggest / app.py
# Author: Manish-4007 (commit 9e26e04, "update")
import os
import time

import requests
import streamlit as st
# Hugging Face Inference API endpoint for the hosted Pegasus summarizer.
API_URL = "https://api-inference.huggingface.co/models/tuner007/pegasus_summarizer"

# The access token comes from the environment (e.g. a Space secret named
# HF_TOKEN) rather than being committed to source control.
# NOTE(review): the token previously hard-coded here was published with the
# source code; it should be revoked on huggingface.co.
_HF_TOKEN = os.environ.get("HF_TOKEN")
headers = {"Authorization": f"Bearer {_HF_TOKEN}"} if _HF_TOKEN else {}
def query(payload, *, timeout=60):
    """POST *payload* as JSON to the summarization endpoint and decode the reply.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": text}``.
        timeout: Seconds to wait for the connection/response before giving
            up. Without a timeout a stalled request blocks the app forever.

    Returns:
        The decoded JSON body. Non-2xx responses are not raised; their body
        is returned as-is (presumably an object like ``{"error": ...}`` from
        the Inference API — confirm), so callers must check its shape.

    Raises:
        requests.exceptions.RequestException: on network failure or timeout.
        ValueError: (requests' JSON decode error) if the body is not JSON.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
@st.cache_resource
def load_topic_transfomers():
    """Build the zero-shot topic classifier, on GPU in half precision if possible.

    ``st.cache_resource`` makes Streamlit construct the pipeline once and
    reuse it across reruns instead of reloading the model every time.

    Returns:
        A transformers ``zero-shot-classification`` pipeline backed by
        ``facebook/bart-large-mnli``. It falls back to the default device and
        dtype when CUDA is unavailable or GPU construction fails.
    """
    # Imported lazily, only when the model is first built.
    import torch
    from transformers import pipeline

    model_name = "facebook/bart-large-mnli"
    if torch.cuda.is_available():
        try:
            # `compute_type` (used previously) is not a transformers pipeline
            # argument, so float16 was never applied; `torch_dtype` is the
            # parameter that selects the model weights' precision.
            return pipeline(
                "zero-shot-classification",
                model=model_name,
                device="cuda",
                torch_dtype=torch.float16,
            )
        except Exception as e:
            # Best effort: report the GPU problem and use the default setup.
            print("Error: ", e)
    return pipeline("zero-shot-classification", model=model_name)
# Candidate labels scored by the zero-shot classifier when the caller does
# not supply its own list.
DEFAULT_TOPICS = (
    "Gadgets", "Business", "Finance", "Health", "Sports", "Politics",
    "Government", "Science", "Education", "Travel", "Tourism",
    "Finance & Economics", "Market", "Technology", "Scientific Discovery",
    "Entertainment", "Environment", "News & Media", "Space,Universe & Cosmos",
    "Fashion", "Manufacturing and Constructions", "Law & Crime", "Motivation",
    "Development & Socialization", "Archeology",
)


def suggest_topic(topic_classifier, text, candidate_labels=None):
    """Rank candidate topics for *text* using a zero-shot classifier.

    Args:
        topic_classifier: Callable with the zero-shot pipeline interface,
            ``topic_classifier(text, labels) -> {"labels": [...], ...}``.
        text: The content to classify.
        candidate_labels: Optional iterable of topic names to score; defaults
            to ``DEFAULT_TOPICS``.

    Returns:
        The ``"labels"`` list returned by the classifier (for the transformers
        pipeline this is ordered from most to least likely).
    """
    labels = list(DEFAULT_TOPICS if candidate_labels is None else candidate_labels)
    result = topic_classifier(text, labels)
    return result["labels"]
st.title("Topic Suggestion")

# First run of a browser session: build (or fetch from Streamlit's resource
# cache) the classifier and remember it in session state for later reruns.
if "topic_model" not in st.session_state:
    with st.spinner("Loading Model....."):
        st.session_state["topic_model"] = load_topic_transfomers()
    st.success("Model_loaded")
    st.session_state["model"] = True
# A text area keeps multi-line articles intact; a single-line text_input is
# ill-suited to the long content this app summarises.
whole_text = st.text_area("Enter the text Here: ")

if st.button('Suggest topic'):
    if not whole_text.strip():
        st.warning("Please enter some text first.")
    else:
        try:
            start = time.perf_counter()
            output = query({"inputs": whole_text})

            st.subheader('Original Text: ')
            st.write(whole_text)
            st.subheader('\nSummarized Text:')
            # A successful reply is a list whose first item holds
            # "summary_text". Any other shape (e.g. an error object) is shown
            # to the user instead of raising and skipping the topic step.
            if (
                isinstance(output, list)
                and output
                and isinstance(output[0], dict)
                and "summary_text" in output[0]
            ):
                st.write(output[0]["summary_text"])
            else:
                st.error(f"Summarization failed: {output}")

            with st.spinner("Scanning content to suggest topics"):
                topic_classifier = st.session_state.topic_model
                predicted_topic = suggest_topic(topic_classifier, whole_text)

            clk = time.perf_counter() - start
            if clk < 60:
                st.write(f'Generated in {clk:.2f} secs')
            else:
                st.write(f'Generated in {clk / 60:.2f} minutes')

            st.subheader('Top 10 Topics related to the content')
            for topic in predicted_topic[:10]:
                st.write(topic)
        except Exception as e:
            # Top-level UI boundary: keep the server-side log, and also tell
            # the user instead of failing silently.
            print("Error", e)
            st.error(f"Something went wrong: {e}")