|
"""Streamlit app for demoing SambaCoder-nsql-llama-2-70b.""" |
|
|
|
import json |
|
import os |
|
|
|
import pandas as pd |
|
import requests |
|
import streamlit as st |
|
from manifest import Manifest, Response |
|
from manifest.connections.client_pool import ClientConnection |
|
|
|
# Candidate stop sequences for generated SQL.
# NOTE(review): nothing in the visible code references this constant;
# presumably it is kept for truncating model output -- confirm before removing.
STOP_TOKENS = [
    "###",
    ";",
    "--",
    "```",
]
|
|
|
|
|
def generate_prompt(question, schema):
    """Build the text-to-SQL prompt sent to the model.

    The prompt is the schema text, three newlines, a SQL comment instructing
    the model to answer with valid SQLite, a blank line, and finally the
    question rendered as a SQL comment followed by a newline.

    Args:
        question: Natural-language question to be answered.
        schema: Table definitions placed at the top of the prompt.

    Returns:
        The assembled prompt string.
    """
    prompt = f"""{schema}\n\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- {question}\n"""
    return prompt
|
|
|
|
|
def _build_generation_request(prompt):
    """Return the JSON-serializable request body for a single-prompt call."""
    return {
        "inputs": [prompt],
        "params": {
            "do_sample": {"type": "bool", "value": "false"},
            "max_tokens_to_generate": {"type": "int", "value": "1000"},
            "repetition_penalty": {"type": "float", "value": "1"},
            "temperature": {"type": "float", "value": "1"},
            "top_k": {"type": "int", "value": "50"},
            "top_logprobs": {"type": "int", "value": "0"},
            "top_p": {"type": "float", "value": "1"},
        },
    }


def generate_sql(question, schema, timeout=(10, 120)):
    """Stream generated SQL tokens for ``question`` against ``schema``.

    Builds the prompt, POSTs it to the backend configured in
    ``st.secrets["backend_url"]`` (authenticated with ``st.secrets["key"]``)
    and yields tokens as the backend streams them back. This is a generator:
    no request is made until the first item is requested.

    Args:
        question: Natural-language question to translate to SQL.
        schema: Table definitions the question refers to.
        timeout: Forwarded to ``requests.post``; a ``(connect, read)`` tuple
            in seconds (or a single number). For a streamed response the
            read timeout bounds the wait between chunks, not the whole reply.

    Yields:
        Each non-empty ``"stream_token"`` value found in the JSON payload of
        a response line that starts with ``"data: "``.

    Raises:
        requests.HTTPError: The backend answered with a 4xx/5xx status.
        requests.RequestException: Connection failure or timeout.
        json.JSONDecodeError: A ``data: `` line did not contain valid JSON.
    """
    prompt = generate_prompt(question, schema)
    url = st.secrets["backend_url"]
    headers = {
        "Content-Type": "application/json",
        "key": st.secrets["key"],
    }
    body = json.dumps(_build_generation_request(prompt))
    prefix = b"data: "

    # The context manager releases the connection even when the consumer
    # stops iterating early or an exception interrupts the stream.
    with requests.post(
        url, headers=headers, data=body, stream=True, timeout=timeout
    ) as response:
        # Fail loudly instead of silently yielding nothing on an error page.
        response.raise_for_status()

        # Iterate raw bytes: requests falls back to ISO-8859-1 for text/*
        # content types without a charset, and splitting decoded text would
        # also break lines on extra Unicode separators. json.loads decodes
        # the UTF-8 bytes itself.
        for raw_line in response.iter_lines():
            if not raw_line.startswith(prefix):
                continue
            event = json.loads(raw_line[len(prefix):])
            token = event.get("stream_token")
            # Truthiness also skips a null token, which len() would choke on.
            if token:
                yield token
|
|
|
|
|
# Schema pre-filled in the editor; the user may overwrite it before
# generating SQL.
default_schema = """CREATE TABLE stadium (
stadium_id number,
location text,
name text,
capacity number,
highest number,
lowest number,
average number
)
CREATE TABLE singer (
singer_id number,
name text,
country text,
song_name text,
song_release_year text,
age number,
is_male others
)
CREATE TABLE concert (
concert_id number,
concert_name text,
theme text,
stadium_id text,
year text
)
CREATE TABLE singer_in_concert (
concert_id number,
singer_id text
)"""

# Page heading.
st.title(body="SambaCoder-nsql-llama-2-70b Demo")

# Collapsible section holding the editable schema text.
expander = st.expander(label="Database Schema")
schema = expander.text_area(
    label="Current schema:",
    value=default_schema,
    height=500,
)

# Free-text question, pre-filled with an example that matches the default
# schema.
text_prompt = st.text_input(
    label="Please let me know what question do you want to ask?",
    value="What is the maximum, the average, and the minimum capacity of stadiums ?",
)
|
|
|
|
|
|
|
# Run generation only when the button is clicked on this script run.
if st.button("Generate SQL"):
    if not text_prompt.strip():
        # Avoid a backend round-trip for a blank question.
        st.warning("Please enter a question before generating SQL.")
    else:
        try:
            # generate_sql is a generator; the HTTP request happens while
            # write_stream consumes it, so the try must wrap this call.
            st.write_stream(generate_sql(text_prompt, schema))
        except requests.RequestException as exc:
            # Show only the exception type: the message can embed the
            # backend URL, which comes from st.secrets.
            st.error(
                "Could not get a response from the SQL generation backend "
                f"({type(exc).__name__}). Please try again."
            )
|
|