Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
# @Time : 2023/04/18 08:11 p.m.
# @Author : JianingWang
# @File : uncertainty.py
from sklearn.utils import shuffle | |
import logging | |
import numpy as np | |
import os | |
import random | |
logger = logging.getLogger(__name__) | |
def get_BALD_acquisition(y_T, eps=1e-10):
    """Return the BALD acquisition score for each example.

    The score is the entropy of the mean prediction minus the mean of the
    per-prediction entropies:

        BALD = H[mean_t(p_t)] - mean_t(H[p_t])

    Args:
        y_T: Array-like of probabilities. The mean is taken over axis 0
            (presumably the T stochastic forward passes -- confirm against
            the caller) and the entropy sum over the last axis (classes).
        eps: Small constant added inside ``log`` so that zero probabilities
            do not produce ``-inf``. Defaults to the historical ``1e-10``.

    Returns:
        numpy.ndarray of shape ``y_T.shape[1:-1]`` holding one score per
        example.
    """
    y_T = np.asarray(y_T)
    # Mean over passes of the entropy of each individual prediction.
    expected_entropy = -np.mean(np.sum(y_T * np.log(y_T + eps), axis=-1), axis=0)
    # Entropy of the prediction averaged over passes.
    expected_p = np.mean(y_T, axis=0)
    entropy_expected_p = -np.sum(expected_p * np.log(expected_p + eps), axis=-1)
    return entropy_expected_p - expected_entropy
def sample_by_bald_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Sample examples with probability proportional to their BALD score.

    Args:
        tokenizer: Unused; kept for a signature parallel to the other samplers.
        X: Mapping with "input_ids" and "attention_mask" (and optionally
            "token_type_ids") whose values support indexing by an integer
            index array.
        y_mean: Unused.
        y_var: 2-D array indexed by example; its first column is returned as
            the per-sample weight.
        y: Array of labels indexed by example.
        num_samples: Number of examples to draw.
        num_classes: Unused.
        y_T: Predictions passed to ``get_BALD_acquisition``.

    Returns:
        Tuple ``(X_s, y_s, w_s)``: the selected inputs, their labels and
        ``y_var[indices][:, 0]``.
    """
    logger.info("Sampling by difficulty BALD acquisition function")
    BALD_acq = get_BALD_acquisition(y_T)
    # Negative scores (floating-point noise) carry no sampling mass.
    p_norm = np.maximum(0.0, BALD_acq)
    total = np.sum(p_norm)
    if total == 0:
        # All scores are zero: dividing would yield NaN probabilities, so
        # fall back to uniform sampling.
        p_norm = np.full(len(p_norm), 1.0 / max(len(p_norm), 1))
    else:
        p_norm = p_norm / total
    # Drawing without replacement needs at least num_samples entries with
    # non-zero probability; otherwise sample with replacement, as the
    # per-class samplers do.
    replace = np.count_nonzero(p_norm) < num_samples
    if replace:
        logger.info("Sampling with replacement.")
    indices = np.random.choice(len(X['input_ids']), num_samples, p=p_norm, replace=replace)
    # Datasets expose column names via ``features``; plain mappings via keys.
    available = getattr(X, "features", X)
    X_s = {"input_ids": X["input_ids"][indices]}
    if "token_type_ids" in available:
        X_s["token_type_ids"] = X["token_type_ids"][indices]
    X_s["attention_mask"] = X["attention_mask"][indices]
    y_s = y[indices]
    w_s = y_var[indices][:, 0]
    return X_s, y_s, w_s
def sample_by_bald_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Sample examples with probability proportional to ``1 - BALD``.

    Args:
        tokenizer: Unused; kept for a signature parallel to the other samplers.
        X: Mapping with "input_ids" and "attention_mask" (and optionally
            "token_type_ids") whose values support indexing by an integer
            index array.
        y_mean: Unused.
        y_var: 2-D array indexed by example; its first column is returned as
            the per-sample weight.
        y: Array of labels indexed by example.
        num_samples: Number of examples to draw.
        num_classes: Unused.
        y_T: Predictions passed to ``get_BALD_acquisition``.

    Returns:
        Tuple ``(X_s, y_s, w_s)``: the selected inputs, their labels and
        ``y_var[indices][:, 0]``.
    """
    logger.info("Sampling by easy BALD acquisition function")
    BALD_acq = get_BALD_acquisition(y_T)
    # Clamp BEFORE normalising. Normalising first by sum(1 - BALD) flips the
    # sign of every weight when that sum is negative (BALD can exceed 1 with
    # more than two classes), which would favour the hardest examples.
    p_norm = np.maximum(0.0, 1. - BALD_acq)
    total = np.sum(p_norm)
    if total == 0:
        # No example has positive easiness: fall back to uniform sampling
        # rather than dividing by zero.
        p_norm = np.full(len(p_norm), 1.0 / max(len(p_norm), 1))
    else:
        p_norm = p_norm / total
    logger.info(p_norm[:10])
    # Drawing without replacement needs at least num_samples entries with
    # non-zero probability; otherwise sample with replacement.
    replace = np.count_nonzero(p_norm) < num_samples
    if replace:
        logger.info("Sampling with replacement.")
    indices = np.random.choice(len(X['input_ids']), num_samples, p=p_norm, replace=replace)
    # Datasets expose column names via ``features``; plain mappings via keys.
    available = getattr(X, "features", X)
    X_s = {"input_ids": X["input_ids"][indices]}
    if "token_type_ids" in available:
        X_s["token_type_ids"] = X["token_type_ids"][indices]
    X_s["attention_mask"] = X["attention_mask"][indices]
    y_s = y[indices]
    w_s = y_var[indices][:, 0]
    return X_s, y_s, w_s
def sample_by_bald_class_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Draw a class-balanced sample, favouring examples with low BALD score.

    For every label in ``range(num_classes)``, ``num_samples // num_classes``
    examples are drawn from the examples carrying that label, with
    probability proportional to ``max(0, 1 - BALD)``. Labels with no examples
    are skipped. The combined selection is shuffled before being returned.

    Args:
        tokenizer: Unused; kept for a signature parallel to the other samplers.
        X: Dataset (anything exposing ``features``) or mapping with
            "input_ids" and "attention_mask" columns, and optionally
            "token_type_ids" and/or "mask_pos".
        y_mean: Unused.
        y_var: 2-D array indexed by example; its first column is returned as
            the per-sample weight.
        y: NumPy array of labels indexed by example.
        num_samples: Total number of examples requested.
        num_classes: Number of labels to iterate over.
        y_T: Predictions passed to ``get_BALD_acquisition``.

    Returns:
        Tuple ``(pseudo_labeled_input, y_s, w_s)`` of a dict of NumPy arrays
        (one per available column), the selected labels and the selected
        ``y_var[:, 0]`` values.
    """
    logger.info("Sampling by easy BALD acquisition function per class")
    BALD_acq = get_BALD_acquisition(y_T)
    # Clamp BEFORE normalising: dividing by a negative sum(1 - BALD) would
    # flip the sign of every weight and favour the hardest examples.
    easiness = np.maximum(0.0, 1. - BALD_acq)
    easiness_total = np.sum(easiness)
    if easiness_total != 0:
        easiness = easiness / easiness_total
    logger.info(easiness)
    samples_per_class = num_samples // num_classes

    # Datasets expose column names via ``features``; plain mappings via keys.
    available = getattr(X, "features", X)
    names = ["input_ids", "attention_mask"]
    names.extend(opt for opt in ("token_type_ids", "mask_pos") if opt in available)
    # Convert each column once instead of once per class.
    columns = {name: np.array(X[name]) for name in names}
    selected = {name: [] for name in names}
    y_s, w_s = [], []

    for label in range(num_classes):
        class_mask = y == label
        class_size = int(np.count_nonzero(class_mask))
        if class_size == 0:
            # Nothing to draw from for this label.
            continue
        p_norm = easiness[class_mask]
        class_total = np.sum(p_norm)
        if class_total == 0:
            # Every example of this class has zero weight: sample uniformly
            # instead of dividing by zero.
            p_norm = np.full(class_size, 1.0 / class_size)
        else:
            p_norm = p_norm / class_total
        # Without replacement we need enough entries with non-zero weight.
        replace = np.count_nonzero(p_norm) < samples_per_class
        if replace:
            logger.info("Sampling with replacement.")
        indices = np.random.choice(class_size, samples_per_class, p=p_norm, replace=replace)
        for name in names:
            selected[name].extend(columns[name][class_mask][indices])
        y_s.extend(y[class_mask][indices])
        w_s.extend(y_var[class_mask][indices][:, 0])

    if y_s:
        # One shuffle call applies the same permutation to every column.
        shuffled = shuffle(*[selected[name] for name in names], y_s, w_s)
        for name, values in zip(names, shuffled):
            selected[name] = values
        y_s, w_s = shuffled[-2], shuffled[-1]

    pseudo_labeled_input = {name: np.array(selected[name]) for name in names}
    return pseudo_labeled_input, np.array(y_s), np.array(w_s)
def sample_by_bald_class_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Draw a class-balanced sample, favouring examples with high BALD score.

    For every label in ``range(num_classes)``, ``num_samples // num_classes``
    examples are drawn from the examples carrying that label, with
    probability proportional to the (non-negative) BALD score. Labels with no
    examples are skipped. The combined selection is shuffled before being
    returned.

    Args:
        tokenizer: Unused; kept for a signature parallel to the other samplers.
        X: Dataset (anything exposing ``features``) or mapping with
            "input_ids" and "attention_mask" columns, and optionally
            "token_type_ids".
        y_mean: Unused.
        y_var: 2-D array indexed by example; its first column is returned as
            the per-sample weight.
        y: NumPy array of labels indexed by example.
        num_samples: Total number of examples requested.
        num_classes: Number of labels to iterate over.
        y_T: Predictions passed to ``get_BALD_acquisition``.

    Returns:
        Tuple ``(inputs, y_s, w_s)`` of a dict of NumPy arrays keyed by
        'input_ids', 'token_type_ids' (when available) and 'attention_mask',
        the selected labels and the selected ``y_var[:, 0]`` values.
    """
    logger.info("Sampling by difficulty BALD acquisition function per class")
    BALD_acq = get_BALD_acquisition(y_T)
    samples_per_class = num_samples // num_classes

    # Datasets expose column names via ``features``; plain mappings via keys.
    available = getattr(X, "features", X)
    names = ["input_ids"]
    if "token_type_ids" in available:
        names.append("token_type_ids")
    names.append("attention_mask")
    # Convert each column once so boolean-mask indexing works for any input.
    columns = {name: np.asarray(X[name]) for name in names}
    selected = {name: [] for name in names}
    y_s, w_s = [], []

    for label in range(num_classes):
        class_mask = y == label
        class_size = int(np.count_nonzero(class_mask))
        if class_size == 0:
            # np.random.choice cannot draw from an empty population.
            continue
        # Negative scores (floating-point noise) carry no sampling mass.
        p_norm = np.maximum(0.0, BALD_acq[class_mask])
        class_total = np.sum(p_norm)
        if class_total == 0:
            # Every example of this class has zero weight: sample uniformly
            # instead of dividing by zero.
            p_norm = np.full(class_size, 1.0 / class_size)
        else:
            p_norm = p_norm / class_total
        # Without replacement we need enough entries with non-zero weight.
        replace = np.count_nonzero(p_norm) < samples_per_class
        if replace:
            logger.info("Sampling with replacement.")
        indices = np.random.choice(class_size, samples_per_class, p=p_norm, replace=replace)
        for name in names:
            selected[name].extend(columns[name][class_mask][indices])
        y_s.extend(y[class_mask][indices])
        w_s.extend(y_var[class_mask][indices][:, 0])

    if y_s:
        # One shuffle call applies the same permutation to every column.
        shuffled = shuffle(*[selected[name] for name in names], y_s, w_s)
        for name, values in zip(names, shuffled):
            selected[name] = values
        y_s, w_s = shuffled[-2], shuffled[-1]

    return {name: np.array(selected[name]) for name in names}, np.array(y_s), np.array(w_s)