chsubhasis committed on
Commit
2dd051b
1 Parent(s): 3aa009f
Files changed (4)
  1. AIML.pdf +0 -0
  2. app.py +203 -138
  3. mini-llama-articles.csv +0 -0
  4. requirements.txt +0 -0
AIML.pdf ADDED
Binary file (89.9 kB).
 
app.py CHANGED
@@ -1,77 +1,59 @@
  import os
  from getpass import getpass
- import csv
- from langchain_core.documents import Document
  from langchain_text_splitters import RecursiveCharacterTextSplitter
- #from langchain.schema import Document
  from langchain_huggingface import HuggingFaceEmbeddings
  import torch
  from langchain_huggingface import HuggingFaceEndpoint
- from langchain_community.cache import InMemoryCache
- from langchain.globals import set_llm_cache
  from langchain_chroma import Chroma
  from langchain.chains import RetrievalQA
- import numpy as np
  import gradio
- import sqlite3
- from dotenv import load_dotenv
-
- # Load environment variables
- load_dotenv()
-
- #hfapi_key = getpass("Enter you HuggingFace access token:")
- hfapi_key = os.getenv("Mytoken")
-
- if not hfapi_key:
-     raise ValueError("HUGGINGFACE_API_KEY not found in environment variables")

  os.environ["HF_TOKEN"] = hfapi_key
  os.environ["HUGGINGFACEHUB_API_TOKEN"] = hfapi_key

- set_llm_cache(InMemoryCache())

  persist_directory = 'docs/chroma/'

- ####################################
- def load_file_as_JSON():
-     print("$$$$$ ENTER INTO load_file_as_JSON $$$$$")
-     rows = []
-     with open("mini-llama-articles.csv", mode="r", encoding="utf-8") as file:
-         csv_reader = csv.reader(file)
-         for idx, row in enumerate(csv_reader):
-             if idx == 0:
-                 continue
-             # Skip header row
-             rows.append(row)
-
-     print("@@@@@@ EXIT FROM load_file_as_JSON @@@@@")
-     return rows
  ####################################
  def get_documents():
      print("$$$$$ ENTER INTO get_documents $$$$$")
-     documents = [
-         Document(
-             page_content=row[1], metadata={"title": row[0], "url": row[2], "source_name": row[3]}
-         )
-         for row in load_file_as_JSON()
-     ]
-     print("documents lenght is ", len(documents))
-     print("first entry from documents ", documents[0])
-     print("document metadata ", documents[0].metadata)
      print("@@@@@@ EXIT FROM get_documents @@@@@")
-     return documents
  ####################################
- def getDocSplitter():
      print("$$$$$ ENTER INTO getDocSplitter $$$$$")
      text_splitter = RecursiveCharacterTextSplitter(
          chunk_size = 512,
          chunk_overlap = 128
      )
-     splits = text_splitter.split_documents(get_documents())
-     print("Split length ", len(splits))
-     print("Page content ", splits[0].page_content)
      print("@@@@@@ EXIT FROM getDocSplitter @@@@@")
-     return splits
  ####################################
  def getEmbeddings():
      print("$$$$$ ENTER INTO getEmbeddings $$$$$")
@@ -90,133 +72,216 @@ def getEmbeddings():
          encode_kwargs=encode_kwargs # Pass the encoding options
      )

-     print("Embedding ", embedding)
      print("@@@@@@ EXIT FROM getEmbeddings @@@@@")
      return embedding
  ####################################
  def getLLM():
      print("$$$$$ ENTER INTO getLLM $$$$$")
      llm = HuggingFaceEndpoint(
          repo_id="HuggingFaceH4/zephyr-7b-beta",
-         #repo_id="chsubhasis/ai-tutor-towardsai",
          task="text-generation",
-         max_new_tokens = 512,
-         top_k = 10,
-         temperature = 0.1,
-         repetition_penalty = 1.03,
      )
-     print("llm ", llm)
-     print("Who is the CEO of Apple? ", llm.invoke("Who is the CEO of Apple?")) #test
      print("@@@@@@ EXIT FROM getLLM @@@@@")
      return llm
  ####################################
  def is_chroma_db_present(directory: str):
-     """
-     Check if the directory exists and contains any files.
-     """
      return os.path.exists(directory) and len(os.listdir(directory)) > 0
  ####################################
- def getRetiriver():
      print("$$$$$ ENTER INTO getRetiriver $$$$$")
      if is_chroma_db_present(persist_directory):
          print(f"Chroma vector DB found in '{persist_directory}' and will be loaded.")
          # Load vector store from the local directory
-         #vectordb = Chroma(persist_directory=persist_directory)
          vectordb = Chroma(
              persist_directory=persist_directory,
              embedding_function=getEmbeddings(),
              collection_name="ai_tutor")
      else:
-         vectordb = Chroma.from_documents(
              collection_name="ai_tutor",
-             documents=getDocSplitter(), # splits we created earlier
              embedding=getEmbeddings(),
              persist_directory=persist_directory, # save the directory
          )
-     print("vectordb collection count ", vectordb._collection.count())

-     docs = vectordb.search("What is Artificial Intelligence", search_type="mmr", k=5)
-     for i in range(len(docs)):
-         print(docs[i].page_content)

-     metadata_filter = {
-         "result": "llama" # ChromaDB will perform a substring search
-     }

-     retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 3, "fetch_k":5, "filter": metadata_filter})
-     print("retriever ", retriever)
-     print("@@@@@@ EXIT FROM getRetiriver @@@@@")
-     return retriever
  ####################################
- def get_rag_response(query):
-     print("$$$$$ ENTER INTO get_rag_response $$$$$")
-     qa_chain = RetrievalQA.from_chain_type(
-         llm=getLLM(),
-         chain_type="stuff",
-         retriever=getRetiriver(),
-         return_source_documents=True
-     )
-
-     #RAG Evaluation
-     # Sample dataset of questions and expected answers
-     dataset = [
-         {"question": "Who is the CEO of Meta?", "expected_answer": "Mark Zuckerberg"},
-         {"question": "Who is the CEO of Apple?", "expected_answer": "Tiiiiiim Coooooook"},
-     ]

-     hit_rate, mrr = evaluate_rag(qa_chain, dataset)
-     print(f"Hit Rate: {hit_rate:.2f}, Mean Reciprocal Rank (MRR): {mrr:.2f}")
-
-     result = qa_chain({"query": query})
-     print("Result ",result)
-     print("@@@@@@ EXIT FROM get_rag_response @@@@@")
-     return result["result"]
  ####################################
- def evaluate_rag(qa, dataset):
-     print("$$$$$ ENTER INTO evaluate_rag $$$$$")
-     hits = 0
-     reciprocal_ranks = []
-
-     for entry in dataset:
-         question = entry["question"]
-         expected_answer = entry["expected_answer"]
-
-         # Get the answer from the RAG system
-         response = qa({"query": question})
-         answer = response["result"]
-
-         # Check if the answer matches the expected answer
-         if expected_answer.lower() in answer.lower():
-             hits += 1
-             reciprocal_ranks.append(1) # Hit at rank 1
-         else:
-             reciprocal_ranks.append(0)
-
-     # Calculate Hit Rate and MRR
-     hit_rate = hits / len(dataset)
-     mrr = np.mean(reciprocal_ranks)
-
-     print("@@@@@@ EXIT FROM evaluate_rag @@@@@")
-     return hit_rate, mrr
  ####################################
- def launch_ui():
-     print("$$$$$ ENTER INTO launch_ui $$$$$")
-     # Input from user
-     in_question = gradio.Textbox(lines=10, placeholder=None, value="query", label='Enter your query')

-     # Output prediction
-     out_response = gradio.Textbox(type="text", label='RAG Response')

-     # Gradio interface to generate UI
-     iface = gradio.Interface(fn = get_rag_response,
-         inputs = [in_question],
-         outputs = [out_response],
-         title = "RAG Response",
-         description = "Write the query and get the response from the RAG system",
-         allow_flagging = 'never')

-     iface.launch(share = True)

- ####################################
- if __name__ == "__main__":
-     launch_ui()
 
  import os
  from getpass import getpass
  from langchain_text_splitters import RecursiveCharacterTextSplitter
  from langchain_huggingface import HuggingFaceEmbeddings
  import torch
  from langchain_huggingface import HuggingFaceEndpoint
+ from langchain_core.caches import InMemoryCache
+ from langchain_core.globals import set_llm_cache
  from langchain_chroma import Chroma
  from langchain.chains import RetrievalQA
  import gradio
+ import PyPDF2
+ import json
+ import re
+ import time
+ import threading
+ from langchain_core.runnables import RunnableConfig, RunnablePassthrough
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import PromptTemplate
+ from langchain_core.runnables import RunnableLambda

+ hfapi_key = getpass("Enter your HuggingFace access token:")
  os.environ["HF_TOKEN"] = hfapi_key
  os.environ["HUGGINGFACEHUB_API_TOKEN"] = hfapi_key

+ set_llm_cache(InMemoryCache())  # Set cache globally

  persist_directory = 'docs/chroma/'
+ pdf_path = 'AIML.pdf'

  ####################################
  def get_documents():
      print("$$$$$ ENTER INTO get_documents $$$$$")
+
+     with open(pdf_path, 'rb') as file:
+         # Create a PDF reader object
+         pdf_reader = PyPDF2.PdfReader(file)
+
+         # Extract text from all pages
+         full_text = ""
+         for page in pdf_reader.pages:
+             full_text += page.extract_text() + "\n"
+
      print("@@@@@@ EXIT FROM get_documents @@@@@")
+     return full_text
  ####################################
+ def getTextSplits():
      print("$$$$$ ENTER INTO getDocSplitter $$$$$")
      text_splitter = RecursiveCharacterTextSplitter(
          chunk_size = 512,
          chunk_overlap = 128
      )
+     texts = text_splitter.split_text(get_documents())
+     #print("Page content ", texts)
      print("@@@@@@ EXIT FROM getDocSplitter @@@@@")
+     return texts
  ####################################
  def getEmbeddings():
      print("$$$$$ ENTER INTO getEmbeddings $$$$$")
          encode_kwargs=encode_kwargs # Pass the encoding options
      )

      print("@@@@@@ EXIT FROM getEmbeddings @@@@@")
      return embedding
  ####################################
  def getLLM():
      print("$$$$$ ENTER INTO getLLM $$$$$")
+
+     model_kwargs = {
+         'device': "cuda" if torch.cuda.is_available() else "cpu",
+         'stream': True  # Ensure streaming is enabled
+     }
+
      llm = HuggingFaceEndpoint(
          repo_id="HuggingFaceH4/zephyr-7b-beta",
          task="text-generation",
+         max_new_tokens= 512,
+         do_sample= True,
+         temperature = 0.7,
+         repetition_penalty= 1.2,
+         top_k = 10
+         #model_kwargs=model_kwargs # Pass the model configuration options
      )
      print("@@@@@@ EXIT FROM getLLM @@@@@")
      return llm
  ####################################
  def is_chroma_db_present(directory: str):
+
+     # Check if the directory exists and contains any files.
      return os.path.exists(directory) and len(os.listdir(directory)) > 0
  ####################################
+ def getRetiriver(query, metadata_filter=None):
      print("$$$$$ ENTER INTO getRetiriver $$$$$")
+
+     # Classify query
+     query_type = classify_query(query)
+     print("Query classification", query_type)
+
+     k_default = 2
+     fetch_k_default = 5
+     search_type_default = "mmr"
+
+     # Routing logic
+     if query_type == 'concept':
+         # For conceptual queries, prioritize comprehensive context
+         k_default = 5
+         fetch_k_default = 10
+         search_type_default = "mmr"
+     elif query_type == 'example':
+         # For example queries, focus on more specific, relevant contexts
+         search_type_default = "similarity"
+     elif query_type == 'code':
+         # For code-related queries, use a more targeted retrieval
+         search_type_default = "similarity"
+
      if is_chroma_db_present(persist_directory):
          print(f"Chroma vector DB found in '{persist_directory}' and will be loaded.")
          # Load vector store from the local directory
          vectordb = Chroma(
              persist_directory=persist_directory,
              embedding_function=getEmbeddings(),
              collection_name="ai_tutor")
      else:
+         vectordb = Chroma.from_texts(
              collection_name="ai_tutor",
+             texts=getTextSplits(),
              embedding=getEmbeddings(),
              persist_directory=persist_directory, # save the directory
          )

+     print("metadata_filter", metadata_filter)
+     if(metadata_filter):
+         metadata_filter_dict = {
+             "result": metadata_filter # ChromaDB will perform a substring search
+         }
+         print("@@@@@@ EXIT FROM getRetiriver with metadata_filter @@@@@")
+
+         if search_type_default == "similarity":
+             return vectordb.as_retriever(search_type=search_type_default, search_kwargs={"k": k_default, "filter": metadata_filter_dict})
+
+         return vectordb.as_retriever(search_type=search_type_default, search_kwargs={"k": k_default, "fetch_k":fetch_k_default, "filter": metadata_filter_dict})
+
+     print("@@@@@@ EXIT FROM getRetiriver without metadata_filter @@@@@")
+     if search_type_default == "similarity":
+         return vectordb.as_retriever(search_type=search_type_default, search_kwargs={"k": k_default})

+     return vectordb.as_retriever(search_type=search_type_default, search_kwargs={"k": k_default, "fetch_k":fetch_k_default})
+ ####################################
+ def classify_query(query):
+     """
+     Classify the type of query to determine routing strategy.
+
+     Query Types:
+     - 'concept': Theoretical or conceptual questions
+     - 'example': Requests for practical examples
+     - 'code': Coding or implementation-related queries
+     - 'general': Default catch-all category
+     """
+     query = query.lower()

+     # Concept detection patterns
+     concept_patterns = [
+         r'what is',
+         r'define',
+         r'explain',
+         r'describe',
+         r'theory of',
+         r'concept of'
+     ]
+
+     # Example detection patterns
+     example_patterns = [
+         r'give an example',
+         r'show me an example',
+         r'demonstrate',
+         r'illustrate'
+     ]
+
+     # Code-related detection patterns
+     code_patterns = [
+         r'how to implement',
+         r'code for',
+         r'python code',
+         r'algorithm implementation',
+         r'write a program'
+     ]
+
+     # Check patterns
+     for pattern in concept_patterns:
+         if re.search(pattern, query):
+             return 'concept'
+
+     for pattern in example_patterns:
+         if re.search(pattern, query):
+             return 'example'
+
+     for pattern in code_patterns:
+         if re.search(pattern, query):
+             return 'code'
+
+     return 'general'
  ####################################
+ def get_rag_response(query, metadata_filter=None):
+     print("$$$$$ ENTER INTO get_rag_response $$$$$")

+     # Create the retriever
+     retriever = getRetiriver(query, metadata_filter)
+
+     # Get the LLM
+     llm = getLLM()
+
+     # Create a prompt template
+     template = """Use the following pieces of context to answer the question at the end.
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+ Context: {context}
+
+ Question: {question}
+
+ Helpful Answer:"""
+
+     prompt = PromptTemplate.from_template(template)
+
+     # Function to prepare input for the chain
+     def prepare_inputs(inputs):
+         retrieved_docs = retriever.invoke(inputs["question"])
+         context = format_docs(retrieved_docs)
+         return {
+             "context": context,
+             "question": inputs["question"]
+         }
+
+     # Construct the RAG chain with streaming
+     rag_chain = (
+         RunnablePassthrough()
+         | RunnableLambda(prepare_inputs)
+         | prompt
+         | llm
+         | StrOutputParser()
+     )
+
+     # Stream the response
+     full_response = ""
+     for chunk in rag_chain.stream({"question": query}):
+         full_response += chunk
+         # Add a small delay to create a streaming effect
+         time.sleep(0.05)  # 50 milliseconds between chunk updates
+         yield full_response
+
  ####################################
+ # Utility function to format documents
+ def format_docs(docs):
+     return "\n\n".join(doc.page_content for doc in docs)
  ####################################
+ # Input from user
+ in_question = gradio.Textbox(lines=10, placeholder=None, value="What are Artificial Intelligence and Machine Learning?", label='Ask a question to your AI Tutor')

+ # Optional metadata filter input
+ in_metadata_filter = gradio.Textbox(lines=2, placeholder=None, label='Optionally add a filter word')

+ # Output prediction
+ out_response = gradio.Textbox(label='Response', interactive=False, show_copy_button=True)

+ # Gradio interface to generate UI
+ iface = gradio.Interface(
+     fn = get_rag_response,
+     inputs=[in_question, in_metadata_filter],
+     outputs=out_response,
+     title="Your AI Tutor",
+     description="Ask a question, optionally add metadata filters.",
+     allow_flagging='never',
+     stream_every=0.5
+ )

+ iface.launch(share = True)
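
Note: the new getRetiriver() routes retrieval settings by query type before touching the vector store. A condensed, self-contained sketch of that routing (the regexes are abbreviated from the classify_query patterns above, and the helper name retriever_kwargs is illustrative, not part of the commit):

import re

def classify_query(query):
    q = query.lower()
    if re.search(r'what is|define|explain|describe|theory of|concept of', q):
        return 'concept'
    if re.search(r'give an example|show me an example|demonstrate|illustrate', q):
        return 'example'
    if re.search(r'how to implement|code for|python code|write a program', q):
        return 'code'
    return 'general'

def retriever_kwargs(query):
    # Mirrors the routing in getRetiriver(): broad MMR for concepts,
    # tighter similarity search for examples and code.
    qtype = classify_query(query)
    if qtype == 'concept':
        return {"search_type": "mmr", "search_kwargs": {"k": 5, "fetch_k": 10}}
    if qtype in ('example', 'code'):
        return {"search_type": "similarity", "search_kwargs": {"k": 2}}
    return {"search_type": "mmr", "search_kwargs": {"k": 2, "fetch_k": 5}}

for q in ("What is overfitting?", "Give an example of regularization", "How to implement k-means?"):
    print(q, "->", classify_query(q), retriever_kwargs(q))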
 
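The get_rag_response() generator replaces RetrievalQA with an LCEL chain (prepare_inputs | prompt | llm | StrOutputParser) consumed via .stream(). A minimal sketch of the same wiring that runs without a HuggingFace token, using a RunnableLambda as a stand-in LLM and a canned context; both stand-ins are assumptions for illustration, not what the commit deploys:

from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

prompt = PromptTemplate.from_template(
    "Use the following pieces of context to answer the question at the end.\n\n"
    "Context: {context}\n\nQuestion: {question}\n\nHelpful Answer:"
)

def prepare_inputs(inputs):
    # In app.py this calls retriever.invoke(...) and format_docs(); here the context is canned.
    return {"context": "AI is the study of building intelligent agents.",
            "question": inputs["question"]}

# Stand-in for HuggingFaceEndpoint so the sketch runs offline.
fake_llm = RunnableLambda(lambda prompt_value: "AI is about building systems that act intelligently.")

rag_chain = (
    RunnablePassthrough()
    | RunnableLambda(prepare_inputs)
    | prompt
    | fake_llm
    | StrOutputParser()
)

for chunk in rag_chain.stream({"question": "What is Artificial Intelligence?"}):
    print(chunk, end="", flush=True)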
 
mini-llama-articles.csv DELETED
The diff for this file is too large to render.
 
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
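The requirements.txt diff is not rendered here. Judging only from the imports in the new app.py, the environment would need roughly these packages (an inference, not the file's actual contents): langchain, langchain-core, langchain-huggingface, langchain-chroma, langchain-text-splitters, sentence-transformers, torch, PyPDF2, and gradio.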