File size: 8,623 Bytes
f8b8183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e311b6
f8b8183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11a765e
f8b8183
 
11a765e
f8b8183
 
11a765e
f8b8183
 
11a765e
f8b8183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""
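Module for generating structured rebuttals to misinformation.

Defines ``HamburgerStyle``, which detects the fallacy behind a claim and
assembles a FACT / MYTH / FALLACY / FACT rebuttal around it.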

"""

import os
import re
import time
import json
import csv
from ast import literal_eval
from collections import namedtuple
import requests
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain.agents import AgentExecutor, load_tools, create_react_agent
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
from langchain.tools import Tool
from langchain.tools import DuckDuckGoSearchRun
from .templates import (
    REACT,
    INCONTEXT,
    SUMMARIZATION,
    CONCLUDING,
    CONCLUDING_INCONTEXT,
)
from .definitions import DEFINITIONS
from .examples import FALLACY_CLAIMS, DEBUNKINGS

# Expose the project's HF_API_KEY under HUGGINGFACEHUB_API_TOKEN, the name
# this module reads when calling the Hugging Face inference API.
# Only copy the value when it is actually set: os.environ values must be str,
# so assigning the None returned by .get() for a missing key raises TypeError
# at import time. An already-exported HUGGINGFACEHUB_API_TOKEN is left as is.
if "HF_API_KEY" in os.environ:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.environ["HF_API_KEY"]


class HamburgerStyle:
    """Build a "hamburger-style" rebuttal for a piece of misinformation.

    The rebuttal is made of four generated layers, stored in
    ``self.hamburger`` (index 0 holds the original myth):

    1. ``##FACT``    - ReAct agent with a web-search tool.
    2. ``##MYTH``    - summary of the misinformation.
    3. ``##FALLACY`` - explanation of the detected fallacy, prompted with the
       most similar worked example.
    4. ``##FACT``    - concluding statement, optionally enriched with
       evidence from the CSV at ``self.filename``.

    Layers 3 and 4 read the content of layer 1, so the layers must be
    generated in order (``rebuttal_generator`` does this).
    """

    def __init__(self):
        """Set up the layer structure, the LLM endpoint and model ids."""
        # hamburger-style structure:
        self.heading = namedtuple("Heading", ["name", "content"])
        self.hamburger = [
            self.heading(name="Myth", content=None),
            self.heading(name="##FACT", content=None),
            self.heading(name="##MYTH", content=None),
            self.heading(name="##FALLACY", content=None),
            self.heading(name="##FACT", content=None),
        ]
        self.llm = HuggingFaceEndpoint(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            temperature=1,
            top_k=1,
            model_kwargs={
                "use_cache": False,
            },
        )
        # Model ids queried through endpoint_query().
        self.flicc_model = "fzanartu/flicc"
        self.card_model = "crarojasca/BinaryAugmentedCARDS"
        self.semantic_textual_similarity = "sentence-transformers/all-MiniLM-L6-v2"
        self.taxonomy_cards = "crarojasca/TaxonomyAugmentedCARDS"

        # NOTE(review): "__file__" is a string literal here, not the module's
        # __file__, so dirname resolves to the current working directory.
        # Kept as is because the deployed layout may rely on it - confirm
        # before switching to the real __file__.
        self.dirname = os.path.dirname(os.path.abspath("__file__"))
        self.filename = os.path.join(self.dirname, "utils/climate_fever_cards.csv")

    def generate_st_layer(self, misinformation):
        """Generate the first FACT layer with a ReAct agent.

        Args:
            misinformation: The myth text passed to the agent as ``input``.

        Returns:
            The ``output`` entry of the agent executor's result (``None`` if
            the result has no such key).
        """
        ## FACT: ReAct
        prompt = REACT

        chat_model = ChatHuggingFace(llm=self.llm)
        chat_model_with_stop = chat_model.bind(stop=["\nObservation"])

        search = DuckDuckGoSearchRun()

        # The tool keeps its original name/description (they are part of the
        # prompt the model sees) although the backend is DuckDuckGo.
        tools = [
            Tool(
                name="google_search",
                description="Search Google for recent results.",
                func=search.run,
            )
        ]

        # A hand-assembled agent pipeline used to be built here and was then
        # immediately overwritten by this one; it also referenced a helper
        # that is not imported in this module, so it has been removed.
        agent = create_react_agent(chat_model_with_stop, tools, prompt)

        agent_executor = AgentExecutor(
            agent=agent, tools=tools, verbose=False, handle_parsing_errors=True
        )

        return agent_executor.invoke({"input": misinformation}).get("output")

    def generate_nd_layer(self, misinformation):
        """Generate the MYTH layer by summarizing the misinformation.

        Args:
            misinformation: The myth text, passed to the prompt as ``text``.

        Returns:
            The LLM output for the ``SUMMARIZATION`` prompt.
        """
        ## MYTH: Summ
        chain = SUMMARIZATION | self.llm
        return chain.invoke({"text": misinformation})

    def generate_rd_layer(self, misinformation):
        """Generate the FALLACY layer using the closest worked example.

        Reads ``self.hamburger[1].content``, so the first FACT layer must
        already have been generated.

        Args:
            misinformation: The myth text to classify and rebut.

        Returns:
            The LLM output for the ``INCONTEXT`` prompt.
        """
        ## FALLACY: Fallacy

        # 1 predict fallacy label in FLICC taxonomy
        # (the response is indexed as [0][0]["label"])
        fallacy_label = self.endpoint_query(
            model=self.flicc_model, payload=misinformation
        )[0][0].get("label")
        fallacy_definition = DEFINITIONS.get(fallacy_label)

        # 2 get all examples with the same label
        claims = FALLACY_CLAIMS.get(fallacy_label, None)

        # 3 pick the example claim most similar to the myth
        example_myth = self._most_similar_claim(misinformation, claims)
        example_response = DEBUNKINGS.get(example_myth)

        # 4 keep only the fallacy layer of the example debunking: everything
        # from "## FALLACY:" up to the next "##" heading, without the heading.
        fallacy_example = re.findall(
            r"## FALLACY:.*?(?=##)", example_response, re.DOTALL
        )[0]
        fallacy_example = fallacy_example.replace("## FALLACY:", "")

        chain = INCONTEXT | self.llm
        return chain.invoke(
            {
                "misinformation": misinformation,
                "example_myth": example_myth,
                "example_response": fallacy_example,
                "fallacy": fallacy_label,
                "fallacy_definition": fallacy_definition,
                "factual_information": self.hamburger[1].content,
            }
        )

    def generate_th_layer(self, misinformation):
        """Generate the concluding FACT layer.

        When the CSV holds claims with the same CARDS label as the myth, the
        evidence of the most similar claim is added to the prompt
        (``CONCLUDING_INCONTEXT``); otherwise ``CONCLUDING`` is used.

        Reads ``self.hamburger[1].content``, so the first FACT layer must
        already have been generated.

        Args:
            misinformation: The myth text to classify.

        Returns:
            The LLM output for the selected concluding prompt.
        """
        ## FACT: Concluding
        cards_label = self.endpoint_query(
            model=self.taxonomy_cards, payload=misinformation
        )[0][0].get("label")

        # 1 get all claims with same label from FEVER dataset
        claims = self.get_fever_claims(cards_label)
        prompt_inputs = {"fact": self.hamburger[1].content}
        if claims:
            prompt = CONCLUDING_INCONTEXT
            # Same payload shape as the fallacy layer: endpoint_query() adds
            # the "inputs" wrapper itself, so no extra nesting here.
            example_myth = self._most_similar_claim(misinformation, claims)
            prompt_inputs["complementary_details"] = self.get_fever_evidence(
                example_myth
            )
        else:
            prompt = CONCLUDING

        chain = prompt | self.llm
        return chain.invoke(prompt_inputs)

    def rebuttal_generator(self, misinformation):
        """Generate all four layers and return the formatted rebuttal.

        Args:
            misinformation: The myth text to rebut.

        Returns:
            A string with one ``"<name>: <content>"`` line per generated
            layer, joined by newlines.
        """
        self.hamburger[0] = self.hamburger[0]._replace(content=misinformation)

        # Order matters: later layers read the content of the first FACT layer.
        layer_builders = (
            self.generate_st_layer,
            self.generate_nd_layer,
            self.generate_rd_layer,
            self.generate_th_layer,
        )
        for position, build_layer in enumerate(layer_builders, start=1):
            self.hamburger[position] = self.hamburger[position]._replace(
                content=build_layer(misinformation).strip()
            )

        # compose and format the string (the myth itself is not included)
        return "\n".join(
            f"{layer.name}: {layer.content}" for layer in self.hamburger[1:]
        )

    def endpoint_query(self, payload, model):
        """POST ``payload`` to the Hugging Face inference API for ``model``.

        Args:
            payload: Value sent under the ``"inputs"`` key of the request.
            model: Model id appended to the inference API URL.

        Returns:
            The decoded JSON body of the response. The HTTP status is not
            checked, so an error body is returned like any other.

        Raises:
            KeyError: If ``HUGGINGFACEHUB_API_TOKEN`` is not set.
            ValueError: If the response body is not valid JSON.
        """
        headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
        options = {"use_cache": False, "wait_for_model": True}
        payload = {"inputs": payload, "options": options}
        api_url = f"https://api-inference.huggingface.co/models/{model}"
        response = requests.post(api_url, headers=headers, json=payload, timeout=120)
        return json.loads(response.content.decode("utf-8"))

    def retry_on_exceptions(self, function, *args):
        """Call ``function(*args)``, retrying on ``KeyError``/``ValueError``.

        Up to five attempts are made, sleeping 5, 10, ... seconds after each
        failed one. Other exceptions propagate immediately.

        Args:
            function: Callable to invoke.
            *args: Positional arguments forwarded to ``function``.

        Returns:
            The first successful return value, or ``None`` if all five
            attempts failed.
        """
        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            try:
                return function(*args)
            except (KeyError, ValueError):
                # print() does not apply %-formatting to extra arguments, so
                # the message is built explicitly.
                print(f"retrying {attempt} out of {max_attempts}")
                time.sleep(5 * attempt)
        # Return None if no response after five attempts
        return None

    def get_fever_claims(self, label):
        """Return the claims in the CSV that carry the given CARDS label.

        Only rows whose ``claim_label`` equals 1 are considered.

        Args:
            label: Value compared against the ``CARDS_label`` column.

        Returns:
            A list with the ``claim`` column of every matching row, in file
            order (empty if none match).
        """
        return [
            row["claim"]
            for row in self._iter_label_one_rows()
            if row["CARDS_label"] == label
        ]

    def get_fever_evidence(self, claim):
        """Return the evidence recorded for ``claim`` as a bulleted string.

        Only rows whose ``claim_label`` equals 1 are considered. The
        ``evidences`` column is parsed with ``literal_eval`` and is expected
        to hold a list of dicts with an ``"evidence"`` key.

        Args:
            claim: Exact text compared against the ``claim`` column.

        Returns:
            One ``"* <evidence>"`` line per evidence, joined by newlines
            (empty string if nothing matches).
        """
        evidences = [
            evidence_dict["evidence"]
            for row in self._iter_label_one_rows()
            if row["claim"] == claim
            for evidence_dict in literal_eval(row["evidences"])
        ]
        return "\n".join("* " + evidence for evidence in evidences)

    def _most_similar_claim(self, misinformation, claims):
        """Return the entry of ``claims`` most similar to ``misinformation``.

        Assumes the similarity endpoint answers with one score per sentence,
        in the same order as ``claims``.
        """
        scores = self.endpoint_query(
            payload={"source_sentence": misinformation, "sentences": claims},
            model=self.semantic_textual_similarity,
        )
        return claims[scores.index(max(scores))]

    def _iter_label_one_rows(self):
        """Yield the CSV rows whose ``claim_label`` column equals 1."""
        # newline="" is what the csv module asks for so that quoted fields
        # containing line breaks are parsed correctly.
        with open(self.filename, "r", encoding="utf-8", newline="") as csvfile:
            for row in csv.DictReader(csvfile):
                if self._claim_label_is_one(row["claim_label"]):
                    yield row

    @staticmethod
    def _claim_label_is_one(raw_label):
        """Tell whether a raw CSV ``claim_label`` cell represents 1.

        csv.DictReader yields every cell as a string, so comparing the cell
        with the integer 1 never matches; parse it instead (accepts "1" as
        well as "1.0").
        """
        try:
            return float(raw_label) == 1
        except (TypeError, ValueError):
            return False