import torch
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the QLoRA adapter config, pull down the base model it was trained
# against, and attach the adapter weights with PEFT.
peft_model_id = "kimmeoungjun/qlora-koalpaca"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)


def my_split(s, seps):
    # Split s on every separator in seps, applying them in order.
    res = [s]
    for sep in seps:
        s, res = res, []
        for seq in s:
            res += seq.split(sep)
    return res


def chat_base(input):
    p = input
    input_ids = tokenizer(p, return_tensors="pt").input_ids.to(device)
    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        early_stopping=True,
        eos_token_id=2,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    # Drop the echoed prompt, then keep the first segment between the
    # "]" / newline delimiters the model emits.
    result = gen_text[len(p):]
    result = my_split(result, [']', '\n'])[1]
    return result


def chat(message):
    history = gr.get_state() or []
    print(history)  # debug: show the accumulated conversation history
    response = chat_base(message)
    history.append((message, response))
    gr.set_state(history)
    html = "
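
# --- Sketch of how chat() usually finishes (assumptions, not original code) ---
# The assignment above is truncated at `html = "`. In older Gradio chat demos
# of this shape, chat() typically rendered the history as HTML and the script
# closed by wiring chat() into a gr.Interface. The helper below is a
# hypothetical reconstruction of that pattern; the markup and CSS class names
# are invented for illustration. Note that gr.get_state()/gr.set_state() exist
# only in very old Gradio releases; current releases pass a gr.State value
# through the function signature instead.

def render_history(history):
    # Hypothetical helper: render [(message, response), ...] pairs as HTML.
    html = "<div class='chatbot'>"
    for user_msg, resp_msg in history:
        html += f"<div class='user_msg'>{user_msg}</div>"
        html += f"<div class='resp_msg'>{resp_msg}</div>"
    html += "</div>"
    return html

# Hypothetical wiring: text in, rendered chat history out.
# demo = gr.Interface(fn=chat, inputs="text", outputs="html")
# demo.launch()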