few_shot_intent_gpt2_base
这个模型是基于 uer/gpt2-chinese-cluecorpussmall 模型在 qgyd2021/few_shot_intent_sft 数据集上微调的结果.
(1)训练在(11000 steps)处 Early Stop。这相当于加载的 qgyd2021/few_shot_intent_sft 数据集的 1 个 epoch 处。
(2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。
最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。
你可以在此处体验该模型 qgyd2021/gpt2_chat。
TensorBoard 数集
Eval Loss 见下图:
Learning rate 见下图:
学习率从 2e-4 下降到 1.4e-4。
讨论
(1)最优解在不到 1 个 epoch 处得到。
这可能说明 GPT2 模型大小,相对于任务复杂度来说太小了。
模型进入到局部最终解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。
(2)后续应考虑针对 prompt-response 中 response 部分进行训练。
- 即只优化 response 部分的损失以提升识别结果与 prompt 之间的注意力机制。当前的训练有可能只是使模型拟合了 few shot 数据的格式,而并没有拟合到意图识别的目的。
(3)模型使用中的体会。
如果在使用过程中,模型生成 response 不在 prompt 中给定的选项,这可能说明模型已经过拟合了。
如果模型生成 response 在 prompt 中,但答案不正确,则说明模型已学习到生成的表层模型,而没有学习到意图识别的目的。则建议在此模型基础上进一步优化 response 部分的损失。
其它
训练时加载数据集的代码
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm
from project_settings import project_path
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str)
parser.add_argument("--dataset_split", default=None, type=str)
parser.add_argument(
"--dataset_cache_dir",
default=(project_path / "hub_datasets").as_posix(),
type=str
)
parser.add_argument("--num_epochs", default=1, type=int)
parser.add_argument("--train_subset", default="train.jsonl", type=str)
parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
args = parser.parse_args()
return args
def main():
args = get_args()
name_list = [
# "a_intent_prompt",
"amazon_massive_intent_en_us_prompt",
"amazon_massive_intent_zh_cn_prompt",
"atis_intents_prompt",
"banking77_prompt",
"bi_text11_prompt",
"bi_text27_prompt",
# "book6_prompt",
"carer_prompt",
"chatbots_prompt",
"chinese_news_title_prompt",
"cmid_4class_prompt",
"cmid_36class_prompt",
"coig_cqia_prompt",
"conv_intent_prompt",
"crosswoz_prompt",
"dmslots_prompt",
"dnd_style_intents_prompt",
"emo2019_prompt",
"finance21_prompt",
"ide_intent_prompt",
"intent_classification_prompt",
"jarvis_intent_prompt",
"mobile_assistant_prompt",
"mtop_intent_prompt",
"out_of_scope_prompt",
"ri_sawoz_domain_prompt",
"ri_sawoz_general_prompt",
"small_talk_prompt",
"smp2017_task1_prompt",
"smp2019_task1_domain_prompt",
"smp2019_task1_intent_prompt",
# "snips_built_in_intents_prompt",
"star_wars_prompt",
"suicide_intent_prompt",
"snips_built_in_intents_prompt",
"telemarketing_intent_cn_prompt",
"telemarketing_intent_en_prompt",
"vira_intents_prompt",
]
with open(args.train_subset, "w", encoding="utf-8") as f:
for _ in range(args.num_epochs):
for name in name_list:
print(name)
dataset = load_dataset(
path=args.dataset_path,
name=name,
split="train",
cache_dir=args.dataset_cache_dir,
download_mode=DownloadMode.FORCE_REDOWNLOAD,
ignore_verifications=True
)
for sample in tqdm(dataset):
row = json.dumps(sample, ensure_ascii=False)
f.write("{}\n".format(row))
with open(args.valid_subset, "w", encoding="utf-8") as f:
for _ in range(args.num_epochs):
for name in name_list:
print(name)
dataset = load_dataset(
path=args.dataset_path,
name=name,
split="test",
cache_dir=args.dataset_cache_dir,
download_mode=DownloadMode.FORCE_REDOWNLOAD,
ignore_verifications=True
)
for sample in tqdm(dataset):
row = json.dumps(sample, ensure_ascii=False)
f.write("{}\n".format(row))
return
if __name__ == '__main__':
main()
- Downloads last month
- 28
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.