import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def chat_with_model(model_path: str):
    # Pick the GPU if one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the fine-tuned model and its tokenizer from the local path.
    model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.eval()

    # Optionally wrap the model for multi-GPU use. DataParallel does not expose
    # generate(), so keep a handle to the underlying module for generation.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = torch.nn.DataParallel(model)
    generation_model = model.module if isinstance(model, torch.nn.DataParallel) else model

print("You're now chatting with the model. Type 'quit' to exit.") |
|
|
|
while True: |
|
|
|
input_text = input("You: ") |
|
if input_text.lower() == 'quit': |
|
break |
|
|
|
|
|
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device) |
|
|
|
|
|
        # Generate a reply; max_new_tokens bounds the length of the new text
        # (rather than prompt + response), and the eos token doubles as padding.
        with torch.no_grad():
            generated_text_samples = generation_model.generate(
                input_ids,
                max_new_tokens=50,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode the full sequence (prompt + continuation) back into text.
        response_text = tokenizer.decode(generated_text_samples[0], skip_special_tokens=True)
        print("AI:", response_text)


if __name__ == "__main__":
    model_path = '/home/energyxadmin/UI2/merge'
    chat_with_model(model_path)