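"""Abuse, hate speech, and profanity detection for free-form text and
speaker-labelled transcripts.

Each category is scored by a dedicated Hugging Face text-classification
model. Long inputs are split into token-budgeted chunks with LangChain's
RecursiveCharacterTextSplitter so every chunk fits the model's context
window, and a Gradio interface exposes the classifier as a web app.
"""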
import re
import torch
import gradio as gr
from transformers import pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter

class AbuseHateProfanityDetector:
    def __init__(self):
        # Run on the GPU when one is available, otherwise fall back to CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # One text-classification pipeline per detection category
        self.abuse_detector = pipeline("text-classification", model="Hate-speech-CNERG/english-abusive-MuRIL", device=self.device)
        self.hate_speech_detector = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-hate-latest", device=self.device)
        self.profanity_detector = pipeline("text-classification", model="tarekziade/pardonmyai", device=self.device)

        # Reuse the tokenizers the pipelines already loaded rather than
        # downloading each one a second time with AutoTokenizer
        self.abuse_tokenizer = self.abuse_detector.tokenizer
        self.hate_speech_tokenizer = self.hate_speech_detector.tokenizer
        self.profanity_tokenizer = self.profanity_detector.tokenizer

        # Maximum context window (in tokens) for each model
        self.abuse_max_context_size = 512
        self.hate_speech_max_context_size = 512
        self.profanity_max_context_size = 512

    def preprocess_and_clean_text(self, text: str) -> str:
        """
        Cleans transcript-style text: collapses stammered word repeats,
        drops "um" fillers, and normalizes punctuation and whitespace.
        """
        # Collapse stammering such as "I, I, I think" -> "I think"
        stammering_pattern = r'\b(\w+)\s*[,;]+\s*(\1\b\s*[,;]*)+'
        passage_without_stammering = re.sub(stammering_pattern, r'\1', text)
        # Remove "um" fillers (case-insensitive, so "Um" is caught too)
        passage_without_um = re.sub(r'\bum\b', ' ', passage_without_stammering, flags=re.IGNORECASE)
        # Normalize comma spacing: "a ,, b" -> "a, b"
        modified_text = re.sub(r'\s*,+\s*', ', ', passage_without_um)
        # Remove stray spaces before punctuation: "word ." -> "word."
        processed_text = re.sub(r'\s+([^\w\s])', r'\1', modified_text)
        # Collapse runs of whitespace into single spaces
        processed_text = re.sub(r'\s+', ' ', processed_text)
        # Collapse runs of periods (". . .", "...") into one, keeping the space
        # that separates sentences. The previous pattern r'(\.\s*)+' -> '.'
        # also deleted the space after every ordinary sentence-ending period.
        cleaned_text = re.sub(r'(?:\.\s*){2,}', '. ', processed_text)
        return cleaned_text.strip()

    def token_length(self, text, tokenizer):
        """
        Computes the token length of a text.
        """
        tokens = tokenizer.encode(text, add_special_tokens=False)
        return len(tokens)

    def create_token_length_wrapper(self, tokenizer):
        """
        Creates a closure to calculate token length using the tokenizer.
        """
        def token_length_wrapper(text):
            return self.token_length(text, tokenizer)
        return token_length_wrapper

    def chunk_text(self, text, tokenizer, max_length):
        """
        Cleans the input text, then splits it into chunks that fit within the max token length.
        """
        text = self.preprocess_and_clean_text(text)
        token_length_wrapper = self.create_token_length_wrapper(tokenizer)
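        # chunk_size is measured in tokens (via token_length_wrapper); reserve
        # 2 tokens for the special tokens (e.g. [CLS]/[SEP]) the model adds.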
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=max_length - 2, length_function=token_length_wrapper)
        chunks = text_splitter.split_text(text)
        return chunks

    def classify_text(self, text: str):
        """
        Classifies text for abuse, hate speech, and profanity using the
        respective models. Returns a dict with one boolean flag per category.
        """
        # Split the text into token-budgeted chunks for each model
        abuse_chunks = self.chunk_text(text, self.abuse_tokenizer, self.abuse_max_context_size)
        hate_speech_chunks = self.chunk_text(text, self.hate_speech_tokenizer, self.hate_speech_max_context_size)
        profanity_chunks = self.chunk_text(text, self.profanity_tokenizer, self.profanity_max_context_size)

        # Initialize flags
        abusive_flag = False
        hatespeech_flag = False
        profanity_flag = False

        # Detect abuse. truncation=True is a safety net in case the splitter
        # ever emits a chunk longer than the model's context window.
        for chunk in abuse_chunks:
            result = self.abuse_detector(chunk, truncation=True)
            if result[0]['label'] == 'LABEL_1':  # LABEL_1 = abusive for this model
                abusive_flag = True
                break  # one positive chunk is enough to set the flag

        # Detect hate speech
        for chunk in hate_speech_chunks:
            result = self.hate_speech_detector(chunk, truncation=True)
            if result[0]['label'] == 'HATE':
                hatespeech_flag = True
                break

        # Detect profanity
        for chunk in profanity_chunks:
            result = self.profanity_detector(chunk, truncation=True)
            if result[0]['label'] == 'OFFENSIVE':
                profanity_flag = True
                break

        # Return classification results
        return {
            "abusive_flag": abusive_flag,
            "hatespeech_flag": hatespeech_flag,
            "profanity_flag": profanity_flag
        }

    def extract_speaker_text(self, transcript, client_label="Client", care_provider_label="Care Provider"):
        """
        Extracts text spoken by the client and the care provider from the transcript.
        """
        client_text = []
        care_provider_text = []

        lines = transcript.split("\n")
        for line in lines:
            if line.startswith(client_label + ":"):
                client_text.append(line[len(client_label) + 1:].strip())
            elif line.startswith(care_provider_label + ":"):
                care_provider_text.append(line[len(care_provider_label) + 1:].strip())

        return " ".join(client_text), " ".join(care_provider_text)

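# Note: extract_speaker_text is not wired into the Gradio UI below, but it
# lets callers classify each speaker separately. Illustrative usage:
#   client_text, provider_text = detector.extract_speaker_text(transcript)
#   client_results = detector.classify_text(client_text)
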
# Gradio interface for the web app
detector = AbuseHateProfanityDetector()

interface = gr.Interface(
    fn=detector.classify_text,
    inputs=gr.Textbox(label="Enter text", lines=8),
    outputs="json",
    title="Abuse, Hate Speech, and Profanity Detection",
    description="Enter text to detect whether it contains abusive, hateful, or offensive content."
)

# Launch the Gradio app; share=True also serves a temporary public URL
if __name__ == "__main__":
    interface.launch(share=True)