File size: 8,229 Bytes
4943752 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
"""
HuggingFaceDataset Class
=========================
TextAttack allows users to provide their own dataset or load from HuggingFace.
"""
import collections
import datasets
import textattack
from .dataset import Dataset
def _cb(s):
    """Return ``str(s)`` wrapped in ANSI blue for terminal output."""
    text = str(s)
    return textattack.shared.utils.color_text(text, color="blue", method="ansi")
# Known (input_columns, output_column) layouts, checked in priority order.
# More specific layouts must come before the layouts they contain (e.g.
# ("question", "sentence") before ("sentence",)), because the first match wins.
_KNOWN_SCHEMAS = (
    (("premise", "hypothesis"), "label"),
    (("question", "sentence"), "label"),
    (("sentence1", "sentence2"), "label"),
    (("question1", "question2"), "label"),
    # Common schema for SQUAD dataset
    (("title", "context", "question"), "answers"),
    (("text",), "label"),
    (("sentence",), "label"),
    (("document",), "summary"),
    (("content",), "summary"),
    (("review",), "label"),
)


def get_datasets_dataset_columns(dataset):
    """Common schemas for datasets found in dataset hub.

    Args:
        dataset: Any object exposing a ``column_names`` iterable of column
            names.

    Returns:
        A pair ``(input_columns, output_column)`` where ``input_columns`` is a
        tuple of column names and ``output_column`` is a single column name.

    Raises:
        ValueError: If the columns do not contain any of the known schemas.
    """
    schema = set(dataset.column_names)
    for input_columns, output_column in _KNOWN_SCHEMAS:
        # A layout matches when all of its columns are present; extra columns
        # in the dataset (e.g. "idx") are ignored.
        if output_column in schema and schema.issuperset(input_columns):
            return input_columns, output_column
    raise ValueError(
        f"Unsupported dataset schema {schema}. Try passing your own `dataset_columns` argument."
    )
class HuggingFaceDataset(Dataset):
    """Loads a dataset from 🤗 Datasets and prepares it as a TextAttack dataset.

    Args:
        name_or_dataset (:obj:`Union[str, datasets.Dataset]`):
            The dataset name as :obj:`str` or actual :obj:`datasets.Dataset` object.
            If it's your custom :obj:`datasets.Dataset` object, please pass the input and output columns via :obj:`dataset_columns` argument.
        subset (:obj:`str`, `optional`, defaults to :obj:`None`):
            The subset of the main dataset. Dataset will be loaded as :obj:`datasets.load_dataset(name, subset)`.
        split (:obj:`str`, `optional`, defaults to :obj:`"train"`):
            The split of the dataset.
        dataset_columns (:obj:`tuple(list[str], str)`, `optional`, defaults to :obj:`None`):
            Pair of :obj:`list[str]` representing list of input column names (e.g. :obj:`["premise", "hypothesis"]`)
            and :obj:`str` representing the output column name (e.g. :obj:`label`). If not set, we will try to automatically determine column names from known designs.
        label_map (:obj:`dict[int, int]`, `optional`, defaults to :obj:`None`):
            Mapping if output labels of the dataset should be re-mapped. Useful if model was trained with a different label arrangement.
            For example, if dataset's arrangement is 0 for `Negative` and 1 for `Positive`, but model's label
            arrangement is 1 for `Negative` and 0 for `Positive`, passing :obj:`{0: 1, 1: 0}` will remap the dataset's label to match with model's arrangements.
            Could also be used to remap literal labels to numerical labels (e.g. :obj:`{"positive": 1, "negative": 0}`).
        label_names (:obj:`list[str]`, `optional`, defaults to :obj:`None`):
            List of label names in corresponding order (e.g. :obj:`["World", "Sports", "Business", "Sci/Tech"]` for AG-News dataset).
            If not set, labels will be printed as is (e.g. "0", "1", ...). This should be set to :obj:`None` for non-classification datasets.
        output_scale_factor (:obj:`float`, `optional`, defaults to :obj:`None`):
            Factor to divide ground-truth outputs by. Generally, TextAttack goal functions require model outputs between 0 and 1.
            Some datasets are regression tasks, in which case this is necessary.
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to shuffle the underlying dataset.

            .. note::
                Generally not recommended to shuffle the underlying dataset. Shuffling can be performed using DataLoader or by shuffling the order of indices we attack.

    Raises:
        ValueError: If the first element of ``dataset_columns`` is not a list
            or a tuple.
    """

    def __init__(
        self,
        name_or_dataset,
        subset=None,
        split="train",
        dataset_columns=None,
        label_map=None,
        label_names=None,
        output_scale_factor=None,
        shuffle=False,
    ):
        if isinstance(name_or_dataset, datasets.Dataset):
            self._dataset = name_or_dataset
        else:
            self._name = name_or_dataset
            self._subset = subset
            self._dataset = datasets.load_dataset(self._name, subset)[split]
            subset_print_str = f", subset {_cb(subset)}" if subset else ""
            textattack.shared.logger.info(
                f"Loading {_cb('datasets')} dataset {_cb(self._name)}{subset_print_str}, split {_cb(split)}."
            )

        # Input/output column order, like (('premise', 'hypothesis'), 'label')
        (
            self.input_columns,
            self.output_column,
        ) = dataset_columns or get_datasets_dataset_columns(self._dataset)

        if not isinstance(self.input_columns, (list, tuple)):
            raise ValueError(
                "First element of `dataset_columns` must be a list or a tuple."
            )

        self.label_map = label_map
        self.output_scale_factor = output_scale_factor

        if label_names:
            self.label_names = label_names
        else:
            try:
                self.label_names = self._dataset.features[self.output_column].names
            except (KeyError, AttributeError):
                # This happens when the dataset doesn't have 'features' or a 'label' column.
                self.label_names = None

        # If labels are remapped, the label names have to be remapped as well.
        # NOTE(review): names are picked by mapped value, in key-iteration
        # order of `label_map`. That is the expected result for swap-style maps
        # such as {0: 1, 1: 0}; confirm it is the intended ordering for other
        # permutations.
        if self.label_names and label_map:
            self.label_names = [
                self.label_names[self.label_map[i]] for i in self.label_map
            ]

        self.shuffled = shuffle
        if shuffle:
            # `datasets.Dataset.shuffle` returns a new dataset instead of
            # shuffling in place, so the result has to be kept.
            self._dataset = self._dataset.shuffle()

    def _format_as_dict(self, example):
        """Convert one raw row into an ``(input_dict, output)`` pair.

        ``input_dict`` is an ordered mapping of the input columns (in
        ``self.input_columns`` order) to their values. ``output`` is the value
        of the output column, passed through ``label_map`` and divided by
        ``output_scale_factor`` when those are set.
        """
        input_dict = collections.OrderedDict(
            (column, example[column]) for column in self.input_columns
        )

        output = example[self.output_column]
        if self.label_map:
            output = self.label_map[output]
        if self.output_scale_factor:
            output = output / self.output_scale_factor

        return (input_dict, output)

    def filter_by_labels_(self, labels_to_keep):
        """Filter items by their labels for classification datasets. Performs
        in-place filtering.

        Args:
            labels_to_keep (:obj:`Union[Set, Tuple, List, Iterable]`):
                Set, tuple, list, or iterable of integers representing labels.
        """
        if not isinstance(labels_to_keep, set):
            labels_to_keep = set(labels_to_keep)
        self._dataset = self._dataset.filter(
            lambda x: x[self.output_column] in labels_to_keep
        )

    def __getitem__(self, i):
        """Return i-th sample, or a list of samples when ``i`` is a slice."""
        if isinstance(i, int):
            return self._format_as_dict(self._dataset[i])

        # `i` is a slice: resolve missing/negative bounds and the step against
        # the dataset length, the same way list slicing does, so that
        # open-ended slices such as `ds[:10]` or `ds[5:]` work.
        start, stop, step = i.indices(len(self._dataset))
        return [
            self._format_as_dict(self._dataset[j]) for j in range(start, stop, step)
        ]

    def shuffle(self):
        """Shuffle the underlying dataset and mark this dataset as shuffled."""
        # `datasets.Dataset.shuffle` returns a new dataset; keep the result.
        self._dataset = self._dataset.shuffle()
        self.shuffled = True
|