LLM Google T5 integration

#3
by AryanJh - opened
Files changed (1)
  1. app.py +107 -65
app.py CHANGED
@@ -16,17 +16,26 @@ class BrockEventsRAG:
     def __init__(self):
         """Initialize the RAG system with improved caching"""
         self.model = SentenceTransformer('all-MiniLM-L6-v2')
-        self.chroma_client = chromadb.Client()
+        self.embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2')
+
+        # ChromaDB client setup
+        self.chroma_client = chromadb.Client(Settings(persist_directory="chroma_db", chroma_db_impl="duckdb+parquet"))
+
+        # LLM model setup
+        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+        self.llm = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+

         # Get current date range
         self.eastern = pytz.timezone('America/New_York')
         self.today = datetime.now(self.eastern).replace(hour=0, minute=0, second=0, microsecond=0)
         self.date_range_end = self.today + timedelta(days=14)
-
+
         # Cache directory setup
         os.makedirs("cache", exist_ok=True)
         self.cache_file = "cache/events_cache.json"
-
+
+
         # Initialize or reset collection
         try:
             self.collection = self.chroma_client.create_collection(
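
Note for reviewers: HuggingFaceEmbeddings and Settings are assumed to be imported near the top of app.py (the import hunk is not shown here), and the Settings(chroma_db_impl="duckdb+parquet") form is the pre-0.4 chromadb API. A minimal sketch of the persistence pattern this hunk relies on, with a placeholder collection name and record:

import chromadb
from chromadb.config import Settings

# Pre-0.4 persistent client; on chromadb >= 0.4 the equivalent is
# chromadb.PersistentClient(path="chroma_db").
client = chromadb.Client(Settings(persist_directory="chroma_db",
                                  chroma_db_impl="duckdb+parquet"))

# "events" and the sample record are placeholders for a smoke test.
collection = client.get_or_create_collection("events")
collection.add(ids=["e1"], documents=["Sample event for a smoke test"])
print(collection.count())  # expect 1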
@@ -42,69 +51,18 @@ class BrockEventsRAG:

         # Load initial events
         self.update_database()
-
-    def save_cache(self, data: dict):
-        """Save events data to cache file"""
-        try:
-            # Convert datetime objects to strings for JSON serialization
-            serializable_data = {
-                'last_update': data['last_update'],
-                'events': []
-            }
-
-            for event in data['events']:
-                event_copy = event.copy()
-                # Convert datetime objects to strings
-                if event_copy.get('start_time'):
-                    event_copy['start_time'] = event_copy['start_time'].isoformat()
-                if event_copy.get('end_time'):
-                    event_copy['end_time'] = event_copy['end_time'].isoformat()
-                serializable_data['events'].append(event_copy)
-
-            with open(self.cache_file, 'w', encoding='utf-8') as f:
-                json.dump(serializable_data, f, ensure_ascii=False, indent=2)
-            print(f"Cache saved successfully to {self.cache_file}")
-
-        except Exception as e:
-            print(f"Error saving cache: {e}")
-
-    def load_cache(self) -> dict:
-        """Load and parse cached events data"""
+
+    def fetch_rss_feed(self, url: str) -> List[Dict]:
+        """Fetch and parse RSS feed from the given URL"""
         try:
-            if os.path.exists(self.cache_file):
-                with open(self.cache_file, 'r', encoding='utf-8') as f:
-                    data = json.load(f)
-
-                # Convert string timestamps back to datetime objects
-                for event in data['events']:
-                    if event.get('start_time'):
-                        event['start_time'] = datetime.fromisoformat(event['start_time'])
-                    if event.get('end_time'):
-                        event['end_time'] = datetime.fromisoformat(event['end_time'])
-
-                return data
-            return {'last_update': None, 'events': []}
-
+            feed = feedparser.parse(url)
+            entries = feed.entries
+            print(f"Fetched {len(entries)} entries from the feed.")
+            return entries
         except Exception as e:
-            print(f"Error loading cache: {e}")
-            return {'last_update': None, 'events': []}
-
-    def should_update_cache(self) -> bool:
-        """Check if cache needs updating (older than 24 hours)"""
-        try:
-            cached_data = self.load_cache()
-            if not cached_data['last_update']:
-                return True
-
-            last_update = datetime.fromisoformat(cached_data['last_update'])
-            time_since_update = datetime.now() - last_update
-
-            return time_since_update.total_seconds() > 86400  # 24 hours
+            print(f"Error fetching RSS feed: {e}")
+            return []

-        except Exception as e:
-            print(f"Error checking cache: {e}")
-            return True
-
     def parse_event_datetime(self, entry) -> tuple:
         """Parse start and end times from both RSS and HTML"""
         try:
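
fetch_rss_feed assumes feedparser and typing's List/Dict are imported elsewhere in app.py; that hunk is not part of this diff. A quick sketch of what feedparser.parse returns, using a placeholder feed URL:

import feedparser

feed = feedparser.parse("https://example.edu/events/rss")  # placeholder URL
print(f"Fetched {len(feed.entries)} entries from the feed.")
for entry in feed.entries[:3]:
    # Common RSS fields; any given feed may omit some of them.
    print(entry.get("title"), entry.get("published"), entry.get("link"))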
@@ -294,6 +252,28 @@ class BrockEventsRAG:
         except Exception as e:
             print(f"Error during query: {e}")
             return None
+
+    def generate_response_with_llm(self, events: List[Dict]) -> str:
+        """Use the LLM to generate a natural language response for the given events."""
+        try:
+            if not events:
+                input_text = "There are no events matching the query. How should I respond?"
+            else:
+                event_summaries = "\n".join([
+                    f"Event: {event['title']}. Start: {event['start_time']}, Location: {event['location']}."
+                    for event in events
+                ])
+                input_text = f"Format this information into a friendly response: {event_summaries}"
+
+            inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
+            outputs = self.llm.generate(**inputs)
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            return response
+        except Exception as e:
+            print(f"Error generating response: {e}")
+            return "Sorry, I couldn't generate a response."
+
+
     def generate_response(self, question: str, history: list) -> str:
         """Generate a response based on the query and chat history"""
         try:
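
A standalone round trip through google/flan-t5-small, mirroring the tokenize/generate/decode flow added above. One thing worth checking in review: generate() defaults to a short output budget, so long event summaries can come back truncated; the max_new_tokens argument below is an addition for illustration, not part of this PR:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

# Prompt shaped like the one built in generate_response_with_llm;
# the event itself is made up for the example.
prompt = ("Format this information into a friendly response: "
          "Event: Trivia Night. Start: 7:00 pm, Location: Union Hall.")
inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
outputs = model.generate(**inputs, max_new_tokens=128)  # explicit output budget
print(tokenizer.decode(outputs[0], skip_special_tokens=True))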
@@ -308,7 +288,7 @@ class BrockEventsRAG:
             is_location_query = any(word in question_lower for word in ['where', 'location', 'place', 'building', 'room'])

             # Format the response
-            response = "Here are some relevant events I found:\n\n"
+            response = self.generate_response_with_llm(matched_events)

             # Add top 3 matching events
             for i, (doc, metadata) in enumerate(zip(results['documents'][0][:3], results['metadatas'][0][:3]), 1):
@@ -326,7 +306,69 @@ class BrockEventsRAG:
         except Exception as e:
             print(f"Error generating response: {e}")
             return "I encountered an error while searching for events. Please try asking in a different way."
-
+    def save_cache(self, data: dict):
+        """Save events data to cache file"""
+        try:
+            # Convert datetime objects to strings for JSON serialization
+            serializable_data = {
+                'last_update': data['last_update'],
+                'events': []
+            }
+
+            for event in data['events']:
+                event_copy = event.copy()
+                # Convert datetime objects to strings
+                if event_copy.get('start_time'):
+                    event_copy['start_time'] = event_copy['start_time'].isoformat()
+                if event_copy.get('end_time'):
+                    event_copy['end_time'] = event_copy['end_time'].isoformat()
+                serializable_data['events'].append(event_copy)
+
+            with open(self.cache_file, 'w', encoding='utf-8') as f:
+                json.dump(serializable_data, f, ensure_ascii=False, indent=2)
+            print(f"Cache saved successfully to {self.cache_file}")
+
+        except Exception as e:
+            print(f"Error saving cache: {e}")
+    """
+    def load_cache(self) -> dict:
+        # Load and parse cached events data
+        try:
+            if os.path.exists(self.cache_file):
+                with open(self.cache_file, 'r', encoding='utf-8') as f:
+                    data = json.load(f)
+
+                # Convert string timestamps back to datetime objects
+                for event in data['events']:
+                    if event.get('start_time'):
+                        event['start_time'] = datetime.fromisoformat(event['start_time'])
+                    if event.get('end_time'):
+                        event['end_time'] = datetime.fromisoformat(event['end_time'])
+
+                return data
+            return {'last_update': None, 'events': []}
+
+        except Exception as e:
+            print(f"Error loading cache: {e}")
+            return {'last_update': None, 'events': []}
+
+    def should_update_cache(self) -> bool:
+        # Check if cache needs updating (older than 24 hours)
+        try:
+            cached_data = self.load_cache()
+            if not cached_data['last_update']:
+                return True
+
+            last_update = datetime.fromisoformat(cached_data['last_update'])
+            time_since_update = datetime.now() - last_update
+
+            return time_since_update.total_seconds() > 86400  # 24 hours
+
+        except Exception as e:
+            print(f"Error checking cache: {e}")
+            return True
+    """
+
 def create_demo():
     # Initialize the RAG system
     rag_system = BrockEventsRAG()
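
A rough local smoke test for the pieces this PR touches, assuming the surrounding app.py runs as-is; the feed URL is a stand-in and the empty list is just a blank chat history:

rag = BrockEventsRAG()
entries = rag.fetch_rss_feed("https://example.edu/events/rss")  # stand-in URL
print(f"{len(entries)} entries fetched")
print(rag.generate_response("What events are happening this week?", []))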