import streamlit as st

from ragchatbot import RAGChatBot
from pydantic_models import RequestModel, ChatHistoryItem

# NOTE: The delimiters used to strip the generated context off a contextual
# chunk were lost in this copy of the file (the original called `.split("")`,
# which raises ValueError on an empty separator). The tags below are
# assumptions; replace them with whatever markers the chunking pipeline
# actually wraps the raw chunk in.
CHUNK_START_TAG = "<chunk>"
CHUNK_END_TAG = "</chunk>"


def validate_chat_history_item(chat_history_item: ChatHistoryItem) -> ChatHistoryItem:
    """Round-trip a ChatHistoryItem through a dump/validate cycle."""
    return ChatHistoryItem.model_validate(chat_history_item.model_dump())


st.set_page_config(page_title="RAG-Chatbot", page_icon=":mag:", layout="wide")
st.title("Test Contextual Retrieval - KCS10")

col1, col2, col3 = st.columns(3)
col1.title("Contextual Chunking")
col2.title("Current Model")
col3.title("Formatted Text")

# One chatbot per vectorstore variant, created once and cached in session state.
if "context_ragchatbot" not in st.session_state:
    st.session_state.context_ragchatbot = RAGChatBot(vectorstore_path="context_vectorstore")
if "formatted_ragchatbot" not in st.session_state:
    st.session_state.formatted_ragchatbot = RAGChatBot(vectorstore_path="formatted_vectorstore")
if "just_ragchatbot" not in st.session_state:
    st.session_state.just_ragchatbot = RAGChatBot(vectorstore_path="just_vectorstore")

if "context_chat_history" not in st.session_state:
    st.session_state.context_chat_history = []
if "formatted_chat_history" not in st.session_state:
    st.session_state.formatted_chat_history = []
if "just_chat_history" not in st.session_state:
    st.session_state.just_chat_history = []

# The three histories are appended to in lockstep below, so they must always
# be the same length. (Checked once here instead of on every loop iteration.)
assert (
    len(st.session_state.context_chat_history)
    == len(st.session_state.formatted_chat_history)
    == len(st.session_state.just_chat_history)
)

# Replay the conversation so far: one row of three columns per turn.
for chat_index in range(len(st.session_state.context_chat_history)):
    for col, chat_history, sources_text in zip(
        st.columns(3, vertical_alignment="top"),
        [
            st.session_state.context_chat_history,
            st.session_state.just_chat_history,
            st.session_state.formatted_chat_history,
        ],
        ["Contextual Chunking", "Current Model", "Formatted Text"],
    ):
        chat = chat_history[chat_index]
        with col.chat_message("user"):
            st.write(chat.get("user_message").replace("\n", "\n\n"))
        with col.chat_message("assistant"):
            st.write(chat.get("assistant_message").replace("\n", "\n\n"))
            st.write(chat.get("search_phrase"))
            for i, doc in enumerate(chat.get("sources_documents")):
                with st.expander(f"{sources_text} Sources - {i + 1}"):
                    st.subheader(f"{doc.get('heading')} - {doc.get('relevance_score')}")
                    page_content = doc.get("page_content").replace("\n", "\n\n")
                    if sources_text == "Contextual Chunking":
                        # Show only the raw chunk, without the prepended context.
                        st.write(page_content.split(CHUNK_START_TAG)[1].split(CHUNK_END_TAG)[0])
                    else:
                        st.write(page_content)

# print_session_state_variables()

if user_query := st.chat_input("Enter your query"):
    for col in st.columns(3, vertical_alignment="top"):
        with col.chat_message("user"):
            st.write(user_query.replace("\n", "\n\n"))
    with st.spinner("Generating response..."):
        # Ask all three pipelines the same question, each against its own chat
        # history, and record the answers in lockstep. (This replaces three
        # near-identical copy-pasted blocks; the order matches the columns.)
        for ragchatbot, chat_history in [
            (st.session_state.context_ragchatbot, st.session_state.context_chat_history),
            (st.session_state.just_ragchatbot, st.session_state.just_chat_history),
            (st.session_state.formatted_ragchatbot, st.session_state.formatted_chat_history),
        ]:
            response = ragchatbot.get_response(
                RequestModel(
                    user_question=user_query,
                    chat_history=[
                        ChatHistoryItem(
                            user_message=chat.get("user_message"),
                            assistant_message=chat.get("assistant_message"),
                        )
                        for chat in chat_history
                    ],
                )
            )
            sources_documents = [
                {
                    "heading": doc.heading,
                    "page_content": doc.page_content,
                    "relevance_score": doc.relevance_score,
                }
                for doc in response.sources_documents
            ]
            chat_history.append(
                {
                    "user_message": user_query,
                    "assistant_message": response.answer,
                    "search_phrase": response.search_phrase,
                    "sources_documents": sources_documents,
                }
            )
    st.rerun()

# Legacy single-run rendering, superseded by the history-replay loop above:
# with col1.chat_message("assistant"):
#     st.write(context_response.answer.replace("\n", "\n\n"))
#     with col1.expander("Contextual Chunking Sources"):
#         for doc in context_response.sources_documents:
#             st.subheader(f"{doc.heading} - {doc.relevance_score}")
#             st.write(doc.page_content.replace("\n", "\n\n").split("")[1].split("")[0])
#             st.divider()
# with col2.chat_message("assistant"):
#     st.write(just_response.answer.replace("\n", "\n\n"))
#     with st.expander("Without Contextual Chunking Sources"):
#         st.write(just_response.chat_history[-1].search_phrase)
#         for doc in just_response.sources_documents:
#             st.subheader(f"{doc.heading} - {doc.relevance_score}")
#             st.write(doc.page_content.replace("\n", "\n\n"))
#             st.divider()
# with col3.chat_message("assistant"):
#     st.write(formatted_response.answer.replace("\n", "\n\n"))
#     with st.expander("Formatted Contextual Chunking Sources"):
#         st.write(formatted_response.chat_history[-1].search_phrase)
#         for doc in formatted_response.sources_documents:
#             st.subheader(f"{doc.heading} - {doc.relevance_score}")
#             st.write(doc.page_content.replace("\n", "\n\n"))
#             st.divider()