import os

import chromadb
import numpy as np
from dotenv import load_dotenv
import gradio as gr
from groq import Groq
import torch
from transformers import AutoTokenizer, AutoModel

# Load environment variables
load_dotenv()

# List of API keys for Groq; unset keys are filtered out
api_keys = [
    key
    for key in (
        os.getenv("GROQ_API_KEY"),
        os.getenv("GROQ_API_KEY_2"),
        os.getenv("GROQ_API_KEY_3"),
        os.getenv("GROQ_API_KEY_4"),
    )
    if key
]

if not api_keys:
    raise ValueError("At least one GROQ_API_KEY environment variable must be set.")


# Define Groq-based chatbot with fallback across multiple API keys
class GroqChatbot:
    def __init__(self, api_keys):
        self.api_keys = api_keys
        self.current_key_index = 0
        self.client = Groq(api_key=self.api_keys[self.current_key_index])
        # The embedding model is loaded lazily on first use and cached
        self.tokenizer = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def switch_key(self):
        """Switch to the next API key in the list."""
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        self.client = Groq(api_key=self.api_keys[self.current_key_index])
        print(f"Switched to API key index {self.current_key_index}")

    def get_response(self, prompt):
        """Get a response from the API, switching keys on failure."""
        while True:
            try:
                response = self.client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": "You are a helpful AI assistant."},
                        {"role": "user", "content": prompt},
                    ],
                    model="llama3-70b-8192",
                )
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error: {e}")
                self.switch_key()
                # Once we wrap back around to the first key, every key has failed
                if self.current_key_index == 0:
                    return "All API keys have been exhausted. Please try again later."

    def _load_embedding_model(self):
        """Load the embedding model and tokenizer once and cache them."""
        if self.model is None:
            self.tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
            self.model = AutoModel.from_pretrained("NousResearch/Llama-3.2-1B")

            # Move model to GPU if available
            self.model = self.model.to(self.device)
            self.model.eval()

            # Ensure tokenizer has a padding token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

    def text_to_embedding(self, text):
        """Convert text to a normalized, mean-pooled embedding."""
        try:
            self._load_embedding_model()

            # Tokenize the text
            encoded_input = self.tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self.device)

            # Generate token embeddings
            with torch.no_grad():
                model_output = self.model(**encoded_input)
                token_embeddings = model_output.last_hidden_state

            # Mean pooling over non-padding tokens
            attention_mask = encoded_input["attention_mask"]
            mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            masked_embeddings = token_embeddings * mask
            summed = torch.sum(masked_embeddings, dim=1)
            summed_mask = torch.clamp(torch.sum(attention_mask, dim=1).unsqueeze(-1), min=1e-9)
            mean_pooled = (summed / summed_mask).squeeze()

            # Move to CPU and convert to numpy
            embedding = mean_pooled.cpu().numpy()

            # Normalize the embedding vector to unit length
            embedding = embedding / np.linalg.norm(embedding)

            print(f"Generated embedding for text: {text}")
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return None


# Embedding store backed by ChromaDB persistent storage
class LocalEmbeddingStore:
    def __init__(self, storage_dir="./chromadb_storage"):
        self.client = chromadb.PersistentClient(path=storage_dir)  # ChromaDB client with persistent storage
        self.collection_name = "chatbot_docs"  # Collection for storing embeddings
        self.collection = self.client.get_or_create_collection(name=self.collection_name)

    def add_embedding(self, doc_id, embedding, metadata):
        """Add a document and its embedding to ChromaDB."""
        self.collection.add(
            documents=[doc_id],      # Document text (here the ID doubles as the text)
            embeddings=[embedding],  # Embedding for the document
            metadatas=[metadata],    # Optional metadata
            ids=[doc_id],            # Unique ID for the entry
        )
        print(f"Added embedding for document ID: {doc_id}")

    def search_embedding(self, query_embedding, num_results=3):
        """Search for the most relevant documents by embedding similarity."""
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=num_results,
        )
        print(f"Search results: {results}")
        # ChromaDB returns one list per query; unwrap the single query here
        return results["documents"][0], results["distances"][0]


# RAG system that integrates ChromaDB search with the Groq chatbot
class RAGSystem:
    def __init__(self, groq_client, embedding_store):
        self.groq_client = groq_client
        self.embedding_store = embedding_store

    def get_most_relevant_document(self, query_embedding):
        """Retrieve the closest document and its distance (lower is more similar)."""
        docs, distances = self.embedding_store.search_embedding(query_embedding)
        if docs:
            return docs[0], distances[0]  # Top-ranked document and its distance
        return None, None

    def chat_with_rag(self, user_input):
        """Handle the RAG process: embed, retrieve, then generate."""
        query_embedding = self.groq_client.text_to_embedding(user_input)
        if query_embedding is None or query_embedding.size == 0:
            return "Failed to generate embeddings."

        context_document_id, distance = self.get_most_relevant_document(query_embedding)
        if not context_document_id:
            return "No relevant documents found."

        # Placeholder metadata lookup; implement real retrieval as needed
        context_metadata = f"Metadata for {context_document_id}"

        prompt = f"""Context (distance {distance:.2f}):
{context_metadata}

User: {user_input}
AI:"""
        return self.groq_client.get_response(prompt)


# Initialize components
embedding_store = LocalEmbeddingStore(storage_dir="./chromadb_storage")
chatbot = GroqChatbot(api_keys=api_keys)
rag_system = RAGSystem(groq_client=chatbot, embedding_store=embedding_store)


# Gradio UI
def chat_ui(user_input, chat_history):
    """Handle chat interactions and update history."""
    if not user_input.strip():
        return chat_history

    ai_response = rag_system.chat_with_rag(user_input)
    chat_history.append((user_input, ai_response))
    return chat_history


# Gradio interface
with gr.Blocks() as demo:
    chat_history = gr.Chatbot(label="Groq Chatbot with RAG", elem_id="chatbox")
    user_input = gr.Textbox(placeholder="Enter your prompt here...")
    submit_button = gr.Button("Submit")

    submit_button.click(chat_ui, inputs=[user_input, chat_history], outputs=chat_history)

if __name__ == "__main__":
    demo.launch()
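
# --- Example (not part of the original script): seeding the store ---
# Nothing above ever calls add_embedding, so the ChromaDB collection starts
# empty and every query returns "No relevant documents found." A minimal
# sketch of how one might index a few documents before launching, assuming
# short text snippets; the doc IDs, texts, and metadata below are hypothetical:
#
# sample_docs = {
#     "doc1": "Groq serves large language models with low latency.",
#     "doc2": "ChromaDB persists embeddings for similarity search.",
# }
# for doc_id, text in sample_docs.items():
#     emb = chatbot.text_to_embedding(text)
#     if emb is not None:
#         embedding_store.add_embedding(doc_id, emb.tolist(), {"source": "example"})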