import os
import zipfile
import gradio as gr
import nltk
import pandas as pd
import requests
from pyabsa import TADCheckpointManager
from textattack.attack_recipes import (
BAEGarg2019,
PWWSRen2019,
TextFoolerJin2019,
PSOZang2020,
IGAWang2019,
GeneticAlgorithmAlzantot2018,
DeepWordBugGao2018,
CLARE2020,
)
from textattack.attack_results import SuccessfulAttackResult
from utils import SentAttacker, get_agnews_example, get_sst2_example, get_amazon_example, get_imdb_example, diff_texts
# from utils import get_yahoo_example
import difflib
def render_diff(text1, text2):
    """Return a line diff of ``text2`` against ``text1`` as one markup string.

    Both texts are split into lines (line endings kept) and compared with
    ``difflib.ndiff``. In every resulting diff line each space is replaced by
    ``&nbsp;`` and each newline by ``<br>``; the pieces are then concatenated.

    Args:
        text1: The reference text.
        text2: The text that is compared against ``text1``.

    Returns:
        The concatenated, whitespace-substituted ndiff output. An empty
        string when both texts are empty.
    """
    diff = difflib.ndiff(
        text1.splitlines(keepends=True), text2.splitlines(keepends=True)
    )
    # NOTE(review): the two replacement literals were unreadable in the
    # previous revision (the second one was an unterminated string);
    # "&nbsp;" and "<br>" are the presumed originals - confirm.
    # NOTE(review): the diffed text itself is not HTML-escaped here; verify
    # whether callers need that before rendering.
    return "".join(
        line.replace(" ", "&nbsp;").replace("\n", "<br>") for line in diff
    )
def update_diff(original_example, repaired_example, adv_example):
    """Render three diffs that all use ``original_example`` as the reference.

    Args:
        original_example: The reference text.
        repaired_example: The repaired text.
        adv_example: The adversarial text.

    Returns:
        A 3-tuple of ``render_diff`` outputs, in this order: original against
        itself, original against ``adv_example``, original against
        ``repaired_example``.
    """
    # Order matters: note that the adversarial diff comes before the repaired
    # one, although the parameters are declared the other way round.
    targets = (original_example, adv_example, repaired_example)
    ori_diff, adv_diff, restored_diff = (
        render_diff(original_example, target) for target in targets
    )
    return ori_diff, adv_diff, restored_diff
# Module-level registries; both start empty and are filled in by init().
sent_attackers = dict()
tad_classifiers = dict()

# Short attack name -> TextAttack recipe class.
attack_recipes = dict(
    bae=BAEGarg2019,
    pwws=PWWSRen2019,
    textfooler=TextFoolerJin2019,
    pso=PSOZang2020,
    iga=IGAWang2019,
    ga=GeneticAlgorithmAlzantot2018,
    deepwordbug=DeepWordBugGao2018,
    clare=CLARE2020,
)
def init():
    """Prepare NLTK data, unpack the checkpoints and fill the registries.

    Steps:
      1. Download the ``omw-1.4`` NLTK resource.
      2. If no ``TAD-SST2`` path exists, extract ``checkpoints.zip`` into the
         current working directory.
      3. For every (attacker, dataset) pair, load the dataset's classifier
         once into ``tad_classifiers`` and register a ``SentAttacker`` for
         the pair in ``sent_attackers``. Each classifier's ``sent_attacker``
         attribute is pointed at the PWWS attacker of its dataset.

    Registry keys are lower-cased (``"tad-<dataset>"`` and
    ``"tad-<dataset><attacker>"``) because ``generate_adversarial_example``
    looks entries up with ``dataset.lower()``; keeping the mixed-case
    ``"MR"`` in the key made that dataset unreachable.
    """
    nltk.download("omw-1.4")
    if not os.path.exists("TAD-SST2"):
        # The context manager closes the archive handle after extraction.
        with zipfile.ZipFile("checkpoints.zip", "r") as archive:
            archive.extractall(os.getcwd())

    # "pwws" must stay first: the sent_attacker assignment below reads the
    # PWWS entry of the current dataset on every iteration.
    for attacker in ["pwws", "bae", "textfooler", "deepwordbug"]:
        for dataset in ["agnews10k", "sst2", "MR", "imdb"]:
            # Checkpoints are requested by their upper-case name.
            checkpoint_name = "tad-{}".format(dataset).upper()
            classifier_key = "tad-{}".format(dataset).lower()

            if classifier_key not in tad_classifiers:
                tad_classifiers[classifier_key] = (
                    TADCheckpointManager.get_tad_text_classifier(checkpoint_name)
                )

            classifier = tad_classifiers[classifier_key]
            sent_attackers[classifier_key + attacker] = SentAttacker(
                classifier, attack_recipes[attacker]
            )
            classifier.sent_attacker = sent_attackers[classifier_key + "pwws"]
# Texts that have already been attacked; a repeated (or missing) text is
# replaced by a freshly sampled dataset example.
cache = set()


def generate_adversarial_example(dataset, attacker, text=None, label=None):
    """Attack one example, run the defended classifier on it and report.

    If ``text`` is empty or was already used (it is in ``cache``), a new
    ``(text, label)`` pair is drawn from the example helper matching
    ``dataset``. The text is attacked with the registered attacker for
    ``(dataset, attacker)``. Only when the attack succeeds, the original
    prediction was correct and the perturbed prediction is wrong, the
    perturbed text is passed to the classifier's ``infer`` with the defense
    enabled. If that does not yield a result, the function retries with a
    newly sampled example.

    Args:
        dataset: Dataset name; matched case-insensitively.
        attacker: Attack name; lower-cased for the registry lookup and also
            passed as ``defense`` to ``infer``.
        text: Optional text to attack.
        label: Ground-truth label of ``text``; converted with ``int``.

    Returns:
        An 11-tuple: original text, original label, restored text, label
        predicted for the restored text, perturbed text, diff of the
        original with itself, diff original/perturbed, diff
        original/restored, the attacker's perturbed output, a one-row
        classification DataFrame and a one-row adversarial-detection
        DataFrame.

    Raises:
        KeyError: If no attacker/classifier is registered for the
            lower-cased ``dataset``/``attacker`` combination.
    """
    dataset_key = dataset.lower()

    if not text or text in cache:
        if "agnews" in dataset_key:
            text, label = get_agnews_example()
        elif "sst2" in dataset_key:
            text, label = get_sst2_example()
        # Compared in lower case: the previous test looked for "MR" inside an
        # already lower-cased string and therefore could never match.
        elif "mr" in dataset_key:
            text, label = get_amazon_example()
        elif "imdb" in dataset_key:
            text, label = get_imdb_example()

    cache.add(text)

    result = None
    sent_attacker = sent_attackers["tad-{}{}".format(dataset_key, attacker.lower())]
    attack_result = sent_attacker.attacker.simple_attack(text, int(label))

    if isinstance(attack_result, SuccessfulAttackResult):
        original = attack_result.original_result
        perturbed = attack_result.perturbed_result
        flipped = perturbed.output != original.ground_truth_output
        was_correct = original.output == original.ground_truth_output
        if flipped and was_correct:
            # with defense
            result = tad_classifiers["tad-{}".format(dataset_key)].infer(
                perturbed.attacked_text.text
                + "$LABEL${},{},{}".format(
                    original.ground_truth_output,
                    1,
                    perturbed.output,
                ),
                print_result=True,
                defense=attacker,
            )

    if not result:
        # Nothing usable for this example: start over with a sampled one.
        return generate_adversarial_example(dataset, attacker)

    classification_df = {}
    classification_df["is_repaired"] = result["is_fixed"]
    classification_df["pred_label"] = result["label"]
    classification_df["confidence"] = round(result["confidence"], 3)
    # NOTE(review): this reads result["pred_label"] while the row above reads
    # result["label"] - confirm both keys exist in the inference result.
    classification_df["is_correct"] = str(result["pred_label"]) == str(label)

    advdetection_df = {}
    # NOTE(review): the source carried no indentation; all five assignments
    # are assumed to belong to this condition - confirm.
    if result["is_adv_label"] != "0":
        # The flag may arrive as str or int, so both spellings are mapped.
        advdetection_df["is_adversarial"] = {
            "0": False,
            "1": True,
            0: False,
            1: True,
        }[result["is_adv_label"]]
        advdetection_df["perturbed_label"] = result["perturbed_label"]
        advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)
        advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
        advdetection_df['is_correct'] = result['ref_is_adv_check']

    perturbed_text = attack_result.perturbed_result.attacked_text.text
    return (
        text,
        label,
        result["restored_text"],
        result["label"],
        perturbed_text,
        diff_texts(text, text),
        diff_texts(text, perturbed_text),
        diff_texts(text, result["restored_text"]),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )
def run_demo(dataset, attacker, text=None, label=None):
    """Run the demo on the remote service, falling back to local execution.

    The request parameters are posted as JSON to the remote
    ``generate_adversarial_example`` endpoint. Any failure on that path
    (connection error, timeout, HTTP error status, invalid JSON, missing
    key) is printed and the local ``generate_adversarial_example`` is
    called instead with the same arguments.

    Args:
        dataset: Dataset name forwarded to the service / local function.
        attacker: Attack name forwarded to the service / local function.
        text: Optional text to attack.
        label: Optional ground-truth label of ``text``.

    Returns:
        On the remote path a 12-tuple: text, label, restored text, result
        label, perturbed text, the three diffs, attacker output, two
        DataFrames built from the response, and the service message. On the
        fallback path whatever ``generate_adversarial_example`` returns.
    """
    try:
        payload = {
            "dataset": dataset,
            "attacker": attacker,
            "text": text,
            "label": label,
        }
        # (connect, read) timeout in seconds: without one, an unreachable or
        # stalled service would block this call - and the fallback - forever.
        response = requests.post(
            'https://rpddemo.pagekite.me/api/generate_adversarial_example',
            json=payload,
            timeout=(5, 300),
        )
        # Treat HTTP error statuses as failures so the fallback takes over.
        response.raise_for_status()
        result = response.json()
        print(result)
        # NOTE(review): this path returns 12 values (including "message"),
        # while the local fallback below returns one value fewer - verify
        # against the outputs the caller expects.
        return (
            result["text"],
            result["label"],
            result["restored_text"],
            result["result_label"],
            result["perturbed_text"],
            result["text_diff"],
            result["perturbed_diff"],
            result["restored_diff"],
            result["output"],
            pd.DataFrame(result["classification_df"]),
            pd.DataFrame(result["advdetection_df"]),
            result["message"]
        )
    except Exception as e:
        # Deliberate best-effort boundary: any remote failure means "run
        # locally", so the exception is reported and not re-raised.
        print(e)
        return generate_adversarial_example(dataset, attacker, text, label)
def check_gpu():
    """Probe the remote endpoint and report its availability as a string.

    An empty POST is sent with a 3 second timeout.

    Returns:
        ``'GPU available'`` when a response with a status code below 500
        arrives; ``'GPU not available'`` for a status of 500 or above or
        when the request raises any exception.
    """
    try:
        reachable = requests.post(
            'https://rpddemo.pagekite.me/api/generate_adversarial_example',
            timeout=3,
        ).status_code < 500
    except Exception:
        # Any failure to get a response counts as "not available".
        reachable = False
    return 'GPU available' if reachable else 'GPU not available'
if __name__ == "__main__":
try:
init()
except Exception as e:
print(e)
print("Failed to initialize the demo. Please try again later.")
demo = gr.Blocks()
with demo:
gr.Markdown("
The (+) and (-) in the boxes indicate the added and deleted characters in the adversarial example compared to the original input natural example.
""") ori_text_diff = gr.HTML(label="The Original Natural Example", value="") adv_text_diff = gr.HTML(label="Character Editions of Adversarial Example Compared to the Natural Example", value="") restored_text_diff = gr.HTML(label="Character Editions of Repaired Adversarial Example Compared to the Natural Example", value="") gr.Markdown( "##