|
import os |
|
import sys |
|
from pathlib import Path |
|
from datetime import datetime |
|
import json |
|
import traceback |
|
import uuid |
|
from huggingface_hub import CommitScheduler |
|
|
|
# Make the bundled duckdb-nsql evaluation harness importable as `eval.*`.
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
for _extra_path in (current_dir, duckdb_nsql_dir, eval_dir):
    sys.path.append(str(_extra_path))
|
|
|
from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql |
|
from eval.evaluate import evaluate, compute_metrics, get_to_print |
|
from eval.evaluate import test_suite_evaluation, read_tables_json |
|
from eval.schema import TextToSQLParams, Table |
|
|
|
# Prompt-format names the harness accepts (keys of the formatter registry).
AVAILABLE_PROMPT_FORMATS = [*PROMPT_FORMATTERS]
|
|
|
# Local staging directories; their contents are mirrored to the Hub by the
# CommitScheduler instances defined below.
prediction_folder: Path = Path("prediction_results/")

evaluation_folder: Path = Path("evaluation_results/")

# Unique per-process suffix so concurrent runs append to distinct files
# instead of interleaving writes in one file.
file_uuid: uuid.UUID = uuid.uuid4()
|
|
|
# Background schedulers that periodically commit the local result folders to
# Hugging Face Hub datasets. `every=10` is presumably minutes per the
# huggingface_hub CommitScheduler API — TODO confirm. Writers must hold the
# scheduler's `.lock` while appending (see save_prediction/save_evaluation)
# so a commit never snapshots a half-written record.
prediction_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-predictions",
    repo_type="dataset",
    folder_path=prediction_folder,
    path_in_repo="data",
    every=10,
)

evaluation_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-scores",
    repo_type="dataset",
    folder_path=evaluation_folder,
    path_in_repo="data",
    every=10,
)
|
|
|
def save_prediction(inference_api, model_name, prompt_format, question, generated_sql):
    """Append one prediction record to this process's JSON Lines file.

    The file lives in ``prediction_folder`` and is committed to the Hub by
    ``prediction_scheduler``; the scheduler's lock is held while writing so a
    background commit never captures a partially written record.

    Args:
        inference_api: Name of the inference backend used.
        model_name: Model identifier the SQL was generated with.
        prompt_format: Prompt-format key used for generation.
        question: The natural-language question that was answered.
        generated_sql: The model's predicted SQL string.
    """
    prediction_file = prediction_folder / f"prediction_{file_uuid}.json"
    prediction_folder.mkdir(parents=True, exist_ok=True)
    with prediction_scheduler.lock:
        with prediction_file.open("a") as f:
            json.dump({
                "inference_api": inference_api,
                "model_name": model_name,
                "prompt_format": prompt_format,
                "question": question,
                "generated_sql": generated_sql,
                "timestamp": datetime.now().isoformat()
            }, f)
            # Newline delimiter: without it, repeated appends concatenate into
            # a single unparseable JSON blob. This also matches what
            # save_evaluation already does for its file.
            f.write('\n')
|
|
|
def save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics):
    """Flatten per-category execution metrics and append them as one JSON line.

    The record is written to this process's file in ``evaluation_folder``
    under ``evaluation_scheduler``'s lock, so the Hub commit thread never
    sees a half-written line.

    Args:
        inference_api: Name of the inference backend used.
        model_name: Model identifier that was evaluated.
        prompt_format: Prompt-format key (possibly a derived "custom_*" name).
        custom_prompt: Custom template text; recorded only for custom formats.
        metrics: Metrics mapping with an 'exec' sub-dict keyed by category.
    """
    evaluation_file = evaluation_folder / f"evaluation_{file_uuid}.json"
    evaluation_folder.mkdir(parents=True, exist_ok=True)

    record = {
        "inference_api": inference_api,
        "model_name": model_name,
        "prompt_format": prompt_format,
        "custom_prompt": str(custom_prompt) if prompt_format.startswith("custom") else "",
        "timestamp": datetime.now().isoformat()
    }

    # One count/accuracy pair per category; zeros when a category is absent.
    for category in ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']:
        if category in metrics['exec']:
            per_category = metrics['exec'][category]
            record[f"{category}_count"] = per_category['count']
            record[f"{category}_execution_accuracy"] = per_category['exec']
        else:
            record[f"{category}_count"] = 0
            record[f"{category}_execution_accuracy"] = 0.0

    with evaluation_scheduler.lock:
        with evaluation_file.open("a") as f:
            json.dump(record, f)
            f.write('\n')
|
|
|
def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
    """Generate SQL predictions for the dev set, streaming progress strings.

    Generator: yields human-readable status messages (suitable for a
    streaming UI). Side effects: writes one JSON object per example (input
    record plus a "pred" key) to ``output_file`` as JSON Lines, and mirrors
    each prediction to the Hub-backed log via ``save_prediction``.

    Args:
        inference_api: Manifest client name (forwarded as ``manifest_client``).
        model_name: Model identifier (forwarded as ``manifest_engine``).
        prompt_format: Key into PROMPT_FORMATTERS; names starting with
            "custom" use ``custom_prompt`` as the template.
        custom_prompt: Template text, used only for "custom*" formats.
        output_file: Path the predictions file is written to.
    """
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")
    # Presumably a never-emitted sentinel so generation is not truncated —
    # TODO confirm against eval.predict.generate_sql.
    stop_tokens = ['`<|dummy|>`']
    max_tokens = 1000
    temperature = 0  # greedy/deterministic decoding
    num_beams = -1
    manifest_client = inference_api
    manifest_engine = model_name
    # Only relevant for clients that talk to a local server — TODO confirm.
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False

    yield "Starting prediction..."

    try:
        data_formatter = DefaultLoader()
        if prompt_format.startswith("custom"):
            # Patch the user-supplied template onto the shared "custom"
            # formatter class before instantiating it. NOTE(review): this
            # mutates class state, so concurrent custom runs would race.
            prompt_formatter_cls = PROMPT_FORMATTERS["custom"]
            prompt_formatter_cls.PROMPT_TEMPLATE = custom_prompt
            prompt_formatter = prompt_formatter_cls()
        else:
            prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

        # Build one TextToSQLParams per benchmark question, attaching the
        # schema tables for its database when one is declared.
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []

            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=table_params,
            ))

        generated_sqls = generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],  # no retrieval
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel
        )

        # Write JSON Lines: each input record plus the model's SQL as "pred".
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with output_file.open('w') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                output = {**original_data, "pred": sql}
                json.dump(output, f)
                f.write('\n')

                # Mirror each prediction to the Hub-backed log as well.
                save_prediction(inference_api, model_name, prompt_format, original_data["question"], sql)

        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        # Broad catch is deliberate: surface failures into the yielded stream
        # rather than crashing the consumer (e.g. a UI).
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
|
|
|
def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort", custom_prompt=None):
    """Run prediction (unless cached) then execution-accuracy evaluation.

    Generator: yields progress and result strings. Persists flattened scores
    to the Hub-backed log via ``save_evaluation``.

    Args:
        inference_api: Inference backend name, forwarded to run_prediction.
        model_name: Model identifier to predict/evaluate with.
        prompt_format: Prompt-format key; "custom" is renamed to a derived
            "custom_<hash>" name below.
        custom_prompt: Template text for the custom format, else None.
    """
    # NOTE(review): OPENROUTER_API_KEY is required unconditionally, even when
    # a different inference API is selected — confirm this is intended.
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return
    if "HF_TOKEN" not in os.environ:
        yield "Error: HF_TOKEN not found in environment variables."
        return

    try:
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

        # Give each distinct custom template its own name/output file.
        # NOTE(review): hash() of a str is salted per process, so this name
        # is not stable across runs — the same-day cache below never hits
        # for custom prompts across restarts. Confirm whether that matters.
        if prompt_format == "custom":
            prompt_format = prompt_format+"_"+str(abs(hash(custom_prompt)) % (10 ** 8))

        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"

        output_dir.mkdir(parents=True, exist_ok=True)

        # Same-day predictions for this model+format are reused as a cache.
        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            # Re-yield run_prediction's progress messages to our consumer.
            for output in run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
                yield output

        yield "Starting evaluation..."

        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        # Gold set is one JSON array; predictions are JSON Lines (one object
        # per line, written by run_prediction).
        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(l) for l in output_file.open("r").readlines()]

        # Tolerate both key spellings used by the benchmark data.
        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )

        # Persist the flattened scores to the Hub-backed log.
        save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics)

        yield "Evaluation completed."

        if metrics:
            yield "Overall Results:"
            overall_metrics = metrics['exec']['all']
            yield f"All (n={overall_metrics['count']}) - Execution Accuracy: {overall_metrics['exec']:.3f}"
            yield f"All (n={overall_metrics['count']}) - Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"

            categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']

            # Per-category breakdown; categories without data are reported
            # rather than silently skipped.
            for category in categories:
                if category in metrics['exec']:
                    category_metrics = metrics['exec'][category]
                    yield f"{category} (n={category_metrics['count']}) - Execution Accuracy: {category_metrics['exec']:.3f}"
                else:
                    yield f"{category}: No data available"
        else:
            yield "No evaluation metrics returned."
    except Exception as e:
        # Broad catch keeps the yielded stream alive and reports the error
        # plus traceback to the consumer instead of raising.
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
|
|
|
if __name__ == "__main__":
    # CLI entry point: collect the three run parameters interactively.
    # BUG FIX: the previous call passed only (model_name, prompt_format),
    # so model_name was bound to run_evaluation's `inference_api` parameter
    # and prompt_format to `model_name`, silently dropping the user's
    # chosen prompt format.
    inference_api = input("Enter the inference API (e.g. openrouter): ")
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
    for result in run_evaluation(inference_api, model_name, prompt_format):
        print(result, flush=True)