Update app.py
app.py CHANGED
@@ -92,22 +92,27 @@ def replace_pp_with_pause(sentence, entity_tags):
     for tag in entity_tags:
         start = tag['start']
         end = tag['end']
-        if end<len(sentence)-1:
+        if end < len(sentence) - 1:
             token = sentence[start:end] # Adjust for 0-based indexing
         else:
-            token = sentence[start:end+1]
-        tag_name = f"[{tag['entity_group']}]"
+            token = sentence[start:end + 1]
 
-        if tag['entity_group'] == 'PP':
-
-
+        tag_name = '[PAUSE]' if tag['entity_group'] == 'PP' else ''
+        tagged_tokens.append(f"{token}{tag_name}")
+        print(tagged_tokens)
+
+    # Return the sentence with [PAUSE] replacement and spaces preserved
+    modified_words = []
+    for i, word in enumerate(tagged_tokens):
+        if word.startswith("'s"):
+            modified_words[-1] = modified_words[-1] + word
         else:
-
+            modified_words.append(word)
 
-
+    output = " ".join(modified_words)
+
+    return output
 
-    # Return the sentence with [PAUSE] replacement
-    return " ".join(tagged_tokens)
 
 
 def get_split_sentences(sentence, entity_tags):
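A minimal usage sketch of the revised replacement logic, assuming `entity_tags` follows the Hugging Face token-classification grouped-entities shape (dicts with `entity_group`, `start`, `end`) and that `tagged_tokens` is initialized to an empty list above this hunk; the sentence and tags below are hypothetical:

```python
# Hypothetical chunk tags in the grouped-entities shape assumed above.
sentence = "The cat sat on the mat"
entity_tags = [
    {'entity_group': 'NP', 'start': 0, 'end': 7},    # "The cat"
    {'entity_group': 'VP', 'start': 8, 'end': 11},   # "sat"
    {'entity_group': 'PP', 'start': 12, 'end': 22},  # "on the mat" (end == len(sentence))
]

# PP spans get "[PAUSE]" appended; all other groups pass through unchanged.
print(replace_pp_with_pause(sentence, entity_tags))
# Expected under these assumptions: "The cat sat on the mat[PAUSE]"
```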
@@ -139,34 +144,9 @@ def get_split_sentences(sentence, entity_tags):
 
     # If the sentence ends without a [PAUSE] token, add the final sentence
     if current_sentence:
-        split_sentences.append(" ".join(current_sentence))
+        split_sentences.append("".join(current_sentence))
 
     return split_sentences
-
-# def get_split_sentences(sentence, entity_tags):
-#     split_sentences = []
-
-#     # Initialize a variable to hold the current sentence
-#     current_sentence = []
-
-#     # Process the entity tags to split the sentence
-#     for tag in entity_tags:
-#         if tag['entity_group'] == 'PP':
-#             if current_sentence:
-#                 print(current_sentence)
-#                 split_sentences.append(" ".join(current_sentence))
-#                 current_sentence = [] # Reset the current sentence
-#         else:
-#             start = tag['start']
-#             end = tag['end']
-#             token = sentence[start - 1:end] # Adjust for 0-based indexing
-#             current_sentence.append(token)
-
-#     # If the sentence ends without a [PAUSE] token, add the final sentence
-#     if current_sentence:
-#         split_sentences.append(" ".join(current_sentence))
-
-#     return split_sentences
-
 
 
 
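The one functional change in this hunk swaps `" ".join` for `"".join` when flushing the final chunk. That only reassembles the text correctly if the accumulated tokens carry their own spacing, as raw `sentence[start:end]` slices do; a small illustration with hypothetical chunks:

```python
# Slices taken straight from the sentence keep their own spaces,
# so joining with "" rebuilds the text without doubling them.
chunks = ["The cat ", "sat ", "on the mat"]
print("".join(chunks))   # "The cat sat on the mat"
print(" ".join(chunks))  # "The cat  sat  on the mat"  (doubled spaces)
```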
@@ -510,9 +490,9 @@ def analyze_heatmap(df_input):
     )
 
     # Additional styling
-    ax.set_title("Importance Score per Token", size=25)
-    ax.set_xlabel("Token")
-    ax.set_ylabel("Importance Value")
+    # ax.set_title("Importance Score per Token", size=25)
+    # ax.set_xlabel("Token")
+    # ax.set_ylabel("Importance Value")
     ax.set_xticks(range(len(df["token"])))
     ax.set_xticklabels(df["token"], rotation=45)
 
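For context, a self-contained sketch of the plot these styling lines belong to, with the chart call itself elided in the diff and a hypothetical data frame standing in for the attribution output:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical frame shaped like the token-attribution output in app.py.
df = pd.DataFrame({
    "token": ["The", "cat", "sat"],
    "importance_value": [0.1, 0.7, 0.2],
})

fig, ax = plt.subplots()
ax.bar(range(len(df["token"])), df["importance_value"])

# The commit disables these three label calls; uncomment to restore them.
# ax.set_title("Importance Score per Token", size=25)
# ax.set_xlabel("Token")
# ax.set_ylabel("Importance Value")
ax.set_xticks(range(len(df["token"])))
ax.set_xticklabels(df["token"], rotation=45)
plt.show()
```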
@@ -724,7 +704,8 @@ class SentenceAnalyzer:
         attribution_df1 = process_integrated_gradients(split_sentence, self._gpt2tokenizer, self.model)
         if i < len(self.split_sentences) - 1:
             # Add a row with [PAUSE] and value 0 at the end
-            pause_row = pd.DataFrame({'token': '[PAUSE]', 'importance_value': 0},index=[len(attribution_df1)])
+            # pause_row = pd.DataFrame({'token': '[PAUSE]', 'importance_value': 0},index=[len(attribution_df1)])
+            pause_row = pd.DataFrame({'', '': 0},index=[len(attribution_df1)])
             attribution_df1 = pd.concat([attribution_df1,pause_row], ignore_index=True)
 
         dataframes_list.append(attribution_df1)
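Worth flagging: `{'', '': 0}` is not a valid Python dict literal; it mixes set-item and dict-item syntax and raises a SyntaxError, so the added line cannot even parse as committed. If the intent was a blank spacer row in place of the old explicit `[PAUSE]` row, a working equivalent would look like the sketch below; the column names are assumed from the commented-out line, and the frame is a stand-in for the real `attribution_df1`:

```python
import pandas as pd

# Stand-in for the attribution frame built by process_integrated_gradients.
attribution_df1 = pd.DataFrame({'token': ['The', 'cat'], 'importance_value': [0.1, 0.7]})

# Assumed intent: a blank token with zero importance as the spacer row.
pause_row = pd.DataFrame(
    {'token': '', 'importance_value': 0},
    index=[len(attribution_df1)],
)
attribution_df1 = pd.concat([attribution_df1, pause_row], ignore_index=True)
print(attribution_df1)
```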