alexkueck committed on
Commit
8cf4fcd
·
verified ·
1 Parent(s): 72fc81a

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +1 -2
utils.py CHANGED
@@ -389,11 +389,10 @@ def query(api_llm, payload):
389
  def llm_chain2(prompt, context):
390
  full_prompt = RAG_CHAIN_PROMPT.format(context=context, question=prompt)
391
  inputs = tokenizer_rag(full_prompt, return_tensors="pt", max_length=1024, truncation=True)
392
- attention_mask = (inputs != tokenizer_rag.pad_token_id).long()
393
  #Generiere die Antwort
394
  outputs = modell_rag.generate(
395
  inputs.input_ids,
396
- attention_mask=attention_mask,
397
  max_new_tokens=1024,
398
  do_sample=True,
399
  temperature=0.9,
 
389
  def llm_chain2(prompt, context):
390
  full_prompt = RAG_CHAIN_PROMPT.format(context=context, question=prompt)
391
  inputs = tokenizer_rag(full_prompt, return_tensors="pt", max_length=1024, truncation=True)
 
392
  #Generiere die Antwort
393
  outputs = modell_rag.generate(
394
  inputs.input_ids,
395
+ attention_mask=inputs.attention_mask,
396
  max_new_tokens=1024,
397
  do_sample=True,
398
  temperature=0.9,