# -*- coding: utf-8 -*-
# @Time    : 2023/04/18 08:11 p.m.
# @Author  : JianingWang
# @File    : uncertainty.py
"""Uncertainty-aware selection of pseudo-labeled examples via BALD scores."""

from sklearn.utils import shuffle
import logging
import numpy as np
import os
import random

logger = logging.getLogger(__name__)

# Model-input columns copied into a sampled batch, in output order.
_INPUT_KEYS = ("input_ids", "token_type_ids", "attention_mask", "mask_pos")
# Columns that only some tokenizers / prompt formats produce.
_OPTIONAL_INPUT_KEYS = frozenset({"token_type_ids", "mask_pos"})


def get_BALD_acquisition(y_T):
    """Return the BALD (mutual information) score of every example.

    Args:
        y_T: predictive distributions from repeated stochastic forward passes.
            Axis 0 is averaged over (the pass axis) and the last axis is
            summed over (the class axis); presumably shape
            ``(T, N, num_classes)`` -- TODO confirm with callers.

    Returns:
        Entropy of the mean prediction minus the mean entropy of the
        individual predictions. This is non-negative in exact arithmetic, but
        floating-point error can make it slightly negative, so the samplers
        clip scores at zero before normalizing.
    """
    y_T = np.asarray(y_T)
    # The 1e-10 keeps log() finite for zero probabilities.
    expected_entropy = -np.mean(np.sum(y_T * np.log(y_T + 1e-10), axis=-1), axis=0)
    expected_p = np.mean(y_T, axis=0)
    entropy_expected_p = -np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1)
    return entropy_expected_p - expected_entropy


def _input_keys(X):
    """Return the model-input column names of ``X`` that should be sampled."""
    # Objects exposing a ``features`` mapping (as the per-class sampler has
    # always assumed -- presumably a HuggingFace dataset) are checked through
    # that mapping; a plain dict of arrays is its own key set.
    available = getattr(X, "features", None)
    if available is None:
        available = X
    return [key for key in _INPUT_KEYS
            if key not in _OPTIONAL_INPUT_KEYS or key in available]


def _normalize(scores):
    """Clip ``scores`` at zero and rescale them into a probability vector.

    Falls back to a uniform distribution when no positive, finite mass is
    left. Without the fallback, the division would produce NaNs and
    ``np.random.choice`` would raise.
    """
    scores = np.maximum(np.asarray(scores, dtype=float), 0.)
    if scores.size == 0:
        return scores
    total = np.sum(scores)
    if not np.isfinite(total) or total <= 0:
        return np.full(scores.shape, 1. / scores.size)
    return scores / total


def _choose(population_size, num_samples, p, replace=False):
    """Draw ``num_samples`` indices from ``range(population_size)`` using ``p``.

    Sampling without replacement needs at least ``num_samples`` non-zero
    probabilities (otherwise ``np.random.choice`` raises). When there are too
    few, this switches to sampling with replacement instead.
    """
    if not replace and np.count_nonzero(p) < num_samples:
        logger.info("Sampling with replacement.")
        replace = True
    return np.random.choice(population_size, num_samples, p=p, replace=replace)


def _gather(X, indices):
    """Select rows ``indices`` from every model-input column of ``X``."""
    return {key: np.asarray(X[key])[indices] for key in _input_keys(X)}


def _sample_per_class(X, y, y_var, scores, num_samples, num_classes):
    """Draw ``num_samples // num_classes`` examples from every label.

    Within a label, examples are drawn with probability proportional to
    ``scores`` (clipped at zero). Labels with no examples are skipped. The
    concatenated sample is shuffled (one shared permutation for all columns)
    before it is returned.

    Returns:
        ``(inputs, y_s, w_s)`` where ``inputs`` maps each model-input column
        to a stacked array and ``w_s`` is the first column of ``y_var`` for
        the drawn rows.
    """
    samples_per_class = num_samples // num_classes
    keys = _input_keys(X)
    # Convert each column once rather than once per label.
    columns = {key: np.asarray(X[key]) for key in keys}
    scores = np.asarray(scores)

    picked = {key: [] for key in keys}
    y_s, w_s = [], []
    for label in range(num_classes):
        in_class = y == label
        if not np.any(in_class):
            # Nothing to draw from for this label.
            continue
        p_norm = _normalize(scores[in_class])
        indices = _choose(len(p_norm), samples_per_class, p_norm)
        for key in keys:
            picked[key].extend(columns[key][in_class][indices])
        y_s.extend(y[in_class][indices])
        w_s.extend(y_var[in_class][indices][:, 0])

    *shuffled_inputs, y_s, w_s = shuffle(*(picked[key] for key in keys), y_s, w_s)
    sampled = {key: np.array(values) for key, values in zip(keys, shuffled_inputs)}
    return sampled, np.array(y_s), np.array(w_s)


def sample_by_bald_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Sample ``num_samples`` examples with probability proportional to BALD.

    High-uncertainty ("difficult") examples are favoured. ``tokenizer``,
    ``y_mean`` and ``num_classes`` are unused; they are kept so all samplers
    share one signature.

    Returns:
        ``(X_s, y_s, w_s)``. ``X_s`` holds the sampled ``input_ids`` and
        ``attention_mask``, plus ``token_type_ids`` / ``mask_pos`` when ``X``
        has them. ``y_s`` is the corresponding rows of ``y``. ``w_s`` is the
        first column of ``y_var`` for those rows.
    """
    logger.info("Sampling by difficulty BALD acquisition function")
    BALD_acq = get_BALD_acquisition(y_T)
    p_norm = _normalize(BALD_acq)
    indices = _choose(len(X['input_ids']), num_samples, p_norm)
    return _gather(X, indices), y[indices], y_var[indices][:, 0]


def sample_by_bald_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Sample ``num_samples`` examples favouring low BALD ("easy") examples.

    Sampling weights are ``max(0, 1 - BALD)``, normalized. The return value
    has the same layout as ``sample_by_bald_difficulty``.
    """
    logger.info("Sampling by easy BALD acquisition function")
    BALD_acq = get_BALD_acquisition(y_T)
    # Clip before normalizing. BALD is bounded by log(num_classes), not by 1,
    # so ``1 - BALD`` can sum to a negative number. Dividing by that sum
    # first would flip every score's sign and favour the hardest examples.
    p_norm = _normalize(1. - BALD_acq)
    logger.info(p_norm[:10])
    indices = _choose(len(X['input_ids']), num_samples, p_norm)
    return _gather(X, indices), y[indices], y_var[indices][:, 0]


def sample_by_bald_class_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Sample ``num_samples // num_classes`` easy examples from every label.

    Within each label, the weights are ``max(0, 1 - BALD)``. Rescaling the
    scores globally beforehand is unnecessary, because each label is
    renormalized on its own.

    Returns:
        ``(inputs, y_s, w_s)``. ``inputs`` holds ``input_ids`` and
        ``attention_mask``, plus ``token_type_ids`` / ``mask_pos`` when ``X``
        has them. All outputs are shuffled together.
    """
    logger.info("Sampling by easy BALD acquisition function per class")
    BALD_acq = get_BALD_acquisition(y_T)
    easiness = 1. - BALD_acq
    logger.info(easiness)
    return _sample_per_class(X, y, y_var, easiness, num_samples, num_classes)


def sample_by_bald_class_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Sample ``num_samples // num_classes`` difficult examples from every label.

    Within each label, examples are drawn with probability proportional to
    their (clipped) BALD score. The return value has the same layout as
    ``sample_by_bald_class_easiness``.
    """
    logger.info("Sampling by difficulty BALD acquisition function per class")
    BALD_acq = get_BALD_acquisition(y_T)
    return _sample_per_class(X, y, y_var, BALD_acq, num_samples, num_classes)