add remote code and hf-format "pytorch_model.bin"

#20
by chuhac - opened
config.json ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "",
3
+ "architectures": [
4
+ "BiomedCLIPModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_biomed_clip.BiomedCLIPConfig",
8
+ "AutoProcessor": "processing_biomed_clip.BiomedCLIPProcessor",
9
+ "AutoModel": "modeling_biomed_clip.BiomedCLIPModel",
10
+ "AutoModelForImageClassification": "modeling_biomed_clip.BiomedCLIPForImageClassification"
11
+ },
12
+ "initializer_factor": 1.0,
13
+ "logit_scale_init_value": 4.4454,
14
+ "model_type": "clip",
15
+ "projection_dim": 512,
16
+ "text_config": {
17
+ "attention_probs_dropout_prob": 0.1,
18
+ "gradient_checkpointing": false,
19
+ "hidden_act": "gelu",
20
+ "hidden_dropout_prob": 0.1,
21
+ "hidden_size": 768,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 3072,
24
+ "layer_norm_eps": 1e-12,
25
+ "max_position_embeddings": 512,
26
+ "model_type": "bert",
27
+ "num_attention_heads": 12,
28
+ "num_hidden_layers": 12,
29
+ "pad_token_id": 0,
30
+ "position_embedding_type": "absolute",
31
+ "transformers_version": "4.6.0.dev0",
32
+ "type_vocab_size": 2,
33
+ "use_cache": true,
34
+ "vocab_size": 30522
35
+ },
36
+ "text_config_dict": {
37
+ "attention_probs_dropout_prob": 0.1,
38
+ "gradient_checkpointing": false,
39
+ "hidden_act": "gelu",
40
+ "hidden_dropout_prob": 0.1,
41
+ "hidden_size": 768,
42
+ "initializer_range": 0.02,
43
+ "intermediate_size": 3072,
44
+ "layer_norm_eps": 1e-12,
45
+ "max_position_embeddings": 512,
46
+ "model_type": "bert",
47
+ "num_attention_heads": 12,
48
+ "num_hidden_layers": 12,
49
+ "pad_token_id": 0,
50
+ "position_embedding_type": "absolute",
51
+ "transformers_version": "4.6.0.dev0",
52
+ "type_vocab_size": 2,
53
+ "use_cache": true,
54
+ "vocab_size": 30522
55
+ },
56
+ "text_projection_config": {
57
+ "hidden_size": 768,
58
+ "intermediate_size": 640,
59
+ "projection_dim": 512,
60
+ "hidden_act": "gelu"
61
+ },
62
+ "text_projection_config_dict": {
63
+ "hidden_size": 768,
64
+ "intermediate_size": 640,
65
+ "projection_dim": 512,
66
+ "hidden_act": "gelu",
67
+ "num_hidden_layers": 2
68
+ },
69
+ "torch_dtype": "float32",
70
+ "transformers_version": null,
71
+ "vision_config": {
72
+ "attention_probs_dropout_prob": 0.0,
73
+ "hidden_act": "gelu",
74
+ "hidden_dropout_prob": 0.0,
75
+ "hidden_size": 768,
76
+ "image_size": 224,
77
+ "initializer_range": 0.02,
78
+ "intermediate_size": 3072,
79
+ "layer_norm_eps": 1e-12,
80
+ "model_type": "vit",
81
+ "num_attention_heads": 12,
82
+ "num_channels": 3,
83
+ "num_hidden_layers": 12,
84
+ "patch_size": 16,
85
+ "qkv_bias": true
86
+ },
87
+ "vision_config_dict": {
88
+ "attention_probs_dropout_prob": 0.0,
89
+ "hidden_act": "gelu",
90
+ "hidden_dropout_prob": 0.0,
91
+ "hidden_size": 768,
92
+ "image_size": 224,
93
+ "initializer_range": 0.02,
94
+ "intermediate_size": 3072,
95
+ "layer_norm_eps": 1e-12,
96
+ "model_type": "vit",
97
+ "num_attention_heads": 12,
98
+ "num_channels": 3,
99
+ "num_hidden_layers": 12,
100
+ "patch_size": 16,
101
+ "qkv_bias": true
102
+ }
103
+ }
configuration_biomed_clip.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import *
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
5
+
6
+ class BiomedCLIPTextProjectionConfig(PretrainedConfig):
7
+ def __init__(
8
+ self,
9
+ hidden_size=768,
10
+ intermediate_size=640,
11
+ projection_dim=512,
12
+ num_hidden_layers=2,
13
+ **kwargs,
14
+ ):
15
+ super().__init__(**kwargs)
16
+
17
+ self.hidden_size = hidden_size
18
+ self.intermediate_size = intermediate_size
19
+ self.projection_dim = projection_dim
20
+ self.num_hidden_layers = num_hidden_layers
21
+
22
+ @classmethod
23
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
24
+ cls._set_token_in_kwargs(kwargs)
25
+
26
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
27
+
28
+ # get the vision config dict if we are loading from CLIPConfig
29
+ if config_dict.get("model_type") == "clip":
30
+ config_dict = config_dict["text_projection_config"]
31
+
32
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
33
+ logger.warning(
34
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
35
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
36
+ )
37
+
38
+ return cls.from_dict(config_dict, **kwargs)
39
+
40
+ class BiomedCLIPConfig(CLIPConfig):
41
+ def __init__(
42
+ self, text_config=None, text_projection_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
43
+ ):
44
+ # If `_config_dict` exist, we use them for the backward compatibility.
45
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
46
+ # of confusion!).
47
+ super().__init__(text_config, vision_config, projection_dim, logit_scale_init_value, **kwargs)
48
+
49
+ text_projection_config_dict = kwargs.pop("text_projection_config_dict", None)
50
+ if text_projection_config is None:
51
+ if text_projection_config_dict is not None:
52
+ text_projection_config = {}
53
+
54
+ _text_projection_config_dict = BiomedCLIPTextProjectionConfig(**text_projection_config_dict)
55
+
56
+ text_projection_config.update(_text_projection_config_dict)
57
+ else:
58
+ text_projection_config = BiomedCLIPTextProjectionConfig(**text_projection_config)
59
+
60
+ self.text_projection_config = text_projection_config
modeling_biomed_clip.py ADDED
@@ -0,0 +1,938 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Modified by chuhac for a timm-free implementation
3
+ # Model can be directly imported with ``from_pretrained`` and ``trust_remote_code = True`` in the huggingface format
4
+ # Diff from HF CLIP Implementation:
5
+ # 1. pre-norm instead of post-norm in Vision Tower (the original implementation is right but the module registration order is misleading)
6
+ # 2. CLS Pooling with MLP in Text Tower
7
+ # 3. Remove pre norm in Vision Tower
8
+ # 4. CNN bias in Vision Tower
9
+ # 5. Change layer_norm eps from 1e-5 to 1e-12, which introduce a little numerical variations (1e-5 level)
10
+ ## ******************************** ##
11
+ # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
12
+ # Licensed under the Apache License, Version 2.0 (the "License");
13
+ # you may not use this file except in compliance with the License.
14
+ # You may obtain a copy of the License at
15
+ #
16
+ # http://www.apache.org/licenses/LICENSE-2.0
17
+ #
18
+ # Unless required by applicable law or agreed to in writing, software
19
+ # distributed under the License is distributed on an "AS IS" BASIS,
20
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
+ # See the License for the specific language governing permissions and
22
+ # limitations under the License.
23
+ """ PyTorch BiomedCLIP model """
24
+ """ No need for timm or open-clip-torch """
25
+
26
+
27
+ from dataclasses import dataclass
28
+ from typing import Any, Optional, Tuple, Union, List
29
+
30
+ import math
31
+ import torch
32
+ import torch.utils.checkpoint
33
+ from torch import nn
34
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
35
+
36
+ from transformers.activations import ACT2FN
37
+ from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutput,
40
+ BaseModelOutputWithPooling,
41
+ ImageClassifierOutput,
42
+ BaseModelOutputWithPoolingAndCrossAttentions,
43
+ BaseModelOutputWithPastAndCrossAttentions
44
+ )
45
+ from transformers.modeling_utils import PreTrainedModel
46
+ from transformers.utils import (
47
+ ModelOutput,
48
+ add_code_sample_docstrings,
49
+ add_start_docstrings,
50
+ add_start_docstrings_to_model_forward,
51
+ logging,
52
+ replace_return_docstrings,
53
+ )
54
+ from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
55
+ from transformers.models.clip.modeling_clip import *
56
+
57
+ from .configuration_biomed_clip import BiomedCLIPTextProjectionConfig, BiomedCLIPConfig
58
+
59
+
60
+ logger = logging.get_logger(__name__)
61
+
62
+
63
+
64
+ # contrastive loss function, adapted from
65
+ # https://sachinruk.github.io/blog/2021-03-07-clip.html
66
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
67
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
68
+
69
+
70
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
71
+ caption_loss = contrastive_loss(similarity)
72
+ image_loss = contrastive_loss(similarity.t())
73
+ return (caption_loss + image_loss) / 2.0
74
+
75
+
76
+ class BiomedCLIPVisionEmbeddings(CLIPVisionEmbeddings):
77
+ def __init__(self, config: CLIPVisionConfig):
78
+ super().__init__(config)
79
+
80
+ self.patch_embedding = nn.Conv2d(
81
+ in_channels=config.num_channels,
82
+ out_channels=self.embed_dim,
83
+ kernel_size=self.patch_size,
84
+ stride=self.patch_size,
85
+ # True in open_clip
86
+ bias=True,
87
+ )
88
+
89
+ # TODO
90
+ class BiomedCLIPTextEmbeddings(nn.Module):
91
+ def __init__(self, config: CLIPTextConfig):
92
+ super().__init__()
93
+ embed_dim = config.hidden_size
94
+
95
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
96
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
97
+ self.token_type_embedding = nn.Embedding(config.type_vocab_size, embed_dim)
98
+
99
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
100
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
101
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
102
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
103
+
104
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
105
+ self.register_buffer(
106
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
107
+ )
108
+ self.register_buffer(
109
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
110
+ )
111
+
112
+ def forward(
113
+ self,
114
+ input_ids: Optional[torch.LongTensor] = None,
115
+ token_type_ids: Optional[torch.LongTensor] = None,
116
+ position_ids: Optional[torch.LongTensor] = None,
117
+ inputs_embeds: Optional[torch.FloatTensor] = None,
118
+ past_key_values_length: int = 0,
119
+ ) -> torch.Tensor:
120
+
121
+ if input_ids is not None:
122
+ input_shape = input_ids.size()
123
+ else:
124
+ input_shape = inputs_embeds.size()[:-1]
125
+
126
+ seq_length = input_shape[1]
127
+
128
+ if position_ids is None:
129
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
130
+
131
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
132
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
133
+ # issue #5664
134
+ if token_type_ids is None:
135
+ if hasattr(self, "token_type_ids"):
136
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
137
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
138
+ token_type_ids = buffered_token_type_ids_expanded
139
+ else:
140
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
141
+
142
+ if inputs_embeds is None:
143
+ inputs_embeds = self.token_embedding(input_ids)
144
+ token_type_embeddings = self.token_type_embedding(token_type_ids)
145
+
146
+ embeddings = inputs_embeds + token_type_embeddings
147
+ if self.position_embedding_type == "absolute":
148
+ position_embeddings = self.position_embedding(position_ids)
149
+ embeddings += position_embeddings
150
+
151
+ embeddings = self.layer_norm(embeddings)
152
+ embeddings = self.dropout(embeddings)
153
+ return embeddings
154
+
155
+
156
+ class BiomedCLIPAttention(nn.Module):
157
+ def __init__(self, config, position_embedding_type=None):
158
+ super().__init__()
159
+ super().__init__()
160
+ self.config = config
161
+ self.embed_dim = config.hidden_size
162
+ self.num_heads = config.num_attention_heads
163
+ self.head_dim = self.embed_dim // self.num_heads
164
+ if self.head_dim * self.num_heads != self.embed_dim:
165
+ raise ValueError(
166
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
167
+ f" {self.num_heads})."
168
+ )
169
+ self.scale = self.head_dim**-0.5
170
+ self.dropout = nn.Dropout(config.attention_dropout)
171
+
172
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
173
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
174
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
175
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
176
+
177
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
178
+ new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
179
+ x = x.view(new_x_shape)
180
+ return x.permute(0, 2, 1, 3)
181
+
182
+ def forward(
183
+ self,
184
+ hidden_states: torch.Tensor,
185
+ attention_mask: Optional[torch.FloatTensor] = None,
186
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
187
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
188
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
189
+ output_attentions: Optional[bool] = False,
190
+ ) -> Tuple[torch.Tensor]:
191
+
192
+ mixed_query_layer = self.q_proj(hidden_states)
193
+
194
+ # If this is instantiated as a cross-attention module, the keys
195
+ # and values come from an encoder; the attention mask needs to be
196
+ # such that the encoder's padding tokens are not attended to.
197
+ is_cross_attention = encoder_hidden_states is not None
198
+
199
+ key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
200
+ value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
201
+
202
+ query_layer = self.transpose_for_scores(mixed_query_layer)
203
+
204
+
205
+ # Take the dot product between "query" and "key" to get the raw attention scores.
206
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
207
+
208
+
209
+ attention_scores = attention_scores / math.sqrt(self.head_dim)
210
+ if attention_mask is not None:
211
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
212
+ attention_scores = attention_scores + attention_mask
213
+
214
+ # Normalize the attention scores to probabilities.
215
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
216
+
217
+ # This is actually dropping out entire tokens to attend to, which might
218
+ # seem a bit unusual, but is taken from the original Transformer paper.
219
+ attention_probs = self.dropout(attention_probs)
220
+
221
+
222
+ context_layer = torch.matmul(attention_probs, value_layer)
223
+
224
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
225
+ new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
226
+ context_layer = context_layer.view(new_context_layer_shape).contiguous()
227
+
228
+ outputs = self.out_proj(context_layer)
229
+ return outputs, attention_probs
230
+
231
+
232
+
233
+
234
+ class BiomedCLIPEncoderLayer(nn.Module):
235
+ def __init__(self, config: BiomedCLIPConfig, norm='pre'):
236
+ super().__init__()
237
+ self.embed_dim = config.hidden_size
238
+ # pre-norm
239
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
240
+ self.self_attn = BiomedCLIPAttention(config)
241
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
242
+ self.mlp = CLIPMLP(config)
243
+ self.norm = norm
244
+
245
+ if self.norm == 'pre':
246
+ self.forward = self.pre_norm_forward
247
+ elif self.norm == 'post':
248
+ self.forward = self.post_norm_forward
249
+
250
+
251
+ def pre_norm_forward(
252
+ self,
253
+ hidden_states: torch.Tensor,
254
+ attention_mask: torch.Tensor,
255
+ output_attentions: Optional[bool] = False,
256
+ ) -> Tuple[torch.FloatTensor]:
257
+ """
258
+ Args:
259
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
260
+ attention_mask (`torch.FloatTensor`): attention mask of size
261
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
262
+ `(config.encoder_attention_heads,)`.
263
+ output_attentions (`bool`, *optional*):
264
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
265
+ returned tensors for more detail.
266
+ """
267
+ residual = hidden_states
268
+
269
+ hidden_states = self.layer_norm1(hidden_states)
270
+ hidden_states, attn_weights = self.self_attn(
271
+ hidden_states=hidden_states,
272
+ attention_mask=attention_mask,
273
+ output_attentions=output_attentions,
274
+ )
275
+ hidden_states = residual + hidden_states
276
+
277
+ residual = hidden_states
278
+ hidden_states = self.layer_norm2(hidden_states)
279
+ hidden_states = self.mlp(hidden_states)
280
+ hidden_states = residual + hidden_states
281
+
282
+ outputs = (hidden_states,)
283
+
284
+ if output_attentions:
285
+ outputs += (attn_weights,)
286
+
287
+ return outputs
288
+
289
+ def post_norm_forward(
290
+ self,
291
+ hidden_states: torch.Tensor,
292
+ attention_mask: torch.Tensor,
293
+ output_attentions: Optional[bool] = False,
294
+ ) -> Tuple[torch.FloatTensor]:
295
+ """
296
+ Args:
297
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
298
+ attention_mask (`torch.FloatTensor`): attention mask of size
299
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
300
+ `(config.encoder_attention_heads,)`.
301
+ output_attentions (`bool`, *optional*):
302
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
303
+ returned tensors for more detail.
304
+ """
305
+ residual = hidden_states
306
+
307
+ hidden_states, attn_weights = self.self_attn(
308
+ hidden_states=hidden_states,
309
+ attention_mask=attention_mask,
310
+ output_attentions=output_attentions,
311
+ )
312
+ hidden_states = residual + hidden_states
313
+
314
+ hidden_states = self.layer_norm1(hidden_states)
315
+
316
+ residual = hidden_states
317
+ hidden_states = self.mlp(hidden_states)
318
+ hidden_states = residual + hidden_states
319
+ hidden_states = self.layer_norm2(hidden_states)
320
+ outputs = (hidden_states,)
321
+
322
+ if output_attentions:
323
+ outputs += (attn_weights,)
324
+
325
+ return outputs
326
+
327
+
328
+ class BiomedCLIPTextProjection(nn.Module):
329
+ def __init__(self, config):
330
+ super().__init__()
331
+ self.config = config
332
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
333
+ self.activation_fn = ACT2FN[config.hidden_act]
334
+ self.fc2 = nn.Linear(config.intermediate_size, config.projection_dim, bias=False)
335
+
336
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
337
+ hidden_states = self.fc1(hidden_states)
338
+ hidden_states = self.activation_fn(hidden_states)
339
+ hidden_states = self.fc2(hidden_states)
340
+ return hidden_states
341
+
342
+
343
+ class BiomedCLIPEncoder(nn.Module):
344
+ """
345
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
346
+ [`BiomedCLIPEncoderLayer`].
347
+
348
+ Args:
349
+ config: BiomedCLIPConfig
350
+ """
351
+ def __init__(self, config, norm='pre'):
352
+ super().__init__()
353
+ self.config = config
354
+ self.norm = norm
355
+ self.layers = nn.ModuleList([BiomedCLIPEncoderLayer(config, norm) for _ in range(config.num_hidden_layers)])
356
+ self.gradient_checkpointing = False
357
+
358
+ def forward(
359
+ self,
360
+ hidden_states: torch.Tensor,
361
+ attention_mask: Optional[torch.FloatTensor] = None,
362
+ head_mask: Optional[torch.FloatTensor] = None,
363
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
364
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
365
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
366
+ use_cache: Optional[bool] = None,
367
+ output_attentions: Optional[bool] = False,
368
+ output_hidden_states: Optional[bool] = False,
369
+ return_dict: Optional[bool] = True,
370
+ ) :
371
+ all_hidden_states = () if output_hidden_states else None
372
+ all_self_attentions = () if output_attentions else None
373
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
374
+
375
+ if self.gradient_checkpointing and self.training:
376
+ if use_cache:
377
+ logger.warning_once(
378
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
379
+ )
380
+ use_cache = False
381
+
382
+ next_decoder_cache = () if use_cache else None
383
+ for i, layer_module in enumerate(self.layers):
384
+ if output_hidden_states:
385
+ all_hidden_states = all_hidden_states + (hidden_states,)
386
+
387
+ layer_head_mask = head_mask[i] if head_mask is not None else None
388
+ past_key_value = past_key_values[i] if past_key_values is not None else None
389
+
390
+ if self.gradient_checkpointing and self.training:
391
+ layer_outputs = self._gradient_checkpointing_func(
392
+ layer_module.__call__,
393
+ hidden_states,
394
+ attention_mask,
395
+ output_attentions,
396
+ )
397
+ else:
398
+ layer_outputs = layer_module(
399
+ hidden_states,
400
+ attention_mask,
401
+ output_attentions,
402
+ )
403
+
404
+ hidden_states = layer_outputs[0]
405
+ if use_cache:
406
+ next_decoder_cache += (layer_outputs[-1],)
407
+ if output_attentions:
408
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
409
+ if self.config.add_cross_attention:
410
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
411
+
412
+ if output_hidden_states:
413
+ all_hidden_states = all_hidden_states + (hidden_states,)
414
+
415
+ if not return_dict:
416
+ return tuple(
417
+ v
418
+ for v in [
419
+ hidden_states,
420
+ next_decoder_cache,
421
+ all_hidden_states,
422
+ all_self_attentions,
423
+ all_cross_attentions,
424
+ ]
425
+ if v is not None
426
+ )
427
+ return BaseModelOutputWithPastAndCrossAttentions(
428
+ last_hidden_state=hidden_states,
429
+ past_key_values=next_decoder_cache,
430
+ hidden_states=all_hidden_states,
431
+ attentions=all_self_attentions,
432
+ cross_attentions=all_cross_attentions,
433
+ )
434
+
435
+
436
+
437
+ class BiomedCLIPTextTransformer(CLIPPreTrainedModel):
438
+ def __init__(self, config: CLIPTextConfig):
439
+ super().__init__(config)
440
+ self.config = config
441
+ embed_dim = config.hidden_size
442
+ self.embeddings = BiomedCLIPTextEmbeddings(config)
443
+ self.encoder = BiomedCLIPEncoder(config, norm='post')
444
+ # no final_ln
445
+ # self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
446
+
447
+ # For `pooled_output` computation
448
+
449
+ def forward(
450
+ self,
451
+ input_ids: Optional[torch.Tensor] = None,
452
+ attention_mask: Optional[torch.Tensor] = None,
453
+ token_type_ids: Optional[torch.Tensor] = None,
454
+ position_ids: Optional[torch.Tensor] = None,
455
+ inputs_embeds: Optional[torch.Tensor] = None,
456
+ encoder_hidden_states: Optional[torch.Tensor] = None,
457
+ encoder_attention_mask: Optional[torch.Tensor] = None,
458
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
459
+ use_cache: Optional[bool] = None,
460
+ output_attentions: Optional[bool] = None,
461
+ output_hidden_states: Optional[bool] = None,
462
+ return_dict: Optional[bool] = None,
463
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
464
+ r"""
465
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
466
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
467
+ the model is configured as a decoder.
468
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
469
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
470
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
471
+
472
+ - 1 for tokens that are **not masked**,
473
+ - 0 for tokens that are **masked**.
474
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
475
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
476
+
477
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
478
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
479
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
480
+ use_cache (`bool`, *optional*):
481
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
482
+ `past_key_values`).
483
+ """
484
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
485
+ output_hidden_states = (
486
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
487
+ )
488
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
489
+
490
+ if self.config.is_decoder:
491
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
492
+ else:
493
+ use_cache = False
494
+
495
+ if input_ids is not None and inputs_embeds is not None:
496
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
497
+ elif input_ids is not None:
498
+ input_shape = input_ids.size()
499
+ elif inputs_embeds is not None:
500
+ input_shape = inputs_embeds.size()[:-1]
501
+ else:
502
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
503
+
504
+ batch_size, seq_length = input_shape
505
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
506
+
507
+ # past_key_values_length
508
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
509
+
510
+ if token_type_ids is None:
511
+ if hasattr(self.embeddings, "token_type_ids"):
512
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
513
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
514
+ token_type_ids = buffered_token_type_ids_expanded
515
+ else:
516
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
517
+
518
+ embedding_output = self.embeddings(
519
+ input_ids=input_ids,
520
+ position_ids=position_ids,
521
+ token_type_ids=token_type_ids,
522
+ inputs_embeds=inputs_embeds,
523
+ past_key_values_length=past_key_values_length,
524
+ )
525
+
526
+ if attention_mask is None:
527
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
528
+
529
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
530
+ # ourselves in which case we just need to make it broadcastable to all heads.
531
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
532
+
533
+ # If a 2D or 3D attention mask is provided for the cross-attention
534
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
535
+ if self.config.is_decoder and encoder_hidden_states is not None:
536
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
537
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
538
+ if encoder_attention_mask is None:
539
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
540
+
541
+ if use_sdpa_attention_masks:
542
+ # Expand the attention mask for SDPA.
543
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
544
+ encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
545
+ encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
546
+ )
547
+ else:
548
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
549
+ else:
550
+ encoder_extended_attention_mask = None
551
+
552
+
553
+ encoder_outputs = self.encoder(
554
+ embedding_output,
555
+ attention_mask=extended_attention_mask,
556
+ output_attentions=output_attentions,
557
+ output_hidden_states=output_hidden_states,
558
+ return_dict=return_dict,
559
+ )
560
+ sequence_output = encoder_outputs[0]
561
+
562
+ return (sequence_output, sequence_output[:, 0, :])
563
+
564
+
565
+
566
+ class BiomedCLIPVisionTransformer(nn.Module):
567
+ def __init__(self, config: CLIPVisionConfig):
568
+ super().__init__()
569
+ self.config = config
570
+ embed_dim = config.hidden_size
571
+
572
+ self.embeddings = BiomedCLIPVisionEmbeddings(config)
573
+ # No pre_norm in open_clip Vision Tower
574
+ # self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
575
+ self.encoder = BiomedCLIPEncoder(config)
576
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
577
+
578
+ def forward(
579
+ self,
580
+ pixel_values: Optional[torch.FloatTensor] = None,
581
+ output_attentions: Optional[bool] = None,
582
+ output_hidden_states: Optional[bool] = None,
583
+ return_dict: Optional[bool] = None,
584
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
585
+ r"""
586
+ Returns:
587
+
588
+ """
589
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
590
+ output_hidden_states = (
591
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
592
+ )
593
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
594
+
595
+ if pixel_values is None:
596
+ raise ValueError("You have to specify pixel_values")
597
+
598
+ hidden_states = self.embeddings(pixel_values)
599
+ # hidden_states = self.pre_layrnorm(hidden_states)
600
+
601
+ encoder_outputs = self.encoder(
602
+ hidden_states=hidden_states,
603
+ output_attentions=output_attentions,
604
+ output_hidden_states=output_hidden_states,
605
+ return_dict=return_dict,
606
+ )
607
+
608
+ last_hidden_state = encoder_outputs[0]
609
+ pooled_output = last_hidden_state[:, 0, :]
610
+ pooled_output = self.post_layernorm(pooled_output)
611
+
612
+ if not return_dict:
613
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
614
+
615
+ return BaseModelOutputWithPooling(
616
+ last_hidden_state=last_hidden_state,
617
+ pooler_output=pooled_output,
618
+ hidden_states=encoder_outputs.hidden_states,
619
+ attentions=encoder_outputs.attentions,
620
+ )
621
+
622
+
623
+ class BiomedCLIPModel(CLIPPreTrainedModel):
624
+ config_class = BiomedCLIPConfig
625
+ _no_split_modules = ["BiomedCLIPTextEmbeddings", "BiomedCLIPEncoderLayer"]
626
+
627
+ def __init__(self, config: BiomedCLIPConfig):
628
+ super().__init__(config)
629
+
630
+ if not isinstance(config.text_config, CLIPTextConfig):
631
+ raise ValueError(
632
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
633
+ f" {type(config.text_config)}."
634
+ )
635
+
636
+ if not isinstance(config.vision_config, CLIPVisionConfig):
637
+ raise ValueError(
638
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
639
+ f" {type(config.vision_config)}."
640
+ )
641
+
642
+ text_config = config.text_config
643
+ text_projection_config = config.text_projection_config
644
+ vision_config = config.vision_config
645
+
646
+
647
+ self.projection_dim = config.projection_dim
648
+ self.text_embed_dim = text_config.hidden_size
649
+ self.vision_embed_dim = vision_config.hidden_size
650
+
651
+ self.text_model = BiomedCLIPTextTransformer(text_config)
652
+ self.vision_model = BiomedCLIPVisionTransformer(vision_config)
653
+
654
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
655
+
656
+ self.text_projection = BiomedCLIPTextProjection(text_projection_config)
657
+
658
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
659
+
660
+ # Initialize weights and apply final processing
661
+ self.post_init()
662
+
663
+ def get_text_features(
664
+ self,
665
+ input_ids: Optional[torch.Tensor] = None,
666
+ attention_mask: Optional[torch.Tensor] = None,
667
+ token_type_ids: Optional[torch.Tensor] = None,
668
+ position_ids: Optional[torch.Tensor] = None,
669
+ output_attentions: Optional[bool] = None,
670
+ output_hidden_states: Optional[bool] = None,
671
+ return_dict: Optional[bool] = None,
672
+ ) -> torch.FloatTensor:
673
+ r"""
674
+ Returns:
675
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
676
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
677
+
678
+ Examples:
679
+
680
+ ```python
681
+ >>> from transformers import AutoTokenizer, CLIPModel
682
+
683
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
684
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
685
+
686
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
687
+ >>> text_features = model.get_text_features(**inputs)
688
+ ```"""
689
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
690
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
691
+ output_hidden_states = (
692
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
693
+ )
694
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
695
+
696
+ text_outputs = self.text_model(
697
+ input_ids=input_ids,
698
+ attention_mask=attention_mask,
699
+ token_type_ids=token_type_ids,
700
+ position_ids=position_ids,
701
+ output_attentions=output_attentions,
702
+ output_hidden_states=output_hidden_states,
703
+ return_dict=return_dict,
704
+ )
705
+
706
+ pooled_output = text_outputs[1]
707
+ text_features = self.text_projection(pooled_output)
708
+
709
+ return text_features
710
+
711
+ def get_image_features(
712
+ self,
713
+ pixel_values: Optional[torch.FloatTensor] = None,
714
+ output_attentions: Optional[bool] = None,
715
+ output_hidden_states: Optional[bool] = None,
716
+ return_dict: Optional[bool] = None,
717
+ ) -> torch.FloatTensor:
718
+ r"""
719
+ Returns:
720
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
721
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
722
+
723
+ Examples:
724
+
725
+ ```python
726
+ >>> from PIL import Image
727
+ >>> import requests
728
+ >>> from transformers import AutoProcessor, CLIPModel
729
+
730
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
731
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
732
+
733
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
734
+ >>> image = Image.open(requests.get(url, stream=True).raw)
735
+
736
+ >>> inputs = processor(images=image, return_tensors="pt")
737
+
738
+ >>> image_features = model.get_image_features(**inputs)
739
+ ```"""
740
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
741
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
742
+ output_hidden_states = (
743
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
744
+ )
745
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
746
+
747
+ vision_outputs = self.vision_model(
748
+ pixel_values=pixel_values,
749
+ output_attentions=output_attentions,
750
+ output_hidden_states=output_hidden_states,
751
+ return_dict=return_dict,
752
+ )
753
+
754
+ pooled_output = vision_outputs[1] # pooled_output
755
+ image_features = self.visual_projection(pooled_output)
756
+
757
+ return image_features
758
+
759
+ def forward(
760
+ self,
761
+ input_ids: Optional[torch.LongTensor] = None,
762
+ pixel_values: Optional[torch.FloatTensor] = None,
763
+ attention_mask: Optional[torch.Tensor] = None,
764
+ token_type_ids: Optional[torch.LongTensor] = None,
765
+ position_ids: Optional[torch.LongTensor] = None,
766
+ return_loss: Optional[bool] = None,
767
+ output_attentions: Optional[bool] = None,
768
+ output_hidden_states: Optional[bool] = None,
769
+ return_dict: Optional[bool] = None,
770
+ ) -> Union[Tuple, CLIPOutput]:
771
+ r"""
772
+ Returns:
773
+
774
+ Examples:
775
+
776
+ ```python
777
+ >>> from PIL import Image
778
+ >>> import requests
779
+ >>> from transformers import AutoProcessor, CLIPModel
780
+
781
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
782
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
783
+
784
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
785
+ >>> image = Image.open(requests.get(url, stream=True).raw)
786
+
787
+ >>> inputs = processor(
788
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
789
+ ... )
790
+
791
+ >>> outputs = model(**inputs)
792
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
793
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
794
+ ```"""
795
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
796
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
797
+ output_hidden_states = (
798
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
799
+ )
800
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
801
+
802
+ vision_outputs = self.vision_model(
803
+ pixel_values=pixel_values,
804
+ output_attentions=output_attentions,
805
+ output_hidden_states=output_hidden_states,
806
+ return_dict=return_dict,
807
+ )
808
+
809
+ text_outputs = self.text_model(
810
+ input_ids=input_ids,
811
+ token_type_ids=token_type_ids,
812
+ attention_mask=attention_mask,
813
+ position_ids=position_ids,
814
+ output_attentions=output_attentions,
815
+ output_hidden_states=output_hidden_states,
816
+ return_dict=return_dict,
817
+ )
818
+
819
+ image_embeds = vision_outputs[1]
820
+ image_embeds = self.visual_projection(image_embeds)
821
+
822
+ text_embeds = text_outputs[1]
823
+ text_embeds = self.text_projection(text_embeds)
824
+
825
+ # normalized features
826
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
827
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
828
+
829
+ # cosine similarity as logits
830
+ logit_scale = self.logit_scale.exp()
831
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
832
+ logits_per_image = logits_per_text.t()
833
+
834
+ loss = None
835
+ if return_loss:
836
+ loss = clip_loss(logits_per_text)
837
+
838
+ if not return_dict:
839
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
840
+ return ((loss,) + output) if loss is not None else output
841
+
842
+ return CLIPOutput(
843
+ loss=loss,
844
+ logits_per_image=logits_per_image,
845
+ logits_per_text=logits_per_text,
846
+ text_embeds=text_embeds,
847
+ image_embeds=image_embeds,
848
+ text_model_output=text_outputs,
849
+ vision_model_output=vision_outputs,
850
+ )
851
+
852
+
853
+ class BiomedCLIPForImageClassification(CLIPPreTrainedModel):
854
+ main_input_name = "pixel_values"
855
+
856
+ def __init__(self, config: BiomedCLIPConfig) -> None:
857
+ super().__init__(config)
858
+
859
+ self.num_labels = config.num_labels
860
+ self.vision_model = BiomedCLIPVisionTransformer(config.vision_config)
861
+
862
+ # Classifier head
863
+ self.classifier = (
864
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
865
+ )
866
+
867
+ # Initialize weights and apply final processing
868
+ self.post_init()
869
+
870
+ def forward(
871
+ self,
872
+ pixel_values: Optional[torch.Tensor] = None,
873
+ labels: Optional[torch.Tensor] = None,
874
+ output_attentions: Optional[bool] = None,
875
+ output_hidden_states: Optional[bool] = None,
876
+ return_dict: Optional[bool] = None,
877
+ ) -> Union[tuple, ImageClassifierOutput]:
878
+ r"""
879
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
880
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
881
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
882
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
883
+ """
884
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
885
+ output_hidden_states = (
886
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
887
+ )
888
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
889
+
890
+ outputs = self.vision_model(
891
+ pixel_values,
892
+ output_attentions=output_attentions,
893
+ output_hidden_states=output_hidden_states,
894
+ return_dict=return_dict,
895
+ )
896
+
897
+ sequence_output = outputs[0]
898
+
899
+ # average pool the patch tokens
900
+ sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
901
+ # apply classifier
902
+ logits = self.classifier(sequence_output)
903
+
904
+ loss = None
905
+ if labels is not None:
906
+ # move labels to correct device to enable model parallelism
907
+ labels = labels.to(logits.device)
908
+ if self.config.problem_type is None:
909
+ if self.num_labels == 1:
910
+ self.config.problem_type = "regression"
911
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
912
+ self.config.problem_type = "single_label_classification"
913
+ else:
914
+ self.config.problem_type = "multi_label_classification"
915
+
916
+ if self.config.problem_type == "regression":
917
+ loss_fct = MSELoss()
918
+ if self.num_labels == 1:
919
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
920
+ else:
921
+ loss = loss_fct(logits, labels)
922
+ elif self.config.problem_type == "single_label_classification":
923
+ loss_fct = CrossEntropyLoss()
924
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
925
+ elif self.config.problem_type == "multi_label_classification":
926
+ loss_fct = BCEWithLogitsLoss()
927
+ loss = loss_fct(logits, labels)
928
+
929
+ if not return_dict:
930
+ output = (logits,) + outputs[2:]
931
+ return ((loss,) + output) if loss is not None else output
932
+
933
+ return ImageClassifierOutput(
934
+ loss=loss,
935
+ logits=logits,
936
+ hidden_states=outputs.hidden_states,
937
+ attentions=outputs.attentions,
938
+ )
preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": 224,
3
+ "do_center_crop": true,
4
+ "do_normalize": true,
5
+ "do_resize": true,
6
+ "image_processor_type": "CLIPImageProcessor",
7
+ "tokenizer_type": "BertTokenizer",
8
+ "image_mean": [
9
+ 0.48145466,
10
+ 0.4578275,
11
+ 0.40821073
12
+ ],
13
+ "image_std": [
14
+ 0.26862954,
15
+ 0.26130258,
16
+ 0.27577711
17
+ ],
18
+ "resample": 3,
19
+ "size": 224
20
+ }
processing_biomed_clip.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Image/Text processor class for CLIP
17
+ """
18
+
19
+ import warnings
20
+
21
+ from transformers.processing_utils import ProcessorMixin
22
+ from transformers.tokenization_utils_base import BatchEncoding
23
+
24
+
25
+ class BiomedCLIPProcessor(ProcessorMixin):
26
+ r"""
27
+ Constructs a CLIP processor which wraps a CLIP image processor and a CLIP tokenizer into a single processor.
28
+
29
+ [`CLIPProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`CLIPTokenizerFast`]. See the
30
+ [`~CLIPProcessor.__call__`] and [`~CLIPProcessor.decode`] for more information.
31
+
32
+ Args:
33
+ image_processor ([`CLIPImageProcessor`], *optional*):
34
+ The image processor is a required input.
35
+ tokenizer ([`CLIPTokenizerFast`], *optional*):
36
+ The tokenizer is a required input.
37
+ """
38
+
39
+ attributes = ["image_processor", "tokenizer"]
40
+ image_processor_class = "CLIPImageProcessor"
41
+ tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
42
+
43
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
44
+ feature_extractor = None
45
+ if "feature_extractor" in kwargs:
46
+ warnings.warn(
47
+ "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
48
+ " instead.",
49
+ FutureWarning,
50
+ )
51
+ feature_extractor = kwargs.pop("feature_extractor")
52
+
53
+ image_processor = image_processor if image_processor is not None else feature_extractor
54
+ if image_processor is None:
55
+ raise ValueError("You need to specify an `image_processor`.")
56
+ if tokenizer is None:
57
+ raise ValueError("You need to specify a `tokenizer`.")
58
+
59
+ super().__init__(image_processor, tokenizer)
60
+
61
+ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
62
+ """
63
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
64
+ and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
65
+ the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
66
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
67
+ of the above two methods for more information.
68
+
69
+ Args:
70
+ text (`str`, `List[str]`, `List[List[str]]`):
71
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
72
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
73
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
74
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
75
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
76
+ tensor. Both channels-first and channels-last formats are supported.
77
+
78
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
79
+ If set, will return tensors of a particular framework. Acceptable values are:
80
+
81
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
82
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
83
+ - `'np'`: Return NumPy `np.ndarray` objects.
84
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
85
+
86
+ Returns:
87
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
88
+
89
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
90
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
91
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
92
+ `None`).
93
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
94
+ """
95
+ tokenizer_kwargs, image_processor_kwargs = {}, {}
96
+ if kwargs:
97
+ tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys}
98
+ image_processor_kwargs = {
99
+ k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys
100
+ }
101
+
102
+ if text is None and images is None:
103
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
104
+
105
+ if text is not None:
106
+ encoding = self.tokenizer(text, return_tensors=return_tensors, **tokenizer_kwargs)
107
+
108
+ if images is not None:
109
+ image_features = self.image_processor(images, return_tensors=return_tensors, **image_processor_kwargs)
110
+
111
+ if text is not None and images is not None:
112
+ encoding["pixel_values"] = image_features.pixel_values
113
+ return encoding
114
+ elif text is not None:
115
+ return encoding
116
+ else:
117
+ return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
118
+
119
+ def batch_decode(self, *args, **kwargs):
120
+ """
121
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
122
+ refer to the docstring of this method for more information.
123
+ """
124
+ return self.tokenizer.batch_decode(*args, **kwargs)
125
+
126
+ def decode(self, *args, **kwargs):
127
+ """
128
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
129
+ the docstring of this method for more information.
130
+ """
131
+ return self.tokenizer.decode(*args, **kwargs)
132
+
133
+ @property
134
+ def model_input_names(self):
135
+ tokenizer_input_names = self.tokenizer.model_input_names
136
+ image_processor_input_names = self.image_processor.model_input_names
137
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
138
+
139
+ @property
140
+ def feature_extractor_class(self):
141
+ warnings.warn(
142
+ "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
143
+ FutureWarning,
144
+ )
145
+ return self.image_processor_class
146
+
147
+ @property
148
+ def feature_extractor(self):
149
+ warnings.warn(
150
+ "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
151
+ FutureWarning,
152
+ )
153
+ return self.image_processor
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bdc400de59a85620ddc7584d06913dc901c47f22647899c6addec71b9a5c9a2
3
+ size 783733062