File size: 1,033 Bytes
66a3123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

import numpy as np
import datetime


def print_batch(tokenizer, batch, n, header=None):
    '''
    Print the first ``n`` sequences of a batch of token ids as decoded text.
    Used mainly for debugging.

    Parameters
    ------------
    tokenizer : object exposing ``batch_decode(sequences, skip_special_tokens=...)``
        e.g. a Hugging Face ``transformers`` tokenizer
        (https://huggingface.co/docs/transformers/main_classes/tokenizer).
        Note: a raw ``tokenizers.Tokenizer`` names its method
        ``decode_batch`` and therefore cannot be passed here.
    batch : sequence of sequences of int
        token ids; anything that supports slicing and ``len`` works
    n : int
        number of sequences to print, taken from the start of the batch
    header : str, optional
        title printed before the sentences; "Batch" is used when it is
        None or empty

    Returns
    ------------
    None. Writes to stdout: the header line, the list of decoded strings
    (special tokens skipped), then "..." when the batch was truncated,
    followed by a blank line.
    '''
    title = header or 'Batch'
    print(f'=== {title} ===')
    print(tokenizer.batch_decode(batch[:n], skip_special_tokens=True))
    # A trailing "..." signals that some sequences of the batch were not shown.
    truncated = n < len(batch)
    print('...\n' if truncated else '')



def flat_accuracy(preds, labels):
    '''
    Fraction of samples whose arg-max prediction equals the true label.

    Parameters
    ------------
    preds : array-like, at least 2-D
        per-class scores with classes along axis 1, typically of shape
        (n_samples, n_classes); the predicted class is the arg-max over axis 1
    labels : array-like
        true class indices; flattened, so shapes such as (n_samples,) or
        (n_samples, 1) are both accepted

    Returns
    ------------
    numpy.float64
        number of correct predictions divided by the number of labels.
        An empty batch gives 0/0, i.e. NaN.

    Raises
    ------------
    ValueError
        if the number of predictions differs from the number of labels
    '''
    # np.asarray lets plain Python lists be passed as well as ndarrays.
    pred_flat = np.argmax(np.asarray(preds), axis=1).ravel()
    labels_flat = np.asarray(labels).ravel()
    # Without this check NumPy broadcasting could compare every prediction
    # against a single label and divide by 1, producing an "accuracy" > 1.
    if pred_flat.size != labels_flat.size:
        raise ValueError(
            f'preds yields {pred_flat.size} predictions but labels has '
            f'{labels_flat.size} elements'
        )
    return np.sum(pred_flat == labels_flat) / labels_flat.size



def format_time(elapsed):
    '''
    Render a duration given in seconds as ``str(datetime.timedelta)``.

    The value is first rounded to the nearest whole second with Python's
    built-in ``round`` (ties go to the even neighbour, e.g. 2.5 -> 2).
    The result looks like "H:MM:SS" (hours are not zero-padded, e.g.
    "0:01:05"); durations of a day or more get a "N day(s), " prefix.
    '''
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)