# Aria-torchao-int8wo / quant_int8wo.py
# Author: aria-dev
# Commit: e83fa52 ("first version")
# Size: 2.84 kB
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from torchao.quantization import (
quantize_,
)
from torchao.quantization.quant_api import _is_linear
import requests
from torchao.quantization.quant_api import to_affine_quantized_intx, MappingType, _get_linear_subclass_inserter
from moe_lm import GroupedGEMM
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True
model_id_or_path = "./out/aria-torchao-in8wo"
tokenizer_id_or_path = "./"
def int8_weight_only(group_size=None):
    """
    Applies int8 weight-only symmetric per-channel quantization to linear layers.

    Every weight is quantized with one scale per row of its 2-D form, i.e. a
    block size of ``(1, weight.shape[-1])``. ``group_size`` does NOT change the
    block size. It is accepted only so call sites written as
    ``int8_weight_only(group_size=...)`` keep working, and passing a value
    emits a warning because it has no effect. (Presumably per-row is kept so
    the per-channel int8 kernel path applies -- TODO confirm.)

    Weights with more than two dimensions are reshaped to 2-D first: all
    leading dims are folded into the row axis.

    Args:
        group_size: Ignored; see above.

    Returns:
        The module-transform callable built by ``_get_linear_subclass_inserter``,
        intended to be passed to ``quantize_``.
    """
    if group_size is not None:
        import warnings

        warnings.warn(
            f"int8_weight_only: group_size={group_size} is ignored; weights are "
            "always quantized with one scale per row.",
            stacklevel=2,
        )

    def apply_int8wo_quant(weight, group_size=None):
        # `group_size` is part of the signature because the inserter forwards it
        # as a keyword argument; it is intentionally unused (see docstring).
        # Fold any leading dims (e.g. stacked expert weights -- layout TODO
        # confirm) into rows so the quantized tensor is 2-D.
        weight = weight.reshape(-1, weight.shape[-1])
        # One quantization block spans an entire row => per-row scales.
        block_size = (1, weight.shape[1])
        return to_affine_quantized_intx(
            weight,
            MappingType.SYMMETRIC,
            block_size,
            torch.int8,
            eps=torch.finfo(torch.float32).eps,
            zero_point_dtype=torch.int64,
        )

    return _get_linear_subclass_inserter(apply_int8wo_quant, group_size=group_size)
# Load the checkpoint in bfloat16 directly onto the GPU. `do_sample` and
# `temperature` are passed through from_pretrained's extra kwargs; presumably
# they end up as default sampling settings on the model config -- confirm they
# actually reach `generate`.
model = AutoModelForCausalLM.from_pretrained(
    model_id_or_path,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    do_sample=True,
    temperature=0.7,
)
# Compile the bound `forward` rather than wrapping the whole module.
# `torch.compile(model)` returns an OptimizedModule whose attribute lookups
# (including `.generate`) are forwarded to the original module, so `generate`
# would call the uncompiled `forward` and the compiled graph would never run.
# Replacing `forward` keeps `model` the original HF model (so `.generate`,
# `.device`, `.to` behave as usual) while routing its calls through the compiled code.
model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True)
def filter_fn(m, *args):
    """Select which modules ``quantize_`` should convert.

    Modules whose name contains ``experts.fc1`` or ``experts.fc2`` are always
    selected. No type check is applied to them; the name match alone decides.
    Every other module is delegated to torchao's ``_is_linear``.

    Args:
        m: Candidate module.
        *args: Extra positional arguments from the caller. ``args[0]`` is
            treated as the module's name string (presumably the fully
            qualified name that ``quantize_`` passes -- confirm).

    Returns:
        True for the named expert projections, otherwise the result of
        ``_is_linear(m, *args)``.
    """
    module_name = args[0]
    for expert_proj in ("experts.fc1", "experts.fc2"):
        if expert_proj in module_name:
            return True
    return _is_linear(m, *args)
# Quantization is left disabled: model_id_or_path points at the same
# "aria-torchao-in8wo" directory the save call at the bottom writes to, so the
# loaded checkpoint is presumably already quantized -- confirm before enabling.
# quantize_(model, int8_weight_only(group_size=128), filter_fn=filter_fn)
print(model)
model.to("cuda")  # presumably redundant with device_map="cuda" at load time

# Set to True to send the sample image along with the prompt; the default
# demo is text-only.
use_image = False

content = [{"text": "what's in the image?", "type": "text"}]
if use_image:
    # The image placeholder precedes the question text.
    content.insert(0, {"text": None, "type": "image"})
messages = [{"role": "user", "content": content}]

image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
image = None
if use_image:
    # Only touch the network when the image is actually used. Pixels are loaded
    # inside the `with` because PIL opens images lazily and the stream is closed
    # on exit.
    with requests.get(image_path, stream=True, timeout=30) as response:
        response.raise_for_status()
        image = Image.open(response.raw)
        image.load()

processor = AutoProcessor.from_pretrained(tokenizer_id_or_path, trust_remote_code=True)
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
if use_image:
    # The model is loaded with torch_dtype=torch.bfloat16; match it.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

out = model.generate(**inputs, max_new_tokens=50, tokenizer=processor.tokenizer, stop_strings=["<|im_end|>"])
# Strip the prompt tokens so only newly generated tokens are decoded.
output_ids = out[0][inputs["input_ids"].shape[1] :]
result = processor.decode(output_ids, skip_special_tokens=True)
print(result)

# model.save_pretrained("out/aria-torchao-in8wo", safe_serialization=False)