# File size: 12,718 Bytes
# 2dbc77d
# (web-viewer line-number gutter 1-258 omitted)
"""
processing_prismatic.py
HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
specifies `siglip-224px+7b`.
"""
from typing import Any, ClassVar, List, Optional, Tuple, Union
import timm.data
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType
# === Image Processing ===
def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
    """
    Pad a PIL image toward a square canvas by adding equal borders on opposing sides.

    Each side of the shorter dimension receives `int(difference / 2)` pixels of `padding_fill_value`; because of
    the truncation, an odd width/height difference leaves the result one pixel short of square.

    @param image: Image to pad
    @param padding_fill_value: Three-int fill value used for the constant-mode border
    @return: Result of `torchvision.transforms.functional.pad` applied to `image`
    """
    width, height = image.size
    longest_side = max(image.size)
    pad_x = int((longest_side - width) / 2)
    pad_y = int((longest_side - height) / 2)

    # Same border on both sides of each axis: (pad_x, pad_y) is repeated for the opposite edges
    return TVF.pad(image, (pad_x, pad_y, pad_x, pad_y), fill=padding_fill_value, padding_mode="constant")
class PrismaticImageProcessor(ImageProcessingMixin):
    """
    Image processor for Prismatic VLMs; parses a TIMM-created torchvision transform into plain (JSON-friendly)
    parameter dictionaries and replays them through `torchvision.transforms.functional` at call time.
    """

    model_input_names: ClassVar[List[str]] = ["pixel_values"]

    def __init__(
        self,
        use_fused_vision_backbone: bool = False,
        image_resize_strategy: str = "letterbox",
        input_sizes: Optional[List[Tuple[int, int, int]]] = None,
        interpolations: Optional[List[str]] = None,
        means: Optional[List[Tuple[float, float, float]]] = None,
        stds: Optional[List[Tuple[float, float, float]]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
        created by TIMM, and edited to follow our custom `image_resize_strategy` logic.

        Each list argument holds one entry per vision backbone (one entry, or two for a fused backbone); the
        built-in defaults for `input_sizes`, `means` and `stds` only contain a single entry.

        @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
        @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
        @param input_sizes: [TIMM :: `data_cfg`] Input image sizes as (channels, height, width) tuples
        @param interpolations: [TIMM :: `data_cfg`] Interpolation names (default: "bicubic" per input size)
        @param means: [TIMM :: `data_cfg`] Normalization means as float tuples
        @param stds: [TIMM :: `data_cfg`] Normalization stds as float tuples
        @param kwargs: Forwarded unchanged to `ImageProcessingMixin.__init__`

        @raises ValueError: If TIMM returns an unexpected transform structure, or the resize strategy is unknown
        """
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.image_resize_strategy = image_resize_strategy

        # Handle `None` default values
        input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
        # [Fix] `interpolations` used to stay `None`, so indexing it below raised a TypeError whenever the default
        #       was used; fall back to the documented "bicubic" default, one entry per input size.
        interpolations = ["bicubic"] * len(input_sizes) if interpolations is None else interpolations
        means = [(0.5, 0.5, 0.5)] if means is None else means
        stds = [(0.5, 0.5, 0.5)] if stds is None else stds

        # TIMM `data_cfg` Parameters
        self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds

        # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
        self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
        self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None

        for idx, input_size in enumerate(input_sizes):
            transform = timm.data.create_transform(
                input_size=input_size,
                interpolation=interpolations[idx],
                mean=means[idx],
                std=stds[idx],
                crop_pct=1.0,  # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
                crop_mode="center",  # Default crop mode -- no-op when `crop_pct == 1.0`
                is_training=False,  # No image augmentations when loading the transform!
            )

            # [Validation] Ensure appropriate transform structure, expected sizes
            if not (
                isinstance(transform, Compose)
                and (len(transform.transforms) == 4)
                and isinstance(transform.transforms[0], Resize)
                and isinstance(transform.transforms[1], CenterCrop)
                and isinstance(transform.transforms[2], ToTensor)
                and isinstance(transform.transforms[3], Normalize)
                and (transform.transforms[0].size == input_size[-1])
                and (transform.transforms[1].size == input_size[-2:])
            ):
                raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")

            # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision transform objects
            # as attributes.
            #   => Instead, we parse the transform and call "torchvision.transforms.functional" (`tvf`)
            resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
            self.tvf_resize_params.append(
                {
                    "size": resize_t.size,
                    "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
                    "max_size": None,
                    "antialias": True,
                }
            )
            self.tvf_crop_params.append({"output_size": crop_t.size})
            self.tvf_normalize_params.append(
                {
                    "mean": norm_t.mean.float().numpy().tolist(),
                    "std": norm_t.std.float().numpy().tolist(),
                    "inplace": False,
                }
            )

            # Handle Prismatic `image_resize_strategy`; the letterbox flags are initialized once before the loop,
            # and for "letterbox" the fill color ends up derived from the *last* backbone's mean.
            if self.image_resize_strategy == "resize-naive":
                self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
            elif self.image_resize_strategy == "letterbox":
                self.tvf_do_letterbox = True
                self.tvf_letterbox_fill = tuple(int(x * 255) for x in means[idx])
            elif self.image_resize_strategy == "resize-crop":
                pass
            else:
                raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")

        # Dispatch **kwargs to super()
        super().__init__(**kwargs)

    def apply_transform(self, img: Image.Image) -> torch.Tensor:
        """
        Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize]).

        @param img: PIL image to transform (letterbox-padded first when that strategy is active)
        @return: Per-backbone tensors stacked along dim = 0
        """
        if self.tvf_do_letterbox:
            img = letterbox_pad_transform(img, self.tvf_letterbox_fill)

        # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
        imgs_t = []
        for idx in range(len(self.input_sizes)):
            img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
            img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
            img_idx_t = TVF.to_tensor(img_idx)
            img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
            imgs_t.append(img_idx_t)

        # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
        img_t = torch.vstack(imgs_t)

        return img_t

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **_: Any,
    ) -> BatchFeature:
        """
        Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
        explicitly only handle PIL.Image.Image instances for simplicity.

        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
        @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray

        @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
        """
        if not isinstance(images, list):
            images = [images]

        # Apply `self.apply_transform` to each image (returns a torch.Tensor each); stack into "batched" Tensor
        pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])

        # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
        return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)

    def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
        """Shorthand for `preprocess(images, **kwargs)`."""
        return self.preprocess(images, **kwargs)
# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
# =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
class PrismaticProcessor(ProcessorMixin):
    """Joint processor that pairs an image processor with a tokenizer and merges their outputs."""

    attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
    image_processor_class: str = "AutoImageProcessor"
    tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        image_processor: Optional[ImageProcessingMixin] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
    ) -> None:
        """Hand both components to `ProcessorMixin`, in (image_processor, tokenizer) order."""
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: Union[Image.Image, List[Image.Image]],
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Run the image processor on `images` and the tokenizer on `text`, then combine both results.

        @param text: The (batch) of text to encode; must be a string or list of strings.
        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
        @param padding: Padding strategy forwarded to the tokenizer
        @param truncation: Truncation strategy forwarded to the tokenizer
        @param max_length: Maximum length forwarded to the tokenizer
        @param return_tensors: Tensor type requested from both components (usually "pt" or TensorType.PYTORCH)

        @return: BatchFeature holding the tokenizer outputs plus a `pixel_values` entry.

        @raises ValueError: If the leading dimensions of `pixel_values` and `input_ids` differ
        """
        # Images are processed before text, exactly as the tokenizer call follows the image call below
        image_features = self.image_processor(images, return_tensors=return_tensors)
        pixel_values = image_features["pixel_values"]

        tokenized = self.tokenizer(
            text,
            return_tensors=return_tensors,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
        )

        # [Validate] One image per text input, compared via the leading (batch) dimension of each output
        num_images = pixel_values.shape[0]
        num_texts = tokenized.input_ids.shape[0]
        if num_images != num_texts:
            raise ValueError("Batch is malformed; expected same number of images and text inputs!")

        combined = {**tokenized, "pixel_values": pixel_values}
        return BatchFeature(data=combined)

    # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> List[str]:
        """Delegate to `self.tokenizer.batch_decode`, passing every argument through by keyword."""
        decoded = self.tokenizer.batch_decode(
            sequences=sequences,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
        return decoded

    def decode(
        self,
        token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> str:
        """Delegate to `self.tokenizer.decode`, passing every argument through by keyword."""
        decoded = self.tokenizer.decode(
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
        return decoded

    @property
    def model_input_names(self) -> List[str]:
        """Tokenizer input names followed by image-processor input names, de-duplicated in first-seen order."""
        text_names = self.tokenizer.model_input_names
        image_names = self.image_processor.model_input_names
        # `dict.fromkeys` drops repeats while keeping insertion order
        ordered_unique = dict.fromkeys(text_names + image_names)
        return list(ordered_unique)
|