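"""RAG and direct-generation helpers built on LangChain.

Supports Groq (LLaMA 3 70B), OpenAI (GPT-4o / 4o Mini / 4 Turbo), Google (Gemini 1.5 Pro),
and Anthropic (Claude 3.5 Sonnet) chat models, using Chroma with SentenceTransformer
embeddings for retrieval over uploaded PDFs and scraped URL content.
"""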
import os

from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents import Document
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_anthropic import ChatAnthropic
from dotenv import load_dotenv

# load provider API keys (OPENAI_API_KEY, GROQ_API_KEY, GOOGLE_API_KEY, ANTHROPIC_API_KEY) from .env
load_dotenv()

# suppress grpc and glog logs for gemini
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"
# RAG parameters
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
K = 10
FETCH_K = 20

llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

llm_classes = {
    "llama3-70b-8192": ChatGroq,
    "gpt-4o-mini": ChatOpenAI,
    "gpt-4o": ChatOpenAI,
    "gpt-4-turbo": ChatOpenAI,
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}
def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    """Instantiate the chat model for a display name from llm_model_translation, or return None on failure."""
    model_name = llm_model_translation.get(model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")
    try:
        # all of these chat classes accept the `model` kwarg; provider API keys are read
        # from the environment loaded above, so the `api_key` argument is currently unused
        llm = llm_class(model=model_name, temperature=temperature, max_tokens=max_length)
    except Exception as e:
        print(f"An error occurred: {e}")
        llm = None
    return llm
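# Example (illustrative): load_llm("OpenAI GPT 4o Mini", api_key="", temperature=0.7)
# returns a ChatOpenAI instance configured from OPENAI_API_KEY, or None if instantiation fails.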
def create_db_with_langchain(path: list[str], url_content: dict):
    """Build an in-memory Chroma vector store from PDF files and scraped URL content."""
    all_docs = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

    if path:
        for file in path:
            loader = PyMuPDFLoader(file)
            data = loader.load()
            # split it into chunks
            docs = text_splitter.split_documents(data)
            all_docs.extend(docs)

    if url_content:
        for url, content in url_content.items():
            doc = Document(page_content=content, metadata={"source": url})
            # split it into chunks
            docs = text_splitter.split_documents([doc])
            all_docs.extend(docs)

    # print chunk sizes for debugging
    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")

    assert len(all_docs) > 0, "No PDFs or scraped data provided"
    db = Chroma.from_documents(all_docs, embedding_function)
    return db
def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    db = create_db_with_langchain(path, url_content)
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
    rag_prompt = hub.pull("rlm/rag-prompt")

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # retrieve context once for the given topic, then answer the prompt against that fixed context
    docs = retriever.get_relevant_documents(topic)
    formatted_docs = format_docs(docs)
    rag_chain = (
        {"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
    )
    return rag_chain.invoke(prompt)
def generate_base(
    prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None
    try:
        output = llm.invoke(prompt).content
        return output
    except Exception as e:
        print(f"An error occurred while running the model: {e}")
        return None
def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    # use RAG when PDFs or scraped URL content are provided, otherwise call the model directly
    if path or url_content:
        return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
    else:
        return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
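
if __name__ == "__main__":
    # Usage sketch only: the PDF path, topic, and prompt below are placeholders,
    # and the matching provider API key is assumed to be set in .env / the environment.
    answer = generate(
        prompt="Summarize the key findings of the provided document.",
        topic="key findings",
        model="OpenAI GPT 4o Mini",
        url_content={},
        path=["example.pdf"],  # placeholder PDF path
        temperature=0.7,
        max_length=1024,
    )
    print(answer)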