nileshhanotia committed on
Commit
3b3a6c5
·
verified ·
1 Parent(s): 0a3e8ee

Update models/rag_system.py

Browse files
Files changed (1) hide show
  1. models/rag_system.py +26 -22
models/rag_system.py CHANGED
@@ -1,10 +1,9 @@
1
  import os
2
  import pandas as pd
3
- from transformers import pipeline
4
- from langchain_core.embeddings import HuggingFaceEmbeddings # Updated import
5
- from langchain_community.vectorstores import FAISS
6
- from langchain.text_splitter import CharacterTextSplitter
7
- from langchain.docstore.document import Document
8
  from utils.logger import setup_logger
9
  from utils.model_loader import ModelLoader
10
 
@@ -13,6 +12,7 @@ logger = setup_logger(__name__)
13
  class RAGSystem:
14
  def __init__(self, csv_path="apparel.csv"):
15
  try:
 
16
  self.setup_system(csv_path)
17
  self.qa_pipeline = ModelLoader.load_model_with_retry(
18
  "distilbert-base-cased-distilled-squad",
@@ -28,30 +28,34 @@ class RAGSystem:
28
  raise FileNotFoundError(f"CSV file not found at {csv_path}")
29
 
30
  try:
31
- documents = pd.read_csv(csv_path)
32
- docs = [
33
- Document(
34
- page_content=str(row['Title']),
35
- metadata={'index': idx}
36
- ) for idx, row in documents.iterrows()
37
- ]
38
-
39
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
40
- split_docs = text_splitter.split_documents(docs)
41
-
42
- embeddings = HuggingFaceEmbeddings(
43
- model_name="sentence-transformers/all-MiniLM-L6-v2"
44
  )
45
- self.vector_store = FAISS.from_documents(split_docs, embeddings)
46
- self.retriever = self.vector_store.as_retriever()
47
  except Exception as e:
48
  logger.error(f"Failed to setup RAG system: {str(e)}")
49
  raise
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def process_query(self, query):
52
  try:
53
- retrieved_docs = self.retriever.get_relevant_documents(query)
54
- retrieved_text = "\n".join([doc.page_content for doc in retrieved_docs])[:1000]
55
 
56
  qa_input = {
57
  "question": query,
 
1
  import os
2
  import pandas as pd
3
+ from transformers import pipeline, AutoTokenizer, AutoModel
4
+ import torch
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
 
7
  from utils.logger import setup_logger
8
  from utils.model_loader import ModelLoader
9
 
 
12
  class RAGSystem:
13
  def __init__(self, csv_path="apparel.csv"):
14
  try:
15
+ self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
16
  self.setup_system(csv_path)
17
  self.qa_pipeline = ModelLoader.load_model_with_retry(
18
  "distilbert-base-cased-distilled-squad",
 
28
  raise FileNotFoundError(f"CSV file not found at {csv_path}")
29
 
30
  try:
31
+ self.documents = pd.read_csv(csv_path)
32
+ # Create embeddings for all documents
33
+ self.doc_embeddings = self.model.encode(
34
+ self.documents['Title'].astype(str).tolist(),
35
+ convert_to_tensor=True
 
 
 
 
 
 
 
 
36
  )
 
 
37
  except Exception as e:
38
  logger.error(f"Failed to setup RAG system: {str(e)}")
39
  raise
40
 
41
def get_relevant_documents(self, query, top_k=5):
    """Return the titles most similar to *query*, best match first.

    The query is embedded with ``self.model`` and compared against the
    precomputed ``self.doc_embeddings`` using cosine similarity.

    Args:
        query: Text passed to ``self.model.encode``.
        top_k: Maximum number of titles to return. It is capped at the
            number of rows in ``self.documents``.

    Returns:
        A list of ``str`` values taken from the ``Title`` column of
        ``self.documents``, ordered by descending similarity. The list is
        empty when there are no documents or ``top_k`` is not positive.
    """
    # torch.topk rejects a negative k, and there is nothing to rank when
    # the corpus is empty, so bail out before doing any embedding work.
    k = min(top_k, len(self.documents))
    if k <= 0:
        return []

    # Get query embedding
    query_embedding = self.model.encode(query, convert_to_tensor=True)

    # Calculate cosine similarities; the (1, dim) query broadcasts against
    # every row of the document embedding matrix.
    cos_scores = torch.nn.functional.cosine_similarity(
        query_embedding.unsqueeze(0),
        self.doc_embeddings
    )

    # Get top_k most similar documents. Convert the index tensor to plain
    # Python ints: iterating a tensor yields 0-d tensors, which are not
    # valid integer positions for DataFrame.iloc.
    top_indices = torch.topk(cos_scores, k).indices.tolist()
    titles = self.documents['Title']
    return [str(titles.iloc[idx]) for idx in top_indices]
54
+
55
  def process_query(self, query):
56
  try:
57
+ retrieved_docs = self.get_relevant_documents(query)
58
+ retrieved_text = "\n".join(retrieved_docs)[:1000]
59
 
60
  qa_input = {
61
  "question": query,