import streamlit as st import faiss import numpy as np from sentence_transformers import SentenceTransformer, CrossEncoder import requests import os import torch import pickle from tqdm import tqdm from googleapiclient.discovery import build from google_auth_oauthlib.flow import InstalledAppFlow from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials import base64 import re from pyngrok import ngrok # =============================== # 1. Streamlit App Configuration # =============================== st.set_page_config(page_title="📥 Email Chat Application", layout="wide") st.title("✉️ Email Chat Application") # =============================== # 2. Gmail Authentication Configuration # =============================== SCOPES = ['https://www.googleapis.com/auth/gmail.readonly'] if "authenticated" not in st.session_state: st.session_state.authenticated = False if "creds" not in st.session_state: st.session_state.creds = None if "auth_url" not in st.session_state: st.session_state.auth_url = None if "auth_code" not in st.session_state: st.session_state.auth_code = "" if "flow" not in st.session_state: st.session_state.flow = None if "data_chunks" not in st.session_state: st.session_state.data_chunks = [] # List to store all email chunks if "embeddings" not in st.session_state: st.session_state.embeddings = None if "vector_store" not in st.session_state: st.session_state.vector_store = None def count_tokens(text): return len(text.split()) # =============================== # 3. Gmail Authentication Functions # =============================== def reset_session_state(): st.session_state.authenticated = False st.session_state.creds = None st.session_state.auth_url = None st.session_state.auth_code = "" st.session_state.flow = None st.session_state.data_chunks = [] st.session_state.embeddings = None st.session_state.vector_store = None for filename in ["token.json", "data_chunks.pkl", "embeddings.pkl", "vector_store.index"]: if os.path.exists(filename): os.remove(filename) def authenticate_gmail(credentials_file): creds = None if os.path.exists('token.json'): try: creds = Credentials.from_authorized_user_file('token.json', SCOPES) if creds and creds.valid: st.session_state.creds = creds st.session_state.authenticated = True st.success("✅ Authentication successful!") return creds except Exception as e: st.error(f"❌ Invalid token.json file: {e}") os.remove('token.json') if not creds or not creds.valid: if creds and creds.expired and creds.refresh_token: creds.refresh(Request()) st.session_state.creds = creds st.session_state.authenticated = True st.success("✅ Authentication successful!") with open('token.json', 'w') as token_file: token_file.write(creds.to_json()) return creds else: if not st.session_state.flow: st.session_state.flow = InstalledAppFlow.from_client_secrets_file(credentials_file, SCOPES) st.session_state.flow.redirect_uri = 'http://localhost' auth_url, _ = st.session_state.flow.authorization_url(prompt='consent') st.session_state.auth_url = auth_url st.info("🔗 **Authorize the application by visiting the URL below:**") st.markdown(f"[Authorize]({st.session_state.auth_url})") def submit_auth_code(): try: st.session_state.flow.fetch_token(code=st.session_state.auth_code) st.session_state.creds = st.session_state.flow.credentials st.session_state.authenticated = True with open('token.json', 'w') as token_file: token_file.write(st.session_state.creds.to_json()) st.success("✅ Authentication successful!") except Exception as e: st.error(f"❌ Error during authentication: {e}") # =============================== # 4. Email Data Extraction, Embedding and Vector Store Functions # =============================== def extract_email_body(payload): if 'body' in payload and 'data' in payload['body'] and payload['body']['data']: try: return base64.urlsafe_b64decode(payload['body']['data'].encode('UTF-8')).decode('UTF-8') except Exception as e: st.error(f"Error decoding email body: {e}") return "" if 'parts' in payload: for part in payload['parts']: if part.get('mimeType') == 'text/plain' and 'data' in part.get('body', {}): try: return base64.urlsafe_b64decode(part['body']['data'].encode('UTF-8')).decode('UTF-8') except Exception as e: st.error(f"Error decoding email part: {e}") continue if payload['parts']: first_part = payload['parts'][0] if 'data' in first_part.get('body', {}): try: return base64.urlsafe_b64decode(first_part['body']['data'].encode('UTF-8')).decode('UTF-8') except Exception as e: st.error(f"Error decoding fallback email part: {e}") return "" return "" def combine_email_text(email): parts = [] if email.get("sender"): parts.append(f"Sender: {email['sender']}") if email.get("to"): parts.append(f"To: {email['to']}") if email.get("date"): parts.append(f"Date: {email['date']}") if email.get("subject"): parts.append(f"Subject: {email['subject']}") if email.get("body"): parts.append(f"Body: {email['body']}") return "\n".join(parts) def create_chunks_from_gmail(service, label): try: messages = [] result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500).execute() messages.extend(result.get('messages', [])) while 'nextPageToken' in result: token = result["nextPageToken"] result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500, pageToken=token).execute() messages.extend(result.get('messages', [])) data_chunks = [] progress_bar = st.progress(0) total = len(messages) for idx, msg in enumerate(messages): msg_data = service.users().messages().get(userId='me', id=msg['id'], format='full').execute() headers = msg_data.get('payload', {}).get('headers', []) email_dict = {"id": msg['id']} for header in headers: name = header.get('name', '').lower() if name == 'from': email_dict['sender'] = header.get('value', '') elif name == 'subject': email_dict['subject'] = header.get('value', '') elif name == 'to': email_dict['to'] = header.get('value', '') elif name == 'date': email_dict['date'] = header.get('value', '') email_dict['body'] = extract_email_body(msg_data.get('payload', {})) data_chunks.append(email_dict) progress_bar.progress((idx + 1) / total) st.session_state.data_chunks = data_chunks st.success(f"✅ Data chunks created successfully from Gmail! Total emails processed: {len(data_chunks)}") # Save chunks locally for future use. with open("data_chunks.pkl", "wb") as f: pickle.dump(data_chunks, f) except Exception as e: st.error(f"❌ Error creating chunks from Gmail: {e}") def embed_emails(email_chunks): st.header("🔄 Embedding Data and Creating Vector Store") with st.spinner('🔄 Embedding data...'): try: embed_model = SentenceTransformer("all-MiniLM-L6-v2") device = 'cuda' if torch.cuda.is_available() else 'cpu' embed_model.to(device) combined_texts = [combine_email_text(email) for email in email_chunks] batch_size = 64 embeddings = [] for i in range(0, len(combined_texts), batch_size): batch = combined_texts[i:i+batch_size] batch_embeddings = embed_model.encode( batch, convert_to_numpy=True, show_progress_bar=False, device=device ) embeddings.append(batch_embeddings) embeddings = np.vstack(embeddings) faiss.normalize_L2(embeddings) st.session_state.embeddings = embeddings dimension = embeddings.shape[1] index = faiss.IndexFlatIP(dimension) index.add(embeddings) st.session_state.vector_store = index st.success("✅ Data embedding and vector store created successfully!") # Save embeddings and index to disk. with open('embeddings.pkl', 'wb') as f: pickle.dump(embeddings, f) faiss.write_index(index, 'vector_store.index') except Exception as e: st.error(f"❌ Error during embedding: {e}") def save_embeddings_and_index(): try: with open('embeddings.pkl', 'wb') as f: pickle.dump(st.session_state.embeddings, f) faiss.write_index(st.session_state.vector_store, 'vector_store.index') st.success("💾 Embeddings and vector store saved successfully!") except Exception as e: st.error(f"❌ Error saving embeddings and vector store: {e}") def load_embeddings_and_index(): try: with open('embeddings.pkl', 'rb') as f: st.session_state.embeddings = pickle.load(f) st.session_state.vector_store = faiss.read_index('vector_store.index') st.success("📁 Embeddings and vector store loaded successfully!") except Exception as e: st.error(f"❌ Error loading embeddings and vector store: {e}") def load_chunks(): try: with open("data_chunks.pkl", "rb") as f: st.session_state.data_chunks = pickle.load(f) st.success("📁 Email chunks loaded successfully!") except Exception as e: st.error(f"❌ Error loading email chunks: {e}") # =============================== # 5. Handling User Queries # =============================== def preprocess_query(query): return query.lower().strip() def handle_user_query(): st.header("💬 Let's chat with your Email") user_query = st.text_input("Enter your query:") TOP_K = 10 SIMILARITY_THRESHOLD = 0.4 if st.button("🔍 Get Response"): if (st.session_state.vector_store is None or st.session_state.embeddings is None or st.session_state.data_chunks is None): st.error("❌ Please process your email data or load saved chunks/embeddings first.") return if not user_query.strip(): st.error("❌ Please enter a valid query.") return with st.spinner('🔄 Processing your query...'): try: # Retrieve candidates using the bi-encoder. embed_model = SentenceTransformer("all-MiniLM-L6-v2") device = 'cuda' if torch.cuda.is_available() else 'cpu' embed_model.to(device) processed_query = preprocess_query(user_query) query_embedding = embed_model.encode( [processed_query], convert_to_numpy=True, show_progress_bar=False, device=device ) faiss.normalize_L2(query_embedding) distances, indices = st.session_state.vector_store.search(query_embedding, TOP_K) candidates = [] for idx, score in zip(indices[0], distances[0]): candidates.append((st.session_state.data_chunks[idx], score)) # Boost candidates if sender or "to" field contains query tokens (e.g., email addresses). query_tokens = re.findall(r'\S+@\S+', user_query) if query_tokens: for i in range(len(candidates)): candidate_email_str = ( (candidates[i][0].get("sender", "") + " " + candidates[i][0].get("to", "")).lower() ) for token in query_tokens: if token.lower() in candidate_email_str: candidates[i] = (candidates[i][0], max(candidates[i][1], 1.0)) filtered_candidates = [] for candidate, score in candidates: candidate_text = combine_email_text(candidate).lower() if any(token.lower() in candidate_text for token in query_tokens): filtered_candidates.append((candidate, score)) if filtered_candidates: candidates = filtered_candidates else: st.info("No candidate emails contain the query token(s) exactly. Proceeding with all candidates.") candidates.sort(key=lambda x: x[1], reverse=True) if not candidates: st.subheader("📝 AI Response:") st.write("⚠️ No documents found.") return if candidates[0][1] < SIMILARITY_THRESHOLD: st.subheader("📝 AI Response:") st.write("⚠️ No document strongly matches your query. Try refining your query.") return # Re-rank candidates using the cross-encoder. cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2") candidate_pairs = [(user_query, combine_email_text(candidate[0])) for candidate in candidates] rerank_scores = cross_encoder.predict(candidate_pairs) reranked_candidates = [(candidates[i][0], rerank_scores[i]) for i in range(len(candidates))] reranked_candidates.sort(key=lambda x: x[1], reverse=True) retrieved_emails = [email for email, score in reranked_candidates] retrieved_scores = [score for email, score in reranked_candidates] average_similarity = np.mean(retrieved_scores) # Build the final context string. context_str = "\n\n".join([combine_email_text(email) for email in retrieved_emails]) MAX_CONTEXT_TOKENS = 500 context_tokens = context_str.split() if len(context_tokens) > MAX_CONTEXT_TOKENS: context_str = " ".join(context_tokens[:MAX_CONTEXT_TOKENS]) payload = { "model": "llama3-8b-8192", # Adjust as needed. "messages": [ {"role": "system", "content": f"Use the following context:\n{context_str}"}, {"role": "user", "content": user_query} ] } api_key = "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" # Replace with your API key. url = "https://api.groq.com/openai/v1/chat/completions" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" } response = requests.post(url, headers=headers, json=payload) if response.status_code == 200: response_json = response.json() generated_text = response_json["choices"][0]["message"]["content"] st.subheader("📝 AI Response:") st.write(generated_text) st.write(f"Average Re-Ranked Score: {average_similarity:.4f}") else: st.error(f"❌ Error from LLM API: {response.status_code} - {response.text}") except Exception as e: st.error(f"❌ An error occurred during processing: {e}") # =============================== # 6. Main Application Logic # =============================== def main(): st.sidebar.header("🔒 Gmail Authentication") credentials_file = st.sidebar.file_uploader("📁 Upload `credentials.json`", type=["json"]) if credentials_file and st.sidebar.button("🔓 Authenticate"): reset_session_state() with open("credentials.json", "wb") as f: f.write(credentials_file.getbuffer()) authenticate_gmail("credentials.json") # Option to load previously saved email chunks. chunks_file = st.sidebar.file_uploader("📁 Upload saved email chunks (data_chunks.pkl)", type=["pkl"]) if chunks_file: try: st.session_state.data_chunks = pickle.load(chunks_file) st.success("📁 Email chunks loaded successfully from upload!") except Exception as e: st.error(f"❌ Error loading uploaded email chunks: {e}") # Option to load previously saved embeddings and vector store. embeddings_file = st.sidebar.file_uploader("📁 Upload saved embeddings (embeddings.pkl)", type=["pkl"]) vector_file = st.sidebar.file_uploader("📁 Upload saved vector store (vector_store.index)", type=["index", "idx"]) if embeddings_file and vector_file: try: st.session_state.embeddings = pickle.load(embeddings_file) st.session_state.vector_store = faiss.read_index(vector_file.name) st.success("📁 Embeddings and vector store loaded successfully from upload!") except Exception as e: st.error(f"❌ Error loading uploaded embeddings/vector store: {e}") if st.session_state.auth_url: st.sidebar.markdown("### 🔗 **Authorization URL:**") st.sidebar.markdown(f"[Authorize]({st.session_state.auth_url})") st.sidebar.text_input("🔑 Enter the authorization code:", key="auth_code") if st.sidebar.button("✅ Submit Authentication Code"): submit_auth_code() if st.session_state.authenticated: st.sidebar.success("✅ You are authenticated!") st.sidebar.header("📂 Data Management") label = st.sidebar.selectbox("📬 Select Label to Process Emails From:", ["INBOX", "SENT", "DRAFTS", "TRASH", "SPAM"], key="label_selector") if st.sidebar.button("📥 Create Chunks and Embed Data"): service = build('gmail', 'v1', credentials=st.session_state.creds) create_chunks_from_gmail(service, label) if st.session_state.data_chunks: embed_emails(st.session_state.data_chunks) if (st.session_state.embeddings is not None and st.session_state.vector_store is not None): with st.expander("💾 Save Data"): if st.button("💾 Save Email Chunks"): try: with open("data_chunks.pkl", "wb") as f: pickle.dump(st.session_state.data_chunks, f) st.success("💾 Email chunks saved to disk!") except Exception as e: st.error(f"❌ Error saving email chunks: {e}") if st.button("💾 Save Embeddings & Vector Store"): save_embeddings_and_index() if (st.session_state.vector_store is not None and st.session_state.embeddings is not None and st.session_state.data_chunks is not None): handle_user_query() else: st.warning("⚠️ You are not authenticated yet. Please authenticate to access your Gmail data.") if __name__ == "__main__": main()