Update app.py
app.py CHANGED
@@ -1,19 +1,30 @@
 import streamlit as st
+import faiss
+import numpy as np
+from sentence_transformers import SentenceTransformer, CrossEncoder
+import requests
+import os
+import torch
+import pickle
+from tqdm import tqdm
 from googleapiclient.discovery import build
 from google_auth_oauthlib.flow import InstalledAppFlow
 from google.auth.transport.requests import Request
 from google.oauth2.credentials import Credentials
-import os
-import json
-import pandas as pd
 import base64
-import
-import
-from fpdf import FPDF
+import re
+from pyngrok import ngrok

-
+# ===============================
+# 1. Streamlit App Configuration
+# ===============================
+st.set_page_config(page_title="Email Chat Application", layout="wide")
+st.title("Email Chat Application")

-#
+# ===============================
+# 2. Gmail Authentication Configuration
+# ===============================
+SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']
 if "authenticated" not in st.session_state:
     st.session_state.authenticated = False
 if "creds" not in st.session_state:
@@ -24,178 +35,394 @@ if "auth_code" not in st.session_state:
     st.session_state.auth_code = ""
 if "flow" not in st.session_state:
     st.session_state.flow = None
+if "data_chunks" not in st.session_state:
+    st.session_state.data_chunks = []  # List to store all email chunks
+if "embeddings" not in st.session_state:
+    st.session_state.embeddings = None
+if "vector_store" not in st.session_state:
+    st.session_state.vector_store = None

-
-
+def count_tokens(text):
+    return len(text.split())
+
+# ===============================
+# 3. Gmail Authentication Functions
+# ===============================
+def reset_session_state():
+    st.session_state.authenticated = False
+    st.session_state.creds = None
+    st.session_state.auth_url = None
+    st.session_state.auth_code = ""
+    st.session_state.flow = None
+    st.session_state.data_chunks = []
+    st.session_state.embeddings = None
+    st.session_state.vector_store = None
+    for filename in ["token.json", "data_chunks.pkl", "embeddings.pkl", "vector_store.index"]:
+        if os.path.exists(filename):
+            os.remove(filename)
+
+def authenticate_gmail(credentials_file):
+    creds = None
     if os.path.exists('token.json'):
         try:
             creds = Credentials.from_authorized_user_file('token.json', SCOPES)
             if creds and creds.valid:
                 st.session_state.creds = creds
                 st.session_state.authenticated = True
-                st.success("Authentication successful!")
+                st.success("✅ Authentication successful!")
                 return creds
         except Exception as e:
-            st.error(f"Invalid token.json file: {e}")
+            st.error(f"❌ Invalid token.json file: {e}")
             os.remove('token.json')
-
-
-
-            st.session_state.creds
+    if not creds or not creds.valid:
+        if creds and creds.expired and creds.refresh_token:
+            creds.refresh(Request())
+            st.session_state.creds = creds
             st.session_state.authenticated = True
-            st.success("Authentication successful!")
-
+            st.success("✅ Authentication successful!")
+            with open('token.json', 'w') as token_file:
+                token_file.write(creds.to_json())
+            return creds
         else:
             if not st.session_state.flow:
                 st.session_state.flow = InstalledAppFlow.from_client_secrets_file(credentials_file, SCOPES)
                 st.session_state.flow.redirect_uri = 'http://localhost'
-
-
-
-            st.
-            st.code(st.session_state.auth_url)
+                auth_url, _ = st.session_state.flow.authorization_url(prompt='consent')
+                st.session_state.auth_url = auth_url
+                st.info("**Authorize the application by visiting the URL below:**")
+                st.markdown(f"[Authorize]({st.session_state.auth_url})")

-# Submit Authentication Code
 def submit_auth_code():
     try:
         st.session_state.flow.fetch_token(code=st.session_state.auth_code)
         st.session_state.creds = st.session_state.flow.credentials
         st.session_state.authenticated = True
         with open('token.json', 'w') as token_file:
-
-
-                "refresh_token": st.session_state.creds.refresh_token,
-                "token_uri": st.session_state.creds.token_uri,
-                "client_id": st.session_state.creds.client_id,
-                "client_secret": st.session_state.creds.client_secret,
-                "scopes": st.session_state.creds.scopes
-            }, token_file)
-        st.success("Authentication successful!")
+            token_file.write(st.session_state.creds.to_json())
+        st.success("✅ Authentication successful!")
     except Exception as e:
-        st.error(f"Error during authentication: {e}")
-
-#
[… 124 more removed lines (old lines 78–201), whose content the page did not render …]
+        st.error(f"❌ Error during authentication: {e}")
+
+# ===============================
+# 4. Email Data Extraction, Embedding and Vector Store Functions
+# ===============================
+def extract_email_body(payload):
+    if 'body' in payload and 'data' in payload['body'] and payload['body']['data']:
+        try:
+            return base64.urlsafe_b64decode(payload['body']['data'].encode('UTF-8')).decode('UTF-8')
+        except Exception as e:
+            st.error(f"Error decoding email body: {e}")
+            return ""
+    if 'parts' in payload:
+        for part in payload['parts']:
+            if part.get('mimeType') == 'text/plain' and 'data' in part.get('body', {}):
+                try:
+                    return base64.urlsafe_b64decode(part['body']['data'].encode('UTF-8')).decode('UTF-8')
+                except Exception as e:
+                    st.error(f"Error decoding email part: {e}")
+                    continue
+        if payload['parts']:
+            first_part = payload['parts'][0]
+            if 'data' in first_part.get('body', {}):
+                try:
+                    return base64.urlsafe_b64decode(first_part['body']['data'].encode('UTF-8')).decode('UTF-8')
+                except Exception as e:
+                    st.error(f"Error decoding fallback email part: {e}")
+                    return ""
+    return ""
+
+def combine_email_text(email):
+    parts = []
+    if email.get("sender"):
+        parts.append(f"Sender: {email['sender']}")
+    if email.get("to"):
+        parts.append(f"To: {email['to']}")
+    if email.get("date"):
+        parts.append(f"Date: {email['date']}")
+    if email.get("subject"):
+        parts.append(f"Subject: {email['subject']}")
+    if email.get("body"):
+        parts.append(f"Body: {email['body']}")
+    return "\n".join(parts)
+
+def create_chunks_from_gmail(service, label):
+    try:
+        messages = []
+        result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500).execute()
+        messages.extend(result.get('messages', []))
+        while 'nextPageToken' in result:
+            token = result["nextPageToken"]
+            result = service.users().messages().list(userId='me', labelIds=[label],
+                                                     maxResults=500, pageToken=token).execute()
+            messages.extend(result.get('messages', []))
+
+        data_chunks = []
+        progress_bar = st.progress(0)
+        total = len(messages)
+        for idx, msg in enumerate(messages):
+            msg_data = service.users().messages().get(userId='me', id=msg['id'], format='full').execute()
+            headers = msg_data.get('payload', {}).get('headers', [])
+            email_dict = {"id": msg['id']}
+            for header in headers:
+                name = header.get('name', '').lower()
+                if name == 'from':
+                    email_dict['sender'] = header.get('value', '')
+                elif name == 'subject':
+                    email_dict['subject'] = header.get('value', '')
+                elif name == 'to':
+                    email_dict['to'] = header.get('value', '')
+                elif name == 'date':
+                    email_dict['date'] = header.get('value', '')
+            email_dict['body'] = extract_email_body(msg_data.get('payload', {}))
+            data_chunks.append(email_dict)
+            progress_bar.progress((idx + 1) / total)
+        st.session_state.data_chunks = data_chunks
+        st.success(f"✅ Data chunks created successfully from Gmail! Total emails processed: {len(data_chunks)}")
+        # Save chunks locally for future use.
+        with open("data_chunks.pkl", "wb") as f:
+            pickle.dump(data_chunks, f)
+    except Exception as e:
+        st.error(f"❌ Error creating chunks from Gmail: {e}")
+
+def embed_emails(email_chunks):
+    st.header("Embedding Data and Creating Vector Store")
+    with st.spinner('Embedding data...'):
+        try:
+            embed_model = SentenceTransformer("all-MiniLM-L6-v2")
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            embed_model.to(device)
+            combined_texts = [combine_email_text(email) for email in email_chunks]
+            batch_size = 64
+            embeddings = []
+            for i in range(0, len(combined_texts), batch_size):
+                batch = combined_texts[i:i+batch_size]
+                batch_embeddings = embed_model.encode(
+                    batch,
+                    convert_to_numpy=True,
+                    show_progress_bar=False,
+                    device=device
+                )
+                embeddings.append(batch_embeddings)
+            embeddings = np.vstack(embeddings)
+            faiss.normalize_L2(embeddings)
+            st.session_state.embeddings = embeddings
+            dimension = embeddings.shape[1]
+            index = faiss.IndexFlatIP(dimension)
+            index.add(embeddings)
+            st.session_state.vector_store = index
+            st.success("✅ Data embedding and vector store created successfully!")
+            # Save embeddings and index to disk.
+            with open('embeddings.pkl', 'wb') as f:
+                pickle.dump(embeddings, f)
+            faiss.write_index(index, 'vector_store.index')
+        except Exception as e:
+            st.error(f"❌ Error during embedding: {e}")
+
+def save_embeddings_and_index():
+    try:
+        with open('embeddings.pkl', 'wb') as f:
+            pickle.dump(st.session_state.embeddings, f)
+        faiss.write_index(st.session_state.vector_store, 'vector_store.index')
+        st.success("Embeddings and vector store saved successfully!")
+    except Exception as e:
+        st.error(f"❌ Error saving embeddings and vector store: {e}")
+
+def load_embeddings_and_index():
+    try:
+        with open('embeddings.pkl', 'rb') as f:
+            st.session_state.embeddings = pickle.load(f)
+        st.session_state.vector_store = faiss.read_index('vector_store.index')
+        st.success("Embeddings and vector store loaded successfully!")
+    except Exception as e:
+        st.error(f"❌ Error loading embeddings and vector store: {e}")
+
+def load_chunks():
+    try:
+        with open("data_chunks.pkl", "rb") as f:
+            st.session_state.data_chunks = pickle.load(f)
+        st.success("Email chunks loaded successfully!")
+    except Exception as e:
+        st.error(f"❌ Error loading email chunks: {e}")
+
+# ===============================
+# 5. Handling User Queries
+# ===============================
+def preprocess_query(query):
+    return query.lower().strip()
+
+def handle_user_query():
+    st.header("Let's chat with your Email")
+    user_query = st.text_input("Enter your query:")
+    TOP_K = 10
+    SIMILARITY_THRESHOLD = 0.4
+
+    if st.button("Get Response"):
+        if (st.session_state.vector_store is None or
+            st.session_state.embeddings is None or
+            st.session_state.data_chunks is None):
+            st.error("❌ Please process your email data or load saved chunks/embeddings first.")
+            return
+        if not user_query.strip():
+            st.error("❌ Please enter a valid query.")
+            return
+        with st.spinner('Processing your query...'):
+            try:
+                # Retrieve candidates using the bi-encoder.
+                embed_model = SentenceTransformer("all-MiniLM-L6-v2")
+                device = 'cuda' if torch.cuda.is_available() else 'cpu'
+                embed_model.to(device)
+                processed_query = preprocess_query(user_query)
+                query_embedding = embed_model.encode(
+                    [processed_query],
+                    convert_to_numpy=True,
+                    show_progress_bar=False,
+                    device=device
+                )
+                faiss.normalize_L2(query_embedding)
+                distances, indices = st.session_state.vector_store.search(query_embedding, TOP_K)
+                candidates = []
+                for idx, score in zip(indices[0], distances[0]):
+                    candidates.append((st.session_state.data_chunks[idx], score))
+
+                # Boost candidates if sender or "to" field contains query tokens (e.g., email addresses).
+                query_tokens = re.findall(r'\S+@\S+', user_query)
+                if query_tokens:
+                    for i in range(len(candidates)):
+                        candidate_email_str = (
+                            (candidates[i][0].get("sender", "") + " " + candidates[i][0].get("to", "")).lower()
+                        )
+                        for token in query_tokens:
+                            if token.lower() in candidate_email_str:
+                                candidates[i] = (candidates[i][0], max(candidates[i][1], 1.0))
+                    filtered_candidates = []
+                    for candidate, score in candidates:
+                        candidate_text = combine_email_text(candidate).lower()
+                        if any(token.lower() in candidate_text for token in query_tokens):
+                            filtered_candidates.append((candidate, score))
+                    if filtered_candidates:
+                        candidates = filtered_candidates
+                    else:
+                        st.info("No candidate emails contain the query token(s) exactly. Proceeding with all candidates.")
+
+                candidates.sort(key=lambda x: x[1], reverse=True)
+                if not candidates:
+                    st.subheader("AI Response:")
+                    st.write("⚠️ No documents found.")
+                    return
+                if candidates[0][1] < SIMILARITY_THRESHOLD:
+                    st.subheader("AI Response:")
+                    st.write("⚠️ No document strongly matches your query. Try refining your query.")
+                    return
+
+                # Re-rank candidates using the cross-encoder.
+                cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
+                candidate_pairs = [(user_query, combine_email_text(candidate[0])) for candidate in candidates]
+                rerank_scores = cross_encoder.predict(candidate_pairs)
+                reranked_candidates = [(candidates[i][0], rerank_scores[i]) for i in range(len(candidates))]
+                reranked_candidates.sort(key=lambda x: x[1], reverse=True)
+                retrieved_emails = [email for email, score in reranked_candidates]
+                retrieved_scores = [score for email, score in reranked_candidates]
+                average_similarity = np.mean(retrieved_scores)
+
+                # Build the final context string.
+                context_str = "\n\n".join([combine_email_text(email) for email in retrieved_emails])
+                MAX_CONTEXT_TOKENS = 500
+                context_tokens = context_str.split()
+                if len(context_tokens) > MAX_CONTEXT_TOKENS:
+                    context_str = " ".join(context_tokens[:MAX_CONTEXT_TOKENS])
+
+                payload = {
+                    "model": "llama3-8b-8192",  # Adjust as needed.
+                    "messages": [
+                        {"role": "system", "content": f"Use the following context:\n{context_str}"},
+                        {"role": "user", "content": user_query}
+                    ]
+                }
+                api_key = "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"  # Replace with your API key.
+                url = "https://api.groq.com/openai/v1/chat/completions"
+                headers = {
+                    "Authorization": f"Bearer {api_key}",
+                    "Content-Type": "application/json"
+                }
+                response = requests.post(url, headers=headers, json=payload)
+                if response.status_code == 200:
+                    response_json = response.json()
+                    generated_text = response_json["choices"][0]["message"]["content"]
+                    st.subheader("AI Response:")
+                    st.write(generated_text)
+                    st.write(f"Average Re-Ranked Score: {average_similarity:.4f}")
+                else:
+                    st.error(f"❌ Error from LLM API: {response.status_code} - {response.text}")
+            except Exception as e:
+                st.error(f"❌ An error occurred during processing: {e}")
+
+# ===============================
+# 6. Main Application Logic
+# ===============================
+def main():
+    st.sidebar.header("Gmail Authentication")
+    credentials_file = st.sidebar.file_uploader("Upload `credentials.json`", type=["json"])
+    if credentials_file and st.sidebar.button("Authenticate"):
+        reset_session_state()
+        with open("credentials.json", "wb") as f:
+            f.write(credentials_file.getbuffer())
+        authenticate_gmail("credentials.json")
+
+    # Option to load previously saved email chunks.
+    chunks_file = st.sidebar.file_uploader("Upload saved email chunks (data_chunks.pkl)", type=["pkl"])
+    if chunks_file:
+        try:
+            st.session_state.data_chunks = pickle.load(chunks_file)
+            st.success("Email chunks loaded successfully from upload!")
+        except Exception as e:
+            st.error(f"❌ Error loading uploaded email chunks: {e}")
+
+    # Option to load previously saved embeddings and vector store.
+    embeddings_file = st.sidebar.file_uploader("Upload saved embeddings (embeddings.pkl)", type=["pkl"])
+    vector_file = st.sidebar.file_uploader("Upload saved vector store (vector_store.index)", type=["index", "idx"])
+    if embeddings_file and vector_file:
+        try:
+            st.session_state.embeddings = pickle.load(embeddings_file)
+            st.session_state.vector_store = faiss.read_index(vector_file.name)
+            st.success("Embeddings and vector store loaded successfully from upload!")
+        except Exception as e:
+            st.error(f"❌ Error loading uploaded embeddings/vector store: {e}")
+
+    if st.session_state.auth_url:
+        st.sidebar.markdown("### **Authorization URL:**")
+        st.sidebar.markdown(f"[Authorize]({st.session_state.auth_url})")
+        st.sidebar.text_input("Enter the authorization code:", key="auth_code")
+        if st.sidebar.button("✅ Submit Authentication Code"):
+            submit_auth_code()
+
+    if st.session_state.authenticated:
+        st.sidebar.success("✅ You are authenticated!")
+        st.sidebar.header("Data Management")
+        label = st.sidebar.selectbox("Select Label to Process Emails From:",
+                                     ["INBOX", "SENT", "DRAFTS", "TRASH", "SPAM"],
+                                     key="label_selector")
+        if st.sidebar.button("Create Chunks and Embed Data"):
+            service = build('gmail', 'v1', credentials=st.session_state.creds)
+            create_chunks_from_gmail(service, label)
+            if st.session_state.data_chunks:
+                embed_emails(st.session_state.data_chunks)
+        if (st.session_state.embeddings is not None and st.session_state.vector_store is not None):
+            with st.expander("Save Data"):
+                if st.button("Save Email Chunks"):
+                    try:
+                        with open("data_chunks.pkl", "wb") as f:
+                            pickle.dump(st.session_state.data_chunks, f)
+                        st.success("Email chunks saved to disk!")
+                    except Exception as e:
+                        st.error(f"❌ Error saving email chunks: {e}")
+                if st.button("Save Embeddings & Vector Store"):
+                    save_embeddings_and_index()
+        if (st.session_state.vector_store is not None and
+            st.session_state.embeddings is not None and
+            st.session_state.data_chunks is not None):
+            handle_user_query()
+    else:
+        st.warning("⚠️ You are not authenticated yet. Please authenticate to access your Gmail data.")
+
+if __name__ == "__main__":
+    main()
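
A note on the retrieval design this commit introduces: `embed_emails` L2-normalizes the embeddings before adding them to a `faiss.IndexFlatIP`, so the inner-product scores returned by `search` in `handle_user_query` are cosine similarities in [-1, 1], which is what makes the `SIMILARITY_THRESHOLD` comparison meaningful. A minimal standalone sketch of that pattern (the sample texts and query are illustrative, not taken from the app):

```python
# Sketch of the embed-and-search pattern used in embed_emails /
# handle_user_query: L2-normalized vectors in an inner-product index
# make FAISS scores equal cosine similarity.
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
texts = ["Subject: invoice due Friday", "Subject: team lunch on Monday"]  # stand-ins for combined emails

emb = model.encode(texts, convert_to_numpy=True)
faiss.normalize_L2(emb)                  # in-place; rows become unit vectors
index = faiss.IndexFlatIP(emb.shape[1])  # inner product == cosine after normalizing
index.add(emb)

query = model.encode(["when is the invoice due?"], convert_to_numpy=True)
faiss.normalize_L2(query)
scores, ids = index.search(query, 2)     # scores lie in [-1, 1]
print(ids[0], scores[0])
```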
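
The query path is a two-stage ranker: the bi-encoder index retrieves `TOP_K` candidates cheaply, then `cross-encoder/ms-marco-MiniLM-L-12-v2` rescores each (query, email) pair jointly before the top results become LLM context. Cross-encoders are slower per pair but sharper, which is why they only see the shortlist. A sketch of the re-ranking step on its own (the pairs are made up for illustration):

```python
# Sketch of the cross-encoder re-ranking stage used in handle_user_query:
# each (query, document) pair is scored jointly, then sorted descending.
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
pairs = [
    ("when is the invoice due?", "Subject: invoice due Friday"),
    ("when is the invoice due?", "Subject: team lunch on Monday"),
]
scores = reranker.predict(pairs)  # one relevance score per pair
ranked = sorted(zip(pairs, scores), key=lambda x: x[1], reverse=True)
for (query, doc), score in ranked:
    print(f"{score:.3f}  {doc}")
```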
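
Finally, `create_chunks_from_gmail` pages through the mailbox because `messages.list` returns at most `maxResults` message IDs per call, following `nextPageToken` until it is absent. The sketch below isolates that loop; `service` is assumed to be an authorized Gmail API client, and the helper name `list_all_message_ids` is hypothetical, not part of the app:

```python
# Sketch of the Gmail pagination loop in create_chunks_from_gmail.
def list_all_message_ids(service, label, page_size=500):
    messages = []
    request = {"userId": "me", "labelIds": [label], "maxResults": page_size}
    result = service.users().messages().list(**request).execute()
    messages.extend(result.get("messages", []))
    while "nextPageToken" in result:  # keep fetching until no token remains
        result = service.users().messages().list(
            pageToken=result["nextPageToken"], **request
        ).execute()
        messages.extend(result.get("messages", []))
    return [m["id"] for m in messages]
```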