gmastrapas commited on
Commit
e7432eb
·
1 Parent(s): 66ee61a

fix: kwargs in custom Sentence Transformer

Browse files
Files changed (1) hide show
  1. custom_st.py +80 -12
custom_st.py CHANGED
@@ -2,7 +2,7 @@ import base64
2
  import json
3
  import os
4
  from io import BytesIO
5
- from typing import Any, Dict, List, Optional, Union
6
 
7
  import requests
8
  import torch
@@ -14,23 +14,91 @@ from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenize
14
  class Transformer(nn.Module):
15
  def __init__(
16
  self,
17
- model_name_or_path: str,
18
  tokenizer_name_or_path: Optional[str] = None,
19
  image_processor_name_or_path: Optional[str] = None,
20
  max_seq_length: Optional[int] = None,
21
- config_kwargs: Optional[Dict[str, Any]] = None,
22
- model_kwargs: Optional[Dict[str, Any]] = None,
23
- tokenizer_kwargs: Optional[Dict[str, Any]] = None,
24
- image_processor_kwargs: Optional[Dict[str, Any]] = None,
25
- cache_dir: str = None,
 
26
  **_,
27
  ) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  super(Transformer, self).__init__()
29
-
30
- config_kwargs = config_kwargs or {}
31
- model_kwargs = model_kwargs or {}
32
- tokenizer_kwargs = tokenizer_kwargs or {}
33
- image_processor_kwargs = image_processor_kwargs or {}
 
 
 
 
 
 
 
 
 
 
34
 
35
  config = AutoConfig.from_pretrained(
36
  model_name_or_path, cache_dir=cache_dir, **config_kwargs
 
2
  import json
3
  import os
4
  from io import BytesIO
5
+ from typing import Any, Dict, List, Literal, Optional, Union
6
 
7
  import requests
8
  import torch
 
14
  class Transformer(nn.Module):
15
  def __init__(
16
  self,
17
+ model_name_or_path: str = 'jinaai/jina-clip-v2',
18
  tokenizer_name_or_path: Optional[str] = None,
19
  image_processor_name_or_path: Optional[str] = None,
20
  max_seq_length: Optional[int] = None,
21
+ config_args: Optional[Dict[str, Any]] = None,
22
+ model_args: Optional[Dict[str, Any]] = None,
23
+ tokenizer_args: Optional[Dict[str, Any]] = None,
24
+ image_processor_args: Optional[Dict[str, Any]] = None,
25
+ cache_dir: Optional[str] = None,
26
+ backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
27
  **_,
28
  ) -> None:
29
+ """
30
+ Creates a custom SentenceTransformer module that uses `jinai/jina-clip-v2` to
31
+ map sentences/images to embeddings
32
+
33
+ Args:
34
+ model_name_or_path (str, optional): If it is a filepath on disc, it loads
35
+ the model from that path. If it is not a path, tries to construct a
36
+ model from the Hugging Face Hub with that name. Defaults to
37
+ 'jinaai/jina-clip-v2'
38
+ tokenizer_name_or_path (str, optional): If it is a filepath on disc, it
39
+ loads the tokenizer from that path. If it is not a path, tries to
40
+ construct a tokenizer from the Hugging Face Hub with that name.
41
+ If `None` it is automatically set to the value of `model_name_or_path`
42
+ image_processor_name_or_path (str, optional): If it is a filepath on disc,
43
+ it loads the image processor from that path. If it is not a path, tries
44
+ to construct an image processor from the Hugging Face Hub with that
45
+ name. If `None` it is automatically set to the value of
46
+ `model_name_or_path`
47
+ max_seq_length (int, optional): The maximum sequence length of the model.
48
+ If not provided, will be inferred from model or tokenizer
49
+ config_args (Dict[str, Any], optional): Additional model configuration
50
+ parameters to be passed to the Hugging Face Transformers config
51
+ model_args (Dict[str, Any], optional): Additional model configuration
52
+ parameters to be passed to the Hugging Face Transformers model
53
+ tokenizer_args (Dict[str, Any], optional): Additional tokenizer
54
+ configuration parameters to be passed to the Hugging Face Transformers
55
+ tokenizer
56
+ image_processor_args (Dict[str, Any], optional): Additional image processor
57
+ configuration parameters to be passed to the Hugging Face Transformers
58
+ image processor
59
+ cache_dir (str, optional): The Hugging Face Hub cache directory
60
+ backend (str, optional): Computational backend, only 'torch' is supported
61
+
62
+ Example:
63
+ ::
64
+
65
+ from sentence_transformers import SentenceTransformer
66
+
67
+ model = SentenceTransformer(
68
+ 'jinaai/jina-clip-v2', trust_remote_code=True
69
+ )
70
+ sentences_or_images = [
71
+ "The weather is lovely today.",
72
+ "It's so sunny outside!",
73
+ "/path/to/stadium.jpg",
74
+ ]
75
+ embeddings = model.encode(sentences_or_images)
76
+ print(embeddings.shape)
77
+ # (3, 1024)
78
+
79
+ # Get the similarity scores between all inputs
80
+ similarities = model.similarity(embeddings, embeddings)
81
+ print(similarities)
82
+ # tensor([[1.0000, 0.6817, 0.0492],
83
+ # [0.6817, 1.0000, 0.0421],
84
+ # [0.0492, 0.0421, 1.0000]])
85
+ """
86
  super(Transformer, self).__init__()
87
+ if backend != 'torch':
88
+ raise ValueError(
89
+ f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
90
+ )
91
+
92
+ config_kwargs = config_args or {}
93
+ model_kwargs = model_args or {}
94
+ tokenizer_kwargs = tokenizer_args or {}
95
+ image_processor_kwargs = {
96
+ 'token': model_kwargs.get('token', None),
97
+ 'trust_remote_code': model_kwargs.get('trust_remote_code', False),
98
+ 'revision': model_kwargs.get('revision', None),
99
+ 'local_files_only': model_kwargs.get('local_files_only', None),
100
+ }
101
+ image_processor_kwargs.update(image_processor_args)
102
 
103
  config = AutoConfig.from_pretrained(
104
  model_name_or_path, cache_dir=cache_dir, **config_kwargs