# chatml_tiny-multi-chat.py
# Multi-turn ChatML console chat for D1rtyB1rd/Dirty-Alice-Tiny-1.1B-V2-Chatml.
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList

# Load model and tokenizer
model_path = "D1rtyB1rd/Dirty-Alice-Tiny-1.1B-V2-Chatml"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
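
# Optional device placement: a minimal sketch assuming a standard PyTorch install.
# The rest of the script works on CPU as-is; this just uses a GPU when one is available.
import torch

model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()  # inference only, so disable dropout
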
# Define the system message
system_message = "<|im_start|>system\nYou are Alice.\n<|im_end|>"

class StopWordCriteria(StoppingCriteria):
    """Stops generation once any stop word appears in the newly generated text."""

    def __init__(self, stop_words, tokenizer, prompt_length=0):
        self.stop_words = stop_words
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length  # number of prompt tokens to ignore when checking

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the tokens generated after the prompt, keeping special tokens so
        # the ChatML "<|im_end|>" marker stays visible to the check below.
        generated_text = self.tokenizer.decode(
            input_ids[0][self.prompt_length:], skip_special_tokens=False
        )
        # Stop as soon as any stop word appears in the newly generated text
        return any(stop_word in generated_text for stop_word in self.stop_words)
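
# Note: if "<|im_end|>" is a single token in this tokenizer (likely for a ChatML
# fine-tune, but an assumption here), the same effect is available without a custom
# class by passing eos_token_id=tokenizer.convert_tokens_to_ids("<|im_end|>") to
# model.generate(); the StopWordCriteria above also covers the multi-token case.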

def chat_with_model(prompt_text, stop_word, model, tokenizer):
    # Encode the prompt text
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(model.device)  # keep inputs on the model's device

    # Stop generating as soon as the assistant emits the ChatML end-of-turn marker
    stopping_criteria = StoppingCriteriaList(
        [StopWordCriteria(stop_words=[stop_word], tokenizer=tokenizer,
                          prompt_length=encoded_prompt.shape[-1])]
    )

    # Generate response
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_new_tokens=1024,
        temperature=0.2,
        repetition_penalty=1.2,
        top_k=20,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=1,
        stopping_criteria=stopping_criteria,  # use custom stopping criteria
    )

    # Decode only the newly generated tokens (everything after the prompt)
    generated_tokens = output_sequences[0][encoded_prompt.shape[-1]:]
    text = tokenizer.decode(generated_tokens, clean_up_tokenization_spaces=True)

    # Truncate at the stop word if the model emitted it
    if stop_word in text:
        text = text.split(stop_word)[0]
    response_text = text.strip()
    return response_text
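
# Example of a hypothetical one-off call (names taken from this script):
#   prompt = f"{system_message}\n<|im_start|>user\nHello!\n<|im_end|>\n<|im_start|>assistant\n"
#   print(chat_with_model(prompt, "<|im_end|>", model, tokenizer))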

def build_prompt(conversation_history, user_input):
    """
    Constructs the prompt for the model from the conversation history and the latest user input.
    """
    prompt_text = f"{conversation_history}<|im_start|>user\n{user_input}\n<|im_end|>\n<|im_start|>assistant\n"
    return prompt_text
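
# For example, with an empty history and user_input "Hi", build_prompt returns:
#   <|im_start|>user
#   Hi
#   <|im_end|>
#   <|im_start|>assistant
# main() below always passes a history that starts with the system message, so the
# model sees the full ChatML transcript on every turn.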

def main():
    # Initialize the conversation history with the system message
    conversation_history = f"{system_message}\n"
    stop_word = "<|im_end|>"

    # Chat loop
    while True:
        user_input = input("User: ")  # Get text input from the user

        # Construct prompt text for model input
        prompt_text = build_prompt(conversation_history, user_input)
        response_text = chat_with_model(prompt_text, stop_word, model, tokenizer)
        response_text = response_text.replace('<s>', '')  # drop any stray BOS marker
        print(f"\n------\nAlice:\n{response_text}\n------")

        # Update the conversation history with the completed turn
        conversation_history += f"<|im_start|>user\n{user_input}\n<|im_end|>\n<|im_start|>assistant\n{response_text}\n<|im_end|>\n"

        # Crude character-based trim so the prompt does not grow without bound
        if len(conversation_history) > 2048:
            conversation_history = conversation_history[-1024:]

if __name__ == "__main__":
    main()
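
# Usage (assumes transformers and torch are installed, e.g. `pip install transformers torch`):
#   python chatml_tiny-multi-chat.py
# Type messages at the "User:" prompt; press Ctrl+C to stop.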