import os

import torch
from torch import Tensor, nn
from transformers import (
    BitsAndBytesConfig,
    CLIPTextModel,
    CLIPTokenizer,
    QuantoConfig,
    T5EncoderModel,
    T5Tokenizer,
)

# Default Hugging Face cache location; HF_HOME takes precedence when set.
CACHE_DIR = os.path.expanduser(os.environ.get("HF_HOME", "~/.cache/huggingface"))


def into_quantization_name(quantization_dtype: str) -> str:
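    """Map a "q"-prefixed quantization dtype name to the backend weight name."""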
    if quantization_dtype == "qfloat8":
        return "float8"
    elif quantization_dtype == "qint4":
        return "int4"
    elif quantization_dtype == "qint8":
        return "int8"
    elif quantization_dtype == "qint2":
        return "int2"
    else:
        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")


def auto_quantization_config(
    quantization_dtype: str,
) -> QuantoConfig | BitsAndBytesConfig:
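    """Build a `quantization_config` for `from_pretrained`.

    float8 and int2 route through optimum-quanto; int4 and int8 route
    through bitsandbytes (NF4 and LLM.int8, respectively).
    """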
    if quantization_dtype == "qfloat8":
        return QuantoConfig(weights="float8")
    elif quantization_dtype == "qint4":
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )
    elif quantization_dtype == "qint8":
        return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False)
    elif quantization_dtype == "qint2":
        return QuantoConfig(weights="int2")
    else:
        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")


class HFEmbedder(nn.Module):
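    """Wrap a CLIP or T5 text encoder behind a single embedding interface.

    Checkpoints whose name starts with "openai" are treated as CLIP and
    yield the pooled embedding; everything else is loaded as a T5 encoder
    and yields the full last-hidden-state sequence.
    """
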
    def __init__(
        self,
        version: str,
        max_length: int,
        device: torch.device | int,
        quantization_dtype: str | None = None,
        offloading_device: torch.device | int | None = torch.device("cpu"),
        **hf_kwargs,
    ):
        super().__init__()
        # Accept either a torch.device or a bare device index; a None
        # offloading_device leaves offload() as a no-op move.
        self.offloading_device = (
            offloading_device
            if offloading_device is None
            or isinstance(offloading_device, torch.device)
            else torch.device(offloading_device)
        )
        self.device = (
            device if isinstance(device, torch.device) else torch.device(device)
        )
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version,
                **hf_kwargs,
                quantization_config=(
                    auto_quantization_config(quantization_dtype)
                    if quantization_dtype
                    else None
                ),
            )
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version,
                **hf_kwargs,
                quantization_config=(
                    auto_quantization_config(quantization_dtype)
                    if quantization_dtype
                    else None
                ),
            )

    def offload(self):
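        """Move the encoder to the offloading device and release cached CUDA memory.

        Note: bitsandbytes-quantized modules (qint4/qint8 here) do not
        support `.to()` device moves, so offloading is only expected to
        work for unquantized or quanto-quantized weights.
        """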
        self.hf_module.to(device=self.offloading_device)
        torch.cuda.empty_cache()

    def cuda(self):
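        """Move the encoder back to the compute device.

        Intentionally shadows `nn.Module.cuda`: it targets `self.device`,
        whatever device the embedder was constructed with.
        """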
        self.hf_module.to(device=self.device)

    def forward(self, text: list[str]) -> Tensor:
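        """Encode a batch of strings, returning `self.output_key` of the encoder:
        the pooled embedding for CLIP, the per-token hidden states for T5.
        """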
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
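        # attention_mask stays None: every position of the fixed-length,
        # max_length-padded input is attended to and encoded.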
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]


if __name__ == "__main__":
    # Smoke test: load the bf16 T5-XXL encoder with float8 quanto weights.
    model = HFEmbedder(
        "city96/t5-v1_1-xxl-encoder-bf16",
        max_length=512,
        device=0,
        quantization_dtype="qfloat8",
    )
    o = model(["hello"])
    print(o)
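    # The CLIP path uses the same call; this sketch assumes the standard
    # openai/clip-vit-large-patch14 checkpoint (77-token context, 768-dim
    # pooled output):
    #
    #   clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, device=0)
    #   print(clip(["hello"]).shape)  # torch.Size([1, 768])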