Abrar20's picture
Update app.py
9ce8b35 verified
raw
history blame
9.2 kB
import json
import os

import google.generativeai as genai
import streamlit as st
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from PyPDF2 import PdfReader
# FastAPI application that serves the question-answering endpoint (/api/ask).
api_app = FastAPI()

# Cross-origin policy: accept any origin, method and header so that browser
# front-ends hosted elsewhere can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# risky pairing (Starlette's CORSMiddleware appears to answer by echoing the
# caller's Origin instead of "*" -- verify for the installed version). Nothing
# visible in this app relies on cookies or auth headers; consider an explicit
# origin list or allow_credentials=False once no client depends on it.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
api_app.add_middleware(CORSMiddleware, **_cors_settings)
# Copy settings from a local .env file (if one exists) into the environment.
load_dotenv()

# The Google API key is mandatory: abort at import time with a clear message
# instead of deferring the failure to the first request that needs it.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if GOOGLE_API_KEY is None or GOOGLE_API_KEY == "":
    raise ValueError("API key not found in environment variables. Please check '.env' file.")

# Register the key globally with the google.generativeai SDK.
genai.configure(api_key=GOOGLE_API_KEY)
def get_pdf_text(pdf_path):
    """Extract the text of every page of a PDF document.

    Page texts are joined with a newline so that the last word of one page is
    not glued onto the first word of the next, which would corrupt both words
    for chunking and embedding.

    Args:
        pdf_path: Path of the PDF file on disk.

    Returns:
        The extracted text, or "" when nothing could be extracted. Pages whose
        extraction yields no text are skipped.

    Errors while opening or parsing the file are shown in the Streamlit UI via
    ``st.error``; text from pages read before the failure is still returned.
    """
    page_texts = []
    try:
        with open(pdf_path, "rb") as pdf_file:
            pdf_reader = PdfReader(pdf_file)
            for page in pdf_reader.pages:
                extracted_text = page.extract_text()
                if extracted_text:
                    page_texts.append(extracted_text)
    except Exception as e:
        # Best-effort by design: surface the problem to the user and let the
        # caller decide what to do with whatever text was recovered.
        st.error(f"Error reading PDF file: {e}")
    return "\n".join(page_texts)
def get_text_chunks(text, *, chunk_size=10000, chunk_overlap=1000):
    """Split *text* into overlapping chunks for FAISS indexing.

    Args:
        text: The full document text.
        chunk_size: Maximum chunk length in characters (default 10000, the
            value originally hard-coded here).
        chunk_overlap: Number of characters shared by consecutive chunks
            (default 1000).

    Returns:
        The list of chunk strings produced by RecursiveCharacterTextSplitter.
    """
    # NOTE(review): 10,000 characters is large for a single embedding input;
    # confirm it fits the embedding model's input limit, otherwise the tail of
    # each chunk may not be represented in its vector.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_text(text)
def get_vector_store(text_chunks):
    """Embed *text_chunks* with Gemini and save them as a local FAISS index.

    The index is written to the ``faiss_index`` directory, which is where
    ``user_input`` loads it from. Nothing is returned.
    """
    store = FAISS.from_texts(
        text_chunks,
        embedding=GoogleGenerativeAIEmbeddings(
            model="models/embedding-001",
            google_api_key=GOOGLE_API_KEY,
        ),
    )
    store.save_local("faiss_index")
def get_conversational_chain(*, model_name="gemini-pro", temperature=0.3):
    """Build the "stuff" QA chain that answers from retrieved document chunks.

    The retrieved chunks are placed in the prompt under a labelled context
    section. The model is told to rely on them and not to invent package
    details such as prices.

    Args:
        model_name: Gemini chat model name (default "gemini-pro", the model
            originally hard-coded here).
        temperature: Sampling temperature (default 0.3).

    Returns:
        A LangChain QA chain that takes ``input_documents`` and ``question``
        inputs and produces ``output_text``.
    """
    prompt_template = """
You are an AI assistant for Robi Axiata Ltd., specializing in providing personalized internet package recommendations and customer support. Your role is to:
1. Recommend the best internet package based on the user's current task and its duration.
2. Suggest suitable packages based on the user's weekly internet usage patterns.
3. Provide solutions to common internet-related problems.
## Context:
{context}
Use only the package and support information in the context above. If it does not contain what is needed to answer, say so instead of inventing package names, data volumes, validities, or prices.
## Task:
{question}
Provide a detailed response with:
1. **Recommended Package:** Include package name, data volume, validity, and price.
2. **Alternative Packages:** List alternative options sorted from cheapest to most expensive.
3. **Solution to Issues:** If applicable, provide solutions to any mentioned problems.
Begin your response now.
"""
    # NOTE(review): Google has been retiring older Gemini model names; confirm
    # "gemini-pro" is still served for this key, or pass model_name explicitly.
    # The key is passed explicitly, matching the embedding clients.
    model = ChatGoogleGenerativeAI(
        model=model_name,
        temperature=temperature,
        google_api_key=GOOGLE_API_KEY,
    )
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    return load_qa_chain(model, chain_type="stuff", prompt=prompt)
def user_input(user_question):
    """Answer *user_question* using the locally stored FAISS index.

    The chunks most similar to the question are retrieved and passed, together
    with the question, to the chain from ``get_conversational_chain``.

    Args:
        user_question: The fully formatted question text.

    Returns:
        ``{"response": <answer text>}`` on success, or ``{"error": <message>}``
        when the question is blank or not a string, or when the FAISS index has
        not been built yet.
    """
    if not isinstance(user_question, str) or not user_question.strip():
        return {"error": "No question provided."}

    # Check for the index before creating any clients: without it there is
    # nothing to search.
    faiss_index_path = os.path.join("faiss_index", "index.faiss")
    if not os.path.exists(faiss_index_path):
        return {"error": "FAISS index not found. Please ensure the documents are processed first."}

    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GOOGLE_API_KEY)
    # SECURITY: load_local deserializes a pickle file, hence the explicit
    # opt-in flag. This is only acceptable because the index is written by
    # this app itself (get_vector_store); never load an index from an
    # untrusted source.
    new_db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
    docs = new_db.similarity_search(user_question)

    chain = get_conversational_chain()
    # invoke() replaces the deprecated Chain.__call__ interface.
    response = chain.invoke(
        {"input_documents": docs, "question": user_question},
        return_only_outputs=True,
    )
    return {"response": response["output_text"]}
# FastAPI endpoint to expose the API
@api_app.post("/api/ask")
async def ask_question(request: Request):
    """Answer a question sent as a JSON body of the form {"question": "..."}.

    Returns:
        The dict produced by ``user_input`` ({"response": ...} or
        {"error": ...}). A body that is not a JSON object, or a question that
        is missing, blank or not a string, yields an {"error": ...} payload.
        Errors are reported in the payload, not through the HTTP status code.
    """
    try:
        req_data = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError):
        return {"error": "Request body must be valid JSON."}
    if not isinstance(req_data, dict):
        return {"error": "Request body must be a JSON object."}

    user_question = req_data.get("question", "")
    if not isinstance(user_question, str) or not user_question.strip():
        return {"error": "No question provided."}

    # user_input does blocking work: it reads the index from disk and makes
    # Gemini network calls. Running it directly in this coroutine would stall
    # the event loop for every other request, so it runs in the threadpool.
    return await run_in_threadpool(user_input, user_question)
# Streamlit UI
def _ask(question):
    """Return ``user_input(question)``, reusing this session's earlier answers.

    Streamlit re-runs the whole script after every widget interaction, so
    without memoization each filled-in text box below would send a fresh
    Gemini request whenever any box changed. Successful results are stored in
    ``st.session_state``. Error results are not stored, so they are retried on
    the next run (e.g. once the FAISS index exists).
    """
    if "_answer_cache" not in st.session_state:
        st.session_state["_answer_cache"] = {}
    answers = st.session_state["_answer_cache"]
    if question not in answers:
        result = user_input(question)
        if "error" in result:
            return result
        answers[question] = result
    return answers[question]


def _show_result(result, heading):
    """Render a ``user_input`` result: its error, or *heading* followed by the answer."""
    if "error" in result:
        st.error(result["error"])
    else:
        st.markdown(heading)
        st.write(result["response"])


def _ensure_faiss_index(pdf_path="customerservicebot.pdf"):
    """Create the FAISS index from *pdf_path* unless it already exists.

    Status and problems (missing PDF, no extractable text) are reported with
    Streamlit messages.
    """
    if os.path.exists("faiss_index/index.faiss"):
        st.info("FAISS index loaded from 'faiss_index'.")
        return
    if not os.path.exists(pdf_path):
        st.error(f"Default PDF '{pdf_path}' not found. Please add it to the project directory.")
        return
    with st.spinner(f"Processing '{pdf_path}'..."):
        raw_text = get_pdf_text(pdf_path)
        if raw_text:
            text_chunks = get_text_chunks(raw_text)
            get_vector_store(text_chunks)
            st.success(f"'{pdf_path}' processed and FAISS index created.")
        else:
            st.error(f"No text extracted from '{pdf_path}'. Please check the PDF content.")


def main():
    """Run the Streamlit app: three question forms plus index setup in the sidebar."""
    st.set_page_config(page_title="Robi AI Customer Service", page_icon="📱")

    # Logo; the file is looked up relative to the working directory.
    if os.path.exists("robi_logo.png"):
        st.image("robi_logo.png", width=200)
    else:
        st.warning("Logo image 'robi_logo.png' not found. Please add it to the project directory.")

    # Title and description
    st.title("📶 Robi AI Customer Service & Recommendation System")
    st.markdown("""
Welcome to Robi Axiata Ltd.'s AI-powered Customer Service and Recommendation System. Get personalized internet package recommendations based on your usage and tasks.
""")

    # 1) Recommendation for the task the user is currently performing.
    st.header("🔍 Get Package Recommendations Based on Your Current Task")
    user_task = st.text_input(
        "Describe the task you're performing and its estimated duration:",
        placeholder="e.g., Streaming a movie for 3 hours",
    )
    st.caption("Example: 'I am downloading large files for 2 hours.'")
    if user_task:
        user_question = f"Before my current internet package runs out, I am {user_task}. Please recommend the best package to buy next to complete this task, including how much internet is needed and the best options sorted from cheapest to most expensive."
        _show_result(_ask(user_question), "### 📄 Recommended Package(s):")
    st.markdown("---")

    # 2) Recommendation based on weekly usage patterns.
    st.header("📅 Get Package Recommendations Based on Weekly Usage")
    weekly_usage = st.text_input(
        "Describe your weekly internet usage patterns and the types of content you consume:",
        placeholder="e.g., I use social media and stream videos frequently during the week.",
    )
    st.caption("Example: 'I browse the web, stream videos, and play online games every day.'")
    if weekly_usage:
        user_question = f"I use the internet {weekly_usage}. Please suggest the best Robi internet packages that suit my weekly usage, sorted from cheapest to most expensive."
        _show_result(_ask(user_question), "### 📄 Suggested Package(s):")
    st.markdown("---")

    # 3) Troubleshooting for common problems.
    st.header("🛠️ Need Help? Common Problems and Solutions")
    problem_description = st.text_input(
        "Describe any internet-related issues you're facing:",
        placeholder="e.g., My internet speed is slow during evenings.",
    )
    st.caption("Example: 'I experience low network connection at night.'")
    if problem_description:
        user_question = f"I am experiencing the following issue: {problem_description}. Please provide solutions to resolve this problem."
        _show_result(_ask(user_question), "### 🛠️ Solutions:")
    st.markdown("---")

    # Sidebar: build the FAISS index from the bundled PDF on first run.
    with st.sidebar:
        st.header("📄 Documents:")
        _ensure_faiss_index()
# Launch the Streamlit UI when this file is executed as a script
# (for example via `streamlit run app.py`).
if __name__ == "__main__":
    main()