Spaces:
Runtime error
Runtime error
File size: 4,345 Bytes
9cc3eb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from functools import partial
from itertools import islice
from typing import Callable, List, Optional, Sequence, Union
import torch
import torch.nn.functional as F
def batched(iterable, n):
    """Yield successive lists of up to *n* items taken from *iterable*.

    Every yielded list holds *n* items except possibly the last one, which
    holds whatever remains. Modelled on the more-itertools implementation;
    a candidate for replacement by ``itertools.batched`` (Python 3.12+),
    which yields tuples instead of lists.
    """
    source = iter(iterable)
    # iter(callable, sentinel) keeps calling the lambda until it returns a
    # value equal to the sentinel: an empty list means the input is exhausted.
    for chunk in iter(lambda: list(islice(source, n)), []):
        yield chunk
def build_zero_shot_classifier(
        model,
        tokenizer,
        classnames: Sequence[str],
        templates: Sequence[Union[Callable, str]],
        num_classes_per_batch: Optional[int] = 10,
        device: Union[str, torch.device] = 'cpu',
        use_tqdm: bool = False,
):
    """Build zero-shot classifier weights by iterating over class names in batches.

    Args:
        model: CLIP model instance; its ``encode_text`` is called on tokenized text.
        tokenizer: CLIP tokenizer instance; called on a list of strings, and its
            result must support ``.to(device)``.
        classnames: A sequence of class (label) names.
        templates: A sequence of callables or format() friendly strings to produce
            templates per class name. Strings and callables may be mixed.
        num_classes_per_batch: The number of classes to batch together in each
            forward. If None or 0, all classes go through a single forward.
        device: Device the tokenized text is moved to.
        use_tqdm: Enable TQDM progress bar (only shown when batching iterates).

    Returns:
        Tensor whose columns (one per class, in ``classnames`` order) are the
        L2-normalized means of the L2-normalized per-template text embeddings.
        Assuming ``encode_text`` returns (num_texts, embed_dim), the shape is
        (embed_dim, num_classes).
    """
    assert isinstance(templates, Sequence) and len(templates) > 0
    assert isinstance(classnames, Sequence) and len(classnames) > 0
    num_templates = len(templates)
    num_classes = len(classnames)

    def _render(template, classname):
        # Decide per template (not from templates[0]) so str and callable
        # templates can be mixed in one sequence.
        return template.format(classname) if isinstance(template, str) else template(classname)

    def _process_batch(batch_classnames):
        num_batch_classes = len(batch_classnames)
        # Class-major order (all templates of class 0, then class 1, ...) so the
        # reshape below groups the rows of each class together.
        texts = [_render(template, c) for c in batch_classnames for template in templates]
        texts = tokenizer(texts).to(device)
        class_embeddings = F.normalize(model.encode_text(texts), dim=-1)
        class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
        return class_embeddings.T

    with torch.no_grad():
        if not num_classes_per_batch:
            # None or 0: a single forward over every class. There is nothing to
            # iterate, so no progress bar and no division by the batch size
            # (which previously raised ZeroDivisionError for 0 with use_tqdm).
            return _process_batch(classnames)

        iter_wrap = iter
        if use_tqdm:
            import tqdm
            num_iter = (num_classes - 1) // num_classes_per_batch + 1
            iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)
        batched_embeds = [
            _process_batch(batch)
            for batch in iter_wrap(batched(classnames, num_classes_per_batch))
        ]
        return torch.cat(batched_embeds, dim=1)
def build_zero_shot_classifier_legacy(
        model,
        tokenizer,
        classnames: Sequence[str],
        templates: Sequence[Union[Callable, str]],
        device: Union[str, torch.device] = 'cpu',
        use_tqdm: bool = False,
):
    """Build zero-shot classifier weights by iterating over class names 1 by 1.

    Args:
        model: CLIP model instance; its ``encode_text`` is called on tokenized text.
        tokenizer: CLIP tokenizer instance; called on a list of strings, and its
            result must support ``.to(device)``.
        classnames: A sequence of class (label) names.
        templates: A sequence of callables or format() friendly strings to produce
            templates per class name. Strings and callables may be mixed.
        device: Device the tokenized text and the final weights are moved to.
        use_tqdm: Enable TQDM progress bar.

    Returns:
        Tensor stacking one column per class (in ``classnames`` order); each
        column is the L2-normalized mean of the L2-normalized per-template text
        embeddings.
    """
    assert isinstance(templates, Sequence) and len(templates) > 0
    assert isinstance(classnames, Sequence) and len(classnames) > 0
    if use_tqdm:
        import tqdm
        iter_wrap = tqdm.tqdm
    else:
        iter_wrap = iter

    def _render(template, classname):
        # Decide per template (not from templates[0]) so str and callable
        # templates can be mixed in one sequence.
        return template.format(classname) if isinstance(template, str) else template(classname)

    with torch.no_grad():
        zeroshot_weights = []
        for classname in iter_wrap(classnames):
            texts = [_render(template, classname) for template in templates]
            texts = tokenizer(texts).to(device)  # tokenize
            class_embeddings = model.encode_text(texts)
            # Normalize each template embedding, average over templates
            # (dim 0), then re-normalize the averaged class embedding.
            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights
|