import os

from dotenv import load_dotenv
from groq import Groq
from langchain import hub
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_text_splitters import CharacterTextSplitter
from openai import OpenAI

load_dotenv()

groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))


def create_db_with_langchain(path):
    """Load a PDF, chunk it, embed the chunks, and index them in Chroma."""
    loader = PyMuPDFLoader(path)
    data = loader.load()

    # Split the document into chunks.
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(data)

    # Create the open-source embedding function.
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

    # Load the chunks into an in-memory Chroma vector store.
    db = Chroma.from_documents(docs, embedding_function)
    return db


def generate_groq_rag(text, model, path):
    """Answer `text` with a Groq model, grounded in the document at `path`."""
    llm = ChatGroq(temperature=0, model_name=model)

    db = create_db_with_langchain(path)
    # MMR retrieval: fetch 20 candidate chunks, keep the 4 most diverse.
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 4, "fetch_k": 20})
    prompt = hub.pull("rlm/rag-prompt")

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain.invoke(text)


def generate_groq_base(text, model):
    """Answer `text` with a Groq model directly, streaming the completion."""
    completion = groq_client.chat.completions.create(
        model=model,
        messages=[
            # The writing instruction belongs in a system message, not an
            # assistant turn placed after the user prompt.
            {
                "role": "system",
                "content": (
                    "Please follow the instruction and write about the given "
                    "topic in approximately the given number of words"
                ),
            },
            {"role": "user", "content": text},
        ],
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )

    # Accumulate every streamed chunk; `delta.content` can be None, and
    # skipping the first chunk (as the original loop did) can drop text.
    response = ""
    for chunk in completion:
        response += chunk.choices[0].delta.content or ""
    return response


def generate_groq(text, model, path):
    """Use RAG over the document at `path` when one is given, else plain generation."""
    if path:
        return generate_groq_rag(text, model, path)
    return generate_groq_base(text, model)


def generate_openai(text, model, openai_client):
    """Answer `text` with an OpenAI chat model."""
    response = openai_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": text}],
        temperature=0.2,
        max_tokens=800,
        frequency_penalty=0.0,
    )
    return response.choices[0].message.content


# UI model names mapped to provider model identifiers.
GROQ_MODELS = {
    "Llama 3": "llama3-70b-8192",
    "Groq": "llama3-groq-70b-8192-tool-use-preview",
    "Mistral": "mixtral-8x7b-32768",
    "Gemma": "gemma2-9b-it",
}
OPENAI_MODELS = {
    "OpenAI GPT 3.5": "gpt-3.5-turbo",
    "OpenAI GPT 4": "gpt-4-turbo",
    "OpenAI GPT 4o": "gpt-4o",
}


def generate(text, model, path, api):
    """Route a request to Groq or OpenAI based on the selected model name."""
    if model in GROQ_MODELS:
        return generate_groq(text, GROQ_MODELS[model], path)
    if model in OPENAI_MODELS:
        try:
            openai_client = OpenAI(api_key=api)
            return generate_openai(text, OPENAI_MODELS[model], openai_client)
        except Exception:
            return "Please add a valid API key"
    # The original chain silently returned None for unrecognized names.
    return f"Unknown model: {model}"
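

# A minimal usage sketch, not part of the original module: the prompts, the
# "paper.pdf" path, and the OPENAI_API_KEY environment variable below are
# placeholders for illustration, assuming GROQ_API_KEY is set via .env.
if __name__ == "__main__":
    # Plain generation through Groq (no document, so no RAG).
    print(generate("Write about solar energy in about 100 words", "Llama 3", None, None))

    # RAG-grounded generation: the answer is retrieved from the given PDF.
    print(generate("Summarize the key findings", "Mistral", "paper.pdf", None))

    # OpenAI routing requires the caller to pass an API key.
    print(
        generate(
            "Write a haiku about autumn",
            "OpenAI GPT 4o",
            None,
            os.environ.get("OPENAI_API_KEY"),
        )
    )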