# schema_filter / eval_mode.py
# Uploaded by justinsiow: "Uploaded Utils, Pycache and Python Files"
# Commit 1e712af (verified)
from schema_filter import filter_func, SchemaItemClassifierInference
# Example input for eval mode. The gold SQL does not need to be provided
# in eval mode, so "sql" is left as an empty string.
#
# Each entry pairs a table name with its column names, in schema order.
# Every table comment and column comment in this example is empty, so the
# comment lists are generated with one "" per column.
_TABLE_COLUMNS = (
    (
        "lists",
        (
            "user_id",
            "list_id",
            "list_title",
            "list_movie_number",
            "list_update_timestamp_utc",
            "list_creation_timestamp_utc",
            "list_followers",
            "list_url",
            "list_comments",
            "list_description",
            "list_cover_image_url",
            "list_first_image_url",
            "list_second_image_url",
            "list_third_image_url",
        ),
    ),
    (
        "movies",
        (
            "movie_id",
            "movie_title",
            "movie_release_year",
            "movie_url",
            "movie_title_language",
            "movie_popularity",
            "movie_image_url",
            "director_id",
            "director_name",
            "director_url",
        ),
    ),
    (
        "ratings_users",
        (
            "user_id",
            "rating_date_utc",
            "user_trialist",
            "user_subscriber",
            "user_avatar_image_url",
            "user_cover_image_url",
            "user_eligible_for_trial",
            "user_has_payment_method",
        ),
    ),
    (
        "lists_users",
        (
            "user_id",
            "list_id",
            "list_update_date_utc",
            "list_creation_date_utc",
            "user_trialist",
            "user_subscriber",
            "user_avatar_image_url",
            "user_cover_image_url",
            "user_eligible_for_trial",
            "user_has_payment_method",
        ),
    ),
    (
        "ratings",
        (
            "movie_id",
            "rating_id",
            "rating_url",
            "rating_score",
            "rating_timestamp_utc",
            "critic",
            "critic_likes",
            "critic_comments",
            "user_id",
            "user_trialist",
            "user_subscriber",
            "user_eligible_for_trial",
            "user_has_payment_method",
        ),
    ),
)

data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "",
    "schema": {
        "schema_items": [
            {
                "table_name": table_name,
                "table_comment": "",
                "column_names": list(column_names),
                # One (empty) comment per column keeps the lists aligned.
                "column_comments": [""] * len(column_names),
            }
            for table_name, column_names in _TABLE_COLUMNS
        ]
    },
}
dataset = [data]
# Keep at most 7 tables from the database schema.
num_top_k_tables = 7
# For each retained table keep at most 10 columns, so the prompt contains at
# most num_top_k_tables * num_top_k_columns = 7 * 10 = 70 columns.
num_top_k_columns = 10
# Load the schema-item classifier model. "sic_merged" is presumably a local
# checkpoint directory / model name — TODO confirm against the loader.
sic = SchemaItemClassifierInference("sic_merged")
# For evaluation data, use the trained classifier to score the tables and
# columns against the user question, keeping at most the top-k of each as
# configured above.
dataset = filter_func(
    dataset=dataset,
    dataset_type="eval",
    sic=sic,
    num_top_k_tables=num_top_k_tables,
    num_top_k_columns=num_top_k_columns,
)
print(dataset)