infgrad commited on
Commit
1d33766
·
verified ·
1 Parent(s): c5227da

Introduce a custom Sentence Transformer module for smooth multi-modality (#1)

Browse files

- Introduce custom Sentence Transformer module (9862f98edfbc3c5f1a56b1a00ef87ad1b9af3b76)
- Use self.max_seq_length to inform the maximum tokenize length (c0c6d64415a1e25865af6dbb702ac5ba5a1645e4)
- Merge branch 'main' into pr/1, resolve merge conflict (008f2574a5989c788b0fa395d4342a0e1c40f250)

Files changed (4) hide show
  1. README.md +11 -74
  2. custom_st.py +87 -0
  3. modules.json +12 -6
  4. sentence_bert_config.json +4 -1
README.md CHANGED
@@ -9004,66 +9004,10 @@ Actually, I've got first place on MTEB (Chinese and English), I will not release
9004
 
9005
  ## Usage
9006
  ```python
9007
-
9008
- import functools
9009
- import PIL
9010
- import numpy as np
9011
  import torch
9012
- from typing import Dict
9013
- from io import BytesIO
9014
- from transformers import SiglipImageProcessor
9015
  from sentence_transformers import SentenceTransformer
9016
 
9017
 
9018
- def jasper_vl_forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
9019
- trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
9020
- if "pixel_values" in features:
9021
- trans_features["pixel_values"] = features["pixel_values"]
9022
- sentence_embedding = self.auto_model(**trans_features, **kwargs)["sentence_embedding"]
9023
- features.update({"sentence_embedding": sentence_embedding})
9024
- return features
9025
-
9026
-
9027
- def jasper_vl_tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
9028
- img_start_token = "<|jasper_img_start|>"
9029
- img_token = "<|jasper_img_token|>"
9030
- img_end_token = "<|jasper_img_end|>"
9031
- num_img_tokens = 300
9032
-
9033
- def process_text_item(item):
9034
- if isinstance(item, str):
9035
- return item, []
9036
- text, images = "", []
9037
- for sub_item in item:
9038
- if sub_item["type"] == "text":
9039
- text += sub_item["content"]
9040
- elif sub_item["type"] == "image_bytes":
9041
- text += img_start_token + img_token * num_img_tokens + img_end_token
9042
- images.append(PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB"))
9043
- elif sub_item["type"] == "image_path":
9044
- text += img_start_token + img_token * num_img_tokens + img_end_token
9045
- images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
9046
- else:
9047
- raise ValueError(f"unknown data type {sub_item['type']}")
9048
- return text, images
9049
-
9050
- all_texts, all_images = [], []
9051
- for item in texts:
9052
- text, images = process_text_item(item)
9053
- all_texts.append(text)
9054
- all_images.extend(images)
9055
- ipt = self.tokenizer(all_texts, padding="longest", truncation=True, max_length=1024, return_tensors="pt")
9056
- if all_images:
9057
- ipt["pixel_values"] = self.processor(
9058
- images=all_images,
9059
- return_tensors="pt"
9060
- )["pixel_values"]
9061
- # For the sake of demonstration, external variables are used here, please modify the code according to your own environment.
9062
- if use_gpu:
9063
- ipt["pixel_values"] = ipt["pixel_values"].bfloat16()
9064
- return ipt
9065
-
9066
-
9067
  DOC1 = """
9068
  Blue light is scattered in all directions by the tiny molecules of air in Earth's atmosphere.
9069
  Blue is scattered more than other colors because it travels as shorter, smaller waves. This is why we see a blue sky most of the time.
@@ -9081,10 +9025,6 @@ Color combinations: Decide how to best complement your preferred color with othe
9081
  Color palette: Limit your color palette to a main color and one or two additional colors.
9082
  60-30-10 rule: Use a primary color 60% of the time, a secondary color 30% of the time, and an accent color 10% of the time
9083
  """
9084
- prompt_dict = {
9085
- "s2p_query": "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: ",
9086
- "s2s_query": "Instruct: Retrieve semantically similar text.\nQuery: "
9087
- }
9088
  if __name__ == "__main__":
9089
  # load model
9090
  use_gpu = False
@@ -9092,7 +9032,7 @@ if __name__ == "__main__":
9092
  model = SentenceTransformer(
9093
  model_name,
9094
  trust_remote_code=True,
9095
- device="cpu",
9096
  model_kwargs={
9097
  "torch_dtype": torch.bfloat16 if use_gpu else torch.float32,
9098
  "attn_implementation": "sdpa"
@@ -9101,13 +9041,10 @@ if __name__ == "__main__":
9101
  ## 1024 is recommended
9102
  # set is_text_encoder 'True', if you do not encode image
9103
  config_kwargs={"is_text_encoder": False, "vector_dim": 1024},
9104
- tokenizer_kwargs={"padding_side": "right"}
9105
  )
9106
- # jasper model cannot directly be used in SentenceTransformer, do some modifications
9107
- model.processor = SiglipImageProcessor.from_pretrained(model_name)
9108
- model.tokenize = functools.partial(jasper_vl_tokenize, model)
9109
- model._first_module().forward = functools.partial(jasper_vl_forward, model._first_module())
9110
  model.max_seq_length = 1024
 
9111
  # data
9112
  q_list = [
9113
  "Why the sky is blue?",
@@ -9118,16 +9055,16 @@ if __name__ == "__main__":
9118
  [{"type": "image_path", "content": "./assets/img1.png"}, {"type": "text", "content": "Hope this image helps!"}],
9119
  DOC2,
9120
  [{"type": "image_path", "content": "./assets/img2.png"}],
9121
-
9122
  ]
9123
- q_vecs = model.encode([prompt_dict["s2p_query"] + text for text in q_list], normalize_embeddings=True)
9124
- doc_vecs = model.encode(doc_list, normalize_embeddings=True)
9125
- print(np.matmul(q_vecs, doc_vecs.T))
9126
- # the output is:
9127
- # [[0.777521 0.75944513 0.24291277 0.2187205]
9128
- # [0.32261407 0.30536035 0.74208796 0.5484469]]
9129
-
9130
 
 
 
 
 
 
 
9131
  ```
9132
 
9133
  ## Evaluation on MTEB
 
9004
 
9005
  ## Usage
9006
  ```python
 
 
 
 
9007
  import torch
 
 
 
9008
  from sentence_transformers import SentenceTransformer
9009
 
9010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9011
  DOC1 = """
9012
  Blue light is scattered in all directions by the tiny molecules of air in Earth's atmosphere.
9013
  Blue is scattered more than other colors because it travels as shorter, smaller waves. This is why we see a blue sky most of the time.
 
9025
  Color palette: Limit your color palette to a main color and one or two additional colors.
9026
  60-30-10 rule: Use a primary color 60% of the time, a secondary color 30% of the time, and an accent color 10% of the time
9027
  """
 
 
 
 
9028
  if __name__ == "__main__":
9029
  # load model
9030
  use_gpu = False
 
9032
  model = SentenceTransformer(
9033
  model_name,
9034
  trust_remote_code=True,
9035
+ device="cpu" if not use_gpu else "cuda",
9036
  model_kwargs={
9037
  "torch_dtype": torch.bfloat16 if use_gpu else torch.float32,
9038
  "attn_implementation": "sdpa"
 
9041
  ## 1024 is recommended
9042
  # set is_text_encoder 'True', if you do not encode image
9043
  config_kwargs={"is_text_encoder": False, "vector_dim": 1024},
 
9044
  )
9045
+ # We can reduce the max_seq_length from the default of 2048 for faster encoding
 
 
 
9046
  model.max_seq_length = 1024
9047
+
9048
  # data
9049
  q_list = [
9050
  "Why the sky is blue?",
 
9055
  [{"type": "image_path", "content": "./assets/img1.png"}, {"type": "text", "content": "Hope this image helps!"}],
9056
  DOC2,
9057
  [{"type": "image_path", "content": "./assets/img2.png"}],
 
9058
  ]
9059
+ q_vecs = model.encode(q_list, prompt_name="s2p_query")
9060
+ doc_vecs = model.encode(doc_list)
 
 
 
 
 
9061
 
9062
+ # calculate similarity
9063
+ similarities = model.similarity(q_vecs, doc_vecs)
9064
+ print(similarities)
9065
+ # the output is:
9066
+ # tensor([[0.7775, 0.7594, 0.2429, 0.2187],
9067
+ # [0.3226, 0.3054, 0.7421, 0.5484]])
9068
  ```
9069
 
9070
  ## Evaluation on MTEB
custom_st.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+ import PIL
3
+ import torch
4
+ import PIL
5
+ import torch
6
+ from typing import Dict
7
+ from io import BytesIO
8
+ from transformers import SiglipImageProcessor
9
+ from sentence_transformers.models import Transformer as BaseTransformer
10
+
11
+
12
+ class MultiModalTransformer(BaseTransformer):
13
+
14
+ def __init__(
15
+ self,
16
+ model_name_or_path: str,
17
+ cache_dir: Optional[str] = None,
18
+ tokenizer_args: Optional[Dict[str, Any]] = None,
19
+ **kwargs,
20
+ ):
21
+ super().__init__(model_name_or_path, **kwargs)
22
+ if tokenizer_args is None:
23
+ tokenizer_args = {}
24
+ self.processor = SiglipImageProcessor.from_pretrained(
25
+ model_name_or_path, cache_dir=cache_dir, **tokenizer_args
26
+ )
27
+
28
+ def forward(
29
+ self, features: dict[str, torch.Tensor], **kwargs
30
+ ) -> dict[str, torch.Tensor]:
31
+ trans_features = {
32
+ "input_ids": features["input_ids"],
33
+ "attention_mask": features["attention_mask"],
34
+ }
35
+ if "pixel_values" in features:
36
+ trans_features["pixel_values"] = features["pixel_values"].to(
37
+ self.auto_model.dtype
38
+ )
39
+
40
+ sentence_embedding = self.auto_model(**trans_features, **kwargs)[
41
+ "sentence_embedding"
42
+ ]
43
+ features.update({"sentence_embedding": sentence_embedding})
44
+ return features
45
+
46
+ def tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
47
+ img_start_token = "<|jasper_img_start|>"
48
+ img_token = "<|jasper_img_token|>"
49
+ img_end_token = "<|jasper_img_end|>"
50
+ num_img_tokens = 300
51
+
52
+ def process_text_item(item):
53
+ if isinstance(item, str):
54
+ return item, []
55
+ text, images = "", []
56
+ for sub_item in item:
57
+ if sub_item["type"] == "text":
58
+ text += sub_item["content"]
59
+ elif sub_item["type"] == "image_bytes":
60
+ text += img_start_token + img_token * num_img_tokens + img_end_token
61
+ images.append(
62
+ PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB")
63
+ )
64
+ elif sub_item["type"] == "image_path":
65
+ text += img_start_token + img_token * num_img_tokens + img_end_token
66
+ images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
67
+ else:
68
+ raise ValueError(f"unknown data type {sub_item['type']}")
69
+ return text, images
70
+
71
+ all_texts, all_images = [], []
72
+ for item in texts:
73
+ text, images = process_text_item(item)
74
+ all_texts.append(text)
75
+ all_images.extend(images)
76
+ ipt = self.tokenizer(
77
+ all_texts,
78
+ padding="longest",
79
+ truncation=True,
80
+ max_length=self.max_seq_length,
81
+ return_tensors="pt",
82
+ )
83
+ if all_images:
84
+ ipt["pixel_values"] = self.processor(
85
+ images=all_images, return_tensors="pt"
86
+ )["pixel_values"]
87
+ return ipt
modules.json CHANGED
@@ -1,8 +1,14 @@
1
  [
2
- {
3
- "idx": 0,
4
- "name": "0",
5
- "path": "",
6
- "type": "sentence_transformers.models.Transformer"
7
- }
 
 
 
 
 
 
8
  ]
 
1
  [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "custom_st.MultiModalTransformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Normalize",
12
+ "type": "sentence_transformers.models.Normalize"
13
+ }
14
  ]
sentence_bert_config.json CHANGED
@@ -1,4 +1,7 @@
1
  {
2
  "max_seq_length": 2048,
3
- "do_lower_case": false
 
 
 
4
  }
 
1
  {
2
  "max_seq_length": 2048,
3
+ "do_lower_case": false,
4
+ "tokenizer_args": {
5
+ "padding_side": "right"
6
+ }
7
  }