import os

from dotenv import load_dotenv
from groq import Groq
from langchain import hub
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_text_splitters import CharacterTextSplitter
from openai import OpenAI

load_dotenv()

groq_client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)
def create_db_with_langchain(path):
    """Load a PDF, split it into chunks, embed them, and index them in Chroma."""
    loader = PyMuPDFLoader(path)
    data = loader.load()
    # split the document into chunks
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(data)
    # create the open-source embedding function
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    # load the chunks into an in-memory Chroma vector store
    db = Chroma.from_documents(docs, embedding_function)
    return db
def generate_groq_rag(text, model, path):
    """Answer with retrieval-augmented generation grounded in the PDF at `path`."""
    llm = ChatGroq(
        temperature=0,
        model_name=model,
    )
    db = create_db_with_langchain(path)
    # MMR retrieval: fetch 20 candidate chunks, keep the 4 most relevant and diverse
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 4, "fetch_k": 20})
    prompt = hub.pull("rlm/rag-prompt")

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain.invoke(text)
def generate_groq_base(text, model):
    """Answer with a plain (non-RAG) streamed Groq chat completion."""
    completion = groq_client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": "Please follow the instruction and write about the given topic in approximately the given number of words.",
            },
            {"role": "user", "content": text},
        ],
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )
    # accumulate the streamed deltas into a single response string
    response = ""
    for chunk in completion:
        response += chunk.choices[0].delta.content or ""
    return response
def generate_groq(text, model, path):
    # use RAG when a document path is provided, otherwise plain generation
    if path:
        return generate_groq_rag(text, model, path)
    else:
        return generate_groq_base(text, model)
def generate_openai(text, model, openai_client):
    messages = [{"role": "user", "content": text}]
    response = openai_client.chat.completions.create(
        model=model, messages=messages, temperature=0.2, max_tokens=800, frequency_penalty=0.0
    )
    return response.choices[0].message.content
def generate(text, model, path, api):
    """Map the UI model name to a concrete model ID and dispatch to Groq or OpenAI."""
    groq_models = {
        "Llama 3": "llama3-70b-8192",
        "Groq": "llama3-groq-70b-8192-tool-use-preview",
        "Mistral": "mixtral-8x7b-32768",
        "Gemma": "gemma2-9b-it",
    }
    openai_models = {
        "OpenAI GPT 3.5": "gpt-3.5-turbo",
        "OpenAI GPT 4": "gpt-4-turbo",
        "OpenAI GPT 4o": "gpt-4o",
    }
    if model in groq_models:
        return generate_groq(text, groq_models[model], path)
    elif model in openai_models:
        try:
            openai_client = OpenAI(api_key=api)
            return generate_openai(text, openai_models[model], openai_client)
        except Exception:
            return "Please add a valid API key"