使用 SetFit 进行零样本文本分类的数据标注建议
作者: David Berenstein 和 Sara Han Díaz
建议是使标注团队工作更加轻松快捷的绝佳方式。这些预设选项将使标注过程更加高效,因为标注者只需纠正建议即可。在这个例子中,我们将展示如何使用 SetFit 实现零样本方法,以获取 Argilla 中一个数据集的初步建议,该数据集结合了两个文本分类任务,包括一个 LabelQuestion
和一个 MultiLabelQuestion
。
Argilla 是一个开源的数据策展平台,旨在提升小型和大型语言模型(LLMs)的开发。使用 Argilla,每个人都可以通过使用人类和机器的反馈来更快地进行数据策展,从而构建健壮的语言模型。因此,它为 MLOps 周期的每一步提供支持,从数据标注到模型监控。
反馈是数据策展过程的一个关键部分,Argilla 也提供了一种管理和可视化反馈的方式,以便策展的数据可以后来用于改进语言模型。在本教程中,我们将展示一个实际的例子,说明如何通过提供建议来使我们的标注者工作更加轻松。为此,你将学习如何使用 SetFit 训练零样本情感和主题分类器,然后使用它们为数据集提供建议。
在本教程中,我们将遵循以下步骤:
- 在 Argilla 中创建一个数据集。
- 使用 SetFit 训练零样本分类器。
- 使用训练好的分类器为数据集提供建议。
- 在 Argilla 中可视化这些建议。
让我们开始吧!
初始化设置
对于本教程,你需要运行一个 Argilla 服务器。如果你还没有,请查看我们的快速入门或安装页面。完成后,请完成以下步骤:
- 使用
pip
安装Argilla客户端和所需的第三方库:
!pip install argilla setfit
- 导入必要的库和包
import argilla as rg
from datasets import load_dataset
from setfit import get_templated_dataset
from setfit import SetFitModel, SetFitTrainer
- 如果你使用 Docker 快速启动镜像或 Hugging Face Spaces 运行 Argilla,你需要使用
URL
和API_KEY
初始化 Argilla 客户端:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
rg.init(api_url="http://localhost:6900", api_key="admin.apikey", workspace="admin")
如果你正在运行一个私有的 Hugging Face Space,你还需要按照以下方式设置 HF_TOKEN:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"
# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# rg.init(
# api_url="https://[your-owner-name]-[your_space_name].hf.space",
# api_key="admin.apikey",
# workspace="admin",
# extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )
配置数据集
在这个例子中,我们将加载 banking77 数据集,这是一个流行的开源数据集,包含了银行领域的客户请求。
data = load_dataset("PolyAI/banking77", split="test")
Argilla 使用 FeedbackDataset
,它可以轻松地让你创建数据集并管理数据和反馈。FeedbackDataset
首先需要通过指明两个主要组件(尽管可以添加更多)来进行配置:要添加标注数据 的 字段 和标注者的 问题。关于 FeedbackDataset
和可选组件的更多信息,请查看 Argilla 文档 和我们的 端到端教程。
你也可以直接使用 默认模板 来创建。
在这种情况下,我们将配置一个自定义数据集,其中包含两个不同的问题,以便我们能够同时处理两个文本分类任务。我们将加载该数据集的原始标签,以对请求中提到的主题进行多标签分类,并且我们还将设置一个问题,以将请求的情感分类为“积极”、“中性”或“消极”。
dataset = rg.FeedbackDataset(
fields=[rg.TextField(name="text")],
questions=[
rg.MultiLabelQuestion(
name="topics",
title="Select the topic(s) of the request",
labels=data.info.features["label"].names, # these are the original labels present in the dataset
visible_labels=10,
),
rg.LabelQuestion(
name="sentiment", title="What is the sentiment of the message?", labels=["positive", "neutral", "negative"]
),
],
)
训练模型
现在我们将使用我们加载的数据以及为数据集配置的标签和问题来训练数据集中的每个问题的零样本文本分类模型。如前面所述,我们将使用 SetFit 框架对两个分类器中的 Sentence Transformers 进行少样本微调。此外,我们将使用的模型是 all-MiniLM-L6-v2,这是一个在 10 亿句子对数据集上使用对比目标进行微调的句子嵌入模型。
def train_model(question_name, template, multi_label=False):
# build a training dataset that uses the labels of a specific question in our Argilla dataset
train_dataset = get_templated_dataset(
candidate_labels=dataset.question_by_name(question_name).labels,
sample_size=8,
template=template,
multi_label=multi_label,
)
# train a model using the training dataset we just built
if multi_label:
model = SetFitModel.from_pretrained("all-MiniLM-L6-v2", multi_target_strategy="one-vs-rest")
else:
model = SetFitModel.from_pretrained("all-MiniLM-L6-v2")
trainer = SetFitTrainer(model=model, train_dataset=train_dataset)
trainer.train()
return model
topic_model = train_model(question_name="topics", template="The customer request is about {}", multi_label=True)
sentiment_model = train_model(question_name="sentiment", template="This message is {}", multi_label=False)
预测
一旦训练步骤结束,我们就可以通过我们的数据进行预测了。
def get_predictions(texts, model, question_name):
probas = model.predict_proba(texts, as_numpy=True)
labels = dataset.question_by_name(question_name).labels
for pred in probas:
yield [{"label": label, "score": score} for label, score in zip(labels, pred)]
data = data.map(
lambda batch: {
"topics": list(get_predictions(batch["text"], topic_model, "topics")),
"sentiment": list(get_predictions(batch["text"], sentiment_model, "sentiment")),
},
batched=True,
)
data.to_pandas().head()
构建记录并推送
有了我们生成的数据和预测,现在我们可以构建记录(将由标注团队标注的每个数据项),其中包括我们模型的建议。对于 LabelQuestion
,我们将使用概率得分最高的标签,而对于 MultiLabelQuestion
,我们将包含所有得分高于一定阈值的标签。在这种情况下,我们决定使用 2/len(labels)
作为阈值,但你可以根据你的数据实验,并决定采用更严格或更宽松的阈值。
注意,更宽松的阈值(接近或等于
1/len(labels)
)将建议更多的标签,而严格的阈值(在 2 到 3 之间)将选择更少的标签(或没有标签)。
def add_suggestions(record):
suggestions = []
# get label with max score for sentiment question
sentiment = max(record["sentiment"], key=lambda x: x["score"])["label"]
suggestions.append({"question_name": "sentiment", "value": sentiment})
# get all labels above a threshold for topics questions
threshold = 2 / len(dataset.question_by_name("topics").labels)
topics = [label["label"] for label in record["topics"] if label["score"] >= threshold]
# apply the suggestion only if at least one label was over the threshold
if topics:
suggestions.append({"question_name": "topics", "value": topics})
return suggestions
records = [rg.FeedbackRecord(fields={"text": record["text"]}, suggestions=add_suggestions(record)) for record in data]
一旦我们对结果满意,我们可以将记录添加到我们上面配置的数据集中。最后,为了可视化并开始标注,你需要将其推送到 Argilla。这意味着将你的数据集添加到运行的 Argilla 服务器上,并使其对标注者可用。
dataset.add_records(records)
dataset.push_to_argilla("setfit_tutorial", workspace="admin")
这是从我们的模型看建议的 UI 样式
这部分可选,你还可以将你的 FeedbackDataset
保存并加载到 Hugging Face Hub。请参阅文档以获取更多关于如何执行此操作的信息。
# Push to HuggingFace Hub
dataset.push_to_huggingface("argilla/my-dataset")
# Load a public dataset
dataset = rg.FeedbackDataset.from_huggingface("argilla/my-dataset")
总结
在本教程中,我们介绍了如何使用 SetFit 库的零样本方法向 Feedback Task 数据集添加建议。这将通过减少标注团队必须做出的决定和编辑数量来提高标注过程的效率。
要了解更多关于 SetFit 的信息,请查看以下链接:
< > Update on GitHub