Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import re | |
import string | |
from nltk.stem import WordNetLemmatizer | |
import umap | |
import plotly.graph_objects as go | |
from plotly import tools | |
import plotly.offline as py | |
import plotly.express as px | |
from nltk.corpus import stopwords | |
import nltk | |
nltk.download('stopwords') | |
nltk.download('wordnet') | |
from bertopic import BERTopic | |
import pickle | |
import os | |
def visualizer(prob_req, embed, df, index, company_name):
    """Plot the selected dataset companies together with the user's company.

    Args:
        prob_req: Topic probabilities for the user's request, as returned by
            the topic model; passed to the pickled reducer's ``transform``.
        embed: 2-D reduced embeddings of the dataset rows to plot; columns
            0 and 1 are used as the x and y axes.
        df: Dataset frame; the ``headquarters``, ``company_name`` and
            ``industry_groups`` columns are read.
        index: Row positions in ``df`` of the plotted companies
            (presumably aligned one-to-one with ``embed`` -- confirm
            against the caller).
        company_name: Label used for the user's own marker.
    """
    fname = 'topicmodel/saving_example.sav'
    # NOTE: unpickling executes arbitrary code; this is acceptable only
    # because the file is an artefact shipped with the app, never user input.
    # The context manager closes the handle that was previously leaked.
    with open(fname, 'rb') as model_file:
        reducer = pickle.load(model_file)
    embed_req = reducer.transform(prob_req)

    # Slice the selected rows once instead of once per plotted attribute.
    selected = df.iloc[index]

    # Scatter plot of the dataset embeddings, coloured by headquarters.
    fig1 = px.scatter(
        embed, x=0, y=1,
        color=selected['headquarters'],
        labels={'color': 'states'},
        hover_name=(
            selected['company_name']
            + " with industry group: "
            + selected['industry_groups']
        ),
    )

    # Overlay the user's request as a larger, distinct marker.
    fig1.add_trace(
        go.Scatter(
            x=embed_req[:, 0],
            y=embed_req[:, 1],
            mode='markers',
            marker_symbol="hexagon2",
            marker_size=15,
            showlegend=True,
            name=company_name,
            hovertext=company_name,
        )
    )
    st.plotly_chart(fig1)
def clean_text(text):
    """Lowercase ``text`` and strip URLs, HTML-like tags and punctuation.

    Args:
        text: Any object; it is converted with ``str()`` before cleaning.

    Returns:
        The cleaned, lowercased string. Whitespace is left untouched, so
        removed spans may leave consecutive spaces behind.
    """
    text = str(text).lower()
    # Raw strings keep ``\S`` / ``\.`` as regex escapes instead of invalid
    # string escapes.
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    # Lazily match everything between '<' and '>' so whole tags are removed;
    # the former pattern '<.,*?>+' only matched single-character tags.
    text = re.sub(r'<.*?>+', '', text)
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
    return text
def preprocess(name, group, state, states_used, desc):
    """Build the cleaned, lemmatised text that is fed to the topic model.

    Args:
        name: Company name; every occurrence is removed from ``desc``.
        group: Comma-separated industry groups; the commas are dropped and
            the result is appended to the description.
        state: Unused here; kept so the signature matches the caller.
        states_used: Unused here; kept so the signature matches the caller.
        desc: Free-text company description.

    Returns:
        A single space-joined string of lemmatised tokens with English
        stop words removed.
    """
    desc = desc.replace(name, '')
    # Same result as joining the comma-split pieces with an empty separator.
    cat = group.replace(",", "")
    cleaned = desc + " " + cat

    # A set gives O(1) membership tests; the NLTK list was scanned per token.
    stop_words = set(stopwords.words('english'))
    lemmatizer = WordNetLemmatizer()

    text = clean_text(cleaned)
    # Splitting on a literal space (not arbitrary whitespace) is deliberate:
    # it keeps the tokenisation identical to the previous two-pass version.
    return ' '.join(
        lemmatizer.lemmatize(w)
        for w in text.split(' ')
        if w not in stop_words
    )
def load_topic_model(model_path, name, group, state, states_used, desc):
    """Load the topic model and data, then score the user's company.

    Args:
        model_path: Path passed to ``BERTopic.load``.
        name: Company name entered by the user.
        group: Industry group(s) entered by the user.
        state: State the company is based in (forwarded to ``preprocess``).
        states_used: State names whose companies should be plotted.
        desc: Company description entered by the user.

    Returns:
        A tuple ``(topic, prob_req, select, df, index)``: the two values
        returned by ``model.transform`` for the request, the embeddings of
        the selected rows, the full dataset frame, and the selected row
        labels.
    """
    # Load BERTopic.
    model = BERTopic.load(model_path)

    # Dataset used to build the scatter plot.
    data_path = 'topicmodel/data.csv'
    df = pd.read_csv(data_path)

    # Embeddings already reduced by UMAP for the plotted points
    # (presumably row-aligned with data.csv -- confirm).
    embeddings_path = 'topicmodel/embed.npy'
    embeddings = np.load(embeddings_path)

    # Preprocess the user's inputs.
    request = preprocess(name, group, state, states_used, desc)

    # Keep only the rows whose headquarters mention a requested state.
    index = []
    for state_used in states_used:
        # regex=False: state names are literals. na=False: rows with a
        # missing headquarters count as "no match" instead of producing a
        # NaN mask, which pandas refuses to index with.
        mask = df['headquarters'].str.contains(
            state_used, regex=False, na=False
        )
        index.extend(df.index[mask].tolist())
    # Overlapping choices (e.g. "Washington" and "Washington DC") match the
    # same rows twice; drop repeats while keeping first-seen order.
    index = list(dict.fromkeys(index))
    select = embeddings[index]

    # Use BERTopic to get the request's topic and probabilities.
    topic, prob_req = model.transform([request])
    st.text("Modelling done! plotting results now...")
    return topic, prob_req, select, df, index
def app():
    """Render the competitive-analysis page and run the analysis on request.

    Collects the company name, industry group, description, home state and
    the states to compare against. When the button is pressed and no field
    is empty, the topic model is run and the result is plotted.
    """
    st.title("Competitive Analysis of Companies ")

    name_input = st.text_input('Input company name here:', value="")
    group_input = st.text_input('Input industry group here:', value="")
    desc_input = st.text_input(
        "Input company description: (can be found in the company's linkedin page)",
        value="",
    )

    state_options = [
        'Georgia', 'California', 'Texas', 'Tennessee', 'Massachusetts',
        'New York', 'Ohio', 'Delaware', 'Florida', 'Washington',
        'Connecticut', 'Colorado', 'South Carolina', 'New Jersey',
        'Michigan', 'Maryland', 'Pennsylvania', 'Virginia', 'Vermont',
        'Minnesota', 'Illinois', 'North Carolina', 'Montana', 'Kentucky',
        'Oregon', 'Iowa', 'District of Columbia', 'Arizona', 'Wisconsin',
        'Louisiana', 'Idaho', 'Utah', 'Nevada', 'Nebraska', 'New Mexico',
        'Missouri', 'Kansas', 'New Hampshire', 'Wyoming', 'Arkansas',
        'Indiana', 'North Dakota', 'Hawaii', 'Alabama', 'Maine',
        'Rhode Island', 'Mississippi', 'Alaska', 'Oklahoma',
        'Washington DC', 'Giorgia',
    ]
    home_state = st.selectbox(
        'Select state the company is based in', state_options
    )
    compare_states = st.multiselect(
        'Select states you want to analyse', state_options
    )

    # Nothing to do until the user asks for the analysis.
    if not st.button("Analyse Competition"):
        return

    # Every text field and at least one comparison state are required.
    text_fields = (name_input, desc_input, group_input)
    if any(field == "" for field in text_fields) or compare_states == []:
        st.error("Some fields are empty!")
        return

    results = load_topic_model(
        'topicmodel/my_model',
        name_input,
        group_input,
        home_state,
        compare_states,
        desc_input,
    )
    # The first element (the predicted topic) is not needed for the plot.
    _, request_probs, subset_embeddings, dataset, row_ids = results
    visualizer(request_probs, subset_embeddings, dataset, row_ids, name_input)