awinml commited on
Commit
a7fc504
·
1 Parent(s): 18468cb

Upload 16 files (#11)

Browse files

- Upload 16 files (75295ea33fd86490c00779398a7101ad71ee52e9)

app.py CHANGED
@@ -14,12 +14,14 @@ from utils.entity_extraction import (
14
  extract_ticker_spacy,
15
  format_entities_flan_alpaca,
16
  generate_alpaca_ner_prompt,
 
17
  )
18
  from utils.models import (
19
  generate_entities_flan_alpaca_checkpoint,
20
  generate_entities_flan_alpaca_inference_api,
21
  generate_text_flan_t5,
22
  get_data,
 
23
  get_flan_alpaca_xl_model,
24
  get_flan_t5_model,
25
  get_instructor_embedding_model,
@@ -85,6 +87,8 @@ with st.sidebar:
85
  if ner_choice == "Spacy":
86
  ner_model = get_spacy_model()
87
 
 
 
88
  with col1:
89
  st.subheader("Question")
90
  if document_type == "Single-Document":
@@ -104,6 +108,10 @@ with col1:
104
  value="How was AAPL's capex spend compared to GOOGL?",
105
  )
106
 
 
 
 
 
107
  years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
108
  quarters_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
109
  ticker_choice = [
@@ -382,6 +390,7 @@ if document_type == "Single-Document":
382
  quarter,
383
  ticker,
384
  participant_type,
 
385
  threshold,
386
  )
387
 
 
14
  extract_ticker_spacy,
15
  format_entities_flan_alpaca,
16
  generate_alpaca_ner_prompt,
17
+ extract_keywords
18
  )
19
  from utils.models import (
20
  generate_entities_flan_alpaca_checkpoint,
21
  generate_entities_flan_alpaca_inference_api,
22
  generate_text_flan_t5,
23
  get_data,
24
+ get_alpaca_model,
25
  get_flan_alpaca_xl_model,
26
  get_flan_t5_model,
27
  get_instructor_embedding_model,
 
87
  if ner_choice == "Spacy":
88
  ner_model = get_spacy_model()
89
 
90
+ alpaca_model = get_alpaca_model()
91
+
92
  with col1:
93
  st.subheader("Question")
94
  if document_type == "Single-Document":
 
108
  value="How was AAPL's capex spend compared to GOOGL?",
109
  )
110
 
111
+
112
+ # Extract keywords from query
113
+ keywords = extract_keywords(query_text, alpaca_model)
114
+
115
  years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
116
  quarters_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
117
  ticker_choice = [
 
390
  quarter,
391
  ticker,
392
  participant_type,
393
+ keywords,
394
  threshold,
395
  )
396
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  pandas
 
2
  tqdm
3
  pinecone-client
4
  spacy[transformers] == 3.3.0
@@ -12,3 +13,5 @@ streamlit
12
  streamlit-scrollable-textbox
13
  openai
14
  InstructorEmbedding
 
 
 
1
  pandas
2
+ nltk
3
  tqdm
4
  pinecone-client
5
  spacy[transformers] == 3.3.0
 
13
  streamlit-scrollable-textbox
14
  openai
15
  InstructorEmbedding
16
+ gradio_client
17
+
utils/__pycache__/entity_extraction.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/entity_extraction.cpython-38.pyc and b/utils/__pycache__/entity_extraction.cpython-38.pyc differ
 
utils/__pycache__/models.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/models.cpython-38.pyc and b/utils/__pycache__/models.cpython-38.pyc differ
 
utils/__pycache__/retriever.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/retriever.cpython-38.pyc and b/utils/__pycache__/retriever.cpython-38.pyc differ
 
utils/__pycache__/vector_index.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/vector_index.cpython-38.pyc and b/utils/__pycache__/vector_index.cpython-38.pyc differ
 
utils/entity_extraction.py CHANGED
@@ -1,4 +1,54 @@
1
  import re
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  # Entity Extraction
4
 
 
1
  import re
2
+ from nltk.stem import PorterStemmer, WordNetLemmatizer
3
+
4
+ # Keyword Extracttion
5
+
6
+
7
+ def expand_list_of_lists(list_of_lists):
8
+ """
9
+ Expands a list of lists of strings to a list of strings.
10
+ Args:
11
+ list_of_lists: A list of lists of strings.
12
+ Returns:
13
+ A list of strings.
14
+ """
15
+
16
+ expanded_list = []
17
+ for inner_list in list_of_lists:
18
+ for string in inner_list:
19
+ expanded_list.append(string)
20
+ return expanded_list
21
+
22
+
23
+ def all_keywords_combs(texts):
24
+
25
+ texts = [text.split(" ") for text in texts]
26
+ texts = expand_list_of_lists(texts)
27
+
28
+ # Convert all strings to lowercase.
29
+ lower_texts = [text.lower() for text in texts]
30
+
31
+ # Stem the words in each string.
32
+ stemmer = PorterStemmer()
33
+ stem_texts = [stemmer.stem(text) for text in texts]
34
+
35
+ # Lemmatize the words in each string.
36
+ lemmatizer = WordNetLemmatizer()
37
+ lemm_texts = [lemmatizer.lemmatize(text) for text in texts]
38
+
39
+ texts.extend(lower_texts)
40
+ texts.extend(stem_texts)
41
+ texts.extend(lemm_texts)
42
+ return texts
43
+
44
+
45
+ def extract_keywords(query_text, model):
46
+ prompt = f"###Instruction:Extract the important keywords which describe the context accurately.\n\nInput:{query_text}\n\n###Response:"
47
+ response = model.predict(prompt)
48
+ keywords = response.split(", ")
49
+ keywords = all_keywords_combs(keywords)
50
+ return keywords
51
+
52
 
53
  # Entity Extraction
54
 
utils/models.py CHANGED
@@ -10,6 +10,7 @@ import streamlit_scrollable_textbox as stx
10
  import torch
11
  from InstructorEmbedding import INSTRUCTOR
12
  from sentence_transformers import SentenceTransformer
 
13
  from tqdm import tqdm
14
  from transformers import (
15
  AutoModelForMaskedLM,
@@ -103,6 +104,12 @@ def get_instructor_embedding_model():
103
  return model
104
 
105
 
 
 
 
 
 
 
106
  @st.experimental_memo
107
  def save_key(api_key):
108
  return api_key
 
10
  import torch
11
  from InstructorEmbedding import INSTRUCTOR
12
  from sentence_transformers import SentenceTransformer
13
+ from gradio_client import Client
14
  from tqdm import tqdm
15
  from transformers import (
16
  AutoModelForMaskedLM,
 
104
  return model
105
 
106
 
107
+ @st.experimental_singleton
108
+ def get_alpaca_model():
109
+ client = Client("https://awinml-alpaca-cpp.hf.space")
110
+ return client
111
+
112
+
113
  @st.experimental_memo
114
  def save_key(api_key):
115
  return api_key
utils/retriever.py CHANGED
@@ -7,6 +7,7 @@ def query_pinecone_sparse(
7
  quarter,
8
  ticker,
9
  participant_type,
 
10
  threshold=0.25,
11
  ):
12
  if participant_type == "Company Speaker":
@@ -33,6 +34,7 @@ def query_pinecone_sparse(
33
  "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
34
  "Ticker": {"$eq": ticker},
35
  "QA_Flag": {"$eq": participant},
 
36
  },
37
  include_metadata=True,
38
  )
@@ -54,6 +56,7 @@ def query_pinecone_sparse(
54
  "Quarter": {"$eq": quarter},
55
  "Ticker": {"$eq": ticker},
56
  "QA_Flag": {"$eq": participant},
 
57
  },
58
  include_metadata=True,
59
  )
@@ -68,6 +71,7 @@ def query_pinecone_sparse(
68
  "Quarter": {"$eq": quarter},
69
  "Ticker": {"$eq": ticker},
70
  "QA_Flag": {"$eq": participant},
 
71
  },
72
  include_metadata=True,
73
  )
@@ -88,6 +92,7 @@ def query_pinecone(
88
  quarter,
89
  ticker,
90
  participant_type,
 
91
  threshold=0.25,
92
  ):
93
  if participant_type == "Company Speaker":
@@ -113,6 +118,7 @@ def query_pinecone(
113
  "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
114
  "Ticker": {"$eq": ticker},
115
  "QA_Flag": {"$eq": participant},
 
116
  },
117
  include_metadata=True,
118
  )
@@ -133,6 +139,7 @@ def query_pinecone(
133
  "Quarter": {"$eq": quarter},
134
  "Ticker": {"$eq": ticker},
135
  "QA_Flag": {"$eq": participant},
 
136
  },
137
  include_metadata=True,
138
  )
@@ -146,6 +153,7 @@ def query_pinecone(
146
  "Quarter": {"$eq": quarter},
147
  "Ticker": {"$eq": ticker},
148
  "QA_Flag": {"$eq": participant},
 
149
  },
150
  include_metadata=True,
151
  )
 
7
  quarter,
8
  ticker,
9
  participant_type,
10
+ keywords=None,
11
  threshold=0.25,
12
  ):
13
  if participant_type == "Company Speaker":
 
34
  "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
35
  "Ticker": {"$eq": ticker},
36
  "QA_Flag": {"$eq": participant},
37
+ "Keywords": {"$in": keywords}
38
  },
39
  include_metadata=True,
40
  )
 
56
  "Quarter": {"$eq": quarter},
57
  "Ticker": {"$eq": ticker},
58
  "QA_Flag": {"$eq": participant},
59
+ "Keywords": {"$in": keywords}
60
  },
61
  include_metadata=True,
62
  )
 
71
  "Quarter": {"$eq": quarter},
72
  "Ticker": {"$eq": ticker},
73
  "QA_Flag": {"$eq": participant},
74
+ "Keywords": {"$in": keywords}
75
  },
76
  include_metadata=True,
77
  )
 
92
  quarter,
93
  ticker,
94
  participant_type,
95
+ keywords=None,
96
  threshold=0.25,
97
  ):
98
  if participant_type == "Company Speaker":
 
118
  "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
119
  "Ticker": {"$eq": ticker},
120
  "QA_Flag": {"$eq": participant},
121
+ "Keywords": {"$in": keywords}
122
  },
123
  include_metadata=True,
124
  )
 
139
  "Quarter": {"$eq": quarter},
140
  "Ticker": {"$eq": ticker},
141
  "QA_Flag": {"$eq": participant},
142
+ "Keywords": {"$in": keywords}
143
  },
144
  include_metadata=True,
145
  )
 
153
  "Quarter": {"$eq": quarter},
154
  "Ticker": {"$eq": ticker},
155
  "QA_Flag": {"$eq": participant},
156
+ "Keywords": {"$in": keywords}
157
  },
158
  include_metadata=True,
159
  )