flicc-agent / utils /core.py
Francisco Zanartu
clean leading "response:"
7f7028e
raw
history blame
8.82 kB
"""
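Module for generating hamburger-style (FACT / MYTH / FALLACY / FACT) rebuttals to misinformation, including detecting the fallacy a myth relies on.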
"""
import os
import re
import time
import json
import csv
from ast import literal_eval
from collections import namedtuple
import requests
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain.agents import AgentExecutor, load_tools, create_react_agent
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
from langchain.tools import Tool
from langchain.tools import DuckDuckGoSearchRun
from .templates import (
REACT,
INCONTEXT,
SUMMARIZATION,
CONCLUDING,
CONCLUDING_INCONTEXT,
)
from .definitions import DEFINITIONS
from .examples import FALLACY_CLAIMS, DEBUNKINGS
# Expose HF_API_KEY under the variable name the Hugging Face integrations and
# endpoint_query read. Only copy it when present: assigning None into
# os.environ raises TypeError and would break importing this module.
if "HF_API_KEY" in os.environ:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.environ["HF_API_KEY"]
class HamburgerStyle:
    """Generate hamburger-style rebuttals: FACT, MYTH, FALLACY, FACT.

    Each layer is produced by its own ``generate_*_layer`` method, using a
    LangChain LLM/chat model or hosted models on the Hugging Face Inference
    API. ``rebuttal_generator`` runs the layers in order and stores their
    text in ``self.hamburger``. The FALLACY and concluding FACT layers read
    the first FACT layer from ``self.hamburger[1]``, so the order matters.
    """

    def __init__(self):
        # Hamburger-style structure. Slot 0 holds the input myth; slots 1-4
        # hold the generated layers and the headings used in the output.
        self.heading = namedtuple("Heading", ["name", "content"])
        self.hamburger = [
            self.heading(name="Myth", content=None),
            self.heading(name="##FACT", content=None),
            self.heading(name="##MYTH", content=None),
            self.heading(name="##FALLACY", content=None),
            self.heading(name="##FACT", content=None),
        ]
        self.llm = HuggingFaceEndpoint(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            temperature=1,
            top_k=1,
            model_kwargs={
                "use_cache": False,
            },
        )
        self.chat_model = ChatHuggingFace(llm=self.llm)
        # Hugging Face Hub model ids queried through endpoint_query.
        self.flicc_model = "fzanartu/flicc"
        self.card_model = "crarojasca/BinaryAugmentedCARDS"
        self.semantic_textual_similarity = "sentence-transformers/all-MiniLM-L6-v2"
        self.taxonomy_cards = "crarojasca/TaxonomyAugmentedCARDS"
        # Resolve the Climate-FEVER CSV relative to this module, not the
        # current working directory.
        self.dirname = os.path.dirname(os.path.abspath(__file__))
        self.filename = os.path.join(self.dirname, "climate_fever_cards.csv")

    def generate_st_layer(self, misinformation):
        """Build the first FACT layer with a ReAct agent that can search the web.

        Args:
            misinformation: The myth to rebut.

        Returns:
            The agent executor's ``"output"`` value, or ``None`` if absent.
        """
        # The tool is advertised to the agent as "google_search", but the
        # backend is DuckDuckGo.
        search = DuckDuckGoSearchRun()
        tools = [
            Tool(
                name="google_search",
                description="Search Google for recent results.",
                func=search.run,
            )
        ]
        # Halt generation at "\nObservation" so that the executor supplies
        # the tool results instead of the model.
        chat_model_with_stop = self.chat_model.bind(stop=["\nObservation"])
        agent = create_react_agent(chat_model_with_stop, tools, REACT)
        agent_executor = AgentExecutor(
            agent=agent, tools=tools, verbose=False, handle_parsing_errors=True
        )
        return agent_executor.invoke({"input": misinformation}).get("output")

    def generate_nd_layer(self, misinformation):
        """Build the MYTH layer by summarizing ``misinformation`` with the LLM."""
        chain = SUMMARIZATION | self.llm
        return chain.invoke({"text": misinformation})

    def generate_rd_layer(self, misinformation):
        """Build the FALLACY layer, explaining the reasoning flaw in the myth.

        Steps:

        1. Classify ``misinformation`` with the FLICC model.
        2. Pick the most similar example claim that has the same label.
        3. Use that example's FALLACY section as an in-context demonstration
           for the chat model.

        This method reads ``self.hamburger[1].content``, so the first FACT
        layer must already be generated.

        Raises:
            ValueError: if there are no example claims for the detected label,
                or the chosen example debunking has no ``## FALLACY:`` section.
        """
        detected_fallacy = self._top_label(self.flicc_model, misinformation)
        fallacy_definition = DEFINITIONS.get(detected_fallacy)
        claims = FALLACY_CLAIMS.get(detected_fallacy)
        if not claims:
            raise ValueError(
                f"no example claims for fallacy label {detected_fallacy!r}"
            )
        example_myth = self._most_similar(misinformation, claims)
        example_response = DEBUNKINGS.get(example_myth)
        # Keep only the FALLACY section of the example: the text after
        # "## FALLACY:" up to the next "##" heading (or the end of the text).
        fallacy_section = re.search(
            r"## FALLACY:(.*?)(?=##|\Z)", example_response, re.DOTALL
        )
        if fallacy_section is None:
            raise ValueError(
                f"example debunking for {example_myth!r} has no '## FALLACY:' section"
            )
        chain = INCONTEXT | self.chat_model
        content = chain.invoke(
            {
                "misinformation": misinformation,
                "detected_fallacy": detected_fallacy,
                "fallacy_definition": fallacy_definition,
                "example_response": fallacy_section.group(1),
                "example_myth": example_myth,
                "factual_information": self.hamburger[1].content,
            }
        ).content
        return self._strip_response_prefix(content)

    def generate_th_layer(self, misinformation):
        """Build the concluding FACT layer.

        The myth is classified with the CARDS taxonomy model. If the
        Climate-FEVER CSV holds claims with the same label, the evidence for
        the most similar claim is added as ``complementary_details`` through
        the ``CONCLUDING_INCONTEXT`` prompt. Otherwise the plain
        ``CONCLUDING`` prompt is used.

        This method reads ``self.hamburger[1].content``, so the first FACT
        layer must already be generated.
        """
        cards_label = self._top_label(self.taxonomy_cards, misinformation)
        claims = self.get_fever_claims(cards_label)
        prompt_completion = {"fact": self.hamburger[1].content}
        if claims:
            prompt = CONCLUDING_INCONTEXT
            example_myth = self._most_similar(misinformation, claims)
            prompt_completion["complementary_details"] = self.get_fever_evidence(
                example_myth
            )
        else:
            prompt = CONCLUDING
        chain = prompt | self.llm
        return chain.invoke(prompt_completion)

    def rebuttal_generator(self, misinformation):
        """Generate all four layers and return them as one formatted string.

        The layers are built in order: FACT, MYTH, FALLACY, FACT. They must run
        in this order because later layers read the first FACT layer from
        ``self.hamburger``. Each layer is stored, stripped, in its slot.

        Returns:
            One ``"<heading>: <content>"`` line per generated layer, joined
            with newlines. The input myth in slot 0 is not included.
        """
        self.hamburger[0] = self.hamburger[0]._replace(content=misinformation)
        layer_builders = (
            self.generate_st_layer,  # FACT
            self.generate_nd_layer,  # MYTH
            self.generate_rd_layer,  # FALLACY
            self.generate_th_layer,  # FACT
        )
        for slot, build_layer in enumerate(layer_builders, start=1):
            self.hamburger[slot] = self.hamburger[slot]._replace(
                content=build_layer(misinformation).strip()
            )
        return "\n".join(
            f"{layer.name}: {layer.content}" for layer in self.hamburger[1:]
        )

    def endpoint_query(self, payload, model):
        """POST ``payload`` to the Hugging Face Inference API for ``model``.

        The payload is sent as ``{"inputs": payload, "options": ...}``, with
        caching disabled and ``wait_for_model`` enabled.

        Returns:
            The decoded JSON response body. No HTTP status check is made, so
            an error body is returned like any other JSON.

        Raises:
            KeyError: if ``HUGGINGFACEHUB_API_TOKEN`` is not set.
            ValueError: if the response body is not valid JSON.
        """
        headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
        options = {"use_cache": False, "wait_for_model": True}
        body = {"inputs": payload, "options": options}
        api_url = f"https://api-inference.huggingface.co/models/{model}"
        response = requests.post(api_url, headers=headers, json=body, timeout=120)
        return json.loads(response.content.decode("utf-8"))

    def retry_on_exceptions(self, function, *args, attempts=5, base_delay=5):
        """Call ``function(*args)`` up to ``attempts`` times.

        Only ``KeyError`` and ``ValueError`` trigger a retry. The wait before
        retry ``k`` is ``base_delay * k`` seconds. There is no wait after the
        final attempt.

        Returns:
            The first successful result, or ``None`` if every attempt failed.
        """
        for attempt in range(1, attempts + 1):
            try:
                return function(*args)
            except (KeyError, ValueError):
                if attempt == attempts:
                    break
                print(f"retrying {attempt} out of {attempts}")
                time.sleep(base_delay * attempt)
        return None

    def get_fever_claims(self, label):
        """Return the claims with ``claim_label`` 1 and ``CARDS_label == label``."""
        return [
            row["claim"]
            for row in self._read_labelled_rows()
            if row["CARDS_label"] == label
        ]

    def get_fever_evidence(self, claim):
        """Return the evidence for ``claim`` as a ``"* evidence"`` bullet list.

        Evidence from every matching row (``claim_label`` 1) is included, in
        file order. Returns ``""`` when nothing matches.
        """
        evidences = []
        for row in self._read_labelled_rows():
            if row["claim"] != claim:
                continue
            # The "evidences" cell is parsed as a Python literal: an iterable
            # of dicts with an "evidence" key. literal_eval only accepts
            # literals, so it cannot execute code from the CSV.
            for evidence_dict in literal_eval(row["evidences"]):
                evidences.append(evidence_dict["evidence"])
        return "\n".join(f"* {evidence}" for evidence in evidences)

    def _top_label(self, model, text):
        """Return the ``label`` of the first prediction ``model`` gives for ``text``.

        Assumes the text-classification response shape
        ``[[{"label": ..., "score": ...}, ...]]``. TODO: confirm that the
        first entry is the top-scoring one.
        """
        return self.endpoint_query(model=model, payload=text)[0][0].get("label")

    def _most_similar(self, source, candidates):
        """Return the candidate that the similarity model scores highest against ``source``.

        Ties resolve to the earliest candidate.
        """
        scores = self.endpoint_query(
            payload={"source_sentence": source, "sentences": candidates},
            model=self.semantic_textual_similarity,
        )
        # Assumes one score per candidate, in the same order.
        return candidates[scores.index(max(scores))]

    def _read_labelled_rows(self):
        """Yield the rows of ``self.filename`` whose ``claim_label`` equals 1."""
        # newline="" is what the csv module requires for correct quoting.
        with open(self.filename, "r", encoding="utf-8", newline="") as csvfile:
            for row in csv.DictReader(csvfile):
                if self._is_claim_label_one(row["claim_label"]):
                    yield row

    @staticmethod
    def _is_claim_label_one(value):
        """Return True when a ``claim_label`` CSV cell encodes the number 1.

        ``csv.DictReader`` yields strings, so comparing against the int ``1``
        directly never matches. Both "1" and "1.0" are accepted.
        NOTE(review): label 1 presumably marks refuted claims in the
        Climate-FEVER encoding. Confirm this against the CSV.
        """
        try:
            return float(value) == 1
        except (TypeError, ValueError):
            return False

    @staticmethod
    def _strip_response_prefix(text):
        """Remove a leading "Response:" label (any case) from model output."""
        return re.sub(r"^\s*response:\s*", "", text, flags=re.IGNORECASE)