Spaces:
Sleeping
Sleeping
import os | |
import logging | |
from dotenv import load_dotenv | |
from langchain.memory import ConversationSummaryMemory | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_community.utilities import SQLDatabase | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_openai import ChatOpenAI | |
from langchain_openai import OpenAIEmbeddings | |
from langchain.agents import create_tool_calling_agent, AgentExecutor, Tool | |
from langchain_community.vectorstores import FAISS | |
from config.settings import Settings | |
# Pull variables from a local .env file into the process environment so the
# OpenAI key lookup below can see them.
load_dotenv()
open_api_key_token = os.environ.get('OPENAI_API_KEY')
# The database connection string comes from project settings, not the env.
db_uri = Settings.DB_URI
class ChatAgentService:
    """Question-answering service built on a LangChain tool-calling agent.

    The agent has two tools:

    * ``DatabaseQuery`` - asks the LLM to write SQL for the question and runs
      it against the database at the module-level ``db_uri``.
    * ``DocumentData`` - similarity search over every FAISS index found under
      the directory named by the ``VECTOR_DB_PATH`` environment variable.

    Earlier turns are condensed by a ``ConversationSummaryMemory`` which the
    ``AgentExecutor`` loads into the prompt on every call.
    """

    def __init__(self):
        # Database setup: used for schema introspection and query execution.
        self.db = SQLDatabase.from_uri(db_uri)
        self.llm = ChatOpenAI(
            model="gpt-3.5-turbo-0125",
            api_key=open_api_key_token,
            max_tokens=150,
            temperature=0.2,
        )
        # The summary memory reuses the same LLM to condense previous turns.
        self.memory = ConversationSummaryMemory(llm=self.llm, return_messages=True)

        # Tools setup. Each tool takes the agent's single string argument.
        self.tools = [
            Tool(
                name="DatabaseQuery",
                func=self.database_tool,
                description="Queries the SQL database using dynamically generated SQL queries based on user questions. Aimed to retrieve structured data like counts, specific records, or summaries from predefined schemas.",
            ),
            Tool(
                name="DocumentData",
                func=self.document_data_tool,
                description="Searches through indexed documents to find relevant information based on user queries. Handles unstructured data from various document formats like PDF, DOCX, or TXT files.",
            ),
        ]

        # Agent setup. The plain LLM is passed in: conversation history reaches
        # the model through the executor's ``memory`` and the prompt's history
        # placeholder. Binding ``memory`` onto the LLM would only attach it as
        # an extra model-call keyword argument.
        prompt_template = self.setup_prompt()
        self.agent = create_tool_calling_agent(self.llm, self.tools, prompt_template)
        self.agent_executor = AgentExecutor(
            agent=self.agent, tools=self.tools, memory=self.memory, verbose=True
        )

    def setup_prompt(self):
        """Build the chat prompt used by the tool-calling agent.

        The prompt holds a system instruction, the history produced by
        ``self.memory`` (optional, so the prompt also works without memory),
        the user's ``{input}``, and the ``agent_scratchpad`` placeholder.

        Returns:
            A ``ChatPromptTemplate`` built with ``from_messages``.
        """
        # Local import keeps this block self-contained; langchain_core.prompts
        # is already a dependency of this module.
        from langchain_core.prompts import MessagesPlaceholder

        system_message = (
            "You are an assistant that helps with database queries and document retrieval.\n"
            "Please base your responses strictly on available data and avoid assumptions.\n"
            "If the question pertains to numerical data or structured queries, use the DatabaseQuery tool.\n"
            "If the question relates to content within various documents, use the DocumentData tool."
        )
        return ChatPromptTemplate.from_messages(
            [
                ("system", system_message),
                MessagesPlaceholder(variable_name=self.memory.memory_key, optional=True),
                ("human", "{input}"),
                # The tool-calling agent fills this with AI/tool message
                # objects, so it must be a messages placeholder, not text.
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )

    def database_tool(self, question):
        """Answer *question* by generating SQL and running it on ``self.db``.

        Returns:
            The string returned by ``run_query``, or
            ``"Error executing database query."`` when execution failed, so the
            agent receives readable text instead of ``None``.
        """
        sql_query = self.generate_sql_query(question)
        result = self.run_query(sql_query)
        if result is None:
            return "Error executing database query."
        return result

    def get_schema(self, _):
        """Return ``self.db.get_table_info()``.

        The ignored positional argument lets this method be used directly as a
        runnable step: ``RunnablePassthrough.assign`` calls it with the chain
        input.
        """
        return self.db.get_table_info()

    def generate_sql_query(self, question):
        """Ask the LLM for a SQL query that answers *question*.

        The prompt includes the table info returned by ``get_schema``, and
        generation is cut off at the ``SQL Result:`` marker. Markdown code
        fences around the model's answer are removed before returning.
        """
        template_query_generation = (
            "Generate a SQL query to answer the user's question based on the available database schema.\n"
            "{schema}\n"
            "Question: {question}\n"
            "SQL Query:"
        )
        prompt_query_generation = ChatPromptTemplate.from_template(template_query_generation)
        # ``schema`` is filled in at invocation time by get_schema(chain_input),
        # so it is fetched exactly once per call.
        sql_chain = (
            RunnablePassthrough.assign(schema=self.get_schema)
            | prompt_query_generation
            | self.llm.bind(stop="\nSQL Result:")
            | StrOutputParser()
        )
        raw_sql = sql_chain.invoke({"question": question})
        return self._strip_sql_fences(raw_sql)

    @staticmethod
    def _strip_sql_fences(text):
        """Return *text* without surrounding Markdown code fences or whitespace.

        The chat model may wrap its SQL in a fenced block (e.g. ```sql ... ```),
        which the database would reject as invalid syntax.
        """
        cleaned = text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line (which may carry a language tag)...
            newline = cleaned.find("\n")
            cleaned = cleaned[newline + 1:] if newline != -1 else cleaned[3:]
            # ...and the closing fence, if present.
            cleaned = cleaned.rstrip()
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
        return cleaned.strip()

    def run_query(self, query):
        """Execute *query* on ``self.db`` and return its result.

        NOTE(security): *query* is model-generated from user input; consider
        connecting with a read-only database role.

        Returns:
            Whatever ``self.db.run`` returns, or ``None`` (after logging the
            traceback) if execution raises.
        """
        logging.info("Executing SQL query: %s", query)
        try:
            result = self.db.run(query)
        except Exception as e:
            logging.exception("Error executing query: %s, Error: %s", query, e)
            return None
        logging.info("Query successful: %s", result)
        return result

    def document_data_tool(self, query):
        """Search every FAISS index under ``VECTOR_DB_PATH`` for *query*.

        Each index's matches come from ``query_vector_store`` and the per-index
        results are joined with newlines.

        Returns:
            The joined matches; ``"No document indexes found."`` when there is
            nothing to search; ``"Error processing document query."`` if
            loading or searching fails.
        """
        logging.info("Searching documents for query: %s", query)
        try:
            index_paths = self.find_index_for_document(query)
            if not index_paths:
                # An empty observation would give the agent nothing to act on.
                return "No document indexes found."
            embeddings = OpenAIEmbeddings(api_key=open_api_key_token)
            responses = []
            for index_path in index_paths:
                # load_local deserializes pickled data (hence the
                # allow_dangerous_deserialization flag); VECTOR_DB_PATH should
                # only point at trusted, locally produced indexes.
                vector_store = FAISS.load_local(
                    index_path, embeddings, allow_dangerous_deserialization=True
                )
                responses.append(self.query_vector_store(vector_store, query))
            logging.info("Document search results: %s", responses)
            return "\n".join(responses)
        except Exception as e:
            logging.exception(
                "Error in document data tool for query: %s, Error: %s", query, e
            )
            return "Error processing document query."

    def find_index_for_document(self, query):
        """List every subdirectory of ``VECTOR_DB_PATH`` holding ``index.faiss``.

        *query* is currently unused: all indexes are returned. The base
        directory itself is not considered, only directories below it. Each
        returned path ends with a path separator.

        Returns:
            A list of directory paths; empty when ``VECTOR_DB_PATH`` is unset.
        """
        base_path = os.getenv('VECTOR_DB_PATH')
        if not base_path:
            logging.warning("VECTOR_DB_PATH is not set; no document indexes to search.")
            return []
        index_paths = []
        for root, subdirs, _files in os.walk(base_path):
            for name in subdirs:
                candidate = os.path.join(root, name)
                if os.path.exists(os.path.join(candidate, 'index.faiss')):
                    index_paths.append(os.path.join(candidate, ''))
        return index_paths

    def query_vector_store(self, vector_store, query):
        """Return the ``page_content`` of *vector_store*'s matches for *query*.

        Matches come from ``similarity_search`` with its default settings and
        are separated by blank lines.
        """
        docs = vector_store.similarity_search(query)
        return "\n\n".join(doc.page_content for doc in docs)

    def answer_question(self, user_question):
        """Run the agent on *user_question* and return its final answer text.

        Returns:
            The executor result's ``output`` value (or
            ``"No valid response generated."`` when absent). If the agent run
            raises, returns ``"An error occurred: <message>"`` instead.
        """
        logging.info("Received question: %s", user_question)
        try:
            response = self.agent_executor.invoke({"input": user_question})
            output_response = response.get("output", "No valid response generated.")
        except Exception as e:
            logging.exception("Error processing question: %s, Error: %s", user_question, e)
            return f"An error occurred: {str(e)}"
        logging.info("Response generated: %s", output_response)
        return output_response