LLM Google T5 integration
#3
by AryanJh - opened

app.py CHANGED
@@ -16,17 +16,26 @@ class BrockEventsRAG:
     def __init__(self):
         """Initialize the RAG system with improved caching"""
         self.model = SentenceTransformer('all-MiniLM-L6-v2')
-        self.
+        self.embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2')
+
+        # ChromaDB client setup
+        self.chroma_client = chromadb.Client(Settings(persist_directory="chroma_db", chroma_db_impl="duckdb+parquet"))
+
+        # LLM model setup
+        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+        self.llm = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+
 
         # Get current date range
         self.eastern = pytz.timezone('America/New_York')
         self.today = datetime.now(self.eastern).replace(hour=0, minute=0, second=0, microsecond=0)
         self.date_range_end = self.today + timedelta(days=14)
-
+
         # Cache directory setup
         os.makedirs("cache", exist_ok=True)
         self.cache_file = "cache/events_cache.json"
-
+
+
         # Initialize or reset collection
         try:
             self.collection = self.chroma_client.create_collection(
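For these new attributes to run, app.py would also need matching imports near the top of the file; they are not visible in this hunk, so the sketch below is an assumption about what they would look like. Two review notes: `self.model` and `self.embeddings` both load the same 'all-MiniLM-L6-v2' weights, and `chroma_db_impl="duckdb+parquet"` is the legacy ChromaDB `Settings` style (newer chromadb releases use `chromadb.PersistentClient(path=...)` instead).

```python
# Assumed imports for the new __init__ attributes -- not shown in this diff.
import chromadb
from chromadb.config import Settings
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# HuggingFaceEmbeddings lives in langchain (older releases) or
# langchain_community (newer releases); either import would cover the usage here.
from langchain_community.embeddings import HuggingFaceEmbeddings
```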
@@ -42,69 +51,18 @@
 
         # Load initial events
         self.update_database()
-
-    def save_cache(self, data: dict):
-        """Save events data to cache file"""
-        try:
-            # Convert datetime objects to strings for JSON serialization
-            serializable_data = {
-                'last_update': data['last_update'],
-                'events': []
-            }
-
-            for event in data['events']:
-                event_copy = event.copy()
-                # Convert datetime objects to strings
-                if event_copy.get('start_time'):
-                    event_copy['start_time'] = event_copy['start_time'].isoformat()
-                if event_copy.get('end_time'):
-                    event_copy['end_time'] = event_copy['end_time'].isoformat()
-                serializable_data['events'].append(event_copy)
-
-            with open(self.cache_file, 'w', encoding='utf-8') as f:
-                json.dump(serializable_data, f, ensure_ascii=False, indent=2)
-            print(f"Cache saved successfully to {self.cache_file}")
-
-        except Exception as e:
-            print(f"Error saving cache: {e}")
-
-    def load_cache(self) -> dict:
-        """Load and parse cached events data"""
+
+    def fetch_rss_feed(self, url: str) -> List[Dict]:
+        """Fetch and parse RSS feed from the given URL"""
         try:
-            if os.path.exists(self.cache_file):
-                with open(self.cache_file, 'r', encoding='utf-8') as f:
-                    data = json.load(f)
-
-            # Convert string timestamps back to datetime objects
-            for event in data['events']:
-                if event.get('start_time'):
-                    event['start_time'] = datetime.fromisoformat(event['start_time'])
-                if event.get('end_time'):
-                    event['end_time'] = datetime.fromisoformat(event['end_time'])
-
-            return data
-            return {'last_update': None, 'events': []}
-
+            feed = feedparser.parse(url)
+            entries = feed.entries
+            print(f"Fetched {len(entries)} entries from the feed.")
+            return entries
         except Exception as e:
-            print(f"Error loading cache: {e}")
-            return {'last_update': None, 'events': []}
-
-    def should_update_cache(self) -> bool:
-        """Check if cache needs updating (older than 24 hours)"""
-        try:
-            cached_data = self.load_cache()
-            if not cached_data['last_update']:
-                return True
-
-            last_update = datetime.fromisoformat(cached_data['last_update'])
-            time_since_update = datetime.now() - last_update
-
-            return time_since_update.total_seconds() > 86400  # 24 hours
+            print(f"Error fetching RSS feed: {e}")
+            return []
 
-        except Exception as e:
-            print(f"Error checking cache: {e}")
-            return True
-
     def parse_event_datetime(self, entry) -> tuple:
         """Parse start and end times from both RSS and HTML"""
         try:
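`fetch_rss_feed` also assumes `import feedparser` and `from typing import List, Dict` exist at module level. The entries it returns are `FeedParserDict` objects (dict subclasses), so the `List[Dict]` annotation is close but not exact. A rough usage sketch with a placeholder URL (the real feed URL lives elsewhere in app.py):

```python
import feedparser

# Placeholder feed URL, for illustration only.
entries = feedparser.parse("https://example.edu/events.rss").entries
for entry in entries[:3]:
    # FeedParserDict supports both dict-style and attribute access.
    print(entry.get("title"), entry.get("published"))
```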
@@ -294,6 +252,28 @@
         except Exception as e:
             print(f"Error during query: {e}")
             return None
+
+    def generate_response_with_llm(events: List[Dict]) -> str:
+        """Use the LLM to generate a natural language response for the given events."""
+        try:
+            if not events:
+                input_text = "There are no events matching the query. How should I respond?"
+            else:
+                event_summaries = "\n".join([
+                    f"Event: {event['title']}. Start: {event['start_time']}, Location: {event['location']}."
+                    for event in events
+                ])
+                input_text = f"Format this information into a friendly response: {event_summaries}"
+
+            inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
+            outputs = self.llm.generate(**inputs)
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            return response
+        except Exception as e:
+            print(f"Error generating response: {e}")
+            return "Sorry, I couldn't generate a response."
+
+
     def generate_response(self, question: str, history: list) -> str:
         """Generate a response based on the query and chat history"""
         try:
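As written, `generate_response_with_llm` takes `events` as its first parameter yet still references `self.tokenizer` and `self.llm`, so calling it as a plain function would raise a `NameError`. A minimal sketch of the same idea as an instance method, with an explicit `max_new_tokens` since flan-t5's default generation length is quite short:

```python
# Sketch only: assumes the self.tokenizer / self.llm attributes created in __init__
# and `from typing import List, Dict` at module level.
def generate_response_with_llm(self, events: List[Dict]) -> str:
    """Use the LLM to turn matched events into a natural-language reply."""
    try:
        if not events:
            input_text = "There are no events matching the query. How should I respond?"
        else:
            event_summaries = "\n".join(
                f"Event: {event['title']}. Start: {event['start_time']}, Location: {event['location']}."
                for event in events
            )
            input_text = f"Format this information into a friendly response: {event_summaries}"

        inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        # Without max_new_tokens the model would stop after ~20 tokens.
        outputs = self.llm.generate(**inputs, max_new_tokens=150)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error generating response: {e}")
        return "Sorry, I couldn't generate a response."
```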
@@ -308,7 +288,7 @@
             is_location_query = any(word in question_lower for word in ['where', 'location', 'place', 'building', 'room'])
 
             # Format the response
-            response =
+            response = generate_response_with_llm(matched_events)
 
             # Add top 3 matching events
             for i, (doc, metadata) in enumerate(zip(results['documents'][0][:3], results['metadatas'][0][:3]), 1):
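`matched_events` is not defined anywhere in the visible context of this hunk. One hypothetical way to build it from the Chroma results already used a few lines below, and to route the call through `self` so the method resolves, could look like this (the metadata keys are assumptions, not taken from the PR):

```python
# Hypothetical glue code -- metadata key names are assumed for illustration.
matched_events = [
    {
        "title": meta.get("title", "Untitled event"),
        "start_time": meta.get("start_time", "TBA"),
        "location": meta.get("location", "TBA"),
    }
    for meta in results["metadatas"][0][:3]
]
response = self.generate_response_with_llm(matched_events)
```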
@@ -326,7 +306,69 @@
         except Exception as e:
             print(f"Error generating response: {e}")
             return "I encountered an error while searching for events. Please try asking in a different way."
-
+    def save_cache(self, data: dict):
+        """Save events data to cache file"""
+        try:
+            # Convert datetime objects to strings for JSON serialization
+            serializable_data = {
+                'last_update': data['last_update'],
+                'events': []
+            }
+
+            for event in data['events']:
+                event_copy = event.copy()
+                # Convert datetime objects to strings
+                if event_copy.get('start_time'):
+                    event_copy['start_time'] = event_copy['start_time'].isoformat()
+                if event_copy.get('end_time'):
+                    event_copy['end_time'] = event_copy['end_time'].isoformat()
+                serializable_data['events'].append(event_copy)
+
+            with open(self.cache_file, 'w', encoding='utf-8') as f:
+                json.dump(serializable_data, f, ensure_ascii=False, indent=2)
+            print(f"Cache saved successfully to {self.cache_file}")
+
+        except Exception as e:
+            print(f"Error saving cache: {e}")
+    """
+    def load_cache(self) -> dict:
+        #Load and parse cached events data
+        try:
+            if os.path.exists(self.cache_file):
+                with open(self.cache_file, 'r', encoding='utf-8') as f:
+                    data = json.load(f)
+
+                # Convert string timestamps back to datetime objects
+                for event in data['events']:
+                    if event.get('start_time'):
+                        event['start_time'] = datetime.fromisoformat(event['start_time'])
+                    if event.get('end_time'):
+                        event['end_time'] = datetime.fromisoformat(event['end_time'])
+
+                return data
+            return {'last_update': None, 'events': []}
+
+        except Exception as e:
+            print(f"Error loading cache: {e}")
+            return {'last_update': None, 'events': []}
+
+    def should_update_cache(self) -> bool:
+        #Check if cache needs updating (older than 24 hours)
+        try:
+            cached_data = self.load_cache()
+            if not cached_data['last_update']:
+                return True
+
+            last_update = datetime.fromisoformat(cached_data['last_update'])
+            time_since_update = datetime.now() - last_update
+
+            return time_since_update.total_seconds() > 86400  # 24 hours
+
+        except Exception as e:
+            print(f"Error checking cache: {e}")
+            return True
+    """
+
 def create_demo():
     # Initialize the RAG system
     rag_system = BrockEventsRAG()
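Two things worth noting about this hunk: only `save_cache` ends up active, because `load_cache` and `should_update_cache` sit inside the triple-quoted string that opens right after `save_cache`'s `except` block; and the whole cache depends on the ISO-8601 round trip below, which is why timestamps are converted to strings on save and back to datetimes on load. A standalone sketch of that round trip:

```python
from datetime import datetime

start = datetime(2024, 11, 18, 14, 30)
as_text = start.isoformat()                 # '2024-11-18T14:30:00', JSON-serializable
restored = datetime.fromisoformat(as_text)  # back to a datetime on load
assert restored == start
```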