Any way to set the "stop, split by" when running the model locally?
Hello, I've been trying to replicate the GPT-JT Spaces demo results locally. So far my results are close to the demo's, but the model seems to struggle with tasks that require deeper contextual understanding, such as answering questions from given examples. I've also noticed that when I change the "stop, split by" parameter in the web demo from its default, the model starts struggling on tasks it previously handled well. However, there doesn't seem to be a corresponding argument in the Transformers pipeline. Could anyone explain whether, and how, "stop, split by" can be set locally?
I am using the following code to load and run the model:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
tokenizer = AutoTokenizer.from_pretrained("togethercomputer/GPT-JT-6B-v1")
model = AutoModelForCausalLM.from_pretrained("togethercomputer/GPT-JT-6B-v1", device_map="auto", load_in_8bit=True)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
prompt = "Question: What currency is used in Zurich?\n\nAnswer:"  # example prompt; replace with your own
gen_args = {
    'top_p': 1.0,
    'top_k': 40,
    'temperature': 0.01,
    'repetition_penalty': 1.0,
    'max_new_tokens': 2,
    'do_sample': True,
    # 'stop': '\n',  # This doesn't work and raises an exception about 'model_kwargs' not being used
}
output = pipe(prompt, **gen_args)
print(output[0]['generated_text'])
@jswowah The "stop" parameter is not a configurable setting in the Transformers library. You can either post-process the text yourself after generating it, or define a custom stopping criterion:
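import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopWordsCriteria(StoppingCriteria):
    """Stop generation once any of `stop_words` appears in the newly generated text."""

    def __init__(self, stop_words, tokenizer):
        self.tokenizer = tokenizer
        self.stop_words = stop_words
        self._cache_str = ''  # generated text decoded so far (prompt not included)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Called after each generation step: decode only the newest token
        # of the first sequence in the batch and append it to the cache.
        self._cache_str += self.tokenizer.decode(input_ids[0, -1])
        # Stop as soon as any stop word shows up in the accumulated text.
        for stop_word in self.stop_words:
            if stop_word in self._cache_str:
                return True
        return False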
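Then pass it to the pipeline. Create a fresh instance for each call, since the cache is not reset between calls:
stopping_criteria = StoppingCriteriaList([StopWordsCriteria(['\n', 'other stop words'], tokenizer)])
pipe("Question: What currency is used in Zurich?\n\nAnswer:", stopping_criteria=stopping_criteria)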
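The post-processing route can be sketched as a plain helper that cuts the generated text at the earliest stop word. This is a minimal sketch: `truncate_at_stop` is a hypothetical name, and it assumes the pipeline output still starts with the prompt (the default; the pipeline's `return_full_text=False` option returns only the continuation instead). The sample continuation below is made up for illustration, not real model output.

```python
def truncate_at_stop(generated_text, prompt, stop_words):
    """Return the continuation after `prompt`, cut at the earliest stop word.

    Hypothetical helper for illustration. Assumes `generated_text` begins with
    `prompt`, as the text-generation pipeline returns by default.
    """
    # Drop the prompt so stop words inside it (e.g. the "\n\n" before
    # "Answer:") are not matched.
    if generated_text.startswith(prompt):
        continuation = generated_text[len(prompt):]
    else:
        continuation = generated_text
    # Earliest position of any stop word; the order of the list does not matter.
    positions = [continuation.find(w) for w in stop_words]
    positions = [p for p in positions if p != -1]
    return continuation[:min(positions)] if positions else continuation


prompt = "Question: What currency is used in Zurich?\n\nAnswer:"
text = prompt + " Swiss franc\nQuestion: What currency is used in Paris?"
print(repr(truncate_at_stop(text, prompt, ["\n"])))  # ' Swiss franc'
```

With the pipeline from the question this would be applied as `truncate_at_stop(output[0]['generated_text'], prompt, ['\n'])`. Unlike a stopping criterion, post-processing saves no compute: the model still generates up to `max_new_tokens` tokens before the text is trimmed.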
@juewang
Thanks for the clarification, I'm getting much better results after implementing StopWordsCriteria. Does the order of the stop words in the list affect how and when the model decides to stop?
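@jswowah I don't think it makes a difference: the loop returns True as soon as any of the stop words appears in the generated text, regardless of its position in the list. You can customize this part if you want to.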
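One possible customization, for long generations: instead of rescanning the whole cached string on every step, keep only a short tail, because a stop word that has just been completed must end inside the newest piece of text. A minimal, framework-independent sketch; `StopWordMatcher` is a hypothetical helper, not a Transformers class. Like the original loop, it reports a stop if any stop word matches, so list order does not change when generation stops.

```python
class StopWordMatcher:
    """Incrementally detects stop words in streamed text (hypothetical helper)."""

    def __init__(self, stop_words):
        self.stop_words = [w for w in stop_words if w]  # ignore empty strings
        # A new match can start at most (longest stop word - 1) characters
        # before the newest piece, so that is all the history we keep.
        self._keep = max((len(w) for w in self.stop_words), default=1) - 1
        self._tail = ""
        self.stopped = False

    def feed(self, piece):
        """Append newly decoded text; return True once any stop word has appeared."""
        if self.stopped:
            return True
        window = self._tail + piece
        self.stopped = any(w in window for w in self.stop_words)
        self._tail = window[-self._keep:] if self._keep > 0 else ""
        return self.stopped


matcher = StopWordMatcher(["###", "\n"])
print([matcher.feed(p) for p in ["Hello", " #", "#", "#x"]])  # [False, False, False, True]
```

Inside `StopWordsCriteria.__call__`, this would replace the cache and loop with something like `return self.matcher.feed(self.tokenizer.decode(input_ids[0, -1]))`, with the matcher created in `__init__`.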
@juewang
Great, thanks again for the help.