|
from langchain import SQLDatabaseChain |
|
from langchain.sql_database import SQLDatabase |
|
from langchain.llms.openai import OpenAI |
|
from langchain.chat_models import ChatOpenAI |
|
from langchain.prompts.prompt import PromptTemplate |
|
|
|
# Chat model shared by every request; created once at module level and
# passed to the SQL chain in answer_question().
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0,
    verbose=True,
)

# Tables that get_db() always adds to the caller's selection when the
# large database is opened.
# NOTE(review): "Active Players" contains a space while the other names
# use underscores -- confirm it matches the real table name in nba.db.
DEFAULT_TABLES = [
    "Active Players",
    "Team_Per_Game_Statistics_2022_23",
    "Team_Totals_Statistics_2022_23",
    "Player_Total_Statistics_2022_23",
    "Player_Per_Game_Statistics_2022_23",
]
|
|
|
def get_prompt():
    """Build the prompt template handed to the SQL chain.

    The template text contains three placeholders -- ``dialect``,
    ``table_info`` and ``input`` -- and asks the model to reply in a
    Question / SQLQuery / SQLResult / Answer layout.

    Returns:
        PromptTemplate: template declaring ``input``, ``table_info`` and
        ``dialect`` as its input variables.
    """
    template_text = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"

Answer: "Final answer here"

Only use the following tables:

{table_info}

Question: {input}"""

    return PromptTemplate(
        template=template_text,
        input_variables=["input", "table_info", "dialect"],
    )
|
|
|
def check_query(query):
    """Classify the request text and extract its parts.

    Two input shapes are recognised:

    * Free text: anything that does not start with ``### Query``.
    * The structured form, made of a ``### Query`` section and a
      ``### Tables`` section separated by one blank line::

          ### Query
          <question on a single line>

          ### Tables
          <table name>
          <table name>

    Args:
        query: Raw request text.

    Returns:
        ``'small'`` for free text.
        ``'error'`` when the text starts with ``### Query`` but the rest
        does not follow the structured form (no second section, second
        section not starting with ``### Tables``, or no question line).
        Otherwise a dict ``{'tables': [<names>], 'q': <question>}`` with
        surrounding whitespace removed and blank table lines dropped.
    """
    if not query.startswith("### Query"):
        return 'small'

    # Normalise Windows line endings so the blank-line split below works
    # no matter how the text was pasted.
    sections = query.replace('\r\n', '\n').split('\n\n')
    if len(sections) < 2:
        # Header present but no blank-line-separated second section.
        return 'error'

    question_section, tables_section = sections[0], sections[1]
    if not tables_section.startswith("### Tables"):
        return 'error'

    # Line 0 is the "### Query" header; the question must be on line 1.
    question_lines = question_section.split('\n')
    if len(question_lines) < 2 or not question_lines[1].strip():
        return 'error'

    # Line 0 is the "### Tables" header; skip blank or whitespace-only
    # entries so they are never passed on as table names.
    table_names = [
        name.strip()
        for name in tables_section.split('\n')[1:]
        if name.strip()
    ]

    query_params = {
        'tables': table_names,
        'q': question_lines[1].strip(),
    }
    print(query_params)
    return query_params
|
|
|
def get_db(q, tables):
    """Open the SQLite database that matches the requested tables.

    Args:
        q: The user's question. Not used here; kept so existing callers
            keep working.
        tables: Table names requested by the user. The list is not
            modified.

    Returns:
        SQLDatabase: ``nba_small.db`` with no table restriction when
        ``tables`` is empty; otherwise ``nba.db`` restricted to the
        requested tables plus ``DEFAULT_TABLES``. Both are opened with
        ``sample_rows_in_table_info=2``.
    """
    if not tables:
        return SQLDatabase.from_uri(
            "sqlite:///nba_small.db",
            sample_rows_in_table_info=2,
        )

    # Build a fresh list rather than extending ``tables`` in place, so the
    # caller's list is left untouched; dict.fromkeys drops names that are
    # listed twice (e.g. a requested table that is also a default) while
    # keeping the original order.
    selected_tables = list(dict.fromkeys([*tables, *DEFAULT_TABLES]))
    return SQLDatabase.from_uri(
        "sqlite:///nba.db",
        include_tables=selected_tables,
        sample_rows_in_table_info=2,
    )
|
def answer_question(query):
    """Answer a question by running an LLM-driven SQL chain.

    The request text is classified by check_query(); the matching
    database is opened by get_db(); the question is then run through a
    SQLDatabaseChain built from the module-level ``llm`` and the prompt
    from get_prompt().

    Args:
        query: Raw request text, either free text or the structured
            ``### Query`` / ``### Tables`` form.

    Returns:
        The ``'result'`` entry of the chain output, or an ``ERROR: ...``
        string when check_query() reports ``'error'``.
    """
    prompt_template = get_prompt()
    parsed = check_query(query)

    if parsed == 'error':
        return 'ERROR: Wrong format for getting the big db schema'

    if isinstance(parsed, dict):
        # Structured request: use the extracted question and table list.
        question, requested_tables = parsed['q'], parsed['tables']
    elif parsed == 'small':
        # Free-text request: the whole text is the question and no tables
        # are requested.
        question, requested_tables = query, []

    database = get_db(question, requested_tables)
    chain = SQLDatabaseChain.from_llm(
        llm,
        database,
        prompt=prompt_template,
        verbose=True,
        return_intermediate_steps=True,
    )
    return chain(question)['result']
|
|
|
if __name__ == "__main__":
    # Imported inside the guard, so gradio is only needed when this file
    # is run as a script.
    import gradio as gr

    # NOTE(review): gr.inputs / gr.outputs look like legacy namespaces in
    # newer Gradio releases -- confirm against the installed version.
    query_box = gr.inputs.Textbox(lines=10, label="Query")
    response_box = gr.outputs.Textbox(label="Response")

    # Help text shown in the UI; it describes the two request formats
    # that check_query() distinguishes.
    ui_description = """ Ask NBA Stats is a tool that let's you ask a question with
the NBA SQL tables as a reference

Ask a simple question to use the small database

If you would like to access the large DB use format

### Query
single line query

### Tables
tables to access line by line
table1
table2"""

    demo = gr.Interface(
        answer_question,
        [query_box],
        response_box,
        title="Ask NBA Stats",
        description=ui_description,
    )
    demo.launch()