aari1995 commited on
Commit
38e1149
·
verified ·
1 Parent(s): 5555ddd

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +37 -0
README.md CHANGED
@@ -160,6 +160,43 @@ prompt = "Schreibe eine Stellenanzeige für Data Scientist bei AXA!"
160
  final_prompt = prompt_template.format(prompt=prompt)
161
  ```
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  ### German benchmarks
164
 
165
  | **German tasks:** | **MMLU-DE** | **Hellaswag-DE** | **ARC-DE** |**Average** |
 
160
  final_prompt = prompt_template.format(prompt=prompt)
161
  ```
162
 
163
+ #### Limit the model to output reply-only:
164
+ To solve this, you need to implement a custom stopping criteria:
165
+
166
+ ```python
167
+ from transformers import StoppingCriteria
168
+ class GermeoStoppingCriteria(StoppingCriteria):
169
+ def __init__(self, target_sequence, prompt):
170
+ self.target_sequence = target_sequence
171
+ self.prompt=prompt
172
+
173
+ def __call__(self, input_ids, scores, **kwargs):
174
+ # Get the generated text as a string
175
+ generated_text = tokenizer.decode(input_ids[0])
176
+ generated_text = generated_text.replace(self.prompt,'')
177
+ # Check if the target sequence appears in the generated text
178
+ if self.target_sequence in generated_text:
179
+ return True # Stop generation
180
+
181
+ return False # Continue generation
182
+
183
+ def __len__(self):
184
+ return 1
185
+
186
+ def __iter__(self):
187
+ yield self
188
+ ```
189
+ This then expects your input prompt (formatted as given into the model), and a stopping criteria, in this case the im_end token. Simply add it to the generation:
190
+
191
+ ```python
192
+ generation_output = model.generate(
193
+ tokens,
194
+ streamer=streamer,
195
+ max_new_tokens=1012,
196
+ stopping_criteria=GermeoStoppingCriteria("<|im_end|>", prompt_template.format(prompt=prompt))
197
+ )
198
+ ```
199
+
200
  ### German benchmarks
201
 
202
  | **German tasks:** | **MMLU-DE** | **Hellaswag-DE** | **ARC-DE** |**Average** |