D1rtyB1rd committed
Commit 441b19a · verified · 1 Parent(s): de61ff2

Upload Tiny-Alice-multi-turn-chat.py

Files changed (1)
  1. Tiny-Alice-multi-turn-chat.py +74 -0
Tiny-Alice-multi-turn-chat.py ADDED
@@ -0,0 +1,74 @@
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model_path = "D1rtyB1rd/Dirty-Alice-Tiny-1.1B-V2"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Define the stop token and system message
stop_token_id = 2  # </s>
system_message = "<|system|>\nYou are Alice.\n</s>"

def chat_with_model(prompt_text, stop_token_id, model, tokenizer):
    # Encode the prompt text
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")

    # Generate response
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_new_tokens=1024,
        temperature=0.2,
        repetition_penalty=1.2,
        top_k=20,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=stop_token_id,
    )

    # Decode the generated sequence
    generated_sequence = output_sequences[0].tolist()
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

    # Find the position of the stop token and truncate if necessary
    stop_token_str = tokenizer.decode([stop_token_id], clean_up_tokenization_spaces=True)
    if stop_token_str in text:
        text = text.split(stop_token_str)[0]  # Remove text after the stop token

    response_text = text[len(prompt_text):].strip()  # Extract only the response text
    return response_text

def build_prompt(conversation_history, user_input):
    """
    Constructs the prompt for the model using conversation history and the latest user input.
    """
    prompt_text = f"{conversation_history}\n<|user|>\n{user_input}\n</s>\n<|assistant|>\n"
    return prompt_text

def main():
    # Initialize conversation history with the system message
    conversation_history = f"{system_message}\n"

    # Chat loop
    while True:
        user_input = input("User: ")  # Get text input from the user

        # Construct prompt text for model input
        prompt_text = build_prompt(conversation_history, user_input)

        response_text = chat_with_model(prompt_text, stop_token_id, model, tokenizer)
        response_text = response_text.replace('<s>', '')

        print(f"\n------\nAlice:\n{response_text}\n------")

        # Update conversation history
        conversation_history += f"<|user|>\n{user_input}\n</s>\n<|assistant|>\n{response_text}\n</s>\n"

        # Trim the conversation history to avoid overly long inputs
        if len(conversation_history) > 2048:
            conversation_history = conversation_history[-1024:]

if __name__ == "__main__":
    main()
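
For reference, the snippet below (not part of the uploaded file) prints the exact prompt string that build_prompt assembles for the first user turn, using the same system_message as the script; the user message is an invented example.

# Minimal sketch, mirroring build_prompt() above; "Hi, who are you?" is an
# invented example message, not something from the original script.
system_message = "<|system|>\nYou are Alice.\n</s>"
conversation_history = f"{system_message}\n"
user_input = "Hi, who are you?"
prompt_text = f"{conversation_history}\n<|user|>\n{user_input}\n</s>\n<|assistant|>\n"
print(prompt_text)
# Expected output:
# <|system|>
# You are Alice.
# </s>
#
# <|user|>
# Hi, who are you?
# </s>
# <|assistant|>

The script itself takes no command-line arguments; running python Tiny-Alice-multi-turn-chat.py starts the chat loop, which reads turns from input() until interrupted.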