Tom Aarsen committed on
Commit 9862f98 · 1 Parent(s): 8836013

Introduce custom Sentence Transformer module

Files changed (4):
  1. README.md +12 -74
  2. custom_st.py +87 -0
  3. modules.json +12 -6
  4. sentence_bert_config.json +4 -1
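
In short: the custom `jasper_vl_tokenize` / `jasper_vl_forward` functions that the README previously asked users to monkey-patch onto the model are moved into a `custom_st.MultiModalTransformer` module, which Sentence Transformers now loads automatically via `modules.json`. A minimal sketch of the simplified loading path (assuming `trust_remote_code=True`, with `model_name` left as a placeholder for this repository's Hub id):

```python
from sentence_transformers import SentenceTransformer

model_name = "..."  # placeholder: this repository's Hub id, as set in the README

# The modules declared in modules.json (custom_st.MultiModalTransformer + Normalize)
# are instantiated automatically; no manual patching of tokenize()/forward() is needed.
model = SentenceTransformer(model_name, trust_remote_code=True)
embeddings = model.encode(["Why the sky is blue?"], prompt_name="s2p_query")
```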
README.md CHANGED
@@ -8985,66 +8985,10 @@ This work was accomplished during my free time; please grant time a little time.
 
 ## Usage
 ```python
-
-import functools
-import PIL
-import numpy as np
 import torch
-from typing import Dict
-from io import BytesIO
-from transformers import SiglipImageProcessor
 from sentence_transformers import SentenceTransformer
 
 
-def jasper_vl_forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
-    trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
-    if "pixel_values" in features:
-        trans_features["pixel_values"] = features["pixel_values"]
-    sentence_embedding = self.auto_model(**trans_features, **kwargs)["sentence_embedding"]
-    features.update({"sentence_embedding": sentence_embedding})
-    return features
-
-
-def jasper_vl_tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
-    img_start_token = "<|jasper_img_start|>"
-    img_token = "<|jasper_img_token|>"
-    img_end_token = "<|jasper_img_end|>"
-    num_img_tokens = 300
-
-    def process_text_item(item):
-        if isinstance(item, str):
-            return item, []
-        text, images = "", []
-        for sub_item in item:
-            if sub_item["type"] == "text":
-                text += sub_item["content"]
-            elif sub_item["type"] == "image_bytes":
-                text += img_start_token + img_token * num_img_tokens + img_end_token
-                images.append(PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB"))
-            elif sub_item["type"] == "image_path":
-                text += img_start_token + img_token * num_img_tokens + img_end_token
-                images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
-            else:
-                raise ValueError(f"unknown data type {sub_item['type']}")
-        return text, images
-
-    all_texts, all_images = [], []
-    for item in texts:
-        text, images = process_text_item(item)
-        all_texts.append(text)
-        all_images.extend(images)
-    ipt = self.tokenizer(all_texts, padding="longest", truncation=True, max_length=1024, return_tensors="pt")
-    if all_images:
-        ipt["pixel_values"] = self.processor(
-            images=all_images,
-            return_tensors="pt"
-        )["pixel_values"]
-    # For the sake of demonstration, external variables are used here, please modify the code according to your own environment.
-    if use_gpu:
-        ipt["pixel_values"] = ipt["pixel_values"].bfloat16()
-    return ipt
-
-
 DOC1 = """
 Blue light is scattered in all directions by the tiny molecules of air in Earth's atmosphere.
 Blue is scattered more than other colors because it travels as shorter, smaller waves. This is why we see a blue sky most of the time.
@@ -9062,10 +9006,6 @@ Color combinations: Decide how to best complement your preferred color with othe
 Color palette: Limit your color palette to a main color and one or two additional colors.
 60-30-10 rule: Use a primary color 60% of the time, a secondary color 30% of the time, and an accent color 10% of the time
 """
-prompt_dict = {
-    "s2p_query": "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: ",
-    "s2s_query": "Instruct: Retrieve semantically similar text.\nQuery: "
-}
 if __name__ == "__main__":
     # load model
     use_gpu = False
@@ -9073,7 +9013,7 @@ if __name__ == "__main__":
     model = SentenceTransformer(
         model_name,
         trust_remote_code=True,
-        device="cpu",
+        device="cpu" if not use_gpu else "cuda",
         model_kwargs={
             "torch_dtype": torch.bfloat16 if use_gpu else torch.float32,
             "attn_implementation": "sdpa"
@@ -9082,13 +9022,10 @@ if __name__ == "__main__":
         ## 1024 is recommended
         # set is_text_encoder 'True', if you do not encode image
         config_kwargs={"is_text_encoder": False, "vector_dim": 1024},
-        tokenizer_kwargs={"padding_side": "right"}
     )
-    # jasper model cannot directly be used in SentenceTransformer, do some modifications
-    model.processor = SiglipImageProcessor.from_pretrained(model_name)
-    model.tokenize = functools.partial(jasper_vl_tokenize, model)
-    model._first_module().forward = functools.partial(jasper_vl_forward, model._first_module())
+    # We can reduce the max_seq_length from the default of 2048 for faster encoding
     model.max_seq_length = 1024
+
     # data
     q_list = [
        "Why the sky is blue?",
@@ -9099,16 +9036,17 @@ if __name__ == "__main__":
        [{"type": "image_path", "content": "./assets/img1.png"}, {"type": "text", "content": "Hope this image helps!"}],
        DOC2,
        [{"type": "image_path", "content": "./assets/img2.png"}],
-
    ]
-    q_vecs = model.encode([prompt_dict["s2p_query"] + text for text in q_list], normalize_embeddings=True)
-    doc_vecs = model.encode(doc_list, normalize_embeddings=True)
-    print(np.matmul(q_vecs, doc_vecs.T))
-    # the output is:
-    # [[0.777521   0.75944513 0.24291277 0.2187205 ]
-    #  [0.32261407 0.30536035 0.74208796 0.5484469 ]]
-
+    q_vecs = model.encode(q_list, prompt_name="s2p_query")
+    doc_vecs = model.encode(doc_list)
 
+    # calculate similarity
+    similarities = model.similarity(q_vecs, doc_vecs)
+    print(similarities)
+    # the output is:
+    # tensor([[0.7775, 0.7594, 0.2429, 0.2187],
+    #         [0.3226, 0.3054, 0.7421, 0.5484]])
 ```
+
 ## License
 **This model should not be used for any commercial purpose!**
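
For reviewers skimming the new usage block: the input format accepted by `model.encode` is unchanged. Each item is either a plain string or a list of `{"type": ..., "content": ...}` parts, where `type` may be `text`, `image_path`, or `image_bytes` (see `tokenize()` in `custom_st.py`). A small sketch of the bytes variant, reusing the asset path from the README (`model_name` is again a placeholder for this repository's Hub id):

```python
from sentence_transformers import SentenceTransformer

model_name = "..."  # placeholder: this repository's Hub id
model = SentenceTransformer(model_name, trust_remote_code=True)

# Encode a document that mixes an image (passed as raw bytes) with text.
with open("./assets/img1.png", "rb") as f:
    img_bytes = f.read()

doc_vec = model.encode([
    [
        {"type": "image_bytes", "content": img_bytes},
        {"type": "text", "content": "Hope this image helps!"},
    ]
])
```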
custom_st.py ADDED
@@ -0,0 +1,87 @@
+from typing import Any, Dict, Optional
+import PIL
+import torch
+import PIL
+import torch
+from typing import Dict
+from io import BytesIO
+from transformers import SiglipImageProcessor
+from sentence_transformers.models import Transformer as BaseTransformer
+
+
+class MultiModalTransformer(BaseTransformer):
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        tokenizer_args: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        super().__init__(model_name_or_path, **kwargs)
+        if tokenizer_args is None:
+            tokenizer_args = {}
+        self.processor = SiglipImageProcessor.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
+        )
+
+    def forward(
+        self, features: dict[str, torch.Tensor], **kwargs
+    ) -> dict[str, torch.Tensor]:
+        trans_features = {
+            "input_ids": features["input_ids"],
+            "attention_mask": features["attention_mask"],
+        }
+        if "pixel_values" in features:
+            trans_features["pixel_values"] = features["pixel_values"].to(
+                self.auto_model.dtype
+            )
+
+        sentence_embedding = self.auto_model(**trans_features, **kwargs)[
+            "sentence_embedding"
+        ]
+        features.update({"sentence_embedding": sentence_embedding})
+        return features
+
+    def tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
+        img_start_token = "<|jasper_img_start|>"
+        img_token = "<|jasper_img_token|>"
+        img_end_token = "<|jasper_img_end|>"
+        num_img_tokens = 300
+
+        def process_text_item(item):
+            if isinstance(item, str):
+                return item, []
+            text, images = "", []
+            for sub_item in item:
+                if sub_item["type"] == "text":
+                    text += sub_item["content"]
+                elif sub_item["type"] == "image_bytes":
+                    text += img_start_token + img_token * num_img_tokens + img_end_token
+                    images.append(
+                        PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB")
+                    )
+                elif sub_item["type"] == "image_path":
+                    text += img_start_token + img_token * num_img_tokens + img_end_token
+                    images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
+                else:
+                    raise ValueError(f"unknown data type {sub_item['type']}")
+            return text, images
+
+        all_texts, all_images = [], []
+        for item in texts:
+            text, images = process_text_item(item)
+            all_texts.append(text)
+            all_images.extend(images)
+        ipt = self.tokenizer(
+            all_texts,
+            padding="longest",
+            truncation=True,
+            max_length=1024,
+            return_tensors="pt",
+        )
+        if all_images:
+            ipt["pixel_values"] = self.processor(
+                images=all_images, return_tensors="pt"
+            )["pixel_values"]
+        return ipt
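
As context for the new module: `tokenize()` replaces every image part with `<|jasper_img_start|>`, 300 `<|jasper_img_token|>` placeholders, and `<|jasper_img_end|>` in the text stream and attaches SigLIP `pixel_values` next to `input_ids`/`attention_mask`; `forward()` then casts the pixel values to the model dtype and returns the model's `sentence_embedding`. An illustrative sketch of the features produced for one mixed input (`model_name` is a placeholder for this repository's Hub id):

```python
from sentence_transformers import SentenceTransformer

model_name = "..."  # placeholder: this repository's Hub id
model = SentenceTransformer(model_name, trust_remote_code=True)

# Illustrative only: inspect what the custom module returns for one text+image input.
module = model._first_module()  # the MultiModalTransformer declared in modules.json
features = module.tokenize([
    [
        {"type": "text", "content": "Hope this image helps!"},
        {"type": "image_path", "content": "./assets/img1.png"},
    ]
])
print(features.keys())              # input_ids, attention_mask, pixel_values
print(features["input_ids"].shape)  # the image adds 300 placeholder tokens plus start/end markers
```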
modules.json CHANGED
@@ -1,8 +1,14 @@
 [
-  {
-    "idx": 0,
-    "name": "0",
-    "path": "",
-    "type": "sentence_transformers.models.Transformer"
-  }
+  {
+    "idx": 0,
+    "name": "0",
+    "path": "",
+    "type": "custom_st.MultiModalTransformer"
+  },
+  {
+    "idx": 1,
+    "name": "1",
+    "path": "1_Normalize",
+    "type": "sentence_transformers.models.Normalize"
+  }
 ]
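
For reference, `modules.json` is the Sentence Transformers pipeline definition: module `0` (with `"path": ""`, i.e. the repository root) is now the custom `MultiModalTransformer`, and module `1` adds an L2 `Normalize` step loaded from `1_Normalize/`, which is presumably why the README no longer passes `normalize_embeddings=True`. Conceptually, the loaded pipeline is equivalent to the hand-built one below (a sketch only; in practice `SentenceTransformer(model_name, trust_remote_code=True)` assembles it from `modules.json`, and `model_name` is a placeholder):

```python
from sentence_transformers import SentenceTransformer, models
from custom_st import MultiModalTransformer

model_name = "..."  # placeholder: this repository's Hub id

modules = [
    # idx 0, "path": "" -> loaded from the repository root
    # (may additionally need model_args={"trust_remote_code": True} for the custom architecture)
    MultiModalTransformer(model_name),
    # idx 1, "path": "1_Normalize" -> L2-normalizes the sentence embeddings
    models.Normalize(),
]
model = SentenceTransformer(modules=modules)
```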
sentence_bert_config.json CHANGED
@@ -1,4 +1,7 @@
 {
   "max_seq_length": 2048,
-  "do_lower_case": false
+  "do_lower_case": false,
+  "tokenizer_args": {
+    "padding_side": "right"
+  }
 }
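
The new `tokenizer_args` block replaces the `tokenizer_kwargs={"padding_side": "right"}` argument that the README previously passed at load time, so right-padding should now be applied automatically when the tokenizer is created. A quick sanity check (a sketch, with `model_name` again a placeholder for this repository's Hub id):

```python
from sentence_transformers import SentenceTransformer

model_name = "..."  # placeholder: this repository's Hub id
model = SentenceTransformer(model_name, trust_remote_code=True)
print(model.tokenizer.padding_side)  # expected: "right"
```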