few_shot_intent_gpt2_base

这个模型是基于 uer/gpt2-chinese-cluecorpussmall 模型在 qgyd2021/few_shot_intent_sft 数据集上微调的结果.

(1)训练在(11000 steps)处 Early Stop。这相当于加载的 qgyd2021/few_shot_intent_sft 数据集的 1 个 epoch 处。

(2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。

最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。

你可以在此处体验该模型 qgyd2021/gpt2_chat

TensorBoard 数集

Eval Loss 见下图:

eval_loss.jpg

Learning rate 见下图:

学习率从 2e-4 下降到 1.4e-4。

learning_rate.jpg

讨论

(1)最优解在不到 1 个 epoch 处得到。

  • 这可能说明 GPT2 模型大小,相对于任务复杂度来说太小了。

  • 模型进入到局部最终解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。

(2)后续应考虑针对 prompt-response 中 response 部分进行训练。

  • 即只优化 response 部分的损失以提升识别结果与 prompt 之间的注意力机制。当前的训练有可能只是使模型拟合了 few shot 数据的格式,而并没有拟合到意图识别的目的。

(3)模型使用中的体会。

  • 如果在使用过程中,模型生成 response 不在 prompt 中给定的选项,这可能说明模型已经过拟合了。

  • 如果模型生成 response 在 prompt 中,但答案不正确,则说明模型已学习到生成的表层模型,而没有学习到意图识别的目的。则建议在此模型基础上进一步优化 response 部分的损失。

其它

训练时加载数据集的代码

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json

from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(project_path / "hub_datasets").as_posix(),
        type=str
    )

    parser.add_argument("--num_epochs", default=1, type=int)

    parser.add_argument("--train_subset", default="train.jsonl", type=str)
    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    name_list = [
        # "a_intent_prompt",
        "amazon_massive_intent_en_us_prompt",
        "amazon_massive_intent_zh_cn_prompt",
        "atis_intents_prompt",
        "banking77_prompt",
        "bi_text11_prompt",
        "bi_text27_prompt",
        # "book6_prompt",
        "carer_prompt",
        "chatbots_prompt",
        "chinese_news_title_prompt",
        "cmid_4class_prompt",
        "cmid_36class_prompt",
        "coig_cqia_prompt",
        "conv_intent_prompt",
        "crosswoz_prompt",
        "dmslots_prompt",
        "dnd_style_intents_prompt",
        "emo2019_prompt",
        "finance21_prompt",
        "ide_intent_prompt",
        "intent_classification_prompt",
        "jarvis_intent_prompt",
        "mobile_assistant_prompt",
        "mtop_intent_prompt",
        "out_of_scope_prompt",
        "ri_sawoz_domain_prompt",
        "ri_sawoz_general_prompt",
        "small_talk_prompt",
        "smp2017_task1_prompt",
        "smp2019_task1_domain_prompt",
        "smp2019_task1_intent_prompt",
        # "snips_built_in_intents_prompt",
        "star_wars_prompt",
        "suicide_intent_prompt",
        "snips_built_in_intents_prompt",
        "telemarketing_intent_cn_prompt",
        "telemarketing_intent_en_prompt",
        "vira_intents_prompt",
    ]

    with open(args.train_subset, "w", encoding="utf-8") as f:
        for _ in range(args.num_epochs):
            for name in name_list:
                print(name)
                dataset = load_dataset(
                    path=args.dataset_path,
                    name=name,
                    split="train",
                    cache_dir=args.dataset_cache_dir,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD,
                    ignore_verifications=True
                )
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))

    with open(args.valid_subset, "w", encoding="utf-8") as f:
        for _ in range(args.num_epochs):
            for name in name_list:
                print(name)
                dataset = load_dataset(
                    path=args.dataset_path,
                    name=name,
                    split="test",
                    cache_dir=args.dataset_cache_dir,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD,
                    ignore_verifications=True
                )
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))

    return


if __name__ == '__main__':
    main()
Downloads last month
28
Safetensors
Model size
102M params
Tensor type
F32
·
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Space using qgyd2021/few_shot_intent_gpt2_base 1