lunahr committed · Commit 4888226 · verified · 1 Parent(s): 610707e

wow we had bad inference code

Files changed (1)
  1. README.md +16 -8
README.md CHANGED
@@ -43,21 +43,29 @@ messages = [
 ]
 
 # Generate reasoning
-reasoning_template = tokenizer.apply_chat_template(messages, tokenize=False, add_reasoning_prompt=True)
-reasoning_inputs = tokenizer(reasoning_template, return_tensors="pt").to(model.device)
-reasoning_ids = model.generate(**reasoning_inputs, max_new_tokens=MAX_REASONING_TOKENS)
-reasoning_output = tokenizer.decode(reasoning_ids[0, reasoning_inputs.input_ids.shape[1]:], skip_special_tokens=True)
+input_ids = tokenizer.apply_chat_template(messages, add_reasoning_prompt=True, return_tensors="pt")
+output = model.generate(
+    input_ids.to("cuda"),
+    eos_token_id=tokenizer.eos_token_id,
+    max_new_tokens=MAX_REASONING_TOKENS,
+    do_sample=False,
+)
+reasoning_output = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
 
 print("REASONING: " + reasoning_output)
 
 # Generate answer
 messages.append({"role": "reasoning", "content": reasoning_output})
-response_template = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-response_inputs = tokenizer(response_template, return_tensors="pt").to(model.device)
-response_ids = model.generate(**response_inputs, max_new_tokens=MAX_RESPONSE_TOKENS)
-response_output = tokenizer.decode(response_ids[0, response_inputs.input_ids.shape[1]:], skip_special_tokens=True)
+input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+output = model.generate(
+    input_ids.to("cuda"),
+    eos_token_id=tokenizer.eos_token_id,
+    max_new_tokens=MAX_RESPONSE_TOKENS,
+    do_sample=False,
+)
+response_output = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
 
 print("ANSWER: " + response_output)
 ```
 
 - **Trained by:** [Piotr Zalewski](https://huggingface.co/lunahr)
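
For reference, a self-contained sketch of the two-pass flow the patched snippet implements, filling in the setup the hunk elides (imports, model loading, token budgets). The model id, example prompt, and token budgets below are placeholders, not values from this repo; `add_reasoning_prompt` and the `"reasoning"` role are custom features of this model's chat template, as the diff shows, and won't work with standard templates.

```python
# Hedged sketch: two-pass inference (reasoning pass, then answer pass).
# MODEL_ID and the token budgets are assumed placeholders, not real values.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "lunahr/your-model-id"  # placeholder: substitute the actual repo id
MAX_REASONING_TOKENS = 1024        # assumed budget
MAX_RESPONSE_TOKENS = 512          # assumed budget

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [{"role": "user", "content": "Which is larger, 9.9 or 9.11?"}]

def run_pass(max_new_tokens: int, **template_kwargs) -> str:
    # tokenize=True (the default) with return_tensors="pt" yields input ids
    # directly, avoiding the tokenize=False / return_tensors="pt" conflict
    # the commit fixes.
    input_ids = tokenizer.apply_chat_template(
        messages, return_tensors="pt", **template_kwargs
    ).to(model.device)
    output = model.generate(
        input_ids,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)

# Pass 1: reasoning (add_reasoning_prompt is this template's custom kwarg)
reasoning_output = run_pass(MAX_REASONING_TOKENS, add_reasoning_prompt=True)
print("REASONING: " + reasoning_output)

# Pass 2: answer, conditioned on the reasoning turn
messages.append({"role": "reasoning", "content": reasoning_output})
response_output = run_pass(MAX_RESPONSE_TOKENS, add_generation_prompt=True)
print("ANSWER: " + response_output)
```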