danseith committed on
Commit
cd3e092
·
1 Parent(s): 020fa3d

Fixed loop structure and output sampling to avoid infinite loops. Now allows deletions.

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -135,34 +135,37 @@ PIPELINE_REGISTRY.register_pipeline(
135
  scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
138
def unmask(text, temp, rounds):
    """Replace ``rounds`` masked words in ``text`` with model-sampled tokens.

    Each attempt masks one word via ``add_mask``, asks ``scrambler`` for
    candidate fills, samples one candidate and writes it back wrapped in
    asterisks. An attempt fails when the sampled token is shorter than two
    characters or equal to ``masked[0]``.

    Args:
        text: Input string to rewrite.
        temp: Temperature forwarded to ``scrambler``.
        rounds: Number of successful replacements to perform.

    Returns:
        The rewritten text with its first character upper-cased, or a fixed
        apology string when more than five attempts failed. Empty text is
        returned unchanged.
    """
    sampling = 'multi'
    max_failures = 5
    successful_iters = 0
    unsuccessful_iters = 0
    # Both conditions must hold to keep going. The previous `or ... > 5`
    # kept the loop alive forever once the failure budget was exceeded.
    while successful_iters < rounds and unsuccessful_iters <= max_failures:
        # NOTE(review): add_mask is defined elsewhere; it appears to return
        # (masked_text, masked) where masked[0] is the replaced word - confirm.
        tp = add_mask(text, size=1)
        masked_text, masked = tp[0], tp[1]
        split_text = masked_text.split()
        res = scrambler(masked_text, temp=temp, top_k=10)
        mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
        out = {item["token_str"]: item["score"] for item in res}

        # Sample over the tokens themselves; inverting the dict by score
        # silently dropped candidates that share a score.
        tokens = list(out)
        idx = None
        if sampling == 'multi':
            weights = np.asarray([out[t] for t in tokens], dtype=float)
            total = weights.sum()
            if total > 0 and np.isfinite(total):
                # Normalise so the weights form a proper distribution.
                idx = np.random.choice(len(tokens), p=weights / total)
        if idx is None:
            idx = np.random.randint(0, len(tokens))
        new_token = tokens[idx]

        if len(new_token) < 2 or new_token == masked[0]:
            unsuccessful_iters += 1
            continue
        split_text[mask_pos] = '*' + new_token + '*'
        text = ' '.join(split_text)
        successful_iters += 1

    if unsuccessful_iters > max_failures:
        return "Ran into an issue :( Please try again."
    if not text:
        # Nothing to capitalise; indexing text[0] would raise IndexError.
        return text
    return text[0].upper() + text[1:]
 
135
  scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
136
 
137
 
138
def sample_output(out, sampling):
    """Pick one candidate token from a mapping of token -> score.

    Args:
        out: Dict mapping candidate token strings to their scores.
        sampling: ``'multi'`` draws a token with probability proportional to
            its score; any other value draws uniformly at random.

    Returns:
        The selected token string.

    Raises:
        ValueError: If ``out`` is empty.
    """
    if not out:
        raise ValueError("sample_output() needs at least one candidate token")

    # Index the tokens directly. Keying a dict by score would merge
    # candidates that happen to share the same score.
    tokens = list(out)
    if sampling == 'multi':
        weights = np.asarray([out[t] for t in tokens], dtype=float)
        total = weights.sum()
        if total > 0 and np.isfinite(total):
            # Normalise: top-k scores rarely sum to 1, and an unnormalised
            # multinomial hands the missing mass to the last candidate.
            idx = np.random.choice(len(tokens), p=weights / total)
            return tokens[idx]
        # Degenerate weights (all zero / non-finite): fall back to uniform.
    idx = np.random.randint(0, len(tokens))
    return tokens[idx]
147
+
148
+
149
def unmask(text, temp, rounds):
    """Rewrite ``text`` by masking and re-filling one word per round.

    For each of ``rounds`` iterations one word is masked via ``add_mask``,
    ``scrambler`` proposes candidate fills, and a sampled candidate is
    written back wrapped in asterisks. If the sampled token equals
    ``masked[0]``, it is re-drawn uniformly a bounded number of times; when
    every retry matches, the matching token is used as-is.

    Args:
        text: Input string to rewrite.
        temp: Temperature forwarded to ``scrambler``.
        rounds: Number of mask-and-fill iterations.

    Returns:
        The rewritten text with its first character upper-cased. Empty text
        is returned unchanged.
    """
    sampling = 'multi'
    for _ in range(rounds):
        # NOTE(review): add_mask is defined elsewhere; it appears to return
        # (masked_text, masked) where masked[0] is the replaced word - confirm.
        tp = add_mask(text, size=1)
        masked_text, masked = tp[0], tp[1]
        split_text = masked_text.split()
        res = scrambler(masked_text, temp=temp, top_k=15)
        mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
        out = {item["token_str"]: item["score"] for item in res}

        new_token = sample_output(out, sampling)
        # Bounded retry so a candidate list dominated by the masked word
        # cannot spin forever; same limit as before (retries while count <= 5).
        unsuccessful_iters = 0
        while new_token == masked[0] and unsuccessful_iters <= 5:
            new_token = sample_output(out, sampling='uniform')
            unsuccessful_iters += 1

        split_text[mask_pos] = '*' + new_token + '*'
        text = ' '.join(split_text)

    if not text:
        # Nothing to capitalise; indexing text[0] would raise IndexError.
        return text
    return text[0].upper() + text[1:]