|
|
|
import numpy as np |
|
import datetime |
|
|
|
|
|
def print_batch(tokenizer, batch, n, header=None):
    '''
    Print the decoded text of the first sentences of a batch (debug helper).

    Parameters
    ------------
    tokenizer : object exposing ``batch_decode(sequences, skip_special_tokens=...)``
        presumably a Hugging Face tokenizer -- confirm against the caller
    batch : List of List[int]
        token-id sequences; must support slicing and ``len``
    n : int
        number of sentences to print from the batch
    header : str
        title printed before the sentences; "Batch" is used when it is
        None or empty
    '''
    title = header or "Batch"
    print(f'=== {title} ===')

    # Only the leading n sequences are decoded; special tokens are dropped.
    decoded = tokenizer.batch_decode(batch[:n], skip_special_tokens=True)
    print(decoded)

    # Signal that the batch holds more sentences than were shown.
    if n < len(batch):
        print('...\n')
    else:
        print('')
|
|
|
|
|
|
|
def flat_accuracy(preds, labels):
    '''
    Compute the fraction of predictions that match the labels.

    Parameters
    ------------
    preds : array-like with at least two dimensions
        per-class scores; the predicted class is the argmax along axis 1
    labels : array-like
        expected class ids; flattened before being compared to the predictions

    Returns
    ------------
    number of positions where prediction and label agree, divided by the
    number of labels
    '''
    # np.asarray lets plain lists/tuples through; ndarrays are passed as-is.
    pred_flat = np.argmax(np.asarray(preds), axis=1).ravel()
    labels_flat = np.asarray(labels).ravel()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)
|
|
|
|
|
|
|
def format_time(elapsed):
    '''
    Turn a duration in seconds into a readable string.

    The value is rounded to a whole number of seconds with the built-in
    ``round`` and rendered as ``str(datetime.timedelta)``, i.e. ``h:mm:ss``
    (for example 65.4 -> '0:01:05'); durations of a day or more get a
    leading "N day(s), " part.
    '''
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)