# article_writer/ai_generate.py
import os

from dotenv import load_dotenv
from groq import Groq
from langchain import hub
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_text_splitters import CharacterTextSplitter
from openai import OpenAI

# load GROQ_API_KEY (and any other secrets) from a local .env file
load_dotenv()

# module-level Groq client used by the plain (non-RAG) generation path
groq_client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)


def create_db_with_langchain(path):
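    """Load the PDF at `path`, split it into chunks, embed them, and return a Chroma store."""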
loader = PyMuPDFLoader(path)
data = loader.load()
# split it into chunks
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(data)
# create the open-source embedding function
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
# load it into Chroma
db = Chroma.from_documents(docs, embedding_function)
return db


def generate_groq_rag(text, model, path):
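    """Answer `text` with a RAG chain grounded in the document at `path`, using a Groq-hosted model."""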
llm = ChatGroq(
temperature=0,
model_name=model,
)
db = create_db_with_langchain(path)
    # MMR retrieval: fetch 20 candidate chunks, keep the 4 most diverse
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 4, "fetch_k": 20})
prompt = hub.pull("rlm/rag-prompt")

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # LCEL pipeline: retrieved context plus the original question feed the RAG prompt, then the LLM
    rag_chain = {"context": retriever | format_docs, "question": RunnablePassthrough()} | prompt | llm
return rag_chain.invoke(text).content


def generate_groq_base(text, model):
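    """Stream a plain (no-RAG) completion for `text` from the Groq chat API."""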
completion = groq_client.chat.completions.create(
model=model,
        messages=[
            {
                "role": "system",
                "content": "Please follow the instruction and write about the given topic in approximately the given number of words",
            },
            {"role": "user", "content": text},
        ],
temperature=1,
max_tokens=1024,
top_p=1,
stream=True,
stop=None,
)
response = ""
for i, chunk in enumerate(completion):
if i != 0:
response += chunk.choices[0].delta.content or ""
return response


def generate_groq(text, model, path):
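    """Route to the RAG chain when a source document path is provided, else to plain generation."""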
if path:
return generate_groq_rag(text, model, path)
else:
return generate_groq_base(text, model)


def generate_openai(text, model, openai_client):
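    """Generate a single (non-streaming) completion from the OpenAI chat API."""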
message = [{"role": "user", "content": text}]
response = openai_client.chat.completions.create(
model=model, messages=message, temperature=0.2, max_tokens=800, frequency_penalty=0.0
)
return response.choices[0].message.content


def generate(text, model, path, api):
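    """Dispatch `text` to the selected model; `api` is the caller-supplied OpenAI key."""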
if model == "Llama 3":
return generate_groq(text, "llama3-70b-8192", path)
elif model == "Groq":
return generate_groq(text, "llama3-groq-70b-8192-tool-use-preview", path)
elif model == "Mistral":
return generate_groq(text, "mixtral-8x7b-32768", path)
elif model == "Gemma":
return generate_groq(text, "gemma2-9b-it", path)
elif model == "OpenAI GPT 3.5":
try:
openai_client = OpenAI(api_key=api)
return generate_openai(text, "gpt-3.5-turbo", openai_client)
except:
return "Please add a valid API key"
elif model == "OpenAI GPT 4":
try:
openai_client = OpenAI(api_key=api)
return generate_openai(text, "gpt-4-turbo", openai_client)
except:
return "Please add a valid API key"
elif model == "OpenAI GPT 4o":
try:
openai_client = OpenAI(api_key=api)
return generate_openai(text, "gpt-4o", openai_client)
except:
return "Please add a valid API key"