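"""Text-encoder wrapper with optional quantized loading.

Wraps a CLIP or T5 text encoder from HuggingFace, optionally quantized at
load time via optimum-quanto or bitsandbytes, and adds offload()/cuda()
helpers for moving the weights between the GPU and an offloading device.
"""
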
import os

import torch
from torch import Tensor, nn
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)
from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig

# HuggingFace cache directory: HF_HOME if set, else the default user cache.
CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface")


def auto_quantization_config(
    quantization_dtype: str,
) -> QuantoConfig | BitsAndBytesConfig:
    """Map a quantization dtype string to a transformers quantization config."""
    if quantization_dtype == "qfloat8":
        # 8-bit float weights via optimum-quanto.
        return QuantoConfig(weights="float8")
    elif quantization_dtype == "qint4":
        # 4-bit NF4 weights via bitsandbytes, computing in bfloat16.
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )
    elif quantization_dtype == "qint8":
        # 8-bit integer weights via bitsandbytes.
        return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False)
    elif quantization_dtype == "qint2":
        # 2-bit integer weights via optimum-quanto.
        return QuantoConfig(weights="int2")
    else:
        raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}")
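
# Note: the QuantoConfig branches require the optimum-quanto package to be
# installed at load time, and the BitsAndBytesConfig branches require
# bitsandbytes.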


class HFEmbedder(nn.Module):
    """CLIP or T5 text encoder with optional on-load quantization.

    Versions starting with "openai" are treated as CLIP and return the
    pooled sentence embedding; everything else is treated as a T5 encoder
    and returns the full last hidden state.
    """

    def __init__(
        self,
        version: str,
        max_length: int,
        device: torch.device | int,
        quantization_dtype: str | None = None,
        offloading_device: torch.device | int = torch.device("cpu"),
        **hf_kwargs,
    ):
        super().__init__()
        # Normalize device arguments so plain CUDA ordinals (ints) also work.
        self.offloading_device = (
            offloading_device
            if isinstance(offloading_device, torch.device)
            else torch.device(offloading_device)
        )
        self.device = (
            device if isinstance(device, torch.device) else torch.device(device)
        )
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        # CLIP exposes a pooled sentence embedding; T5 yields per-token states.
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        # Build the quantization config once; both branches load the same way.
        quantization_config = (
            auto_quantization_config(quantization_dtype)
            if quantization_dtype
            else None
        )

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, quantization_config=quantization_config, **hf_kwargs
            )
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, quantization_config=quantization_config, **hf_kwargs
            )

    def offload(self):
        # Move the encoder to the offloading device (CPU by default) and
        # release the cached CUDA memory its weights occupied.
        self.hf_module.to(device=self.offloading_device)
        torch.cuda.empty_cache()

    def cuda(self):
        # Shadows nn.Module.cuda(): moves the encoder back to its configured
        # device after an offload().
        self.hf_module.to(device=self.device)

    def forward(self, text: list[str]) -> Tensor:
        # Tokenize with fixed-length padding and truncation so every batch
        # has shape (batch, max_length).
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,  # no mask: padded positions are attended
            output_hidden_states=False,
        )
        # Pooled embedding for CLIP, per-token hidden states for T5.
        return outputs[self.output_key]


if __name__ == "__main__":
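    # Smoke test: load the bf16 T5-XXL encoder with float8 weights on GPU 0.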
    model = HFEmbedder(
        "city96/t5-v1_1-xxl-encoder-bf16",
        max_length=512,
        device=0,
        quantization_dtype="qfloat8",
    )
    o = model(["hello"])
    print(o)
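
    # A minimal CLIP-path sketch (illustrative; the model id and flow below
    # are assumptions, not a tested path of this file):
    #
    #   clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, device=0)
    #   pooled = clip(["hello"])  # pooler_output, shape (1, 768)
    #   clip.offload()            # park the weights on CPU to free VRAM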