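"""Audio Query Chatbot.

A Streamlit app: upload a WAV/MP3 file, transcribe it with AssemblyAI via
LangChain's AssemblyAIAudioTranscriptLoader, then ask questions about the
transcript, answered by HuggingFaceH4/zephyr-7b-beta through a "stuff" QA chain.

Assumes ASSEMBLYAI_API_KEY and HUGGINGFACEHUB_API_TOKEN are set in the
environment (read by the AssemblyAI loader and the Hub clients, respectively).
"""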
import streamlit as st
# from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader
from langchain_community.embeddings import HuggingFaceHubEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from tempfile import NamedTemporaryFile

# Load environment variables from a local .env file (optional; disabled here)
# load_dotenv()

# Function to create a prompt for the retrieval QA chain
def create_qa_prompt() -> PromptTemplate:
    template = """\n\nHuman: Use the following pieces of context to answer the question at the end. If the answer is not clear, say I DON'T KNOW
{context}
Question: {question}
\n\nAssistant:
Answer:"""
    return PromptTemplate(template=template, input_variables=["context", "question"])

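# Illustrative use of the template (placeholder values, not executed here):
#   create_qa_prompt().format(context="<transcript chunk>", question="Who is speaking?")
# renders the Human/Assistant-wrapped prompt with the context and question filled in.
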
# Function to create documents from a list of file paths/URLs.
# The AssemblyAI loader returns one Document per file, with the full
# transcript text in page_content.
def create_docs(urls_list):
    documents = []
    for url in urls_list:
        st.write(f'Transcribing {url}')
        documents.append(AssemblyAIAudioTranscriptLoader(file_path=url).load()[0])
    return documents

# Function to create a Hugging Face embeddings model.
# HuggingFaceHubEmbeddings embeds via the hosted HF Inference API, so no
# local device or encode settings apply here.
def make_embedder():
    model_name = "sentence-transformers/all-mpnet-base-v2"
    return HuggingFaceHubEmbeddings(
        repo_id=model_name,
        task="feature-extraction",
    )
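
# Note: make_embedder() is only needed by the commented-out FAISS/RetrievalQA
# path below; the active code path sends the whole transcript to the LLM instead.
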
# Function to create the LLM used for question answering. Despite its name,
# this currently returns a bare HuggingFaceHub LLM rather than a chain.
def make_qa_chain():
    llm = HuggingFaceHub(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        model_kwargs={
            "max_new_tokens": 512,
            "top_k": 30,
            "temperature": 0.01,
            "repetition_penalty": 1.5,
        },
    )
    return llm
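    # Alternative (currently disabled): return a full RetrievalQA chain over a
    # FAISS retriever, which would require the vector store `db` built by the
    # commented-out setup in main() to be passed in, e.g. make_qa_chain(db).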
    # return RetrievalQA.from_chain_type(
    #     llm,
    #     retriever=db.as_retriever(search_type="mmr", search_kwargs={'fetch_k': 3}),
    #     return_source_documents=True,
    #     chain_type_kwargs={
    #         "prompt": create_qa_prompt(),
    #     },
    # )

# Streamlit UI
def main():
    st.set_page_config(page_title="Audio Query Chatbot", page_icon=":microphone:", layout="wide")

    # Left pane - Audio file upload
    col1, col2 = st.columns([1, 2])

    with col1:
        st.header("Upload Audio File")
        uploaded_file = st.file_uploader("Choose a WAV or MP3 file", type=["wav", "mp3"], key="audio_uploader")

        if uploaded_file is not None:
            # Persist the upload to a temp file so the loader can read it by
            # path, keeping the uploaded file's own extension
            with NamedTemporaryFile(suffix=f".{uploaded_file.name.rsplit('.', 1)[-1]}") as temp:
                temp.write(uploaded_file.getvalue())
                temp.seek(0)
                docs = create_docs([temp.name])
                # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
                # texts = text_splitter.split_documents(docs)
                # for text in texts:
                #     text.metadata = {"audio_url": text.metadata["audio_url"]}

                st.success('Audio file transcribed successfully!')

                # hf = make_embedder()
                # db = FAISS.from_documents(texts, hf)
                # qa_chain = make_qa_chain(db)

    # Right pane - Chatbot Interface
    with col2:
        st.header("Chatbot Interface")

        if uploaded_file is not None:
            with st.form(key="form"):
                user_input = st.text_input("Ask your question", key="user_input")

                # Pressing Enter in the text input submits the form automatically
                st.markdown("<div><br></div>", unsafe_allow_html=True)  # Adds some space
                st.markdown(
                    """<style>
                    #form input {margin-bottom: 15px;}
                    </style>""", unsafe_allow_html=True
                )

                submit = st.form_submit_button("Submit Question")
            # Display the result once the form is submitted
            if submit:
                llm = make_qa_chain()
                # "stuff" chain: all input documents are stuffed into one prompt
                chain = load_qa_chain(llm, chain_type="stuff")
                # docs = db.similarity_search(user_input)
                result = chain.run(input_documents=docs, question=user_input)
                # result = qa_chain.invoke(user_input)
                # result = qa_chain({"query": user_input})

                st.success("Query Result:")
                st.write(f"User: {user_input}")
                st.write(f"Assistant: {result}")
                # st.subheader("Source Documents:")
                # for idx, elt in enumerate(result['source_documents']):
                #     st.write(f"Source {idx + 1}:")
                #     st.write(f"Filepath: {elt.metadata['audio_url']}")
                #     st.write(f"Contents: {elt.page_content}")

if __name__ == "__main__":
    main()
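
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py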