Update utils.py
Browse files
utils.py
CHANGED
@@ -389,11 +389,10 @@ def query(api_llm, payload):
|
|
389 |
def llm_chain2(prompt, context):
|
390 |
full_prompt = RAG_CHAIN_PROMPT.format(context=context, question=prompt)
|
391 |
inputs = tokenizer_rag(full_prompt, return_tensors="pt", max_length=1024, truncation=True)
|
392 |
-
attention_mask = (inputs != tokenizer_rag.pad_token_id).long()
|
393 |
#Generiere die Antwort
|
394 |
outputs = modell_rag.generate(
|
395 |
inputs.input_ids,
|
396 |
-
attention_mask=attention_mask,
|
397 |
max_new_tokens=1024,
|
398 |
do_sample=True,
|
399 |
temperature=0.9,
|
|
|
389 |
def llm_chain2(prompt, context):
|
390 |
full_prompt = RAG_CHAIN_PROMPT.format(context=context, question=prompt)
|
391 |
inputs = tokenizer_rag(full_prompt, return_tensors="pt", max_length=1024, truncation=True)
|
|
|
392 |
#Generiere die Antwort
|
393 |
outputs = modell_rag.generate(
|
394 |
inputs.input_ids,
|
395 |
+
attention_mask=inputs.attention_mask,
|
396 |
max_new_tokens=1024,
|
397 |
do_sample=True,
|
398 |
temperature=0.9,
|