"""Streamlit app that scores tweets with Hugging Face classification models.

The finetuned cyberbullying model returns one independent probability per
toxicity label. The other selectable models are run through a generic
``sentiment-analysis`` pipeline.
"""

from random import randint

import pandas as pd
import requests
import streamlit as st
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DistilBertForSequenceClassification,
    DistilBertTokenizerFast,
    pipeline,
)

import comments

CYBERBULLYING_MODEL = "kingsotn/finetuned_cyberbullying"
KANYE_API_URL = "https://api.kanye.rest/"
# Seconds to wait for the quote API before giving up, so the app never hangs.
REQUEST_TIMEOUT_S = 10

# Columns of the result table: the input text, followed by one score column per
# model output. The label order is assumed to match the finetuned model's logit
# order -- TODO confirm against the model's config (id2label).
labels = ['comment', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']


def predict_cyberbullying_probability(sentence, tokenizer, model):
    """Return one independent (sigmoid) probability per model output for *sentence*.

    Args:
        sentence: Text to classify.
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Sequence-classification model whose logits are treated as
            independent multi-label scores.

    Returns:
        A flat list of floats, one per model output.
    """
    encoded = tokenizer(
        sentence,
        padding='max_length',
        return_token_type_ids=False,
        return_attention_mask=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = model(encoded['input_ids'], attention_mask=encoded['attention_mask'])
    # Sigmoid rather than softmax: the labels are scored independently.
    # flatten() collapses the single-input batch into a 1-D vector of scores.
    probs = torch.sigmoid(outputs.logits.flatten())
    return probs.tolist()


def perform_cyberbullying_analysis(tweet):
    """Score *tweet* with the finetuned toxicity model and return a one-row table.

    Args:
        tweet: Text to analyze.

    Returns:
        A DataFrame with a ``comment`` column holding *tweet*, plus one float
        column per label in ``labels[1:]``.
    """
    with st.spinner(text="loading model, wait until spinner ends..."):
        # NOTE(review): the model and tokenizer are reloaded on every call.
        # Caching them with your Streamlit version's resource cache would avoid
        # the repeated load.
        model = AutoModelForSequenceClassification.from_pretrained(CYBERBULLYING_MODEL)
        tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        probabilities = predict_cyberbullying_probability(tweet, tokenizer, model)
    row = {'comment': [tweet]}
    row.update({label: [p] for label, p in zip(labels[1:], probabilities)})
    return pd.DataFrame(row)


def perform_default_analysis(model_name):
    """Render a text box and an Analyze button for a generic sentiment model.

    Must be called inside an ``st.form`` because it creates a form submit
    button. A ``POSITIVE``/``POS`` prediction is celebrated; any other label
    is reported as a bad tweet.

    Args:
        model_name: Hugging Face model id passed to ``from_pretrained``.
    """
    tweet = st.text_area(label="Enter Text:", value="I'm nice at ping pong")
    submitted = st.form_submit_button("Analyze")
    if not submitted:
        return
    with st.spinner(text="loading..."):
        # Load lazily, so the model is only built when the user actually asks
        # for an analysis rather than on every Streamlit rerun.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        clf = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt")
        out = clf(tweet)
    st.json(out)
    if out[0]["label"] in ("POSITIVE", "POS"):
        st.balloons()
        st.success("nice tweet!")
    else:
        st.error("bad tweet!")


def _random_comment():
    """Return a uniformly random entry from ``comments.comments``."""
    # Index by the real list length rather than a hard-coded bound, which could
    # overrun a shorter list or skip entries of a longer one.
    return comments.comments[randint(0, len(comments.comments) - 1)]


def _fetch_kanye_quote():
    """Fetch a quote from the kanye.rest API.

    Returns:
        The quote text, or None after showing an error in the app.
    """
    try:
        response = requests.get(KANYE_API_URL, timeout=REQUEST_TIMEOUT_S)
    except requests.RequestException as err:
        st.error(f"Error getting Kanye quote | {err}")
        return None
    if response.status_code != 200:
        st.error("Error getting Kanye quote | status code: " + str(response.status_code))
        return None
    try:
        return response.json()['quote']
    except (ValueError, KeyError) as err:
        st.error(f"Error getting Kanye quote | unexpected response: {err}")
        return None


def _run_cyberbullying_form():
    """Render the toxicity-model widgets and show the score table on any submit."""
    default = "I'm not even going to lie to you. I love me so much right now."
    tweet = st.text_area(label="Enter Text:", value=default)
    submitted = st.form_submit_button("Analyze textbox")
    random_clicked = st.form_submit_button("Get a random 😈😈😈 tweet (warning!!)")
    kanye_clicked = st.form_submit_button("Get a ye quote 🐻🎤🎧🎶")

    if random_clicked:
        tweet = _random_comment()
        st.write(tweet)
        submitted = True
    if kanye_clicked:
        quote = _fetch_kanye_quote()
        # Only analyze when a quote was actually fetched. Otherwise the stale
        # textbox text would be shown and scored as if it were the quote.
        if quote is not None:
            tweet = quote
            st.write(tweet)
            submitted = True
    if submitted:
        st.table(perform_cyberbullying_analysis(tweet))


def main():
    """Build the page: title, banner image, model picker and analysis form."""
    st.title("Toxic Tweets Analyzer")
    image = "kanye_loves_tweet.jpg"
    st.image(image, use_column_width=True)

    with st.form("my_form"):
        model_name = st.selectbox(
            "Enter a text and select a pre-trained model to get the sentiment analysis",
            [
                CYBERBULLYING_MODEL,
                "distilbert-base-uncased-finetuned-sst-2-english",
                "finiteautomata/bertweet-base-sentiment-analysis",
                "distilbert-base-uncased",
            ],
        )
        if model_name == CYBERBULLYING_MODEL:
            _run_cyberbullying_form()
        else:
            perform_default_analysis(model_name)


if __name__ == "__main__":
    main()