Spaces:
Running
Running
import logging | |
import os | |
import re | |
from typing import List | |
import llama_index | |
import phoenix as px | |
from llama_index.llms import ChatMessage, MessageRole | |
from openai import OpenAI | |
from environments import OPENAI_API_KEY | |
class IndexBuilder:
    """Template for building (or re-attaching to) a vector index.

    The constructor stores its arguments and immediately runs
    ``build_index()``. The ``_load_doucments`` / ``_setup_*`` hooks in this
    class only print progress messages (or do nothing); presumably concrete
    subclasses override them to do the real work -- confirm against the
    subclasses.
    """

    def __init__(self, vdb_collection_name, embed_model, is_load_from_vector_store=False):
        """Store the configuration and run the build pipeline.

        Args:
            vdb_collection_name: Stored as ``self.vdb_collection_name``.
            embed_model: Stored as ``self.embed_model``.
            is_load_from_vector_store: When true, the document-loading step
                is skipped and documents are not required.

        Raises:
            ValueError: Propagated from ``_setup_index`` when documents are
                required but ``self.documents`` is still ``None``.
        """
        self.vdb_collection_name = vdb_collection_name
        self.embed_model = embed_model
        self.is_load_from_vector_store = is_load_from_vector_store
        # Both start empty; the hooks are expected to fill them in.
        self.documents = None
        self.index = None
        self.build_index()

    def _load_doucments(self):
        """Hook for loading documents; a no-op in this class.

        NOTE: the misspelled name is kept on purpose -- it is what
        ``build_index`` calls, and overrides elsewhere may rely on it.
        """

    def _setup_service_context(self):
        """Hook for the service context; this class only reports progress."""
        print("Using global service context...")

    def _setup_vector_store(self):
        """Hook for the vector store; this class only reports progress."""
        print("Setup vector store...")

    def _setup_index(self):
        """Check that documents exist when they are needed, then report progress.

        Raises:
            ValueError: If the index is not loaded from a vector store and
                no documents have been loaded.
        """
        documents_required = not self.is_load_from_vector_store
        if documents_required and self.documents is None:
            raise ValueError("No documents provided for index building.")
        print("Building Index")

    def build_index(self):
        """Run the build steps in order, loading documents first unless
        the index comes from an existing vector store."""
        if not self.is_load_from_vector_store:
            self._load_doucments()
        pipeline = (
            self._setup_service_context,
            self._setup_vector_store,
            self._setup_index,
        )
        for step in pipeline:
            step()
class Chatbot:
    """Chatbot that streams answers from a chat engine or from plain OpenAI chat.

    The constructor stores its arguments, configures file logging and runs
    the ``_setup_*`` hooks. In this class the query-engine, tools and
    chat-engine hooks only validate state and print progress, so
    ``self.chat_engine`` stays ``None``; presumably subclasses override the
    hooks to build a real engine -- confirm against the subclasses.

    Three entry points yield the progressively growing answer text:
    ``predict_with_rag`` (chat engine plus source links),
    ``predict_with_prompt_wrapper`` (OpenAI with the system prompt) and
    ``predict_vanilla_chatgpt`` (OpenAI without a system prompt).
    """

    SYSTEM_PROMPT = ""
    DENIED_ANSWER_PROMPT = ""
    CHAT_EXAMPLES = []

    # Source nodes scoring below this value are not listed as references.
    SOURCE_SCORE_THRESHOLD = 0.76
    # Number of most recent history entries forwarded to the chat engine.
    MAX_HISTORY_TURNS = 3
    # Footer that introduces the reference list appended to RAG answers.
    # Keep in sync with _SOURCES_FOOTER_RE, which strips it again.
    _SOURCES_HEADER = "\n \n\n---\n\nSources: \n"
    _SOURCES_FOOTER_RE = re.compile(
        r"\n \n\n---\n\nSources: \n.*$", flags=re.DOTALL)

    def __init__(self, model_name, index_builder: IndexBuilder, llm=None):
        """Store the configuration, then set up logging and the chat pipeline.

        Args:
            model_name: Model identifier passed to the OpenAI chat
                completions call in ``_invoke_chatgpt``.
            index_builder: Builder whose ``index`` attribute is adopted as
                this chatbot's index.
            llm: Stored as ``self.llm``; not used by this class itself.

        Raises:
            ValueError: If the index builder has no index.
        """
        self.model_name = model_name
        self.index_builder = index_builder
        self.llm = llm

        self.documents = None
        self.index = None
        self.chat_engine = None
        self.service_context = None
        self.vector_store = None
        self.tools = None

        self._setup_logger()
        self._setup_chatbot()

    def _setup_logger(self):
        """Configure logging to ``logs/chatbot.log`` and create ``self.logger``."""
        logs_dir = 'logs'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(logs_dir, exist_ok=True)

        logging.basicConfig(
            filename=os.path.join(logs_dir, 'chatbot.log'),
            filemode='a',
            format='%(asctime)s - %(levelname)s - %(message)s',
            level=logging.INFO
        )
        self.logger = logging.getLogger(__name__)

    def _setup_chatbot(self):
        """Run the setup hooks in dependency order."""
        # self._setup_observer()
        self._setup_index()
        self._setup_query_engine()
        self._setup_tools()
        self._setup_chat_engine()

    def _setup_observer(self):
        """Launch the Phoenix app and register it as llama_index's global handler."""
        px.launch_app()
        llama_index.set_global_handler("arize_phoenix")

    def _setup_index(self):
        """Adopt the index held by the index builder."""
        self.index = self.index_builder.index
        print("Inherited index builder")

    def _setup_query_engine(self):
        """Hook for the query engine; this class only validates the index.

        Raises:
            ValueError: If no index is available.
        """
        if self.index is None:
            raise ValueError("No index built")
        print("Setup query engine...")

    def _setup_tools(self):
        """Hook for tools; this class only reports progress."""
        print("Setup tools...")

    def _setup_chat_engine(self):
        """Hook for the chat engine; this class only validates the index.

        Raises:
            ValueError: If no index is available.
        """
        if self.index is None:
            raise ValueError("No index built")
        print("Setup chat engine...")

    def stream_chat(self, message, history):
        """Stream the chat engine's answer, then append links to its sources.

        Args:
            message: The new user message.
            history: Earlier conversation, in the shape accepted by
                ``convert_to_chat_messages``.

        Yields:
            The answer accumulated so far after each token, and once more
            with a "Sources" footer when any source node qualifies.
        """
        # Convert once and reuse the result for both logging and the request.
        chat_history = self.convert_to_chat_messages(history)
        self.logger.info(history)
        self.logger.info(chat_history)

        response = self.chat_engine.stream_chat(
            message, chat_history=chat_history
        )

        # Stream tokens as they are generated
        partial_message = ""
        for token in response.response_gen:
            partial_message += token
            yield partial_message

        # Keyed by URL so each page is listed once even if several of its
        # chunks were retrieved.
        references = {}
        for source in response.source_nodes:
            score = source.score
            # A node may carry no score at all; comparing None with a float
            # would raise TypeError, so unscored nodes are skipped.
            if score is None or score < self.SOURCE_SCORE_THRESHOLD:
                continue

            metadata = source.node.metadata
            url = metadata.get('url')
            title = metadata.get('title')
            if url and title:
                references[url] = title

        if references:
            links = "".join(
                f"- [{title}]({url})\n" for url, title in references.items())
            partial_message = partial_message + self._SOURCES_HEADER + links
            yield partial_message

        # Kept for callers that read the generator's return value.
        return partial_message

    def convert_to_chat_messages(self, history: List[List[str]]) -> List[ChatMessage]:
        """Convert recent history into chat messages led by the system prompt.

        Only the last ``MAX_HISTORY_TURNS`` entries of ``history`` are used.
        Within each entry, even positions become user messages and odd
        positions assistant messages. A trailing "Sources" footer is removed
        from every message and the result is stripped of outer whitespace.

        Args:
            history: List of conversations, each a list of message strings.
                ``None`` or an empty list yields only the system message.

        Returns:
            The system message followed by the converted history messages.
        """
        chat_messages = [ChatMessage(
            role=MessageRole.SYSTEM, content=self.SYSTEM_PROMPT)]

        turns = self.MAX_HISTORY_TURNS
        # history[-0:] would return the whole list, so guard a zero limit.
        recent = (history or [])[-turns:] if turns > 0 else []

        for conversation in recent:
            for index, message in enumerate(conversation):
                # A missing reply (None) cannot be cleaned by the regex and
                # carries no content, so it is skipped.
                if message is None:
                    continue
                role = MessageRole.USER if index % 2 == 0 else MessageRole.ASSISTANT
                clean_message = self._SOURCES_FOOTER_RE.sub("", message)
                chat_messages.append(ChatMessage(
                    role=role, content=clean_message.strip()))

        return chat_messages

    def predict_with_rag(self, message, history):
        """Yield the streamed chat-engine answer for ``message``.

        Written as a generator function, like the other ``predict_*``
        methods, so ``inspect.isgeneratorfunction`` reports all three alike.
        """
        yield from self.stream_chat(message, history)

    # Vanilla chatgpt methods, shared across all chatbot instance
    def _invoke_chatgpt(self, history, message, is_include_system_prompt=False):
        """Stream a plain OpenAI chat completion.

        Args:
            history: Iterable of ``(user_message, assistant_message)`` pairs.
            message: The new user message.
            is_include_system_prompt: When true, ``SYSTEM_PROMPT`` is sent
                as the first message.

        Yields:
            The answer accumulated so far after each content chunk.
        """
        openai_client = OpenAI(api_key=OPENAI_API_KEY)

        history_openai_format = []
        if is_include_system_prompt:
            history_openai_format.append(
                {"role": "system", "content": self.SYSTEM_PROMPT})
        for human, assistant in history:
            history_openai_format.append({"role": "user", "content": human})
            history_openai_format.append(
                {"role": "assistant", "content": assistant})
        history_openai_format.append({"role": "user", "content": message})

        stream = openai_client.chat.completions.create(
            model=self.model_name,
            messages=history_openai_format,
            temperature=1.0,
            stream=True)

        partial_message = ""
        for part in stream:
            # A chunk can arrive with an empty choices list; indexing it
            # would raise IndexError mid-stream.
            if not part.choices:
                continue
            partial_message += part.choices[0].delta.content or ""
            yield partial_message

    # For 'With Prompt Wrapper' - Add system prompt, no Pinecone
    def predict_with_prompt_wrapper(self, message, history):
        """Yield the OpenAI answer with the system prompt included."""
        yield from self._invoke_chatgpt(history, message, is_include_system_prompt=True)

    # For 'Vanilla ChatGPT' - No system prompt
    def predict_vanilla_chatgpt(self, message, history):
        """Yield the OpenAI answer without any system prompt."""
        yield from self._invoke_chatgpt(history, message)