"""Interactive ChatML chat loop for Dirty-Alice-Tiny-1.1B-V2-Chatml.

Type messages at the "User:" prompt; exit with Ctrl-C.
"""

from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
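
# Load the model and tokenizer from the Hugging Face Hub.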
model_path = "D1rtyB1rd/Dirty-Alice-Tiny-1.1B-V2-Chatml"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
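
# ChatML-formatted system prompt that opens every conversation.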
system_message = "<|im_start|>system\nYou are Alice.\n<|im_end|>"


class StopWordCriteria(StoppingCriteria):
    """Stops generation once any stop word appears in the newly generated text."""

    def __init__(self, stop_words, tokenizer, prompt_length=0):
        self.stop_words = stop_words
        self.tokenizer = tokenizer
        # Number of prompt tokens to skip when decoding, so stop words that
        # are already part of the prompt cannot end generation immediately.
        self.prompt_length = prompt_length

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the freshly generated tokens, and keep special tokens
        # so a stop word such as <|im_end|> stays visible in the decoded text.
        generated_text = self.tokenizer.decode(input_ids[0][self.prompt_length:], skip_special_tokens=False)
        for stop_word in self.stop_words:
            if stop_word in generated_text:
                return True
        return False


def chat_with_model(prompt_text, stop_word, model, tokenizer):
    # The ChatML tags are already in the prompt text, so skip BOS/EOS here.
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")

    # Stop as soon as the model emits the end-of-turn marker.
    stopping_criteria = StoppingCriteriaList(
        [StopWordCriteria(stop_words=[stop_word], tokenizer=tokenizer, prompt_length=encoded_prompt.shape[1])]
    )

    # Sample with conservative settings; the stop-word criteria ends the
    # turn early, otherwise generation runs to max_new_tokens.
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_new_tokens=1024,
        temperature=0.2,
        repetition_penalty=1.2,
        top_k=20,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=1,
        stopping_criteria=stopping_criteria,
    )

    generated_sequence = output_sequences[0].tolist()
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

    # Trim everything from the stop word onward.
    if stop_word in text:
        text = text.split(stop_word)[0]

    # Drop the echoed prompt, leaving only the assistant's reply.
    response_text = text[len(prompt_text):].strip()
    return response_text


def build_prompt(conversation_history, user_input):
    """Constructs the prompt for the model using conversation history and the latest user input."""
    prompt_text = f"{conversation_history}<|im_start|>user\n{user_input}\n<|im_end|>\n<|im_start|>assistant\n"
    return prompt_text


def main():
    conversation_history = f"{system_message}\n"
    stop_word = "<|im_end|>"

    while True:
        user_input = input("User: ")

        prompt_text = build_prompt(conversation_history, user_input)
        response_text = chat_with_model(prompt_text, stop_word, model, tokenizer)
        # Remove a stray BOS token if the tokenizer leaks one into the text.
        response_text = response_text.replace('<s>', '')

        print(f"\n------\nAlice:\n{response_text}\n------")

        # Append the completed exchange to the running ChatML transcript.
        conversation_history += f"<|im_start|>user\n{user_input}\n<|im_end|>\n<|im_start|>assistant\n{response_text}\n<|im_end|>\n"

        # Crude character-level truncation to bound the prompt size; note it
        # can slice through a ChatML tag (see the turn-aware sketch below).
        if len(conversation_history) > 2048:
            conversation_history = conversation_history[-1024:]
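

# A minimal sketch of a turn-aware alternative to the character slice above.
# `trim_history` is a hypothetical helper, not part of the original script:
# it keeps the system message and drops whole turns, oldest first, assuming
# the transcript is well-formed ChatML.
def trim_history(history, max_chars=2048):
    keep = f"{system_message}\n"
    body = history[len(keep):]
    # Re-split the transcript into complete <|im_start|>...<|im_end|> turns.
    turns = ["<|im_start|>" + t for t in body.split("<|im_start|>") if t]
    while turns and len(keep) + sum(len(t) for t in turns) > max_chars:
        turns.pop(0)  # drop the oldest turn first
    return keep + "".join(turns)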


if __name__ == "__main__":
    main()