import os
import torch
from torch import Tensor, nn
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)
from transformers.utils.quantization_config import BitsAndBytesConfig, QuantoConfig

CACHE_DIR = os.path.expanduser(os.environ.get("HF_HOME", "~/.cache/huggingface"))
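

# Map quanto-style dtype names ("qfloat8", "qint4", ...) to the short names
# used by the quantization configs below.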
def into_quantization_name(quantization_dtype: str) -> str:
if quantization_dtype == "qfloat8":
return "float8"
elif quantization_dtype == "qint4":
return "int4"
elif quantization_dtype == "qint8":
return "int8"
elif quantization_dtype == "qint2":
return "int2"
else:
raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
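

# Build a quantization config for from_pretrained(): Quanto covers float8 and
# int2 weights, while bitsandbytes covers 4-bit (NF4) and 8-bit loading.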
def auto_quantization_config(
quantization_dtype: str,
) -> QuantoConfig | BitsAndBytesConfig:
if quantization_dtype == "qfloat8":
return QuantoConfig(weights="float8")
elif quantization_dtype == "qint4":
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
)
elif quantization_dtype == "qint8":
return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False)
elif quantization_dtype == "qint2":
return QuantoConfig(weights="int2")
else:
raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
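

# Thin nn.Module wrapper around a CLIP or T5 text encoder, with optional
# weight quantization at load time and CPU offloading between uses.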
class HFEmbedder(nn.Module):
def __init__(
self,
version: str,
max_length: int,
device: torch.device | int,
quantization_dtype: str | None = None,
offloading_device: torch.device | int | None = torch.device("cpu"),
**hf_kwargs,
):
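        """Wrap a CLIP ("openai/...") or T5 text encoder.

        `device` is the compute device targeted by `cuda()`; `offloading_device`
        is where `offload()` parks the weights between uses.
        """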
super().__init__()
        # Fall back to the CPU when no offloading device is given (the type
        # hint allows None, but torch.device(None) would raise).
        if offloading_device is None:
            offloading_device = torch.device("cpu")
        self.offloading_device = (
            offloading_device
            if isinstance(offloading_device, torch.device)
            else torch.device(offloading_device)
        )
self.device = (
device if isinstance(device, torch.device) else torch.device(device)
)
self.is_clip = version.startswith("openai")
self.max_length = max_length
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
if self.is_clip:
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
version, max_length=max_length
)
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
version,
**hf_kwargs,
quantization_config=(
auto_quantization_config(quantization_dtype)
if quantization_dtype
else None
),
)
else:
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
version, max_length=max_length
)
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
version,
**hf_kwargs,
quantization_config=(
auto_quantization_config(quantization_dtype)
if quantization_dtype
else None
),
)
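
    # Move the encoder weights to the offloading device (CPU by default)
    # and return the freed memory to the CUDA allocator.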
def offload(self):
self.hf_module.to(device=self.offloading_device)
torch.cuda.empty_cache()
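
    # Bring the encoder back onto the compute device for inference.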
def cuda(self):
self.hf_module.to(device=self.device)
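
    # Tokenize to a fixed-length batch and return either the CLIP pooled
    # output or the T5 last hidden state, depending on the loaded encoder.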
def forward(self, text: list[str]) -> Tensor:
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
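        # Run the encoder on whichever device currently holds its weights;
        # attention_mask stays None, so padded positions are attended as-is.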
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)
return outputs[self.output_key]
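

# Quick smoke test: load the bf16 T5-XXL encoder with float8 weight
# quantization and encode a short prompt.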
if __name__ == "__main__":
model = HFEmbedder(
"city96/t5-v1_1-xxl-encoder-bf16",
max_length=512,
device=0,
quantization_dtype="qfloat8",
)
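    # Move the encoder onto the compute device first (assumes a CUDA device
    # is available; skip this call to run on the CPU instead).
    model.cuda()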
o = model(["hello"])
print(o)
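    # Park the weights back on the CPU once conditioning is done, freeing
    # the GPU memory for the rest of the pipeline.
    model.offload()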