Spaces:
Build error
Build error
nileshhanotia
committed on
Update models/rag_system.py
Browse files- models/rag_system.py +26 -22
models/rag_system.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
import os
|
2 |
import pandas as pd
|
3 |
-
from transformers import pipeline
|
4 |
-
|
5 |
-
|
6 |
-
from
|
7 |
-
from langchain.docstore.document import Document
|
8 |
from utils.logger import setup_logger
|
9 |
from utils.model_loader import ModelLoader
|
10 |
|
@@ -13,6 +12,7 @@ logger = setup_logger(__name__)
|
|
13 |
class RAGSystem:
|
14 |
def __init__(self, csv_path="apparel.csv"):
|
15 |
try:
|
|
|
16 |
self.setup_system(csv_path)
|
17 |
self.qa_pipeline = ModelLoader.load_model_with_retry(
|
18 |
"distilbert-base-cased-distilled-squad",
|
@@ -28,30 +28,34 @@ class RAGSystem:
|
|
28 |
raise FileNotFoundError(f"CSV file not found at {csv_path}")
|
29 |
|
30 |
try:
|
31 |
-
documents = pd.read_csv(csv_path)
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
) for idx, row in documents.iterrows()
|
37 |
-
]
|
38 |
-
|
39 |
-
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
|
40 |
-
split_docs = text_splitter.split_documents(docs)
|
41 |
-
|
42 |
-
embeddings = HuggingFaceEmbeddings(
|
43 |
-
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
44 |
)
|
45 |
-
self.vector_store = FAISS.from_documents(split_docs, embeddings)
|
46 |
-
self.retriever = self.vector_store.as_retriever()
|
47 |
except Exception as e:
|
48 |
logger.error(f"Failed to setup RAG system: {str(e)}")
|
49 |
raise
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def process_query(self, query):
|
52 |
try:
|
53 |
-
retrieved_docs = self.
|
54 |
-
retrieved_text = "\n".join(
|
55 |
|
56 |
qa_input = {
|
57 |
"question": query,
|
|
|
1 |
import os
|
2 |
import pandas as pd
|
3 |
+
from transformers import pipeline, AutoTokenizer, AutoModel
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
from sentence_transformers import SentenceTransformer
|
|
|
7 |
from utils.logger import setup_logger
|
8 |
from utils.model_loader import ModelLoader
|
9 |
|
|
|
12 |
class RAGSystem:
|
13 |
def __init__(self, csv_path="apparel.csv"):
|
14 |
try:
|
15 |
+
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
16 |
self.setup_system(csv_path)
|
17 |
self.qa_pipeline = ModelLoader.load_model_with_retry(
|
18 |
"distilbert-base-cased-distilled-squad",
|
|
|
28 |
raise FileNotFoundError(f"CSV file not found at {csv_path}")
|
29 |
|
30 |
try:
|
31 |
+
self.documents = pd.read_csv(csv_path)
|
32 |
+
# Create embeddings for all documents
|
33 |
+
self.doc_embeddings = self.model.encode(
|
34 |
+
self.documents['Title'].astype(str).tolist(),
|
35 |
+
convert_to_tensor=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
)
|
|
|
|
|
37 |
except Exception as e:
|
38 |
logger.error(f"Failed to setup RAG system: {str(e)}")
|
39 |
raise
|
40 |
|
41 |
+
def get_relevant_documents(self, query, top_k=5):
    """Return the titles of the documents most similar to ``query``.

    The query is embedded with ``self.model`` and compared, by cosine
    similarity, against the precomputed ``self.doc_embeddings``. The
    ``Title`` values of the best-scoring rows of ``self.documents`` are
    returned, best match first.

    Args:
        query: Text to embed and match against the document embeddings.
        top_k: Maximum number of titles to return. Requests larger than
            the number of documents are clamped; zero or negative values
            yield an empty list.

    Returns:
        A list of ``Title`` values converted to ``str``, ordered from the
        highest to the lowest similarity score.
    """
    # Clamp the request so torch.topk is never asked for more rows than
    # exist, and return early instead of passing it a non-positive count.
    limit = min(top_k, len(self.documents))
    if limit <= 0:
        return []

    # Get query embedding
    query_embedding = self.model.encode(query, convert_to_tensor=True)

    # Calculate cosine similarities. unsqueeze(0) adds a leading axis so
    # the single query vector broadcasts against the document embeddings
    # (presumably one row per document -- confirm against setup code).
    cos_scores = torch.nn.functional.cosine_similarity(
        query_embedding.unsqueeze(0),
        self.doc_embeddings,
    )

    # Get top_k most similar documents. .tolist() converts the index
    # tensor to plain Python ints: iterating the tensor directly yields
    # 0-dim tensors, which are not dependable positional keys for .iloc.
    top_positions = torch.topk(cos_scores, limit).indices.tolist()

    # Look the column up once rather than building a row Series per hit.
    titles = self.documents['Title']
    return [str(titles.iloc[position]) for position in top_positions]
|
54 |
+
|
55 |
def process_query(self, query):
|
56 |
try:
|
57 |
+
retrieved_docs = self.get_relevant_documents(query)
|
58 |
+
retrieved_text = "\n".join(retrieved_docs)[:1000]
|
59 |
|
60 |
qa_input = {
|
61 |
"question": query,
|