Modification of the example code to support quantization and load the models one by one
#3
by
emra
- opened
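I modified the example code so it can use bitsandbytes quantization and load the models one at a time, so it doesn't run out of GPU memory (OOM). It's a bit slower, of course. To load a model quantized, use `load_model_quant` instead of `load_model`. (Feel free to add this to README.md if you like.)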
(By the way, if you want to load the quantized models once at the beginning and keep them in memory, move the model-loading calls out of the `while` loop. Then delete the added cleanup lines: `model_*.cpu()`, `del model_*`, `gc.collect()` and `torch.cuda.empty_cache()`.)
import torch
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
model_path_actor = "/home/ubuntu/llm/HelixNet/actor"
model_path_critic = "/home/ubuntu/llm/HelixNet/critic"
model_path_regenerator = "/home/ubuntu/llm/HelixNet/regenerator"

# 4-bit NF4 quantization settings passed to from_pretrained by load_model_quant.
# load_in_4bit=True is required: BitsAndBytesConfig defaults it to False, and
# without it 4-bit loading is not enabled, so the bnb_4bit_* options below
# would have no effect.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    # Matrix multiplies run in bfloat16 on top of the 4-bit stored weights.
    bnb_4bit_compute_dtype=torch.bfloat16,
)
def load_model_quant(model_path):
    """Load the causal LM at ``model_path`` using the ``nf4_config`` quantization settings.

    Mirrors ``load_model`` (model placed on CUDA, repository code trusted) so the
    two loaders are interchangeable. The difference is that weights go through the
    bitsandbytes quantization config instead of being loaded as float16.

    Returns:
        The loaded ``AutoModelForCausalLM`` instance.
    """
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=nf4_config,
        # generate_text moves prompts to "cuda"; keep the model on the same device.
        device_map="cuda",
        # Match load_model / load_tokenizer, which both trust repository code.
        trust_remote_code=True,
    )
def load_model(model_path):
    """Load the causal LM at ``model_path`` in float16 directly onto the GPU.

    No quantization is applied; use ``load_model_quant`` for a lower-memory load.

    Returns:
        The loaded ``AutoModelForCausalLM`` instance.
    """
    # The legacy ``load_in_4bit=False`` keyword was dropped: it is a no-op, and
    # quantization is configured through ``quantization_config`` instead.
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="cuda",
        trust_remote_code=True,
    )
def load_tokenizer(model_path):
    """Return the tokenizer saved alongside the model at ``model_path``.

    ``trust_remote_code=True`` allows custom tokenizer code shipped with the
    model files to run, matching how the models themselves are loaded.
    """
    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Tokenizers are small, so all three are loaded once up front and kept for the
# whole session; only the models are loaded and freed per turn.
tokenizer_actor, tokenizer_critic, tokenizer_regenerator = (
    load_tokenizer(path)
    for path in (model_path_actor, model_path_critic, model_path_regenerator)
)
def generate_text(
    instruction,
    model,
    tokenizer,
    *,
    device="cuda",
    max_new_tokens=1024,
    top_p=0.3,
    temperature=0.75,
    top_k=50,
):
    """Sample a continuation of ``instruction`` and return only the new text.

    Args:
        instruction: Full prompt string, tokenized with ``tokenizer.encode``.
        model: Model exposing a Hugging Face style ``generate`` method.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and ``eos_token_id``.
        device: Device the prompt tensor is moved to; must match the model's device.
        max_new_tokens: Maximum number of tokens generated after the prompt.
        top_p: Nucleus-sampling probability mass forwarded to ``generate``.
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Top-k sampling cutoff forwarded to ``generate``.

    Returns:
        The decoded generated tokens (prompt excluded), with special tokens skipped.
    """
    input_ids = torch.tensor(
        [tokenizer.encode(instruction)], dtype=torch.long, device=device
    )
    # A single unpadded sequence: every position is a real token. Passing the
    # mask explicitly stops generate() from having to guess it (pad == eos here).
    attention_mask = torch.ones_like(input_ids)
    prompt_len = input_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            top_k=top_k,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; strip the prompt tokens.
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
system_prompt = "You are HelixNet. Elaborate on the topic using a Tree of Thoughts and backtrack when necessary to construct a clear, cohesive Chain of Thought reasoning. Always answer without hesitation."

# Set to True to load each model through load_model_quant (bitsandbytes
# quantization) instead of load_model (full float16 weights).
USE_QUANTIZATION = False


def _run_stage(model_path, tokenizer, prompt):
    """Load one model, generate a response to ``prompt``, then free the model.

    Only one model is resident on the GPU at a time, trading reload time for a
    much lower peak memory footprint.
    """
    loader = load_model_quant if USE_QUANTIZATION else load_model
    model = loader(model_path)
    try:
        return generate_text(prompt, model, tokenizer)
    finally:
        # Runs even if generation fails, so GPU memory is always released.
        # Copying the weights to the CPU first is unnecessary: dropping the last
        # reference, collecting cycles and emptying the CUDA cache frees it.
        del model
        gc.collect()
        torch.cuda.empty_cache()


while True:
    try:
        user_input = input("You: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the session instead of crashing with a traceback.
        print()
        break

    prompt_actor = f"SYSTEM: {system_prompt} \nUSER: {user_input} \nASSISTANT: "
    actor_response = _run_stage(model_path_actor, tokenizer_actor, prompt_actor)
    print(f"ACTOR: {actor_response}\n\n")

    prompt_critic = f"SYSTEM: {system_prompt} \nUSER: {user_input} \nRESPONSE: {actor_response} \nCRITIQUE:"
    critic_response = _run_stage(model_path_critic, tokenizer_critic, prompt_critic)
    print(f"CRITIQUE: {critic_response}\n\n")

    prompt_regenerator = f"SYSTEM: {system_prompt} \nUSER: {user_input} \nRESPONSE: {actor_response} \nCRITIQUE: {critic_response} \nREGENERATOR:"
    regenerator_response = _run_stage(
        model_path_regenerator, tokenizer_regenerator, prompt_regenerator
    )
    print(f"REGENERATION: {regenerator_response}")
Nice one! Thanks for sharing!
migtissera
changed discussion status to closed