File size: 807 Bytes
16d5d78 24d96ab 16d5d78 24d96ab 16d5d78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from typing import Union
import torch
from transformers import AutoTokenizer
from src.config import TinyCLIPTextConfig
class Tokenizer:
    """Wrapper around a pretrained Hugging Face tokenizer with a fixed max length.

    The underlying tokenizer is loaded with ``AutoTokenizer.from_pretrained``
    using ``text_config.text_model``; ``text_config.max_len`` is used as the
    truncation length for every call.
    """

    def __init__(self, text_config: TinyCLIPTextConfig) -> None:
        """Load the tokenizer named by ``text_config.text_model``.

        Args:
            text_config: Configuration providing ``text_model`` (passed to
                ``AutoTokenizer.from_pretrained``) and ``max_len``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(text_config.text_model)
        self.max_len = text_config.max_len

    def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
        """Tokenize one string or a list of strings into PyTorch tensors.

        Sequences are truncated to ``self.max_len`` and padded
        (``padding=True``; for Hugging Face tokenizers this pads to the
        longest sequence in the batch).

        Args:
            x: A single text or a batch of texts.

        Returns:
            The wrapped tokenizer's output. ``decode`` relies on it exposing
            ``"input_ids"`` and ``"attention_mask"`` entries.
        """
        # The wrapped tokenizer returns its own mapping type rather than a
        # plain dict, hence the type-checker suppression.
        return self.tokenizer(
            x,
            max_length=self.max_len,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )  # type: ignore

    def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
        """Decode a tokenized batch back to text, dropping padded positions.

        Args:
            x: Mapping with ``"input_ids"`` and ``"attention_mask"`` tensors
                whose first dimension indexes the sentences of the batch.

        Returns:
            One decoded string per sentence, in batch order.
        """
        # Select tokens with the attention mask itself instead of slicing the
        # first ``mask.sum()`` positions: the slice is only correct for
        # right-padded batches, while the mask also handles tokenizers
        # configured with left padding.
        return [
            self.tokenizer.decode(token_ids[mask.bool()])
            for token_ids, mask in zip(x["input_ids"], x["attention_mask"])
        ]
|