nicholasKluge commited on
Commit
7e63de2
·
1 Parent(s): 3f72719

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -22
app.py CHANGED
@@ -178,36 +178,37 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
178
  decoded_text = [tokenizer.decode(tokens, skip_special_tokens=True).replace(user_msg, "") for tokens in generated_response]
179
 
180
  rewards = list()
181
- toxicities = list()
 
 
182
 
183
  for text in decoded_text:
184
- reward_tokens = rewardTokenizer(user_msg, text,
185
- truncation=True,
186
- max_length=512,
187
- return_token_type_ids=False,
188
- return_tensors="pt",
189
- return_attention_mask=True)
190
-
191
- reward_tokens.to(rewardModel.device)
192
-
193
- reward = rewardModel(**reward_tokens)[0].item()
194
-
195
- toxicity_tokens = toxiciyTokenizer(user_msg + " " + text,
196
  truncation=True,
197
  max_length=512,
198
  return_token_type_ids=False,
199
  return_tensors="pt",
200
  return_attention_mask=True)
201
-
202
- toxicity_tokens.to(toxicityModel.device)
203
-
204
- toxicity = toxicityModel(**toxicity_tokens)[0].item()
205
-
206
- rewards.append(reward)
207
- toxicities.append(toxicity)
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
- toxicity_threshold = 5
210
-
211
  ordered_generations = sorted(zip(decoded_text, rewards, toxicities), key=lambda x: x[1], reverse=True)
212
 
213
  if safety == "On":
 
178
  decoded_text = [tokenizer.decode(tokens, skip_special_tokens=True).replace(user_msg, "") for tokens in generated_response]
179
 
180
  rewards = list()
181
+
182
+ if safety == "On":
183
+ toxicities = list()
184
 
185
  for text in decoded_text:
186
+ reward_tokens = rewardTokenizer(user_msg, text,
 
 
 
 
 
 
 
 
 
 
 
187
  truncation=True,
188
  max_length=512,
189
  return_token_type_ids=False,
190
  return_tensors="pt",
191
  return_attention_mask=True)
192
+
193
+ reward_tokens.to(rewardModel.device)
194
+
195
+ reward = rewardModel(**reward_tokens)[0].item()
196
+ rewards.append(reward)
197
+
198
+ if safety == "On":
199
+ toxicity_tokens = toxiciyTokenizer(user_msg + " " + text,
200
+ truncation=True,
201
+ max_length=512,
202
+ return_token_type_ids=False,
203
+ return_tensors="pt",
204
+ return_attention_mask=True)
205
+
206
+ toxicity_tokens.to(toxicityModel.device)
207
+
208
+ toxicity = toxicityModel(**toxicity_tokens)[0].item()
209
+ toxicities.append(toxicity)
210
+ toxicity_threshold = 5
211
 
 
 
212
  ordered_generations = sorted(zip(decoded_text, rewards, toxicities), key=lambda x: x[1], reverse=True)
213
 
214
  if safety == "On":