File size: 11,104 Bytes
a71961f
 
 
 
e7432eb
a71961f
 
 
 
 
 
 
 
 
 
 
e7432eb
440a9f4
 
a71961f
e7432eb
 
 
 
 
 
6d8e609
a71961f
e7432eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a71961f
e7432eb
 
 
 
440a9f4
e7432eb
 
 
 
 
 
 
 
 
 
440a9f4
6d8e609
 
 
440a9f4
6d8e609
a71961f
d220929
 
440a9f4
a71961f
440a9f4
6d8e609
440a9f4
a71961f
440a9f4
 
6d8e609
440a9f4
a71961f
 
 
 
 
d220929
 
 
a71961f
 
440a9f4
a71961f
 
 
 
440a9f4
a71961f
b8b8f72
440a9f4
d220929
a71961f
 
 
 
440a9f4
a71961f
440a9f4
 
 
 
 
 
 
 
a71961f
d220929
a71961f
d220929
440a9f4
d220929
 
440a9f4
a71961f
 
d220929
440a9f4
b8b8f72
 
440a9f4
 
a71961f
d220929
440a9f4
a71961f
440a9f4
 
d220929
 
a71961f
d220929
 
a71961f
440a9f4
 
 
d220929
 
440a9f4
 
d220929
440a9f4
 
 
 
 
 
d220929
 
 
 
440a9f4
 
 
 
d220929
440a9f4
 
 
 
 
d220929
440a9f4
a71961f
 
440a9f4
a71961f
440a9f4
a71961f
 
d220929
a71961f
 
d220929
 
 
 
 
 
 
a71961f
 
 
 
 
 
 
b8b8f72
a71961f
d220929
 
 
 
440a9f4
d220929
 
440a9f4
d220929
a71961f
d220929
 
a71961f
d220929
b8b8f72
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
import base64
import json
import os
from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union

import requests
import torch
from PIL import Image
from torch import nn
from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer


class Transformer(nn.Module):
    def __init__(
        self,
        model_name_or_path: str = 'jinaai/jina-clip-v2',
        tokenizer_name_or_path: Optional[str] = None,
        image_processor_name_or_path: Optional[str] = None,
        max_seq_length: Optional[int] = None,
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        image_processor_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
        **_,
    ) -> None:
        """
        Creates a custom SentenceTransformer module that uses `jinai/jina-clip-v2` to
        map sentences/images to embeddings

        Args:
            model_name_or_path (str, optional): If it is a filepath on disc, it loads
                the model from that path. If it is not a path, tries to construct a
                model from the Hugging Face Hub with that name. Defaults to
                'jinaai/jina-clip-v2'
            tokenizer_name_or_path (str, optional): If it is a filepath on disc, it
                loads the tokenizer from that path. If it is not a path, tries to
                construct a tokenizer from the Hugging Face Hub with that name.
                If `None` it is automatically set to the value of `model_name_or_path`
            image_processor_name_or_path (str, optional): If it is a filepath on disc,
                it loads the image processor from that path. If it is not a path, tries
                to construct an image processor from the Hugging Face Hub with that
                name. If `None` it is automatically set to the value of
                `model_name_or_path`
            max_seq_length (int, optional): The maximum sequence length of the model.
                If not provided, will be inferred from model or tokenizer
            config_args (Dict[str, Any], optional): Additional model configuration
                parameters to be passed to the Hugging Face Transformers config
            model_args (Dict[str, Any], optional): Additional model configuration
                parameters to be passed to the Hugging Face Transformers model
            tokenizer_args (Dict[str, Any], optional): Additional tokenizer
                configuration parameters to be passed to the Hugging Face Transformers
                tokenizer
            image_processor_args (Dict[str, Any], optional): Additional image processor
                configuration parameters to be passed to the Hugging Face Transformers
                image processor
            cache_dir (str, optional): The Hugging Face Hub cache directory
            backend (str, optional): Computational backend, only 'torch' is supported

        Example:
            ::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer(
                    'jinaai/jina-clip-v2', trust_remote_code=True
                )
                sentences_or_images = [
                    "The weather is lovely today.",
                    "It's so sunny outside!",
                    "/path/to/stadium.jpg",
                ]
                embeddings = model.encode(sentences_or_images)
                print(embeddings.shape)
                # (3, 1024)

                # Get the similarity scores between all inputs
                similarities = model.similarity(embeddings, embeddings)
                print(similarities)
                # tensor([[1.0000, 0.6817, 0.0492],
                #         [0.6817, 1.0000, 0.0421],
                #         [0.0492, 0.0421, 1.0000]])
        """
        super(Transformer, self).__init__()
        if backend != 'torch':
            raise ValueError(
                f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
            )

        config_kwargs = config_args or {}
        model_kwargs = model_args or {}
        tokenizer_kwargs = tokenizer_args or {}
        image_processor_kwargs = {
            'token': model_kwargs.get('token', None),
            'trust_remote_code': model_kwargs.get('trust_remote_code', False),
            'revision': model_kwargs.get('revision', None),
            'local_files_only': model_kwargs.get('local_files_only', None),
        }
        image_processor_kwargs.update(image_processor_args)

        config = AutoConfig.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **config_kwargs
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=config, cache_dir=cache_dir, **model_kwargs
        )
        if max_seq_length is not None and 'model_max_length' not in tokenizer_kwargs:
            tokenizer_kwargs['model_max_length'] = max_seq_length

        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path or model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_kwargs,
        )
        self.image_processor = AutoImageProcessor.from_pretrained(
            image_processor_name_or_path or model_name_or_path,
            cache_dir=cache_dir,
            **image_processor_kwargs,
        )

        # No max_seq_length set. Try to infer from model
        if max_seq_length is None:
            if (
                hasattr(self.model, 'config')
                and hasattr(self.model.config, 'max_position_embeddings')
                and hasattr(self.tokenizer, 'model_max_length')
            ):
                max_seq_length = min(
                    self.model.config.max_position_embeddings,
                    self.tokenizer.model_max_length,
                )
        self.max_seq_length = max_seq_length
        if tokenizer_name_or_path is not None:
            self.model.config.tokenizer_class = self.tokenizer.__class__.__name__

    @staticmethod
    def _decode_data_image(data_image_str: str) -> Image.Image:
        header, data = data_image_str.split(',', 1)
        image_data = base64.b64decode(data)
        return Image.open(BytesIO(image_data))

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """
        Encodes input samples. Text samples are tokenized. Image URLs, image data
        buffers and PIL images are passed through the image processor.
        """
        _images = []
        _texts = []
        _image_or_text_descriptors = []
        for sample in texts:
            if isinstance(sample, str):
                if sample.startswith('http'):
                    response = requests.get(sample)
                    _images.append(Image.open(BytesIO(response.content)).convert('RGB'))
                    _image_or_text_descriptors.append(0)
                elif sample.startswith('data:image/'):
                    _images.append(self._decode_data_image(sample).convert('RGB'))
                    _image_or_text_descriptors.append(0)
                else:
                    try:
                        _images.append(Image.open(sample).convert('RGB'))
                        _image_or_text_descriptors.append(0)
                    except Exception as e:
                        _ = str(e)
                        _texts.append(sample)
                        _image_or_text_descriptors.append(1)
            elif isinstance(sample, Image.Image):
                _images.append(sample.convert('RGB'))
                _image_or_text_descriptors.append(0)

        encoding = {}
        if len(_texts):
            encoding['input_ids'] = self.tokenizer(
                _texts,
                padding=padding,
                truncation='longest_first',
                return_tensors='pt',
                max_length=self.max_seq_length,
            ).input_ids

        if len(_images):
            encoding['pixel_values'] = self.image_processor(
                _images, return_tensors='pt'
            ).pixel_values

        encoding['image_text_info'] = _image_or_text_descriptors
        return encoding

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        image_embeddings = []
        text_embeddings = []

        if 'pixel_values' in features:
            image_embeddings = self.model.get_image_features(features['pixel_values'])
        if 'input_ids' in features:
            text_embeddings = self.model.get_text_features(features['input_ids'])

        sentence_embedding = []
        image_features = iter(image_embeddings)
        text_features = iter(text_embeddings)
        for _, _input_type in enumerate(features['image_text_info']):
            if _input_type == 0:
                sentence_embedding.append(next(image_features))
            else:
                sentence_embedding.append(next(text_features))

        features['sentence_embedding'] = torch.stack(sentence_embedding).float()
        return features

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)
        self.image_processor.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str) -> 'Transformer':
        # Old classes used other config names than 'sentence_bert_config.json'
        for config_name in [
            'sentence_bert_config.json',
            'sentence_roberta_config.json',
            'sentence_distilbert_config.json',
            'sentence_camembert_config.json',
            'sentence_albert_config.json',
            'sentence_xlm-roberta_config.json',
            'sentence_xlnet_config.json',
        ]:
            sbert_config_path = os.path.join(input_path, config_name)
            if os.path.exists(sbert_config_path):
                break

        with open(sbert_config_path) as fIn:
            config = json.load(fIn)

        # Don't allow configs to set trust_remote_code
        if 'config_kwargs' in config and 'trust_remote_code' in config['config_kwargs']:
            config['config_kwargs'].pop('trust_remote_code')
        if 'model_kwargs' in config and 'trust_remote_code' in config['model_kwargs']:
            config['model_kwargs'].pop('trust_remote_code')
        if (
            'tokenizer_kwargs' in config
            and 'trust_remote_code' in config['tokenizer_kwargs']
        ):
            config['tokenizer_kwargs'].pop('trust_remote_code')
        if (
            'image_processor_kwargs' in config
            and 'trust_remote_code' in config['image_processor_kwargs']
        ):
            config['image_processor_kwargs'].pop('trust_remote_code')

        return Transformer(model_name_or_path=input_path, **config)