mlkorra commited on
Commit
0ffa809
·
1 Parent(s): 0d949dd
Files changed (2) hide show
  1. app.py +121 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import re
5
+ import string
6
+
7
+ from nltk.stem import WordNetLemmatizer
8
+ import umap
9
+
10
+ import plotly.graph_objects as go
11
+ from plotly import tools
12
+ import plotly.offline as py
13
+ import plotly.express as px
14
+
15
+ from nltk.corpus import stopwords
16
+ import nltk
17
+ nltk.download('stopwords')
18
+ nltk.download('wordnet')
19
+ from bertopic import BERTopic
20
+ import pickle
21
+ import os
22
+
23
+ def visualizer(prob_req, embed, df, index, company_name):
24
+
25
+ fname = 'topicmodel/saving_example.sav'
26
+ reducer= pickle.load((open(fname, 'rb'))) #load the umap dimensionality reduction model trained on rest of probablities
27
+ embed_req= reducer.transform(prob_req)
28
+
29
+ #add scatter plot for all embeddings from our dataset
30
+ fig1 = px.scatter(
31
+ embed, x=0, y=1,
32
+ color=df.iloc[index]['headquarters'], labels={'color': 'states'}, hover_name= df.iloc[index]['company_name'] + " with industry group: "+ df.iloc[index]['industry_groups'])
33
+ #add the data for users request and display
34
+ fig1.add_trace(
35
+ go.Scatter(
36
+ x=embed_req[:,0],
37
+ y=embed_req[:,1],
38
+ mode='markers',
39
+ marker_symbol="hexagon2", marker_size=15,
40
+ showlegend=True, name= company_name, hovertext= company_name))
41
+ st.plotly_chart(fig1)
42
+
43
+ def clean_text(text):
44
+
45
+ """util function to clean the text"""
46
+
47
+ text = str(text).lower()
48
+ text = re.sub('https?://\S+|www\.\S+', '', text)
49
+ text = re.sub('<.,*?>+', '', text)
50
+ text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
51
+
52
+ return text
53
+
54
+ def preprocess(name, group, state, states_used, desc):
55
+ desc = desc.replace(name,'')
56
+ cat = "".join(cat for cat in group.split(","))
57
+ cleaned= desc + " " + cat
58
+
59
+ stop_words = stopwords.words('english')
60
+ lemmatizer = WordNetLemmatizer()
61
+ text = clean_text(cleaned)
62
+ text = ' '.join(w for w in text.split(' ') if w not in stop_words)
63
+ text = ' '.join(lemmatizer.lemmatize(w) for w in text.split(' '))
64
+ return text
65
+
66
+ @st.cache(persist=True,suppress_st_warning=True)
67
+ def load_topic_model(model_path, name, group, state, states_used, desc):
68
+
69
+
70
+ #load Bertopic
71
+ model=BERTopic.load(model_path)
72
+ #load dataset (used for creating scatter plot)
73
+
74
+ data_path = 'topicmodel/data.csv'
75
+ df = pd.read_csv(data_path)
76
+ #load embeddings reduced by UMAP for the points to be displayed by scatter plot
77
+
78
+ embeddings_path = 'topicmodel/embed.npy'
79
+ embeddings = np.load(embeddings_path)
80
+ #preprocess user inputs
81
+ request= preprocess(name, group, state, states_used, desc)
82
+ index=[]
83
+ #only select states that user wants to compare
84
+ for state_used in states_used:
85
+ index.extend(df.index[df['headquarters'].str.contains(state_used)].tolist())
86
+ select=embeddings[index]
87
+
88
+ #use bert topic to get probabilities
89
+ topic, prob_req= model.transform([request])
90
+ st.text("Modelling done! plotting results now...")
91
+
92
+ return topic, prob_req, select, df, index
93
+
94
+ def app():
95
+
96
+ st.title("Competitive Analysis of Companies ")
97
+ companyname = st.text_input('Input company name here:', value="")
98
+ companygrp = st.text_input('Input industry group here:', value="")
99
+ companydesc = st.text_input("Input company description: (can be found in the company's linkedin page)", value="")
100
+ states= ['Georgia', 'California', 'Texas', 'Tennessee', 'Massachusetts',
101
+ 'New York', 'Ohio', 'Delaware', 'Florida', 'Washington',
102
+ 'Connecticut', 'Colorado', 'South Carolina', 'New Jersey',
103
+ 'Michigan', 'Maryland', 'Pennsylvania', 'Virginia', 'Vermont',
104
+ 'Minnesota', 'Illinois', 'North Carolina', 'Montana', 'Kentucky',
105
+ 'Oregon', 'Iowa', 'District of Columbia', 'Arizona', 'Wisconsin',
106
+ 'Louisiana', 'Idaho', 'Utah', 'Nevada', 'Nebraska', 'New Mexico',
107
+ 'Missouri', 'Kansas', 'New Hampshire', 'Wyoming', 'Arkansas',
108
+ 'Indiana', 'North Dakota', 'Hawaii', 'Alabama', 'Maine',
109
+ 'Rhode Island', 'Mississippi', 'Alaska', 'Oklahoma',
110
+ 'Washington DC', 'Giorgia']
111
+ state= st.selectbox('Select state the company is based in', states)
112
+ states_used = st.multiselect('Select states you want to analyse', states)
113
+
114
+ if(st.button("Analyse Competition")):
115
+ if companyname=="" or companydesc=="" or companygrp=="" or states_used==[]:
116
+ st.error("Some fields are empty!")
117
+ else:
118
+
119
+ model_path = 'topicmodel/my_model'
120
+ topic,prob_req,embed,df,index = load_topic_model(model_path, companyname, companygrp, state, states_used, companydesc)
121
+ visualizer(prob_req, embed, df, index, companyname)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ umap-learn
2
+ nltk
3
+ plotly
4
+ bertopic
5
+ pickle