|
import collections |
|
import numpy as np |
|
from tqdm.auto import tqdm |
|
|
|
|
|
def postprocess_token_span_predictions(
    features,
    examples,
    raw_predictions,
    tokenizer,
    n_best_size=20,
    max_answer_length=30,
    squad_v2=False,
):
    """Convert raw model outputs into ranked answer-span predictions per example.

    Combines extractive-QA start/end logits with per-token classification
    logits: each candidate span is scored as the average of its start and end
    logits plus the mean positive-class token logit over the span.

    Args:
        features: Tokenized features; each feature provides "example_id",
            "offset_mapping", and "input_ids".
        examples: Original examples; indexable both by the column name "id"
            (returns all ids) and by integer position (returns a row with
            "id" and "context") — e.g. a ``datasets.Dataset``.
        raw_predictions: Tuple ``(start_logits, end_logits, token_logits)``.
            Start/end logits are aligned with ``features``; token logits are
            indexed per example.  # assumes token_logits[ex][tok][1] is the
            # positive-class logit — TODO confirm against the model head.
        tokenizer: Tokenizer exposing ``cls_token_id``, used to locate the
            null-answer ([CLS]) position in each feature.
        n_best_size: Number of top start/end indices considered per feature.
        max_answer_length: Maximum allowed span length in tokens.
        squad_v2: Currently unused; kept for interface compatibility.
            NOTE(review): the null answer is prepended unconditionally below —
            confirm whether that should be gated on this flag.

    Returns:
        ``collections.OrderedDict`` mapping example id -> list of candidate
        answers sorted by descending "score"; each candidate has "qa_score",
        "token_score", "score", "text", "start", "end".
    """
    all_start_logits, all_end_logits, token_logits = raw_predictions

    # Map each example to the indices of its (possibly several) features,
    # since long contexts are split into overlapping features.
    example_id_to_index = {k: i for i, k in enumerate(list(examples["id"]))}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    predictions = collections.OrderedDict()

    print(
        f"Post-processing {len(examples)} example predictions split into {len(features)} features."
    )

    for example_index in tqdm(range(len(examples))):

        feature_indices = features_per_example[example_index]

        min_null_score = None
        valid_answers = []

        context = examples[example_index]["context"]

        for feature_index in feature_indices:

            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]

            # Character offsets of each token in the original context; None
            # entries mark tokens outside the context (question/special).
            offset_mapping = features[feature_index]["offset_mapping"]

            # Score of the "no answer" prediction at the [CLS] position.
            cls_index = features[feature_index]["input_ids"].index(
                tokenizer.cls_token_id
            )
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            # NOTE(review): despite the name, this keeps the *largest* null
            # score across features (`<` update). This matches the upstream
            # notebook code, so it is preserved — confirm intent.
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # Top-n indices by logit, best first.
            start_indexes = np.argsort(start_logits)[
                -1 : -n_best_size - 1 : -1
            ].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:

                    # Skip candidates that fall outside the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue

                    # Skip inverted or over-long spans.
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]

                    # Compute each score component once — the original
                    # recomputed the token-logit mean (an O(span) pass)
                    # a second time just to form the combined score.
                    qa_score = (
                        start_logits[start_index] + end_logits[end_index]
                    ) / 2
                    token_score = np.mean(
                        [
                            token_logits[example_index][token_index][1]
                            for token_index in range(start_index, end_index + 1)
                        ]
                    )
                    valid_answers.append(
                        {
                            "qa_score": qa_score,
                            "token_score": token_score,
                            "score": qa_score + token_score,
                            "text": context[start_char:end_char],
                            "start": start_char,
                            "end": end_char,
                        }
                    )

        if len(valid_answers) > 0:
            sorted_answers = sorted(
                valid_answers, key=lambda x: x["score"], reverse=True
            )
        else:
            # No valid span survived filtering; fall back to an empty answer.
            sorted_answers = [{"text": "", "score": 0.0, "start": None, "end": None}]

        # Prepend the null answer when it beats the best span. Guard against
        # min_null_score being None (example with no features), which would
        # otherwise raise TypeError on the comparison.
        if min_null_score is not None and sorted_answers[0]["score"] <= min_null_score:
            sorted_answers = [
                {"text": "", "score": min_null_score, "start": None, "end": None},
            ] + sorted_answers
        predictions[examples[example_index]["id"]] = sorted_answers

    return predictions
|
|
|
|
|
def postprocess_multi_span_predictions(
    features,
    examples,
    raw_predictions,
    tokenizer,
    n_best_size=20,
    max_answer_length=30,
    squad_v2=False,
):
    """Convert raw start/end logits into ranked answer-span predictions.

    Pure extractive-QA variant of the post-processing: each candidate span is
    scored as the sum of its start and end logits (no token-classification
    component).

    Args:
        features: Tokenized features; each feature provides "example_id",
            "offset_mapping", and "input_ids".
        examples: Original examples; indexable both by the column name "id"
            (returns all ids) and by integer position (returns a row with
            "id" and "context") — e.g. a ``datasets.Dataset``.
        raw_predictions: Tuple ``(start_logits, end_logits)`` aligned with
            ``features``.
        tokenizer: Tokenizer exposing ``cls_token_id``, used to locate the
            null-answer ([CLS]) position in each feature.
        n_best_size: Number of top start/end indices considered per feature.
        max_answer_length: Maximum allowed span length in tokens.
        squad_v2: Currently unused; kept for interface compatibility.
            NOTE(review): the null answer is prepended unconditionally below —
            confirm whether that should be gated on this flag.

    Returns:
        ``collections.OrderedDict`` mapping example id -> list of candidate
        answers sorted by descending "score"; each candidate has "score",
        "text", "start", "end".
    """
    all_start_logits, all_end_logits = raw_predictions

    # Map each example to the indices of its (possibly several) features,
    # since long contexts are split into overlapping features.
    example_id_to_index = {k: i for i, k in enumerate(list(examples["id"]))}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    predictions = collections.OrderedDict()

    print(
        f"Post-processing {len(examples)} example predictions split into {len(features)} features."
    )

    for example_index in tqdm(range(len(examples))):

        feature_indices = features_per_example[example_index]

        min_null_score = None
        valid_answers = []

        context = examples[example_index]["context"]

        for feature_index in feature_indices:

            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]

            # Character offsets of each token in the original context; None
            # entries mark tokens outside the context (question/special).
            offset_mapping = features[feature_index]["offset_mapping"]

            # Score of the "no answer" prediction at the [CLS] position.
            cls_index = features[feature_index]["input_ids"].index(
                tokenizer.cls_token_id
            )
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            # NOTE(review): despite the name, this keeps the *largest* null
            # score across features (`<` update). This matches the upstream
            # notebook code, so it is preserved — confirm intent.
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # Top-n indices by logit, best first.
            start_indexes = np.argsort(start_logits)[
                -1 : -n_best_size - 1 : -1
            ].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:

                    # Skip candidates that fall outside the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue

                    # Skip inverted or over-long spans.
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char:end_char],
                            "start": start_char,
                            "end": end_char,
                        }
                    )

        if len(valid_answers) > 0:
            sorted_answers = sorted(
                valid_answers, key=lambda x: x["score"], reverse=True
            )
        else:
            # No valid span survived filtering; fall back to an empty answer.
            sorted_answers = [{"text": "", "score": 0.0, "start": None, "end": None}]

        # Prepend the null answer when it beats the best span. Guard against
        # min_null_score being None (example with no features), which would
        # otherwise raise TypeError on the comparison.
        if min_null_score is not None and sorted_answers[0]["score"] <= min_null_score:
            sorted_answers = [
                {"text": "", "score": min_null_score, "start": None, "end": None},
            ] + sorted_answers

        predictions[examples[example_index]["id"]] = sorted_answers

    return predictions