cobalt-salmon / h2oai_pipeline.py
assafm's picture
Upload h2oai_pipeline.py
259ee24
raw
history blame
1.3 kB
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
# Turn markers used by the prompt template below; postprocess() splits the
# decoded generation on them to recover the model's answer.
_PROMPT_MARKER = "<|prompt|>"
_ANSWER_MARKER = "<|answer|>"

# Prompt template: the instruction is wrapped in the prompt marker, closed
# with "</s>", and followed by the answer marker so generation continues
# with the answer.
STYLE = "<|prompt|>{instruction}</s><|answer|>"


class H2OTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that applies the ``STYLE`` prompt template.

    Each input is formatted with ``STYLE`` before being handed to the base
    ``TextGenerationPipeline``, and each decoded generation is reduced to the
    model's first answer afterwards.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Template applied to every input in preprocess().
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Format *prompt_text* with ``self.prompt``, then defer to the base class.

        Args:
            prompt_text: The user instruction, substituted for ``{instruction}``.
            prefix: Forwarded unchanged to the base ``preprocess``.
            handle_long_generation: Forwarded unchanged to the base ``preprocess``.
            **generate_kwargs: Forwarded unchanged to the base ``preprocess``.

        Returns:
            Whatever ``TextGenerationPipeline.preprocess`` returns.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
        **postprocess_kwargs,
    ):
        """Decode via the base class and keep only the model's answer.

        Args:
            model_outputs: Forwarded unchanged to the base ``postprocess``.
            return_type: ``ReturnType`` requested by the caller. With
                ``FULL_TEXT`` the decoded text starts with the formatted
                prompt, which is stripped up to the first answer marker.
                With any other text return type (``NEW_TEXT``) the whole
                decoded text is treated as the answer.
            clean_up_tokenization_spaces: Forwarded to the base ``postprocess``.
            **postprocess_kwargs: Any further postprocess parameters the base
                pipeline supplies, forwarded unchanged. NOTE(review): newer
                transformers releases may pass extra ones here; confirm for
                the pinned version.

        Returns:
            The base class's records. Each string ``"generated_text"`` is
            replaced by the extracted answer. Records without a string
            ``"generated_text"`` (for example token-id outputs) are left as-is.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **postprocess_kwargs,
        )
        prompt_included = return_type == ReturnType.FULL_TEXT
        for rec in records:
            text = rec.get("generated_text")
            # Tensor outputs carry token ids instead of text; leave them alone
            # instead of failing with KeyError.
            if isinstance(text, str):
                rec["generated_text"] = self._extract_answer(text, prompt_included)
        return records

    @staticmethod
    def _extract_answer(text, prompt_included=True):
        """Return the first answer contained in a decoded generation.

        Args:
            text: Decoded generation text.
            prompt_included: Whether *text* begins with the formatted prompt.
                If so, everything up to and including the first answer marker
                is dropped. If the marker is missing, the whole text is kept
                rather than raising.

        Returns:
            The answer, cut at any further answer or prompt marker the model
            produced, with surrounding whitespace removed.
        """
        if prompt_included:
            _, marker, answer = text.partition(_ANSWER_MARKER)
            # Fall back to the whole text instead of raising IndexError when
            # the marker does not appear in the decoded output.
            if marker:
                text = answer
        # Keep only the first turn: stop at any further answer marker...
        text = text.split(_ANSWER_MARKER)[0]
        # ...or any new prompt the model started on its own.
        return text.strip().split(_PROMPT_MARKER)[0].strip()