# article_writer/ai_generate.py
import os
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents import Document
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_anthropic import ChatAnthropic
from dotenv import load_dotenv

load_dotenv()

# suppress grpc and glog logs for gemini
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"
# RAG parameters
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
K = 10
FETCH_K = 20

# map UI-facing model names to provider model identifiers
llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

# map model identifiers to their LangChain chat classes
llm_classes = {
    "llama3-70b-8192": ChatGroq,
    "gpt-4o-mini": ChatOpenAI,
    "gpt-4o": ChatOpenAI,
    "gpt-4-turbo": ChatOpenAI,
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}


def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    # resolve the UI-facing model name to a provider model id and its chat class
    model_name = llm_model_translation.get(model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")
    try:
        # note: provider API keys are read from the environment (loaded via load_dotenv),
        # not from the api_key argument
        llm = llm_class(model_name=model_name, temperature=temperature, max_tokens=max_length)
    except Exception as e:
        print(f"An error occurred: {e}")
        llm = None
    return llm
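
# Usage sketch for load_llm (illustrative only, not part of the app flow; assumes the
# matching provider key, e.g. OPENAI_API_KEY or GROQ_API_KEY, is set in the environment
# via load_dotenv above):
#   llm = load_llm("OpenAI GPT 4o Mini", api_key="", temperature=0.7, max_length=512)
#   if llm is not None:
#       print(llm.invoke("Say hello in one sentence.").content)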


def create_db_with_langchain(path: list[str], url_content: dict):
    all_docs = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    if path:
        for file in path:
            loader = PyMuPDFLoader(file)
            data = loader.load()
            # split it into chunks
            docs = text_splitter.split_documents(data)
            all_docs.extend(docs)
    if url_content:
        for url, content in url_content.items():
            doc = Document(page_content=content, metadata={"source": url})
            # split it into chunks
            docs = text_splitter.split_documents([doc])
            all_docs.extend(docs)
    # print docs
    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")
    assert len(all_docs) > 0, "No PDFs or scraped data provided"
    db = Chroma.from_documents(all_docs, embedding_function)
    return db
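
# Usage sketch for create_db_with_langchain (illustrative; the PDF path and URL below are
# placeholders, not files shipped with this repo):
#   db = create_db_with_langchain(
#       path=["/tmp/example.pdf"],
#       url_content={"https://example.com": "scraped page text ..."},
#   )
#   retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})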


def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None
    db = create_db_with_langchain(path, url_content)
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
    rag_prompt = hub.pull("rlm/rag-prompt")

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # retrieve context for the topic once, then inject it as a fixed context string
    docs = retriever.get_relevant_documents(topic)
    formatted_docs = format_docs(docs)
    rag_chain = (
        {"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
    )
    return rag_chain.invoke(prompt)
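
# Usage sketch for generate_rag (illustrative; the prompt/topic strings and URL are
# placeholders, and the call needs a valid provider key plus at least one PDF path or
# url_content entry):
#   answer = generate_rag(
#       prompt="Write a short article about transformer models.",
#       topic="transformer models",
#       model="OpenAI GPT 4o Mini",
#       url_content={"https://example.com": "scraped page text ..."},
#       path=[],
#   )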


def generate_base(
    prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None
    try:
        # no retrieval: run the prompt directly against the model
        output = llm.invoke(prompt).content
        return output
    except Exception as e:
        print(f"An error occurred while running the model: {e}")
        return None
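
# Usage sketch for generate_base (illustrative; prompt/topic are placeholder strings and
# the chosen model's provider key must be set in the environment):
#   text = generate_base(
#       prompt="Write a short article about transformer models.",
#       topic="transformer models",
#       model="OpenAI GPT 4o Mini",
#       temperature=0.7,
#       max_length=1024,
#       api_key="",
#   )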


def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    # use RAG when source material (PDFs or scraped URLs) is available, otherwise plain generation
    if path or url_content:
        return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
    else:
        return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
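

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the app flow: with no PDFs or scraped URLs,
    # generate() falls back to generate_base. Assumes the matching provider key (e.g.
    # OPENAI_API_KEY) is available in the environment loaded from .env; the prompt and
    # topic below are illustrative placeholders.
    sample = generate(
        prompt="Write two sentences introducing retrieval-augmented generation.",
        topic="retrieval-augmented generation",
        model="OpenAI GPT 4o Mini",
        url_content={},
        path=[],
        temperature=0.7,
        max_length=256,
    )
    print(sample)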