---
license: apache-2.0
---

Same as https://huggingface.co/HuggingFaceM4/siglip-so400m-14-384-flash-attn2 with two changes:
- the maximum resolution is increased to 980 x 980 (from 384 x 384) by interpolating the position embeddings
- the strategy from [NaViT](https://arxiv.org/abs/2307.06304) is implemented to support (a) variable-resolution images and (b) images that keep their original aspect ratio

These changes apply only to the vision tower; the text tower is unchanged.
The implementation is fully backward compatible with `https://huggingface.co/HuggingFaceM4/siglip-so400m-14-384-flash-attn2`: to get the original behavior, simply omit `patch_attention_mask`.
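
As a rough illustration of the first change, the sketch below resizes a learned position-embedding table from a 27 x 27 grid (384 // 14 patches per side) to a 70 x 70 grid (980 // 14 patches per side). The function name `interpolate_position_embeddings`, the bilinear mode, and the small demo embedding width are assumptions made for this example; the interpolation settings actually used to produce this checkpoint are not stated here.

```python
import torch
import torch.nn.functional as F


def interpolate_position_embeddings(
    pos_embed: torch.Tensor, new_side: int, mode: str = "bilinear"
) -> torch.Tensor:
    """Resize a (old_side**2, dim) position table to (new_side**2, dim).

    Hypothetical helper: the interpolation mode is an assumption.
    """
    num_positions, dim = pos_embed.shape
    old_side = int(round(num_positions ** 0.5))
    if old_side * old_side != num_positions:
        raise ValueError("expected a square grid of positions")
    # (N, dim) -> (1, dim, old_side, old_side), the layout F.interpolate expects
    grid = pos_embed.reshape(old_side, old_side, dim).permute(2, 0, 1).unsqueeze(0)
    resized = F.interpolate(
        grid.float(), size=(new_side, new_side), mode=mode, align_corners=False
    )
    # (1, dim, new_side, new_side) -> (new_side**2, dim)
    resized = resized.squeeze(0).permute(1, 2, 0).reshape(new_side * new_side, dim)
    return resized.to(pos_embed.dtype)


# Demo with a small embedding width (64) to keep it cheap; only the grid size matters here.
old_table = torch.randn((384 // 14) ** 2, 64)
new_table = interpolate_position_embeddings(old_table, new_side=980 // 14)
print(new_table.shape)  # torch.Size([4900, 64])
```
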


Usage:
```python
import torch
from modeling_siglip import SiglipVisionModel

DEVICE = torch.device("cuda:0")
PATCH_SIZE = 14

pixel_values = torch.randn(2, 3, 28, 42, dtype=torch.bfloat16, device=DEVICE)
# Pixel-level mask, shape (batch, height, width): True = real pixel, False = padding.
pixel_attention_mask = torch.zeros(2, 28, 42, dtype=torch.bool, device=DEVICE)
pixel_attention_mask[0, :14, :] = True  # image 0: top 14 rows are real, bottom 14 rows are padding
pixel_attention_mask[1, :, :28] = True  # image 1: left 28 columns are real, right 14 columns are padding
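# Split the pixel mask into non-overlapping PATCH_SIZE x PATCH_SIZE tiles:
# (batch, height, width) -> (batch, height // PATCH_SIZE, width // PATCH_SIZE, PATCH_SIZE, PATCH_SIZE)
patches_subgrid = pixel_attention_mask.unfold(
    dimension=1, size=PATCH_SIZE, step=PATCH_SIZE
).unfold(dimension=2, size=PATCH_SIZE, step=PATCH_SIZE)
# A patch is kept if it contains at least one real pixel.
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()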

model = SiglipVisionModel.from_pretrained("HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit", _flash_attn_2_enabled=True)
model.train()
model.vision_model.to(DEVICE, dtype=torch.bfloat16)

output = model.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
```
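
The mask construction in the usage snippet can be wrapped in a small helper that starts from the unpadded size of each image. This is only a sketch: `build_patch_attention_mask` is a hypothetical name, it is not part of `modeling_siglip`, and it assumes each image sits in the top-left corner of the padded batch with padding on the bottom and right, as in the example above. It runs on CPU so the resulting mask can be inspected without loading the model.

```python
import torch


def build_patch_attention_mask(image_sizes, padded_size, patch_size=14):
    """Return a (batch, padded_h // patch_size, padded_w // patch_size) bool mask.

    Hypothetical helper. `image_sizes` holds the (height, width) of each
    unpadded image; `padded_size` is the (height, width) of the padded batch.
    """
    padded_h, padded_w = padded_size
    if padded_h % patch_size or padded_w % patch_size:
        raise ValueError("padded size must be a multiple of the patch size")
    pixel_mask = torch.zeros(len(image_sizes), padded_h, padded_w, dtype=torch.bool)
    for i, (height, width) in enumerate(image_sizes):
        pixel_mask[i, :height, :width] = True  # real pixels in the top-left corner
    # Tile the pixel mask into patches, then keep patches with any real pixel.
    patches = pixel_mask.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    return patches.sum(dim=(-1, -2)) > 0


# Same layout as the usage example: a 14 x 42 image and a 28 x 28 image padded to 28 x 42.
mask = build_patch_attention_mask([(14, 42), (28, 28)], padded_size=(28, 42))
print(mask.shape)  # torch.Size([2, 2, 3])
print(mask.tolist())
# [[[True, True, True], [False, False, False]], [[True, True, False], [True, True, False]]]
```

The returned tensor can be moved to the model's device and passed as `patch_attention_mask`.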