File size: 2,316 Bytes
54f25c9 8e3f06a f2210e2 8e3f06a 54f25c9 8e3f06a 54f25c9 8e3f06a 54f25c9 8e3f06a 86c493f 1c29a01 b933ab1 746c4fa cb19bba b933ab1 8e3f06a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import gradio as gr
from transformers import pipeline
# Load the pre-trained model from Hugging Face
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
# Hugging Face Hub repo id of the fine-tuned PEFT adapter; the tokenizer is
# loaded from the same repo.
peft_model_id = "jinhybr/code-llama-7b-text-to-sql"

# Tokenizer stored alongside the adapter.
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# Causal-LM with the PEFT adapter attached, loaded in float16; device
# placement is delegated to device_map="auto".
model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Single module-level text-generation pipeline reused by text_to_sql.
pipe = pipeline(task="text-generation", tokenizer=tokenizer, model=model)
def text_to_sql(text):
    """Translate an English question into a SQL query using the fine-tuned model.

    The question becomes the user turn of a two-message chat. The system turn
    holds the translation instruction plus the fixed schema
    ``CREATE TABLE table_17429402_7 (school VARCHAR, last_occ_championship VARCHAR)``.
    The module-level ``pipe`` renders the chat with its tokenizer's chat
    template and decodes it greedily, generating up to 256 new tokens.

    Args:
        text: The user's natural-language question.

    Returns:
        The generated SQL with surrounding whitespace stripped. Returns ``""``
        when ``text`` is empty or whitespace-only; the model is not called.
    """
    # Nothing to translate: skip the slow model call instead of letting the
    # model invent a query for an empty question.
    if not text or not text.strip():
        return ""

    # NOTE(review): wording (including "an text") is kept verbatim; it
    # presumably mirrors the system prompt used during fine-tuning - confirm
    # before editing it.
    system_prompt = (
        "You are an text to SQL query translator. Users will ask you questions in "
        "English and you will generate a SQL query based on the provided SCHEMA.\n"
        "SCHEMA:\n"
        "CREATE TABLE table_17429402_7 (school VARCHAR, last_occ_championship VARCHAR)"
    )
    messages = [
        {"content": system_prompt, "role": "system"},
        {"content": text, "role": "user"},
    ]

    tok = pipe.tokenizer
    # Render the chat as a plain string, with the generation prompt appended
    # so the model continues with the assistant's answer.
    prompt = tok.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Some tokenizers define no pad token; fall back to eos so generation
    # always receives a concrete pad id.
    pad_token_id = tok.pad_token_id
    if pad_token_id is None:
        pad_token_id = tok.eos_token_id

    # Greedy decoding. temperature/top_k/top_p only matter when sampling, so
    # they are not passed. return_full_text=False makes the pipeline return
    # only the newly generated text, not the prompt.
    outputs = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=False,
        eos_token_id=tok.eos_token_id,
        pad_token_id=pad_token_id,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
# Gradio UI: one question textbox wired to text_to_sql, returning the SQL as
# text. The schema the model answers against is shown in the description.
iface = gr.Interface(
    fn=text_to_sql,
    inputs=gr.Textbox(lines=7, label="User Question"),
    outputs=["text"],
    theme="soft",
    examples=["How many schools won their last occ championship in 2006?"],
    # Ask Gradio to cache the example's output rather than recompute it
    # whenever the example is clicked.
    cache_examples=True,
    title="Finetuned code-llama-7b for Text-to-SQL Demo",
    description=(
        "Translate text to SQL query based on the provided schema. "
        "CREATE TABLE table_17429402_7 (school VARCHAR, last_occ_championship VARCHAR)"
    ),
)

# Start the server only when run as a script, so importing this module does
# not launch the web app.
if __name__ == "__main__":
    iface.launch()
|