File size: 2,657 Bytes
4943752 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
"""
ClassificationGoalFunctionResult Class
========================================
"""
import torch
import textattack
from textattack.shared import utils
from .goal_function_result import GoalFunctionResult
class ClassificationGoalFunctionResult(GoalFunctionResult):
    """Result object produced by a classification goal function.

    Forwards every constructor argument to ``GoalFunctionResult`` and tags
    the result with the type name ``"Classification"``. Adds helpers that
    turn the raw model output into a label and a display color.
    """

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
    ):
        # Arguments are forwarded positionally, in the order received; only
        # the result-type tag is fixed by this subclass.
        super().__init__(
            attacked_text,
            raw_output,
            output,
            goal_status,
            score,
            num_queries,
            ground_truth_output,
            goal_function_result_type="Classification",
        )

    @property
    def _processed_output(self):
        """Return a ``(label, color)`` pair describing the model output.

        If the attacked text's ``attack_attrs`` has a non-``None``
        ``"label_names"`` entry, the label is the entry at ``self.output``
        after ``process_label_name``, and the color comes from
        ``color_from_output``. Otherwise the label is the argmax of
        ``raw_output`` and the color comes from ``color_from_label``.
        """
        predicted_index = self.raw_output.argmax()
        label_names = self.attacked_text.attack_attrs.get("label_names")

        # No label names available: report the bare predicted index.
        if label_names is None:
            return predicted_index, utils.color_from_label(predicted_index)

        label = utils.process_label_name(label_names[self.output])
        return label, utils.color_from_output(label, predicted_index)

    def get_text_color_input(self):
        """Return the color for this result's changed portion when the
        result represents the original input."""
        return self._processed_output[1]

    def get_text_color_perturbed(self):
        """Return the color for this result's changed portion when the
        result represents the perturbed input."""
        return self._processed_output[1]

    def get_colored_output(self, color_method=None):
        """Return this result's output as a colored string.

        The text is the processed label followed by the score of the argmax
        entry of ``raw_output`` as a whole-number percentage, e.g.
        ``label (33%)``. ``color_method`` is passed to ``utils.color_text``
        as its ``method`` argument.
        """
        scores = self.raw_output
        confidence = scores[scores.argmax()]
        # Unwrap tensor elements to a plain Python number before formatting.
        if isinstance(confidence, torch.Tensor):
            confidence = confidence.item()

        label, color = self._processed_output
        return utils.color_text(
            f"{label} ({confidence:.0%})", color=color, method=color_method
        )
|