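# Retrieval-augmented Gradio chatbot for a Hugging Face Space: answers
# questions from a persisted Chroma vector store using a local quantized
# Zephyr-7B (GGUF) model served on CPU via CTransformers.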
import os

import chromadb
import gradio as gr
from dotenv import load_dotenv

from langchain import PromptTemplate
from langchain.llms import CTransformers
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings

from constants import CHROMA_SETTINGS
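# Quantized Zephyr-7B (GGUF) from TheBloke, with generation settings tuned for
# factual QA (low temperature); 'threads' uses half the available CPU cores.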
local_llm = "TheBloke/zephyr-7B-beta-GGUF"

config = {
    'max_new_tokens': 1024,
    'repetition_penalty': 1.1,
    'temperature': 0.1,
    'top_k': 50,
    'top_p': 0.9,
    'stream': True,
    'threads': int(os.cpu_count() / 2)
}
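# Load the quantized model; model_type="mistral" matches Zephyr's base
# architecture.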
llm = CTransformers(
    model=local_llm,
    model_type="mistral",
    lib="avx2",  # use the AVX2 build for CPU inference
    **config
)
# Load environment variables from .env before reading any of them
if not load_dotenv():
    print("Could not load .env file or it is empty. Please check if it exists and is readable.")
    exit(1)

embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
persist_directory = os.environ.get('PERSIST_DIRECTORY')
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
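# Build the query embedder and open the persisted Chroma collection that an
# ingestion step is assumed to have populated under PERSIST_DIRECTORY.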
print("Loading embeddings model...") | |
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name) | |
chroma_client = chromadb.PersistentClient(settings=CHROMA_SETTINGS , path=persist_directory) | |
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS, client=chroma_client) | |
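# The prompt constrains the model to answer strictly from the retrieved
# context and to admit when the answer is not present.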
prompt_template = """Use the following pieces of information to answer the user's question. | |
If you don't know the answer, just say that you don't know, don't try to make up an answer. | |
Context: {context} | |
Question: {question} | |
Only return the helpful answer below and nothing else. | |
Helpful answer: | |
""" | |
prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question'])
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
# One-off console smoke test for the retriever and QA chain (disabled)
'''
query = "What is state ownership report"
semantic_search_results = retriever.get_relevant_documents(query)
print(semantic_search_results)

chain_type_kwargs = {"prompt": prompt}
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever,
                                 return_source_documents=False, chain_type_kwargs=chain_type_kwargs, verbose=True)
response = qa(query)
print(response)
'''
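# Gradio front end: a single question box wired to the RetrievalQA chain.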
chain_type_kwargs = {"prompt": prompt}

input_gradio = gr.Text(
    label="Prompt",
    show_label=False,
    max_lines=2,
    placeholder="Enter your question here",
    container=False,
)

def get_response(input_gradio):
    query = input_gradio
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever,
                                     return_source_documents=False, chain_type_kwargs=chain_type_kwargs, verbose=True)
    response = qa(query)
    return response['result']
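# Note: the chain is rebuilt on every request; constructing it once at module
# load would shave per-query latency, since llm and retriever never change.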
iface = gr.Interface(
    fn=get_response,
    inputs=input_gradio,
    outputs="text",
    title="Stimline Chatbot",
    description="A chatbot that uses the LLM to answer anything regarding Stimline",
    allow_flagging='never',
)

# Interactive questions and answers
iface.launch()
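# launch() starts a local Gradio server (default http://127.0.0.1:7860) and
# blocks; in a Hugging Face Space the platform serves this app automatically.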