# -*- coding: utf-8 -*-
# @Time : 2023/04/18 08:11 p.m.
# @Author : JianingWang
# @File : uncertainty.py
from sklearn.utils import shuffle
import logging
import numpy as np
import os
import random
logger = logging.getLogger(__name__)
def get_BALD_acquisition(y_T):
    """Compute the BALD acquisition score for each sample.

    BALD (Bayesian Active Learning by Disagreement) is the mutual information
    between the prediction and the model posterior: the entropy of the mean
    predictive distribution minus the mean of the per-draw entropies. Higher
    scores mean the stochastic forward passes disagree more on that sample.

    Args:
        y_T: array of shape (T, N, C) — T stochastic softmax predictions
            over C classes for each of N samples.

    Returns:
        np.ndarray of shape (N,) with one acquisition score per sample.
    """
    eps = 1e-10  # guards log(0)
    # Mean over the T draws of each draw's predictive entropy.
    per_draw_entropy = -np.sum(y_T * np.log(y_T + eps), axis=-1)
    mean_entropy = np.mean(per_draw_entropy, axis=0)
    # Entropy of the averaged predictive distribution.
    mean_probs = np.mean(y_T, axis=0)
    entropy_of_mean = -np.sum(mean_probs * np.log(mean_probs + eps), axis=-1)
    return entropy_of_mean - mean_entropy
def sample_by_bald_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Draw pseudo-labeled examples biased toward *hard* samples.

    Selects ``num_samples`` indices without replacement, with probability
    proportional to the (negatives-clipped) BALD acquisition score: the more
    the stochastic predictions disagree on a sample, the likelier it is drawn.

    Args:
        tokenizer: unused; kept so all samplers share one signature.
        X: mapping with array-valued 'input_ids', 'token_type_ids',
            'attention_mask' columns supporting fancy indexing.
        y_mean: unused; kept so all samplers share one signature.
        y_var: per-sample predictive variance; column 0 becomes the weight.
        y: predicted labels, shape (N,).
        num_samples: number of examples to draw.
        num_classes: unused; kept so all samplers share one signature.
        y_T: stacked stochastic predictions, shape (T, N, C).

    Returns:
        Tuple of (sampled input dict, sampled labels, sample weights).
    """
    logger.info("Sampling by difficulty BALD acquisition function")
    acq = get_BALD_acquisition(y_T)
    # Clip negative scores to zero, then renormalize into a distribution.
    probs = np.maximum(np.zeros(len(acq)), acq)
    probs = probs / np.sum(probs)
    chosen = np.random.choice(len(X['input_ids']), num_samples, p=probs, replace=False)
    sampled_X = {
        key: X[key][chosen]
        for key in ("input_ids", "token_type_ids", "attention_mask")
    }
    return sampled_X, y[chosen], y_var[chosen][:, 0]
def sample_by_bald_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Draw pseudo-labeled examples biased toward *easy* samples.

    Inverts the BALD acquisition score so that low-disagreement (confident)
    samples receive the most probability mass, then samples ``num_samples``
    indices without replacement.

    Args:
        tokenizer: unused; kept so all samplers share one signature.
        X: mapping with array-valued 'input_ids', 'token_type_ids',
            'attention_mask' columns supporting fancy indexing.
        y_mean: unused; kept so all samplers share one signature.
        y_var: per-sample predictive variance; column 0 becomes the weight.
        y: predicted labels, shape (N,).
        num_samples: number of examples to draw.
        num_classes: unused; kept so all samplers share one signature.
        y_T: stacked stochastic predictions, shape (T, N, C).

    Returns:
        Tuple of (sampled input dict, sampled labels, sample weights).
    """
    logger.info("Sampling by easy BALD acquisition function")
    acq = get_BALD_acquisition(y_T)
    # Invert the score: easy (low-disagreement) samples get high mass.
    inverted = (1. - acq) / np.sum(1. - acq)
    probs = np.maximum(np.zeros(len(acq)), inverted)
    probs = probs / np.sum(probs)
    logger.info(probs[:10])
    chosen = np.random.choice(len(X['input_ids']), num_samples, p=probs, replace=False)
    sampled_X = {
        key: X[key][chosen]
        for key in ("input_ids", "token_type_ids", "attention_mask")
    }
    return sampled_X, y[chosen], y_var[chosen][:, 0]
def sample_by_bald_class_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Class-balanced sampling of pseudo-labeled examples, favoring *easy* ones.

    For each class label, draws ``num_samples // num_classes`` examples with
    probability proportional to the inverted BALD acquisition score (low model
    disagreement -> higher chance of selection), then jointly shuffles the
    collected columns.

    Args:
        tokenizer: unused here; kept for a uniform sampler signature.
        X: column-accessible dataset exposing available columns via
            ``X.features`` — presumably a HuggingFace ``datasets.Dataset``;
            TODO confirm. 'token_type_ids' and 'mask_pos' are optional.
        y_mean: unused here; kept for a uniform sampler signature.
        y_var: per-sample predictive variance; column 0 becomes the weight.
        y: predicted labels, shape (N,).
        num_samples: total budget across all classes.
        num_classes: number of classes to balance over.
        y_T: stacked stochastic predictions, shape (T, N, C).

    Returns:
        Tuple of (input dict of np.ndarray columns, labels, weights).
    """
    logger.info ("Sampling by easy BALD acquisition function per class")
    BALD_acq = get_BALD_acquisition(y_T)
    # Invert: easy (low-disagreement) samples receive the most mass.
    BALD_acq = (1. - BALD_acq)/np.sum(1. - BALD_acq)
    logger.info (BALD_acq)
    samples_per_class = num_samples // num_classes
    X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, X_s_mask_pos, y_s, w_s = [], [], [], [], [], []
    for label in range(num_classes):
        # X_input_ids, X_token_type_ids, X_attention_mask = np.array(X['input_ids'])[y == label], np.array(X['token_type_ids'])[y == label], np.array(X['attention_mask'])[y == label]
        X_input_ids, X_attention_mask = np.array(X['input_ids'])[y == label], np.array(X['attention_mask'])[y == label]
        # Optional columns are only materialized when the dataset provides them.
        if "token_type_ids" in X.features:
            X_token_type_ids = np.array(X['token_type_ids'])[y == label]
        if "mask_pos" in X.features:
            X_mask_pos = np.array(X['mask_pos'])[y == label]
        y_ = y[y==label]
        y_var_ = y_var[y == label]
        # p = y_mean[y == label]
        # Restrict the inverted scores to this class and renormalize.
        p_norm = BALD_acq[y==label]
        p_norm = np.maximum(np.zeros(len(p_norm)), p_norm)
        p_norm = p_norm/np.sum(p_norm)
        # Too few examples of this class: fall back to replacement sampling.
        if len(X_input_ids) < samples_per_class:
            logger.info ("Sampling with replacement.")
            replace = True
        else:
            replace = False
        # No examples predicted for this class — nothing to draw from.
        if len(X_input_ids) == 0: # add by wjn
            continue
        indices = np.random.choice(len(X_input_ids), samples_per_class, p=p_norm, replace=replace)
        X_s_input_ids.extend(X_input_ids[indices])
        # X_s_token_type_ids.extend(X_token_type_ids[indices])
        X_s_attention_mask.extend(X_attention_mask[indices])
        if "token_type_ids" in X.features:
            X_s_token_type_ids.extend(X_token_type_ids[indices])
        if "mask_pos" in X.features:
            X_s_mask_pos.extend(X_mask_pos[indices])
        y_s.extend(y_[indices])
        w_s.extend(y_var_[indices][:,0])
    # X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s)
    # Jointly shuffle exactly the columns that were collected so rows stay aligned.
    if "token_type_ids" in X.features and "mask_pos" not in X.features:
        X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s)
    elif "token_type_ids" not in X.features and "mask_pos" in X.features:
        X_s_input_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s)
    elif "token_type_ids" in X.features and "mask_pos" in X.features:
        X_s_input_ids, X_s_token_type_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s)
    else:
        X_s_input_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_attention_mask, y_s, w_s)
    pseudo_labeled_input = {
        'input_ids': np.array(X_s_input_ids),
        'attention_mask': np.array(X_s_attention_mask)
    }
    if "token_type_ids" in X.features:
        pseudo_labeled_input['token_type_ids'] = np.array(X_s_token_type_ids)
    if "mask_pos" in X.features:
        pseudo_labeled_input['mask_pos'] = np.array(X_s_mask_pos)
    return pseudo_labeled_input, np.array(y_s), np.array(w_s)
def sample_by_bald_class_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
    """Class-balanced sampling of pseudo-labeled examples, favoring *hard* ones.

    For each class label, draws ``num_samples // num_classes`` examples with
    probability proportional to the (clipped) BALD acquisition score — higher
    model disagreement means a higher chance of selection — then jointly
    shuffles the collected columns.

    Args:
        tokenizer: unused here; kept for a uniform sampler signature.
        X: mapping with array-valued 'input_ids', 'token_type_ids',
            'attention_mask' columns supporting boolean/fancy indexing.
        y_mean: unused here; kept for a uniform sampler signature.
        y_var: per-sample predictive variance; column 0 becomes the weight.
        y: predicted labels, shape (N,).
        num_samples: total budget across all classes.
        num_classes: number of classes to balance over.
        y_T: stacked stochastic predictions, shape (T, N, C).

    Returns:
        Tuple of (input dict of np.ndarray columns, labels, weights).
    """
    logger.info("Sampling by difficulty BALD acquisition function per class")
    BALD_acq = get_BALD_acquisition(y_T)
    samples_per_class = num_samples // num_classes
    X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = [], [], [], [], []
    for label in range(num_classes):
        X_input_ids, X_token_type_ids, X_attention_mask = X['input_ids'][y == label], X['token_type_ids'][y == label], X['attention_mask'][y == label]
        y_ = y[y==label]
        y_var_ = y_var[y == label]
        # Restrict scores to this class, clip negatives, renormalize.
        p_norm = BALD_acq[y==label]
        p_norm = np.maximum(np.zeros(len(p_norm)), p_norm)
        p_norm = p_norm/np.sum(p_norm)
        # Too few examples of this class: fall back to replacement sampling.
        if len(X_input_ids) < samples_per_class:
            replace = True
            logger.info("Sampling with replacement.")
        else:
            replace = False
        # Fix: skip classes with no predicted examples — np.random.choice
        # raises on an empty population (and p_norm is NaN after a 0/0
        # normalization). Mirrors the guard in sample_by_bald_class_easiness.
        if len(X_input_ids) == 0:
            continue
        indices = np.random.choice(len(X_input_ids), samples_per_class, p=p_norm, replace=replace)
        X_s_input_ids.extend(X_input_ids[indices])
        X_s_token_type_ids.extend(X_token_type_ids[indices])
        X_s_attention_mask.extend(X_attention_mask[indices])
        y_s.extend(y_[indices])
        w_s.extend(y_var_[indices][:,0])
    # Joint shuffle keeps rows of all five collections aligned.
    X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s)
    return {'input_ids': np.array(X_s_input_ids), 'token_type_ids': np.array(X_s_token_type_ids), 'attention_mask': np.array(X_s_attention_mask)}, np.array(y_s), np.array(w_s)