# Provenance (page header captured with this file): mlkorra, commit 0ffa809 "Add app", 4.64 kB.
import streamlit as st
import pandas as pd
import numpy as np
import re
import string
from nltk.stem import WordNetLemmatizer
import umap
import plotly.graph_objects as go
from plotly import tools
import plotly.offline as py
import plotly.express as px
from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
from bertopic import BERTopic
import pickle
import os
def visualizer(prob_req, embed, df, index, company_name):
    """Plot the user's company against the selected dataset companies.

    Loads the pickled reducer from ``topicmodel/saving_example.sav``, projects
    the request's topic probabilities with it, and renders a Plotly scatter
    plot of the dataset points plus the user's point in the Streamlit page.

    Args:
        prob_req: topic probabilities for the user's request (as returned by
            ``BERTopic.transform``); passed straight to ``reducer.transform``.
        embed: coordinates of the selected dataset companies; columns 0 and 1
            are plotted as x and y.
        df: the company dataset; must have ``headquarters``, ``company_name``
            and ``industry_groups`` columns.
        index: positional row indices into ``df``, aligned with the rows of
            ``embed``.
        company_name: label for the user's point (legend entry and hover text).
    """
    fname = 'topicmodel/saving_example.sav'
    # Use a context manager so the file handle is closed deterministically.
    # NOTE: pickle.load can execute arbitrary code; only load the trusted,
    # locally shipped model artifact here, never user-supplied files.
    with open(fname, 'rb') as fh:
        reducer = pickle.load(fh)  # dimensionality-reduction model for the probabilities
    embed_req = reducer.transform(prob_req)

    # Slice the dataset once instead of re-running df.iloc[index] per column.
    selected = df.iloc[index]

    # Scatter plot of every selected company from the dataset.
    fig1 = px.scatter(
        embed, x=0, y=1,
        color=selected['headquarters'],
        labels={'color': 'states'},
        hover_name=selected['company_name'] + " with industry group: " + selected['industry_groups'],
    )

    # Overlay the user's company as a larger hexagon marker.
    fig1.add_trace(
        go.Scatter(
            x=embed_req[:, 0],
            y=embed_req[:, 1],
            mode='markers',
            marker_symbol="hexagon2", marker_size=15,
            showlegend=True, name=company_name, hovertext=company_name,
        )
    )
    st.plotly_chart(fig1)
def clean_text(text):
    """Normalise free text before topic modelling.

    Lower-cases the input (non-strings are coerced with ``str``), removes
    URLs and HTML-style ``<...>`` tags, then deletes every character in
    ``string.punctuation``. Whitespace is left as-is.

    Args:
        text: value to clean.

    Returns:
        The cleaned, lower-cased string.
    """
    text = str(text).lower()
    # Raw strings: '\S' and '\.' are invalid escapes in ordinary literals
    # (SyntaxWarning on Python 3.12+).
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    # Strip HTML tags. The previous pattern '<.,*?>+' had a stray comma and only
    # matched one-character tags such as '<b>', so '<div>' leaked through as 'div'.
    text = re.sub(r'<.*?>+', '', text)
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
    return text
def preprocess(name, group, state, states_used, desc):
    """Build the cleaned text fed to the topic model for the user's company.

    The company name is removed from the description, the comma-separated
    industry groups are appended, and the result is passed through
    ``clean_text``, filtered against NLTK's English stopwords and lemmatised
    with ``WordNetLemmatizer``.

    Args:
        name: company name; removed (case-sensitively) from ``desc``.
        group: comma-separated industry groups.
        state: unused here; kept for interface compatibility.
        states_used: unused here; kept for interface compatibility.
        desc: free-text company description.

    Returns:
        A single space-separated string of lemmatised, non-stopword tokens.
    """
    desc = desc.replace(name, '')
    # Separate groups with a space: joining with "" glued e.g.
    # "Software,Internet" into the single token "softwareinternet".
    categories = " ".join(part.strip() for part in group.split(","))
    cleaned = desc + " " + categories

    # A set gives O(1) membership tests instead of scanning the list per word.
    stop_words = set(stopwords.words('english'))
    lemmatizer = WordNetLemmatizer()

    text = clean_text(cleaned)
    # split() with no argument also drops empty tokens from runs of spaces.
    tokens = [w for w in text.split() if w not in stop_words]
    return ' '.join(lemmatizer.lemmatize(w) for w in tokens)
@st.cache(persist=True, suppress_st_warning=True)
def load_topic_model(model_path, name, group, state, states_used, desc):
    """Run the topic model on the user's company and gather plotting data.

    Loads the BERTopic model from ``model_path``, the dataset from
    ``topicmodel/data.csv`` and the precomputed 2-D points from
    ``topicmodel/embed.npy``. It keeps only the rows whose ``headquarters``
    contains one of ``states_used`` and transforms the preprocessed request.

    Args:
        model_path: path passed to ``BERTopic.load``.
        name, group, state, states_used, desc: user inputs, forwarded to
            ``preprocess``; ``states_used`` also selects the rows to plot.

    Returns:
        ``(topic, prob_req, select, df, index)``:
        - ``topic`` and ``prob_req`` come from ``model.transform``.
        - ``select`` holds the ``embed.npy`` rows for the chosen states.
        - ``df`` is the full dataset.
        - ``index`` is the list of row indices backing ``select``.
    """
    # Load BERTopic.
    model = BERTopic.load(model_path)
    # Dataset used for the scatter-plot labels/colours.
    df = pd.read_csv('topicmodel/data.csv')
    # Reduced embeddings for the points displayed in the scatter plot.
    embeddings = np.load('topicmodel/embed.npy')

    request = preprocess(name, group, state, states_used, desc)

    # Keep only companies headquartered in the states the user wants to compare.
    # na=False: a missing headquarters would otherwise put NaN in the mask and
    # make boolean indexing raise. regex=False: state names are literal text.
    index = []
    for state_used in states_used:
        mask = df['headquarters'].str.contains(state_used, regex=False, na=False)
        index.extend(df.index[mask].tolist())
    # A row can match several selections (e.g. "Washington" and "Washington DC");
    # drop repeats so each company is plotted once, preserving first-seen order.
    index = list(dict.fromkeys(index))
    select = embeddings[index]

    # Use BERTopic to get the request's topic probabilities.
    topic, prob_req = model.transform([request])
    # NOTE(review): inside a cached function this message is likely only shown
    # on a cache miss - confirm with the Streamlit version in use.
    st.text("Modelling done! plotting results now...")
    return topic, prob_req, select, df, index
def app():
    """Render the competitive-analysis page.

    Collects the company's name, industry group, description, home state and
    the states to compare against. When "Analyse Competition" is pressed, it
    validates the inputs, runs the topic model and plots the result.
    """
    st.title("Competitive Analysis of Companies ")
    companyname = st.text_input('Input company name here:', value="")
    companygrp = st.text_input('Input industry group here:', value="")
    companydesc = st.text_input("Input company description: (can be found in the company's linkedin page)", value="")
    # NOTE(review): 'Giorgia', 'Washington DC' and 'District of Columbia' look
    # like spelling variants; presumably they mirror values in data.csv - confirm.
    states = ['Georgia', 'California', 'Texas', 'Tennessee', 'Massachusetts',
              'New York', 'Ohio', 'Delaware', 'Florida', 'Washington',
              'Connecticut', 'Colorado', 'South Carolina', 'New Jersey',
              'Michigan', 'Maryland', 'Pennsylvania', 'Virginia', 'Vermont',
              'Minnesota', 'Illinois', 'North Carolina', 'Montana', 'Kentucky',
              'Oregon', 'Iowa', 'District of Columbia', 'Arizona', 'Wisconsin',
              'Louisiana', 'Idaho', 'Utah', 'Nevada', 'Nebraska', 'New Mexico',
              'Missouri', 'Kansas', 'New Hampshire', 'Wyoming', 'Arkansas',
              'Indiana', 'North Dakota', 'Hawaii', 'Alabama', 'Maine',
              'Rhode Island', 'Mississippi', 'Alaska', 'Oklahoma',
              'Washington DC', 'Giorgia']
    state = st.selectbox('Select state the company is based in', states)
    states_used = st.multiselect('Select states you want to analyse', states)
    if st.button("Analyse Competition"):
        # Strip so whitespace-only input counts as empty, and so a trailing
        # space in the name does not stop it being removed from the description.
        companyname = companyname.strip()
        companygrp = companygrp.strip()
        companydesc = companydesc.strip()
        if not (companyname and companydesc and companygrp and states_used):
            st.error("Some fields are empty!")
        else:
            model_path = 'topicmodel/my_model'
            topic, prob_req, embed, df, index = load_topic_model(
                model_path, companyname, companygrp, state, states_used, companydesc)
            visualizer(prob_req, embed, df, index, companyname)