aisafe committed
Commit 203cc33 · verified · 1 Parent(s): 1291e3d

Update app.py

Files changed (1): app.py +21 -40
app.py CHANGED
@@ -92,22 +92,27 @@ def replace_pp_with_pause(sentence, entity_tags):
     for tag in entity_tags:
         start = tag['start']
         end = tag['end']
-        if end<len(sentence)-1:
+        if end < len(sentence) - 1:
             token = sentence[start:end] # Adjust for 0-based indexing
         else:
-            token = sentence[start:end+1]
-        tag_name = f"[{tag['entity_group']}]"
+            token = sentence[start:end + 1]
 
-        if tag['entity_group'] == 'PP':
-            # Replace [PP] with [PAUSE]
-            tag_name = '[PAUSE]'
+        tag_name = '[PAUSE]' if tag['entity_group'] == 'PP' else ''
+        tagged_tokens.append(f"{token}{tag_name}")
+    print(tagged_tokens)
+
+    # Return the sentence with [PAUSE] replacement and spaces preserved
+    modified_words = []
+    for i, word in enumerate(tagged_tokens):
+        if word.startswith("'s"):
+            modified_words[-1] = modified_words[-1] + word
         else:
-            tag_name = ''
+            modified_words.append(word)
 
-        tagged_tokens.append(f"{token}{tag_name}")
+    output = " ".join(modified_words)
+
+    return output
 
-    # Return the sentence with [PAUSE] replacement
-    return " ".join(tagged_tokens)
 
 
 def get_split_sentences(sentence, entity_tags):
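For reference, a minimal standalone sketch of what the rewritten `replace_pp_with_pause` now does. The shape of `entity_tags` (dicts with `start`/`end` character offsets and an `entity_group`, as a Hugging Face token-classification pipeline emits) is assumed from the surrounding code, and a guard against an empty `modified_words` list is added here that the committed version does not have:

```python
# Sketch (not part of the commit): entity_tags shape is assumed, e.g.
# {'start': 5, 'end': 9, 'entity_group': 'PP'}.
def replace_pp_with_pause(sentence, entity_tags):
    tagged_tokens = []
    for tag in entity_tags:
        start, end = tag['start'], tag['end']
        # Slice by character offsets; extend by one at the sentence boundary.
        token = sentence[start:end] if end < len(sentence) - 1 else sentence[start:end + 1]
        tag_name = '[PAUSE]' if tag['entity_group'] == 'PP' else ''
        tagged_tokens.append(f"{token}{tag_name}")

    # Re-attach possessive clitics so "dog" + "'s" joins as "dog's".
    modified_words = []
    for word in tagged_tokens:
        if word.startswith("'s") and modified_words:  # empty-list guard added in this sketch
            modified_words[-1] += word
        else:
            modified_words.append(word)
    return " ".join(modified_words)

# Hypothetical tags for "John went home" with a prosodic pause after "went":
tags = [
    {'start': 0, 'end': 4, 'entity_group': 'NP'},
    {'start': 5, 'end': 9, 'entity_group': 'PP'},
    {'start': 10, 'end': 14, 'entity_group': 'NP'},
]
print(replace_pp_with_pause("John went home", tags))  # -> "John went[PAUSE] home"
```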
@@ -139,34 +144,9 @@ def get_split_sentences(sentence, entity_tags):
 
     # If the sentence ends without a [PAUSE] token, add the final sentence
     if current_sentence:
-        split_sentences.append(" ".join(current_sentence))
+        split_sentences.append("".join(current_sentence))
 
     return split_sentences
-    # def get_split_sentences(sentence, entity_tags):
-    #     split_sentences = []
-
-    #     # Initialize a variable to hold the current sentence
-    #     current_sentence = []
-
-    #     # Process the entity tags to split the sentence
-    #     for tag in entity_tags:
-    #         if tag['entity_group'] == 'PP':
-    #             if current_sentence:
-    #                 print(current_sentence)
-    #                 split_sentences.append(" ".join(current_sentence))
-    #                 current_sentence = []  # Reset the current sentence
-    #         else:
-    #             start = tag['start']
-    #             end = tag['end']
-    #             token = sentence[start - 1:end]  # Adjust for 0-based indexing
-    #             current_sentence.append(token)
-
-    #     # If the sentence ends without a [PAUSE] token, add the final sentence
-    #     if current_sentence:
-    #         split_sentences.append(" ".join(current_sentence))
-
-    #     return split_sentences
-
 
 
 
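Worth noting on the `" ".join` → `"".join` change above: it only makes sense if `current_sentence` now collects offset-based slices that already carry their own spacing. A toy illustration with hypothetical chunks:

```python
chunks = ["The cat ", "sat down."]  # hypothetical slices that keep their trailing spaces
print(" ".join(chunks))  # 'The cat  sat down.' -- double space
print("".join(chunks))   # 'The cat sat down.'
```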
@@ -510,9 +490,9 @@ def analyze_heatmap(df_input):
     )
 
     # Additional styling
-    ax.set_title("Importance Score per Token", size=25)
-    ax.set_xlabel("Token")
-    ax.set_ylabel("Importance Value")
+    # ax.set_title("Importance Score per Token", size=25)
+    # ax.set_xlabel("Token")
+    # ax.set_ylabel("Importance Value")
     ax.set_xticks(range(len(df["token"])))
     ax.set_xticklabels(df["token"], rotation=45)
 
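If the title and axis labels are wanted back later, a self-contained version of this styling block looks like the sketch below; the bar chart and sample data are made up, and only the `ax.set_xticks`/`ax.set_xticklabels` calls survive in the committed code:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up stand-in for the attribution dataframe used in analyze_heatmap.
df = pd.DataFrame({"token": ["The", "cat", "sat"], "importance_value": [0.1, 0.7, 0.2]})

fig, ax = plt.subplots()
ax.bar(range(len(df)), df["importance_value"])
ax.set_title("Importance Score per Token", size=25)  # commented out in this commit
ax.set_xlabel("Token")                               # commented out in this commit
ax.set_ylabel("Importance Value")                    # commented out in this commit
ax.set_xticks(range(len(df["token"])))               # set tick positions before labels
ax.set_xticklabels(df["token"], rotation=45)
plt.show()
```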
@@ -724,7 +704,8 @@ class SentenceAnalyzer:
             attribution_df1 = process_integrated_gradients(split_sentence, self._gpt2tokenizer, self.model)
             if i < len(self.split_sentences) - 1:
                 # Add a row with [PAUSE] and value 0 at the end
-                pause_row = pd.DataFrame({'token': '[PAUSE]', 'importance_value': 0},index=[len(attribution_df1)])
+                # pause_row = pd.DataFrame({'token': '[PAUSE]', 'importance_value': 0},index=[len(attribution_df1)])
+                pause_row = pd.DataFrame({'token': '', 'importance_value': 0},index=[len(attribution_df1)])
                 attribution_df1 = pd.concat([attribution_df1,pause_row], ignore_index=True)
 
             dataframes_list.append(attribution_df1)
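The new `pause_row` line is garbled in the rendered diff (`{'', '': 0}` is not valid Python); the keys above are reconstructed from the adjacent commented-out line (`token`/`importance_value`), with the token text blanked. A self-contained sketch of the append, using a made-up attribution frame:

```python
import pandas as pd

# Made-up attribution output for one split sentence.
attribution_df1 = pd.DataFrame({'token': ['The', 'cat'], 'importance_value': [0.1, 0.7]})

# One separator row is appended after every split sentence except the last;
# this commit blanks the separator's token text ('' instead of '[PAUSE]').
pause_row = pd.DataFrame({'token': '', 'importance_value': 0}, index=[len(attribution_df1)])
attribution_df1 = pd.concat([attribution_df1, pause_row], ignore_index=True)
print(attribution_df1)
#   token  importance_value
# 0   The               0.1
# 1   cat               0.7
# 2                     0.0
```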
 