File size: 1,176 Bytes
abbb14d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Optional

from transformers import AutoFeatureExtractor, AutoTokenizer

# Checkpoint on the Hugging Face Hub; both the audio feature extractor and
# the text tokenizer are loaded from this same name.
pretrained_name = "openai/whisper-base"
feature_extractor, tokenizer = (
    AutoFeatureExtractor.from_pretrained(pretrained_name),
    AutoTokenizer.from_pretrained(pretrained_name),
)

def prepare_dataset(
    sample: dict,
    labels_max_len: Optional[int] = None,
) -> dict:
    """Add model input features and tokenized labels to one dataset example.

    NOTE(review): presumably applied per example (e.g. via
    ``datasets.Dataset.map``) -- confirm against the caller.

    Args:
        sample: One example. Must contain an ``'audio'`` dict with ``'array'``
            and ``'sampling_rate'`` keys, plus a top-level ``'sentence'``
            transcript string.
        labels_max_len: If given, label token ids longer than this are cut
            down to this many tokens.

    Returns:
        The same ``sample`` dict, updated in place with:
        ``input_features`` -- first entry of the feature extractor's
        ``input_features`` output for the audio;
        ``input_length`` -- number of raw audio samples (``len(array)``);
        ``labels`` -- token ids of ``sample['sentence']``, possibly truncated;
        ``labels_length`` -- token count before truncation, including the
        special tokens the tokenizer adds;
        ``labels_truncated`` -- 1 if the labels were truncated, else 0.

    Raises:
        KeyError: if ``'audio'``, ``'array'``, ``'sampling_rate'`` or
            ``'sentence'`` is missing, or the extractor output has no
            ``'input_features'``.
    """
    # Keep the audio sub-dict under its own name. Rebinding ``sample`` to it
    # would hide the top-level 'sentence' field and return only the audio dict.
    audio = sample["audio"]
    inputs = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])

    sample["input_features"] = inputs["input_features"][0]
    sample["input_length"] = len(audio["array"])

    sample["labels"] = tokenizer(sample["sentence"]).input_ids
    # Measured before truncation; includes special tokens.
    sample["labels_length"] = len(sample["labels"])

    sample["labels_truncated"] = 0
    # Validation and test labels longer than model.config.max_length must be
    # truncated rather than dropped: dropping such examples would change the
    # validation and test scores.
    if labels_max_len is not None and len(sample["labels"]) > labels_max_len:
        sample["labels"] = sample["labels"][:labels_max_len]
        sample["labels_truncated"] = 1

    return sample