nileshhanotia commited on
Commit
3a16d21
·
verified ·
1 Parent(s): bf23bc0

Create rag_system.py

Browse files
Files changed (1) hide show
  1. rag_system.py +64 -0
rag_system.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from transformers import pipeline
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain.text_splitter import CharacterTextSplitter
7
+ from langchain.docstore.document import Document
8
+ from utils.logger import setup_logger
9
+ from utils.model_loader import ModelLoader
10
+
11
+ logger = setup_logger(__name__)
12
+
13
+ class RAGSystem:
14
+ def __init__(self, csv_path="apparel.csv"):
15
+ try:
16
+ self.setup_system(csv_path)
17
+ self.qa_pipeline = ModelLoader.load_model_with_retry(
18
+ "distilbert-base-cased-distilled-squad",
19
+ pipeline,
20
+ task="question-answering"
21
+ )
22
+ except Exception as e:
23
+ logger.error(f"Failed to initialize RAGSystem: {str(e)}")
24
+ raise
25
+
26
+ def setup_system(self, csv_path):
27
+ if not os.path.exists(csv_path):
28
+ raise FileNotFoundError(f"CSV file not found at {csv_path}")
29
+
30
+ try:
31
+ documents = pd.read_csv(csv_path)
32
+ docs = [
33
+ Document(
34
+ page_content=str(row['Title']),
35
+ metadata={'index': idx}
36
+ ) for idx, row in documents.iterrows()
37
+ ]
38
+
39
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
40
+ split_docs = text_splitter.split_documents(docs)
41
+
42
+ embeddings = HuggingFaceEmbeddings(
43
+ model_name="sentence-transformers/all-MiniLM-L6-v2"
44
+ )
45
+ self.vector_store = FAISS.from_documents(split_docs, embeddings)
46
+ self.retriever = self.vector_store.as_retriever()
47
+ except Exception as e:
48
+ logger.error(f"Failed to setup RAG system: {str(e)}")
49
+ raise
50
+
51
+ def process_query(self, query):
52
+ try:
53
+ retrieved_docs = self.retriever.get_relevant_documents(query)
54
+ retrieved_text = "\n".join([doc.page_content for doc in retrieved_docs])[:1000]
55
+
56
+ qa_input = {
57
+ "question": query,
58
+ "context": retrieved_text
59
+ }
60
+ response = self.qa_pipeline(qa_input)
61
+ return response['answer']
62
+ except Exception as e:
63
+ logger.error(f"Query processing error: {str(e)}")
64
+ return "Failed to process query due to an error."