Spaces:
Runtime error
Runtime error
File size: 4,845 Bytes
7edc5be 03fd59b 708f094 03fd59b 708f094 03fd59b 8f26ea6 03fd59b 8f26ea6 03fd59b 4b92a71 03fd59b 134b51f 03fd59b e1b0f65 8f26ea6 f716a54 8f26ea6 59fbf6a 708f094 4b92a71 708f094 134b51f 708f094 43d4e83 59fbf6a 03fd59b 708f094 59fbf6a 708f094 43d4e83 03fd59b 708f094 59fbf6a 708f094 03fd59b 708f094 f716a54 708f094 4b92a71 708f094 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import uuid
from functools import lru_cache

from dotenv import load_dotenv
from langchain import hub
from langchain.schema import StrOutputParser
from langchain_anthropic import ChatAnthropic
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
load_dotenv()

# Quiet the gRPC / glog chatter emitted by the Gemini client.
# NOTE(review): these are set after langchain_google_genai is imported above;
# confirm gRPC still honours them at that point.
os.environ.update({"GRPC_VERBOSITY": "ERROR", "GLOG_minloglevel": "2"})

# --- Retrieval-augmented-generation tuning knobs ---
CHUNK_SIZE = 1024  # target chunk length handed to RecursiveCharacterTextSplitter
CHUNK_OVERLAP = CHUNK_SIZE // 8  # 128: overlap between neighbouring chunks
K = 10  # number of documents the MMR retriever returns
FETCH_K = 20  # candidate pool MMR selects K documents from

# UI display name -> provider model id.
llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

# Provider model id -> LangChain chat class used to instantiate it.
llm_classes = {
    "llama3-70b-8192": ChatGroq,
    **dict.fromkeys(("gpt-4o-mini", "gpt-4o", "gpt-4-turbo"), ChatOpenAI),
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}
def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    """Instantiate the LangChain chat model selected by the caller.

    Args:
        model: A display name from ``llm_model_translation`` (e.g. "OpenAI GPT 4o"),
            or a provider model id that is already a key of ``llm_classes``.
        api_key: Provider API key. When non-empty it is passed to the model class
            as ``api_key``. When empty, no key is passed and the class resolves
            credentials itself (presumably from env vars loaded by ``load_dotenv``).
        temperature: Sampling temperature forwarded to the model.
        max_length: Generation limit forwarded as ``max_tokens``.

    Returns:
        The chat model instance, or ``None`` if constructing it raised.

    Raises:
        ValueError: If ``model`` maps to no supported model class.
    """
    # Accept either a UI display name or a raw provider model id.
    model_name = llm_model_translation.get(model, model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")

    llm_kwargs = {"model_name": model_name, "temperature": temperature, "max_tokens": max_length}
    # Previously api_key was accepted but never used, so a user-supplied key
    # was silently ignored. Only forward it when given, so the env-var
    # fallback keeps working.
    if api_key:
        llm_kwargs["api_key"] = api_key

    try:
        llm = llm_class(**llm_kwargs)
    except Exception as e:
        # Best-effort: callers treat None as "could not load the model".
        print(f"An error occurred: {e}")
        return None
    return llm
@lru_cache(maxsize=4)
def _get_embedding_function(model_name: str = "all-MiniLM-L6-v2"):
    """Return a cached SentenceTransformerEmbeddings wrapper for ``model_name``."""
    # Loading the transformer weights is expensive; reuse one instance per model
    # instead of reloading it on every request.
    return SentenceTransformerEmbeddings(model_name=model_name)


def create_db_with_langchain(
    path: list[str], url_content: dict, embedding_model: str = "all-MiniLM-L6-v2"
):
    """Build an in-memory Chroma vector store from PDFs and/or scraped web pages.

    Args:
        path: Paths of PDF files to load with ``PyMuPDFLoader``; may be empty or None.
        url_content: Mapping of source URL -> scraped text; may be empty or None.
            Entries with empty/None text are skipped.
        embedding_model: sentence-transformers model name used to embed the chunks.

    Returns:
        A ``Chroma`` store containing every chunk, in its own uniquely named
        collection. The caller may call ``db.delete_collection()`` when done to
        free it.

    Raises:
        ValueError: If the inputs produced no chunks at all.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    all_docs = []

    for pdf_path in path or []:
        pages = PyMuPDFLoader(pdf_path).load()
        all_docs.extend(text_splitter.split_documents(pages))

    for url, content in (url_content or {}).items():
        if not content:
            continue  # nothing to index for this URL
        doc = Document(page_content=content, metadata={"source": url})
        all_docs.extend(text_splitter.split_documents([doc]))

    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")

    # A real exception rather than `assert`, which is stripped under `python -O`.
    if not all_docs:
        raise ValueError("No PDFs or scraped data provided")

    # Without an explicit name every call writes into Chroma's default "langchain"
    # collection. In-process ephemeral clients share it, so documents would pile
    # up across calls and leak between requests. Use a fresh collection each time.
    return Chroma.from_documents(
        all_docs,
        _get_embedding_function(embedding_model),
        collection_name=f"rag-{uuid.uuid4().hex}",
    )
def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Answer ``prompt`` with context retrieved from the supplied PDFs / scraped pages.

    Documents are indexed into a temporary Chroma store. Retrieval uses MMR over
    ``topic`` (K results from a pool of FETCH_K), and the joined chunks fill the
    ``rlm/rag-prompt`` template from the LangChain hub alongside ``prompt``.

    Args:
        prompt: Question passed to the model as the template's ``question``.
        topic: Query used for retrieval (not the prompt itself).
        model: Display name or model id understood by ``load_llm``.
        url_content: Mapping of URL -> scraped text to index.
        path: PDF file paths to index.
        temperature: Sampling temperature for the model.
        max_length: Maximum tokens to generate.
        api_key: Provider API key forwarded to ``load_llm``.
        sys_message: Accepted for signature parity; not used here.

    Returns:
        The model's answer as a string, or ``None`` if the LLM could not be loaded.
        Errors from indexing, retrieval, or the model call propagate.
    """
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    db = create_db_with_langchain(path, url_content)
    try:
        retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
        rag_prompt = hub.pull("rlm/rag-prompt")

        # `invoke` replaces the deprecated `get_relevant_documents`.
        docs = retriever.invoke(topic)
        context = "\n\n".join(doc.page_content for doc in docs)

        rag_chain = rag_prompt | llm | StrOutputParser()
        return rag_chain.invoke({"context": context, "question": prompt})
    finally:
        # The vector store only serves this request. Drop it so vectors don't
        # accumulate in memory or surface in later requests' retrievals.
        try:
            db.delete_collection()
        except Exception as e:
            print(f"Failed to clean up vector store: {e}")
def generate_base(
    prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
    """Send ``prompt`` straight to the selected chat model, with no retrieval.

    ``topic`` and ``sys_message`` are accepted so the signature mirrors
    ``generate_rag``; neither is used here.

    Returns:
        The ``content`` of the model's reply, or ``None`` if the model could not
        be loaded or the call raised. Failures are reported with ``print``.
    """
    chat_model = load_llm(model, api_key, temperature, max_length)
    if chat_model is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    try:
        reply_text = chat_model.invoke(prompt).content
    except Exception as err:
        # Best-effort: report and signal failure with None instead of raising.
        print(f"An error occurred while running the model: {err}")
        return None
    else:
        return reply_text
def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Dispatch a generation request to the RAG or the plain-LLM pipeline.

    If any PDF paths or scraped URL content are supplied, ``generate_rag`` is
    used. Otherwise the prompt goes directly to the model via ``generate_base``.
    Returns whatever the selected pipeline returns.
    """
    has_sources = bool(path or url_content)
    if not has_sources:
        return generate_base(
            prompt,
            topic,
            model,
            temperature=temperature,
            max_length=max_length,
            api_key=api_key,
            sys_message=sys_message,
        )
    return generate_rag(
        prompt,
        topic,
        model,
        url_content=url_content,
        path=path,
        temperature=temperature,
        max_length=max_length,
        api_key=api_key,
        sys_message=sys_message,
    )
|