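# Page-header residue from the hosting page this file was copied from; not part of the program:
# Spaces:
# Sleeping
# Sleeping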
import asyncio
import json
import os

import google.generativeai as genai
import streamlit as st
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from PyPDF2 import PdfReader
# Initialize FastAPI app. Nothing in this file starts a server for it; it has to be
# served separately (e.g. `uvicorn <module>:api_app`).
api_app = FastAPI()

# CORS: let any origin call the JSON API. Credentials are disabled because the CORS
# spec forbids combining them with a wildcard origin; Starlette's middleware copes with
# that combination by echoing the caller's Origin back, which would let any website make
# credentialed requests. The endpoint in this file reads only the JSON body (no cookies
# or auth headers), so nothing depends on credentials being allowed.
api_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pull variables (e.g. from a local .env file) into the process environment.
load_dotenv()

# Fail fast at import time: every embedding and chat call below needs this key.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
if GOOGLE_API_KEY is None or GOOGLE_API_KEY == "":
    raise ValueError("API key not found in environment variables. Please check '.env' file.")

# Hand the key to the google.generativeai client.
genai.configure(api_key=GOOGLE_API_KEY)
def get_pdf_text(pdf_path):
    """Extract the text of every page of the PDF at *pdf_path*.

    Page texts are joined with a newline so the last word of one page is not glued to
    the first word of the next, which would corrupt both chunking and similarity search.
    Pages that yield no extractable text are skipped.

    Args:
        pdf_path: Filesystem path of the PDF to read.

    Returns:
        The extracted text, or "" if nothing could be read. If reading or parsing fails,
        the error is shown in the Streamlit UI and whatever text was extracted before
        the failure is returned.
    """
    pages = []
    try:
        with open(pdf_path, "rb") as pdf_file:
            pdf_reader = PdfReader(pdf_file)
            for page in pdf_reader.pages:
                extracted_text = page.extract_text()
                if extracted_text:
                    pages.append(extracted_text)
    except Exception as e:  # broad on purpose: malformed PDFs raise many error types
        st.error(f"Error reading PDF file: {e}")
    return "\n".join(pages)
def get_text_chunks(text, *, chunk_size=10000, chunk_overlap=1000):
    """Split *text* into overlapping chunks for embedding and FAISS indexing.

    Args:
        text: Raw document text.
        chunk_size: Target maximum length of each chunk, in characters.
        chunk_overlap: Characters shared by consecutive chunks, so content straddling a
            chunk boundary can still be retrieved from either side.

    Returns:
        A list of text chunks (possibly empty for blank input).
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_text(text)
def get_vector_store(text_chunks, *, index_dir="faiss_index"):
    """Embed *text_chunks*, build a FAISS vector store and save it to *index_dir*.

    Args:
        text_chunks: Non-empty list of text chunks to index.
        index_dir: Directory passed to ``FAISS.save_local``. The default is the location
            ``user_input`` loads from.

    Returns:
        The in-memory FAISS vector store that was saved.

    Raises:
        ValueError: If *text_chunks* is empty. FAISS cannot infer the embedding dimension
            from zero vectors and would otherwise fail with an obscure error.
    """
    if not text_chunks:
        raise ValueError("No text chunks to index; the source document produced no text.")
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GOOGLE_API_KEY)
    vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
    vector_store.save_local(index_dir)
    return vector_store
# Prompt for the "stuff" QA chain: {context} receives the retrieved document text,
# {question} the user's (already templated) request.
QA_PROMPT_TEMPLATE = """
{context}
You are an AI assistant for Robi Axiata Ltd., specializing in providing personalized internet package recommendations and customer support. Your role is to:
1. Recommend the best internet package based on the user's current task and its duration.
2. Suggest suitable packages based on the user's weekly internet usage patterns.
3. Provide solutions to common internet-related problems.
## Task:
{question}
Provide a detailed response with:
1. **Recommended Package:** Include package name, data volume, validity, and price.
2. **Alternative Packages:** List alternative options sorted from cheapest to most expensive.
3. **Solution to Issues:** If applicable, provide solutions to any mentioned problems.
Begin your response now.
"""


def get_conversational_chain(model_name="gemini-pro", temperature=0.3):
    """Build the question-answering chain around a Gemini chat model.

    Args:
        model_name: Gemini model id passed to ``ChatGoogleGenerativeAI``.
            NOTE(review): Google has been retiring the original ``gemini-pro`` id. If calls
            fail with a not-found error, pass a current model name (confirm against the
            Gemini model list).
        temperature: Sampling temperature; low values keep recommendations consistent.

    Returns:
        A LangChain QA chain that expects ``input_documents`` and ``question`` inputs and
        produces ``output_text``.
    """
    model = ChatGoogleGenerativeAI(model=model_name, temperature=temperature)
    prompt = PromptTemplate(template=QA_PROMPT_TEMPLATE, input_variables=["context", "question"])
    # "stuff" places all retrieved documents into the single {context} slot.
    return load_qa_chain(model, chain_type="stuff", prompt=prompt)
def user_input(user_question):
    """Answer *user_question* using the local FAISS index and the Gemini QA chain.

    Args:
        user_question: Natural-language question. It is used for similarity search and
            passed to the prompt.

    Returns:
        ``{"response": <answer text>}`` on success, or ``{"error": <message>}`` when the
        index has not been built yet or the embedding/retrieval/LLM step fails. Both
        callers in this file branch on the ``"error"`` key.
    """
    index_dir = "faiss_index"
    # FAISS.save_local writes both files and load_local needs both. Check them before
    # constructing any API client, since nothing can be answered without an index.
    required_files = ("index.faiss", "index.pkl")
    if not all(os.path.exists(os.path.join(index_dir, name)) for name in required_files):
        return {"error": "FAISS index not found. Please ensure the documents are processed first."}

    try:
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GOOGLE_API_KEY)
        # SECURITY: allow_dangerous_deserialization unpickles the docstore. This is acceptable
        # only because the index is produced locally by get_vector_store; never load an
        # index from an untrusted source with this flag set.
        new_db = FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
        docs = new_db.similarity_search(user_question)
        chain = get_conversational_chain()
        response = chain.invoke({"input_documents": docs, "question": user_question})
    except Exception as e:  # network, quota, auth or corrupt-index failures
        return {"error": f"Failed to generate a response: {e}"}
    return {"response": response["output_text"]}
# FastAPI endpoint to expose the API. Without a route decorator the handler is never
# registered. The path is chosen to match the handler name; adjust it if clients
# expect a different route.
@api_app.post("/ask")
async def ask_question(request: Request):
    """Answer a question posted as the JSON body ``{"question": "..."}``.

    Returns:
        ``{"response": ...}`` on success, or ``{"error": ...}`` for a malformed body,
        a missing/blank question, or a failure reported by ``user_input``. Error payloads
        are returned with FastAPI's default 200 status, matching the original contract.
    """
    try:
        req_data = await request.json()
    except ValueError:  # json.JSONDecodeError (and bad-encoding errors) subclass ValueError
        return {"error": "Request body must be valid JSON."}

    user_question = req_data.get("question", "") if isinstance(req_data, dict) else ""
    if not isinstance(user_question, str) or not user_question.strip():
        return {"error": "No question provided."}

    # user_input does blocking disk and network I/O. Run it in a worker thread so one
    # slow request does not stall the event loop for every other request.
    return await asyncio.to_thread(user_input, user_question)
# Streamlit UI
def _cached_answer(user_question):
    """Return ``user_input(user_question)``, memoized for the current browser session.

    Streamlit re-runs the whole script on every widget interaction. Without this, filling
    in one input box would re-query the LLM for every other non-empty box. Error results
    are not cached, so a later rerun (e.g. once the index exists) retries.
    """
    cache = st.session_state.setdefault("answer_cache", {})
    if user_question in cache:
        return cache[user_question]
    result = user_input(user_question)
    if "error" not in result:
        cache[user_question] = result
    return result


def _question_section(header, label, placeholder, caption, build_question, answer_heading):
    """Render one input section and, when its box is filled, the model's answer.

    Args:
        header: Section header text.
        label, placeholder: Passed to ``st.text_input``.
        caption: Example text shown under the input box.
        build_question: Callable turning the user's raw text into the prompt question.
        answer_heading: Markdown heading shown above a successful answer.
    """
    st.header(header)
    user_text = st.text_input(label, placeholder=placeholder)
    st.caption(caption)
    if user_text:
        result = _cached_answer(build_question(user_text))
        if "error" in result:
            st.error(result["error"])
        else:
            st.markdown(answer_heading)
            st.write(result["response"])
    st.markdown("---")


def _ensure_faiss_index(pdf_path="customerservicebot.pdf"):
    """Show index status in the sidebar, building the FAISS index from *pdf_path* if absent."""
    with st.sidebar:
        st.header("📂 Documents:")
        if os.path.exists("faiss_index/index.faiss"):
            st.info("FAISS index loaded from 'faiss_index'.")
            return
        if not os.path.exists(pdf_path):
            st.error(f"Default PDF '{pdf_path}' not found. Please add it to the project directory.")
            return
        with st.spinner(f"Processing '{pdf_path}'..."):
            raw_text = get_pdf_text(pdf_path)
            # Whitespace-only text splits into zero chunks, and FAISS cannot index zero chunks.
            text_chunks = get_text_chunks(raw_text) if raw_text else []
            if text_chunks:
                get_vector_store(text_chunks)
                st.success(f"'{pdf_path}' processed and FAISS index created.")
            else:
                st.error(f"No text extracted from '{pdf_path}'. Please check the PDF content.")


def main():
    """Build the Streamlit page: index status, branding and the three question sections.

    NOTE: the emoji below repair mojibake in the copied source (UTF-8 decoded as
    ISO-8859-7, e.g. 'π±' was '📱'). Bare 'π' glyphs were unrecoverable and were replaced
    with close equivalents.
    """
    # set_page_config must be the first Streamlit call of a script run.
    st.set_page_config(page_title="Robi AI Customer Service", page_icon="📱")

    # Build or report the index first, so the queries below can use it in this same run.
    _ensure_faiss_index()

    # Display Logo
    if os.path.exists("robi_logo.png"):
        st.image("robi_logo.png", width=200)
    else:
        st.warning("Logo image 'robi_logo.png' not found. Please add it to the project directory.")

    # Title and Description
    st.title("📶 Robi AI Customer Service & Recommendation System")
    st.markdown(
        "\nWelcome to Robi Axiata Ltd.'s AI-powered Customer Service and Recommendation System. "
        "Get personalized internet package recommendations based on your usage and tasks.\n"
    )

    _question_section(
        "📝 Get Package Recommendations Based on Your Current Task",
        "Describe the task you're performing and its estimated duration:",
        "e.g., Streaming a movie for 3 hours",
        "Example: 'I am downloading large files for 2 hours.'",
        lambda user_task: f"Before my current internet package runs out, I am {user_task}. Please recommend the best package to buy next to complete this task, including how much internet is needed and the best options sorted from cheapest to most expensive.",
        "### 📋 Recommended Package(s):",
    )
    _question_section(
        "📊 Get Package Recommendations Based on Weekly Usage",
        "Describe your weekly internet usage patterns and the types of content you consume:",
        "e.g., I use social media and stream videos frequently during the week.",
        "Example: 'I browse the web, stream videos, and play online games every day.'",
        lambda weekly_usage: f"I use the internet {weekly_usage}. Please suggest the best Robi internet packages that suit my weekly usage, sorted from cheapest to most expensive.",
        "### 📋 Suggested Package(s):",
    )
    _question_section(
        "🛠️ Need Help? Common Problems and Solutions",
        "Describe any internet-related issues you're facing:",
        "e.g., My internet speed is slow during evenings.",
        "Example: 'I experience low network connection at night.'",
        lambda problem_description: f"I am experiencing the following issue: {problem_description}. Please provide solutions to resolve this problem.",
        "### 🛠️ Solutions:",
    )
# Script entry point. When launched with `streamlit run`, the file executes as __main__,
# so main() builds the page on every script run.
if __name__ == "__main__":
    main()