|
""" |
|
|
|
GoalFunctionResult class |
|
==================================== |
|
|
|
""" |
|
|
|
from abc import ABC, abstractmethod |
|
|
|
import torch |
|
|
|
from textattack.shared import utils |
|
|
|
|
|
class GoalFunctionResultStatus:
    """Integer codes describing the status of a goal function result.

    The codes are plain ``int`` class attributes so they can be compared
    and stored directly:

    ``SUCCEEDED`` is 0, ``SEARCHING`` is 1, ``MAXIMIZING`` is 2 and
    ``SKIPPED`` is 3.
    """

    # Consecutive codes, assigned in declaration order starting from zero.
    SUCCEEDED, SEARCHING, MAXIMIZING, SKIPPED = range(4)
|
|
|
|
|
class GoalFunctionResult(ABC):
    """Represents the result of a goal function evaluating an AttackedText
    object.

    Args:
        attacked_text: The sequence that was evaluated.
        raw_output: The raw counterpart of ``output``. A ``torch.Tensor`` is
            stored as a NumPy array. (Presumably the unprocessed model
            output -- confirm against callers.)
        output: The display-friendly output.
        goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal.
        score: A score representing how close the model is to achieving its goal.
            A ``torch.Tensor`` is stored as a plain Python number.
        num_queries: How many model queries have been used
        ground_truth_output: The ground truth output
        goal_function_result_type: A label for the kind of result; it is
            included in ``repr()``. Defaults to the empty string.
    """

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
        goal_function_result_type="",
    ):
        self.attacked_text = attacked_text
        self.raw_output = raw_output
        self.output = output
        self.score = score
        self.goal_status = goal_status
        self.num_queries = num_queries
        self.ground_truth_output = ground_truth_output
        self.goal_function_result_type = goal_function_result_type

        if isinstance(self.raw_output, torch.Tensor):
            # ``Tensor.numpy()`` raises for tensors that require grad or that
            # live on a non-CPU device, so detach and move to the CPU first.
            # For an ordinary CPU tensor both calls are no-ops.
            self.raw_output = self.raw_output.detach().cpu().numpy()

        if isinstance(self.score, torch.Tensor):
            # Keep a plain Python number rather than a single-element tensor.
            self.score = self.score.item()

    def __repr__(self):
        """Return a multi-line summary of the result's main fields."""
        fields = [
            f"(goal_function_result_type): {self.goal_function_result_type}",
            f"(attacked_text): {self.attacked_text.text}",
            f"(ground_truth_output): {self.ground_truth_output}",
            f"(model_output): {self.output}",
            f"(score): {self.score}",
        ]
        lines = [utils.add_indent(field, 2) for field in fields]
        main_str = "GoalFunctionResult( "
        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    @abstractmethod
    def get_text_color_input(self):
        """A string representing the color this result's changed portion should
        be if it represents the original input."""
        raise NotImplementedError()

    @abstractmethod
    def get_text_color_perturbed(self):
        """A string representing the color this result's changed portion should
        be if it represents the perturbed input."""
        raise NotImplementedError()

    @abstractmethod
    def get_colored_output(self, color_method=None):
        """Returns a string representation of this result's output, colored
        according to `color_method`."""
        raise NotImplementedError()
|
|