|
""" |
|
|
|
Ted Multi TranslationDataset Class |
|
------------------------------------ |
|
""" |
|
|
|
|
|
import collections |
|
|
|
import datasets |
|
import numpy as np |
|
|
|
from textattack.datasets import HuggingFaceDataset |
|
|
|
|
|
class TedMultiTranslationDataset(HuggingFaceDataset):
    """Loads examples from the Ted Talk translation dataset using the
    `datasets` package.

    Every formatted example is a ``(source_dict, target)`` pair, where
    ``source_dict`` is an ``OrderedDict`` with the single key ``"Source"``
    holding the text in ``source_lang`` and ``target`` is the text in
    ``target_lang``.

    dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/

    Args:
        source_lang: Language code of the text used as the model input.
        target_lang: Language code of the reference translation.
        split: Key of the dataset split to load (e.g. ``"test"``).

    Raises:
        ValueError: If ``source_lang`` or ``target_lang`` is not among the
            languages listed in the first example of the chosen split.
    """

    def __init__(self, source_lang="en", target_lang="de", split="test"):
        # NOTE(review): the base-class initializer is not invoked here --
        # confirm that this is intended.
        self._dataset = datasets.load_dataset("ted_multi")[split]
        self.examples = self._dataset["translations"]

        # The set of valid choices is taken from the first example only.
        # NOTE(review): presumably other examples may list a different set
        # of languages -- see the per-example check in _format_raw_example.
        language_options = set(self.examples[0]["language"])
        for role, lang in (("Source", source_lang), ("Target", target_lang)):
            if lang not in language_options:
                raise ValueError(
                    f"{role} language {lang} invalid. Choices: {sorted(language_options)}"
                )

        self.source_lang = source_lang
        self.target_lang = target_lang

    @staticmethod
    def _lookup_translation(languages, translations, lang):
        """Return the entry of ``translations`` aligned with ``lang``.

        Args:
            languages: List of language codes for one example.
            translations: Sequence of texts, parallel to ``languages``.
            lang: Language code to look up.

        Raises:
            IndexError: If ``lang`` does not occur in ``languages``.
        """
        try:
            position = languages.index(lang)
        except ValueError:
            # IndexError is kept (rather than ValueError/KeyError) because
            # that is what the previous mask-based lookup raised.
            raise IndexError(
                f"Example has no translation for language {lang!r}. "
                f"Available: {sorted(languages)}"
            ) from None
        return translations[position]

    def _format_raw_example(self, raw_example):
        """Convert one raw record into a ``(source_dict, target)`` pair.

        Args:
            raw_example: Mapping with the parallel sequences ``"language"``
                and ``"translation"``.

        Returns:
            A tuple ``(OrderedDict([("Source", source_text)]), target_text)``.

        Raises:
            IndexError: If the record lacks a translation for
                ``self.source_lang`` or ``self.target_lang``.
        """
        languages = list(raw_example["language"])
        translations = raw_example["translation"]
        source = self._lookup_translation(languages, translations, self.source_lang)
        target = self._lookup_translation(languages, translations, self.target_lang)
        source_dict = collections.OrderedDict([("Source", source)])
        return (source_dict, target)
|
|