File size: 1,033 Bytes
66a3123 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import numpy as np
import datetime
def print_batch(tokenizer, batch, n, header=None):
    """
    Decode and print the first ``n`` sequences of a batch. Debugging aid.

    Parameters
    ------------
    tokenizer : object
        Anything exposing ``batch_decode(sequences, skip_special_tokens=...)``.
        Presumably a Hugging Face ``transformers`` tokenizer -- confirm
        against the caller.
    batch : List of List[int]
        Token-id sequences; only ``batch[:n]`` is decoded.
    n : int
        Number of sequences to print from the start of the batch.
    header : str, optional
        Title shown between ``===`` markers. Any falsy value (``None`` or
        an empty string) falls back to ``"Batch"``.
    """
    title = header or "Batch"
    decoded = tokenizer.batch_decode(batch[:n], skip_special_tokens=True)

    # An ellipsis (plus a blank line) signals that part of the batch was
    # left out; otherwise only a blank separator line is printed.
    if n < len(batch):
        trailer = '...\n'
    else:
        trailer = ''

    print(f'=== {title} ===')
    print(decoded)
    print(trailer)
def flat_accuracy(preds, labels):
    """
    Fraction of samples whose arg-max prediction matches its label.

    Parameters
    ------------
    preds : array-like
        Per-class scores. The predicted class of each row is the index of
        its largest value along axis 1.
    labels : array-like
        Reference class indices. Flattened before comparison, so any shape
        holding one label per row of ``preds`` works.

    Returns
    ------------
    float
        Number of matching positions divided by the number of labels, or
        ``0.0`` when there are no labels.
    """
    # np.asarray lets plain Python lists through; ndarrays pass unchanged.
    labels_flat = np.asarray(labels).flatten()

    # An empty batch would otherwise evaluate 0/0 and return NaN (with a
    # RuntimeWarning), which poisons any running accuracy total.
    if labels_flat.size == 0:
        return 0.0

    pred_flat = np.argmax(np.asarray(preds), axis=1).flatten()
    return np.sum(pred_flat == labels_flat) / labels_flat.size
def format_time(elapsed):
    """
    Render a duration given in seconds as a ``datetime.timedelta`` string.

    The value is rounded to the nearest whole second first. The result has
    the form ``h:mm:ss`` (the hour is not zero-padded), and durations of a
    day or more gain a ``N day(s), `` prefix, e.g. ``1 day, 1:01:01``.
    """
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)