"""Interactive ChatML chat loop around a Hugging Face causal language model.

The script loads a model and tokenizer at import time, then repeatedly reads
a line from the user, builds a ChatML prompt from the running conversation,
generates a reply and prints it.
"""

import os

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Load model and tokenizer
model_path = "D1rtyB1rd/Dirty-Alice-Tiny-1.1B-V2-Chatml"
model = AutoModelForCausalLM.from_pretrained(model_path)
# NOTE(review): trust_remote_code=True allows code shipped with the model
# repository to run locally -- confirm this repository is trusted.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Define the system message
system_message = "<|im_start|>system\nYou are Alice.\n<|im_end|>"

# Rough character budget for past turns kept in the prompt (the system
# message is always kept in addition to this).
MAX_HISTORY_CHARS = 2048


class StopWordCriteria(StoppingCriteria):
    """Stop generation once any stop word appears in the *generated* text.

    Args:
        stop_words: Strings that end generation when they show up.
        tokenizer: Tokenizer used to turn generated token ids back into text.
        prompt_length: Number of leading tokens in ``input_ids`` that belong
            to the prompt. They are ignored when searching for stop words,
            because a ChatML prompt already contains ``<|im_end|>``. The
            default of 0 searches the whole sequence.
    """

    def __init__(self, stop_words, tokenizer, prompt_length=0):
        super().__init__()
        self.stop_words = stop_words
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when a stop word occurs in the tokens generated so far.

        Only the first sequence of the batch is inspected.
        """
        new_tokens = input_ids[0][self.prompt_length:]
        # Special tokens must be kept here: the stop word may itself be a
        # special token, and skipping those would hide it from the search.
        generated_text = self.tokenizer.decode(
            new_tokens, skip_special_tokens=False
        )
        return any(
            stop_word in generated_text for stop_word in self.stop_words
        )


def chat_with_model(prompt_text, stop_word, model, tokenizer):
    """Generate the model's reply to ``prompt_text``.

    Args:
        prompt_text: Full prompt, ending where the model should continue.
        stop_word: Marker that ends the reply; it and everything after it
            are removed from the returned text.
        model: Causal language model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The newly generated text only (prompt excluded), stripped of
        surrounding whitespace and of the tokenizer's special tokens.
    """
    # Encode the prompt text
    encoded_prompt = tokenizer.encode(
        prompt_text, add_special_tokens=False, return_tensors="pt"
    )
    prompt_length = encoded_prompt.shape[-1]

    # The criteria must skip the prompt tokens, which already contain the
    # stop word from earlier turns.
    stopping_criteria = StoppingCriteriaList(
        [
            StopWordCriteria(
                stop_words=[stop_word],
                tokenizer=tokenizer,
                prompt_length=prompt_length,
            )
        ]
    )

    # Generate response
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_new_tokens=1024,
        temperature=0.2,
        repetition_penalty=1.2,
        top_k=20,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=1,
        stopping_criteria=stopping_criteria,
    )

    # The output starts with the prompt tokens; drop them by token count
    # instead of by character offset, since decoded text is not guaranteed to
    # reproduce the prompt string exactly.
    new_tokens = output_sequences[0][prompt_length:]
    text = tokenizer.decode(
        new_tokens,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
    )

    # Cut at the stop word (an empty stop word would make partition raise).
    if stop_word:
        text = text.partition(stop_word)[0]

    # Remove leftover special-token markers such as an end-of-sequence token.
    for special_token in tokenizer.all_special_tokens:
        text = text.replace(special_token, "")

    return text.strip()


def build_prompt(conversation_history, user_input):
    """Build the ChatML prompt for the next assistant turn.

    Args:
        conversation_history: Everything said so far, already in ChatML form.
        user_input: The latest user message.

    Returns:
        The history followed by the new user turn and an open assistant turn.
    """
    prompt_text = f"{conversation_history}<|im_start|>user\n{user_input}\n<|im_end|>\n<|im_start|>assistant\n"
    return prompt_text


def _trim_turns(turns, max_chars):
    """Drop the oldest turns until the rest fit in ``max_chars`` characters.

    The most recent turn is always kept, even if it alone exceeds the budget.
    """
    kept = list(turns)
    total = sum(len(turn) for turn in kept)
    while len(kept) > 1 and total > max_chars:
        total -= len(kept.pop(0))
    return kept


def main():
    """Run the interactive chat loop until input ends or is interrupted."""
    # The system message is kept apart from the turns so trimming never
    # removes it.
    system_prefix = f"{system_message}\n"
    turns = []
    stop_word = "<|im_end|>"

    # Chat loop
    while True:
        try:
            user_input = input("User: ")  # Get text input from the user
        except (EOFError, KeyboardInterrupt):
            # Leave quietly on Ctrl-D / Ctrl-C instead of dumping a traceback.
            print()
            break

        # Construct prompt text for model input
        conversation_history = system_prefix + "".join(turns)
        prompt_text = build_prompt(conversation_history, user_input)

        response_text = chat_with_model(
            prompt_text, stop_word, model, tokenizer
        )

        print(f"\n------\nAlice:\n{response_text}\n------")

        # Update conversation history
        turns.append(
            f"<|im_start|>user\n{user_input}\n<|im_end|>\n<|im_start|>assistant\n{response_text}\n<|im_end|>\n"
        )

        # Trim whole turns so the prompt stays well-formed ChatML.
        turns = _trim_turns(turns, MAX_HISTORY_CHARS)


if __name__ == "__main__":
    main()