"""Text generation helpers with optional retrieval-augmented generation (RAG).

``generate`` is the entry point: when PDF paths or scraped URL content are
supplied it answers through a Chroma-backed RAG chain, otherwise it queries
the selected chat model directly.
"""

import functools
import os
import uuid

from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents import Document
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
# NOTE: shadowed by the langchain_core StrOutputParser import below.
from langchain.schema import StrOutputParser
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_anthropic import ChatAnthropic
from dotenv import load_dotenv

load_dotenv()

# suppress grpc and glog logs for gemini
# NOTE(review): these are set after langchain_google_genai is imported; if grpc
# reads them at import time they may have no effect — confirm.
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"

# RAG parameters
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
K = 10  # chunks returned by the MMR retriever
FETCH_K = 20  # candidate chunks MMR selects the K results from

# Display name shown to users -> provider model identifier.
llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

# Provider model identifier -> LangChain chat-model class.
llm_classes = {
    "llama3-70b-8192": ChatGroq,
    "gpt-4o-mini": ChatOpenAI,
    "gpt-4o": ChatOpenAI,
    "gpt-4-turbo": ChatOpenAI,
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}

# Constructor keyword (pydantic field name) each chat-model class uses for its
# API key.
_API_KEY_FIELDS = {
    ChatGroq: "groq_api_key",
    ChatOpenAI: "openai_api_key",
    ChatGoogleGenerativeAI: "google_api_key",
    ChatAnthropic: "anthropic_api_key",
}


def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    """Instantiate the chat model registered under the display name *model*.

    Args:
        model: Display name; a key of ``llm_model_translation``.
        api_key: Provider API key. When empty, no key is passed and the model
            class presumably picks one up from the environment (populated by
            ``load_dotenv`` above).
        temperature: Sampling temperature forwarded to the model.
        max_length: Forwarded to the model as ``max_tokens``.

    Returns:
        The chat-model instance, or ``None`` if construction failed (the
        error is printed).

    Raises:
        ValueError: If *model* is not a supported display name.
    """
    model_name = llm_model_translation.get(model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")

    kwargs = {"model_name": model_name, "temperature": temperature, "max_tokens": max_length}
    # The key used to be accepted but silently dropped; forward it when given.
    if api_key:
        kwargs[_API_KEY_FIELDS[llm_class]] = api_key

    try:
        llm = llm_class(**kwargs)
    except Exception as e:
        # Best-effort: callers check for None rather than catching.
        print(f"An error occurred: {e}")
        llm = None
    return llm


@functools.lru_cache(maxsize=1)
def _get_embedding_function():
    """Return the all-MiniLM-L6-v2 embedding function, constructed once per process."""
    return SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")


def create_db_with_langchain(path: list[str], url_content: dict):
    """Build an in-memory Chroma vector store from PDFs and scraped pages.

    Args:
        path: PDF file paths, loaded with ``PyMuPDFLoader``. May be empty or None.
        url_content: Mapping of source URL -> page text. May be empty or None.

    Returns:
        A Chroma store holding every chunk, in a collection unique to this call.

    Raises:
        AssertionError: If neither source produced any chunk.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    all_docs = []

    for file in path or []:
        # split each PDF into chunks
        all_docs.extend(text_splitter.split_documents(PyMuPDFLoader(file).load()))

    if url_content:
        pages = [
            Document(page_content=content, metadata={"source": url})
            for url, content in url_content.items()
        ]
        all_docs.extend(text_splitter.split_documents(pages))

    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")

    # Explicit raise instead of `assert` (stripped under `python -O`); the
    # exception type and message are unchanged for existing callers.
    if not all_docs:
        raise AssertionError("No PDFs or scrapped data provided")

    # Use a fresh collection per call: in-process Chroma clients can share the
    # default "langchain" collection, which would mix in chunks from earlier calls.
    return Chroma.from_documents(
        all_docs,
        _get_embedding_function(),
        collection_name=f"rag-{uuid.uuid4().hex}",
    )


def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Answer *prompt* using chunks retrieved from the supplied PDFs / URL content.

    Chunks are retrieved with MMR search keyed on *topic* (not on *prompt*)
    and inserted as context into the ``rlm/rag-prompt`` hub prompt, with
    *prompt* as the question. ``sys_message`` is currently unused.

    Returns:
        The model's answer as a string, or ``None`` if the LLM failed to load.
    """
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    db = create_db_with_langchain(path, url_content)
    try:
        retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
        rag_prompt = hub.pull("rlm/rag-prompt")

        docs = retriever.invoke(topic)
        formatted_docs = "\n\n".join(doc.page_content for doc in docs)

        rag_chain = (
            {"context": lambda _: formatted_docs, "question": RunnablePassthrough()}
            | rag_prompt
            | llm
            | StrOutputParser()
        )
        return rag_chain.invoke(prompt)
    finally:
        # The collection is per-call; drop it so repeated calls don't accumulate memory.
        db.delete_collection()


def generate_base(
    prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
    """Send *prompt* straight to the chat model, with no retrieval.

    ``topic`` and ``sys_message`` are currently unused.

    Returns:
        The message content, or ``None`` if the LLM failed to load or the call
        raised (the error is printed).
    """
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    try:
        response = llm.invoke(prompt)
    except Exception as e:
        print(f"An error occurred while running the model: {e}")
        return None
    return response.content


def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Generate a response, using RAG when any PDFs or URL content are supplied.

    Dispatches to ``generate_rag`` if *path* or *url_content* is non-empty,
    otherwise to ``generate_base``. Returns whatever the chosen path returns.
    """
    if path or url_content:
        return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
    return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)