Spaces:
Runtime error
Runtime error
mlkorra
committed on
Commit
·
0ffa809
1
Parent(s):
0d949dd
Add app
Browse files- app.py +121 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import re
|
5 |
+
import string
|
6 |
+
|
7 |
+
from nltk.stem import WordNetLemmatizer
|
8 |
+
import umap
|
9 |
+
|
10 |
+
import plotly.graph_objects as go
|
11 |
+
from plotly import tools
|
12 |
+
import plotly.offline as py
|
13 |
+
import plotly.express as px
|
14 |
+
|
15 |
+
from nltk.corpus import stopwords
|
16 |
+
import nltk
|
17 |
+
nltk.download('stopwords')
|
18 |
+
nltk.download('wordnet')
|
19 |
+
from bertopic import BERTopic
|
20 |
+
import pickle
|
21 |
+
import os
|
22 |
+
|
23 |
+
def visualizer(prob_req, embed, df, index, company_name):
    """Plot the selected dataset embeddings and overlay the user's company.

    Loads the pickled reducer stored at ``topicmodel/saving_example.sav``,
    projects ``prob_req`` through its ``transform`` method, draws a scatter
    plot of ``embed`` coloured by headquarters, adds the projected request as
    a hexagon marker and renders the figure with ``st.plotly_chart``.

    Args:
        prob_req: Topic probabilities of the user's request; passed as-is to
            ``reducer.transform``.
        embed: Already-reduced embeddings of the rows to display; columns 0
            and 1 are used as the x and y axes.
        df: DataFrame providing the ``headquarters``, ``company_name`` and
            ``industry_groups`` columns.
        index: Row positions handed to ``df.iloc`` to pick the rows that
            correspond to ``embed``.
        company_name: Legend name and hover text for the user's marker.

    Returns:
        None. The figure is rendered as a side effect.
    """
    fname = 'topicmodel/saving_example.sav'
    # Use a context manager so the file handle is closed even if unpickling
    # fails. NOTE(review): pickle.load can execute arbitrary code; this file
    # must only ever be a trusted artefact shipped with the app.
    with open(fname, 'rb') as reducer_file:
        # Per the original author: the UMAP dimensionality-reduction model
        # trained on the rest of the probabilities.
        reducer = pickle.load(reducer_file)
    embed_req = reducer.transform(prob_req)

    # Select the displayed rows once instead of re-slicing for every column.
    selected = df.iloc[index]

    # Scatter plot for all selected embeddings from the dataset.
    fig1 = px.scatter(
        embed, x=0, y=1,
        color=selected['headquarters'],
        labels={'color': 'states'},
        hover_name=(
            selected['company_name']
            + " with industry group: "
            + selected['industry_groups']
        ),
    )

    # Overlay the point computed for the user's request and display.
    fig1.add_trace(
        go.Scatter(
            x=embed_req[:, 0],
            y=embed_req[:, 1],
            mode='markers',
            marker_symbol="hexagon2", marker_size=15,
            showlegend=True, name=company_name, hovertext=company_name,
        )
    )
    st.plotly_chart(fig1)
|
42 |
+
|
43 |
+
def clean_text(text):
    """Normalise free text for topic modelling.

    The input is converted with ``str()`` and lower-cased, then URLs
    (``http://``, ``https://`` or ``www.`` prefixed), HTML-style tags and all
    ``string.punctuation`` characters are removed, in that order.

    Args:
        text: Any object; it is stringified before cleaning.

    Returns:
        The cleaned, lower-cased string. Whitespace is left untouched, so
        removed fragments may leave consecutive spaces behind.
    """
    text = str(text).lower()
    # Raw strings avoid the invalid "\S" escape warning of a plain literal.
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    # Strip whole tags such as <div> or </p>; the previous pattern contained a
    # stray comma and only matched tags with a single-character name.
    text = re.sub(r'<.*?>+', '', text)
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)

    return text
|
53 |
+
|
54 |
+
def preprocess(name, group, state, states_used, desc):
    """Turn the user's company details into the text fed to the topic model.

    The company name is removed from the description, the industry groups are
    appended, the result is cleaned with ``clean_text``, English stop words
    are dropped and every remaining token is lemmatised.

    Args:
        name: Company name; every occurrence is removed from ``desc``.
        group: Comma-separated industry groups.
        state: Accepted for signature compatibility; not used here.
        states_used: Accepted for signature compatibility; not used here.
        desc: Company description.

    Returns:
        A single space-joined string of lemmatised tokens.
    """
    description = desc.replace(name, '')
    # Commas are dropped without inserting a separator, so "a,b" becomes "ab".
    # NOTE(review): confirm this matches how the training data was prepared.
    categories = group.replace(",", "")

    english_stop_words = set(stopwords.words('english'))
    lemmatize = WordNetLemmatizer().lemmatize

    normalised = clean_text(description + " " + categories)
    # Split on a literal single space (not arbitrary whitespace) so that empty
    # tokens produced by consecutive spaces are preserved in the output.
    kept_tokens = [
        token for token in normalised.split(' ')
        if token not in english_stop_words
    ]
    return ' '.join(map(lemmatize, kept_tokens))
|
65 |
+
|
66 |
+
# NOTE(review): newer Streamlit releases deprecate st.cache in favour of
# st.cache_data / st.cache_resource; confirm the pinned version before changing.
@st.cache(persist=True,suppress_st_warning=True)
def load_topic_model(model_path, name, group, state, states_used, desc):
    """Load the topic model and data, then score the user's request.

    Args:
        model_path: Path handed to ``BERTopic.load``.
        name: Company name, forwarded to ``preprocess``.
        group: Industry group(s), forwarded to ``preprocess``.
        state: State the company is based in, forwarded to ``preprocess``.
        states_used: Iterable of state names; dataset rows whose
            ``headquarters`` contains any of them are selected for plotting.
        desc: Company description, forwarded to ``preprocess``.

    Returns:
        A tuple ``(topic, prob_req, select, df, index)`` where ``topic`` and
        ``prob_req`` are the outputs of ``model.transform`` for the request,
        ``select`` is ``embeddings[index]``, ``df`` is the loaded dataset and
        ``index`` is the list of selected row labels.
    """
    # Load BERTopic.
    model = BERTopic.load(model_path)

    # Load dataset (used for creating the scatter plot).
    data_path = 'topicmodel/data.csv'
    df = pd.read_csv(data_path)

    # Load embeddings reduced by UMAP for the points shown in the scatter plot.
    embeddings_path = 'topicmodel/embed.npy'
    embeddings = np.load(embeddings_path)

    # Preprocess user inputs.
    request = preprocess(name, group, state, states_used, desc)

    # Only select the states the user wants to compare.
    headquarters = df['headquarters']
    index = []
    for state_used in states_used:
        # regex=False: state names are literal text, not patterns.
        # na=False: rows with a missing headquarters count as "no match"
        # instead of producing NaN in the mask, which cannot be used to index.
        matches = headquarters.str.contains(state_used, regex=False, na=False)
        index.extend(df.index[matches].tolist())
    # A row can match several selected states (e.g. "Washington" and
    # "Washington DC"); keep each row once, in first-seen order.
    index = list(dict.fromkeys(index))
    select = embeddings[index]

    # Use BERTopic to get the topic and probabilities for the request.
    topic, prob_req = model.transform([request])
    st.text("Modelling done! plotting results now...")

    return topic, prob_req, select, df, index
|
93 |
+
|
94 |
+
def app():
    """Render the competitive-analysis page.

    Draws the input widgets (company name, industry group, description, home
    state and states to compare). When the "Analyse Competition" button is
    pressed it either reports missing fields with ``st.error`` or calls
    ``load_topic_model`` followed by ``visualizer``.
    """
    st.title("Competitive Analysis of Companies ")

    name_input = st.text_input('Input company name here:', value="")
    group_input = st.text_input('Input industry group here:', value="")
    desc_input = st.text_input(
        "Input company description: (can be found in the company's linkedin page)",
        value="",
    )

    state_choices = [
        'Georgia', 'California', 'Texas', 'Tennessee', 'Massachusetts',
        'New York', 'Ohio', 'Delaware', 'Florida', 'Washington',
        'Connecticut', 'Colorado', 'South Carolina', 'New Jersey',
        'Michigan', 'Maryland', 'Pennsylvania', 'Virginia', 'Vermont',
        'Minnesota', 'Illinois', 'North Carolina', 'Montana', 'Kentucky',
        'Oregon', 'Iowa', 'District of Columbia', 'Arizona', 'Wisconsin',
        'Louisiana', 'Idaho', 'Utah', 'Nevada', 'Nebraska', 'New Mexico',
        'Missouri', 'Kansas', 'New Hampshire', 'Wyoming', 'Arkansas',
        'Indiana', 'North Dakota', 'Hawaii', 'Alabama', 'Maine',
        'Rhode Island', 'Mississippi', 'Alaska', 'Oklahoma',
        'Washington DC', 'Giorgia',
    ]
    home_state = st.selectbox('Select state the company is based in', state_choices)
    compare_states = st.multiselect('Select states you want to analyse', state_choices)

    # Nothing to do until the user asks for the analysis.
    if not st.button("Analyse Competition"):
        return

    # Every text field and at least one comparison state are required.
    text_fields = (name_input, desc_input, group_input)
    if any(field == "" for field in text_fields) or compare_states == []:
        st.error("Some fields are empty!")
        return

    model_path = 'topicmodel/my_model'
    _topic, prob_req, embed, df, index = load_topic_model(
        model_path, name_input, group_input, home_state, compare_states, desc_input
    )
    visualizer(prob_req, embed, df, index, name_input)
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
umap-learn
|
2 |
+
nltk
|
3 |
+
plotly
|
4 |
+
bertopic
|
5 |
+
# pickle is part of the Python standard library and must not be listed as a pip requirement
|