arnocandel commited on
Commit
252f114
·
1 Parent(s): 7f963c1

Upload h2oai_pipeline.py

Browse files
Files changed (1) hide show
  1. h2oai_pipeline.py +30 -0
h2oai_pipeline.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TextGenerationPipeline
2
+ from transformers.pipelines.text_generation import ReturnType
3
+
4
+ human = "<human>:"
5
+ bot = "<bot>:"
6
+
7
+ # human-bot interaction like OIG dataset
8
+ prompt = """{human} {instruction}
9
+ {bot}""".format(
10
+ human=human,
11
+ instruction="{instruction}",
12
+ bot=bot,
13
+ )
14
+
15
+
16
+ class H2OTextGenerationPipeline(TextGenerationPipeline):
17
+ def __init__(self, *args, **kwargs):
18
+ super().__init__(*args, **kwargs)
19
+
20
+ def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
21
+ prompt_text = prompt.format(instruction=prompt_text)
22
+ return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
23
+ **generate_kwargs)
24
+
25
+ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
26
+ records = super().postprocess(model_outputs, return_type=return_type,
27
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces)
28
+ for rec in records:
29
+ rec['generated_text'] = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
30
+ return records