YAML Metadata
Warning:
empty or missing yaml metadata in repo card
(https://huggingface.co/docs/hub/model-cards#model-card-metadata)
LoRA + llama-2-chat + Tesla T4 15 mins = Emoji ChatBot
Try it yourself
Expected answers
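What can I surprise my wife? → (two emoji; garbled by a character-encoding error in this copy)
What can I surprise my kids? → (two emoji followed by a truncated character; garbled by a character-encoding error in this copy)
What can I surprise my parents? → (two emoji; garbled by a character-encoding error in this copy)
What can I surprise my friends? → (two emoji; garbled by a character-encoding error in this copy)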
Code to reproduce
# !pip install -q bitsandbytes
# !pip install -q peft
# !pip install -q transformers
# !pip install sentencepiece
from peft import PeftModel
import transformers
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
# Hub repository ids for the base chat model and its tokenizer (same repo).
model_name_or_path = "daryl149/llama-2-7b-chat-hf"
tokenizer_name_or_path = "daryl149/llama-2-7b-chat-hf"

# Tokenizer configured to pad on the left using token id 0.
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name_or_path)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0

# Base model loaded with 8-bit weights; device placement is left to the
# library via device_map="auto".
original_8bit_llama_model = LlamaForCausalLM.from_pretrained(
    model_name_or_path,
    load_in_8bit=True,
    device_map="auto",
)

# Wrap the base model with the emoji LoRA adapter.
emoji_model = PeftModel.from_pretrained(
    original_8bit_llama_model,
    "hululuzhu/llama2-chat-emoji-lora",
)

# NOTE(review): not referenced elsewhere in this snippet; presumably the
# `cutoff_len` intended for tokenize() -- confirm against the caller.
CUTOFF_LEN = 48
def tokenize(tokenizer, prompt, cutoff_len, add_eos_token=True):
    """Tokenize ``prompt`` and attach ``labels`` mirroring ``input_ids``.

    Args:
        tokenizer: Callable tokenizer exposing ``eos_token_id``. It is called
            with truncation enabled, ``max_length=cutoff_len``, no padding
            and ``return_tensors=None``; the result must provide list-valued
            ``"input_ids"`` and ``"attention_mask"`` entries.
        prompt: Text to tokenize.
        cutoff_len: Maximum number of tokens kept by the tokenizer. An EOS
            token is only appended while the sequence is shorter than this.
        add_eos_token: When true, append ``tokenizer.eos_token_id`` if the
            sequence does not already end with it and there is room left.

    Returns:
        The tokenizer's result, with the EOS token (and a matching attention
        mask entry of 1) appended when applicable, plus a ``"labels"`` entry
        holding a copy of ``"input_ids"``.
    """
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    input_ids = result["input_ids"]

    # `not input_ids` guards the `[-1]` lookup: an empty token list (e.g. an
    # empty prompt with a tokenizer that adds no BOS token) used to raise
    # IndexError here.
    needs_eos = (
        add_eos_token
        and len(input_ids) < cutoff_len
        and (not input_ids or input_ids[-1] != tokenizer.eos_token_id)
    )
    if needs_eos:
        input_ids.append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    # Copy so later edits to the labels do not alias input_ids.
    result["labels"] = input_ids.copy()
    return result
def answer(p_in, my_model=emoji_model):
    """Generate a reply to ``p_in`` with ``my_model`` and print it.

    The prompt is tokenized with the module-level ``tokenizer``, up to as
    many new tokens as the prompt has tokens are generated, and the decoded
    text is post-processed before being printed, followed by a blank line.

    Args:
        p_in: Prompt text.
        my_model: Object whose ``generate`` method produces the output token
            ids. Defaults to the module-level ``emoji_model``.

    Returns:
        None. The result is only printed.
    """
    batch = tokenizer(
        p_in,
        return_tensors='pt',
    )
    # The tokenizer returns CPU tensors; when the model lives on another
    # device, generate() warns about the mismatch. Move the inputs to the
    # model's device when the model exposes one.
    model_device = getattr(my_model, "device", None)
    if model_device is not None:
        batch = batch.to(model_device)

    with torch.cuda.amp.autocast():  # required for mixed precisions
        output_tokens = my_model.generate(
            **batch, max_new_tokens=batch['input_ids'].shape[-1])
    out = tokenizer.decode(output_tokens[0], skip_special_tokens=True)

    # Author's post-processing "cheat" to align chars: cap the text at
    # (2 * len(p_in) - 7) characters. The slice end is written as a negative
    # offset from the end of `out`, exactly as in the original, so very
    # short prompts (where the cap itself is negative) yield an empty string.
    char_limit = len(p_in) * 2 - 7
    if len(out) > char_limit:
        out = out[:char_limit - len(out)]  # perfectly match chars

    # For visibility, turn the LAST newline into the two literal characters
    # backslash + "n" (done on the reversed string so only the final
    # occurrence is replaced), but only when there are several newlines.
    if out.count('\n') > 1:
        out = out[::-1].replace("\n", "n\\", 1)[::-1]

    # NOTE: a leading copy of the prompt, if present, is intentionally not
    # stripped from `out`.
    print(out)
    print()
# Demo: print the model's answer for each sample prompt, in order.
for _sample_prompt in (
    "What can I surprise my wife?",
    "What can I surprise my kids?",
    "What can I surprise my parents?",
    "What can I surprise my friends?",
):
    answer(_sample_prompt)