neuraltextgen_BERT / utils.py
JuanJoseMV's picture
first commit
66a3123
import numpy as np
import datetime
def print_batch(tokenizer, batch, n, header=None):
'''
print a batch of tokens. Used mainly for debugging
Parameters
------------
tokenizer : Tokenizer (https://huggingface.co/docs/tokenizers/python/latest/api/reference.html#tokenizers.Tokenizer)
batch : List of List[int]
n : int
number of sentences to print from the batch
header : str
header of the batch printed before the sentences
'''
print(f'=== {header or "Batch"} ===')
print(tokenizer.batch_decode(batch[:n], skip_special_tokens=True))
print('...\n' if n < len(batch) else '')
def flat_accuracy(preds, labels):
pred_flat = np.argmax(preds, axis=1).flatten()
labels_flat = labels.flatten()
return np.sum(pred_flat == labels_flat) / len(labels_flat)
def format_time(elapsed):
'''
Takes a time in seconds and returns a string hh:mm:ss
'''
elapsed_rounded = int(round((elapsed)))
return str(datetime.timedelta(seconds=elapsed_rounded))