wenkai commited on
Commit
a60ca6c
·
verified ·
1 Parent(s): 7c8fc75

Upload 32 files

Browse files
.gitattributes CHANGED
@@ -39,3 +39,5 @@ data/go1.4-basic.obo filter=lfs diff=lfs merge=lfs -text
39
  data/swissprot_exp/train_exp_prompt_bp_new.csv filter=lfs diff=lfs merge=lfs -text
40
  data/swissprot_exp/train_exp_prompt_cc_new.csv filter=lfs diff=lfs merge=lfs -text
41
  data/swissprot_exp/train_exp_prompt_mf_new.csv filter=lfs diff=lfs merge=lfs -text
 
 
 
39
  data/swissprot_exp/train_exp_prompt_bp_new.csv filter=lfs diff=lfs merge=lfs -text
40
  data/swissprot_exp/train_exp_prompt_cc_new.csv filter=lfs diff=lfs merge=lfs -text
41
  data/swissprot_exp/train_exp_prompt_mf_new.csv filter=lfs diff=lfs merge=lfs -text
42
+ dataset_card/imgs/coco_caption.png filter=lfs diff=lfs merge=lfs -text
43
+ docs/_static/merlion.png filter=lfs diff=lfs merge=lfs -text
dataset_card/coco_caption.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![Samples from the COCO Caption dataset (Image credit: "https://arxiv.org/pdf/1504.00325.pdf").](imgs/coco_caption.png)(Samples from the COCO Caption dataset. Image credit: "https://arxiv.org/pdf/1504.00325.pdf")
2
+
3
+ # Microsoft COCO Dataset (Captioning)
4
+
5
+ ## Description
6
+ [Microsoft COCO Captions dataset](https://github.com/tylin/coco-caption) contains over one and a half million captions describing over 330,000 images. For the training and validation images, five independent human generated captions are be provided for each image.
7
+
8
+ ## Task
9
+
10
+ (from https://paperswithcode.com/task/image-captioning)
11
+
12
+ **Image captioning** is the task of describing the content of an image in words. This task lies at the intersection of computer vision and natural language processing. Most image captioning systems use an encoder-decoder framework, where an input image is encoded into an intermediate representation of the information in the image, and then decoded into a descriptive text sequence.
13
+
14
+ ## Metrics
15
+ Models are typically evaluated according to a [BLEU](https://aclanthology.org/P02-1040/) or [CIDER](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Vedantam_CIDEr_Consensus-Based_Image_2015_CVPR_paper.pdf) metric.
16
+
17
+ ## Leaderboard
18
+
19
+ (Ranked by BLEU-4)
20
+
21
+ | Rank | Model | BLEU-4 | CIDEr | METEOR | SPICE | Resources |
22
+ | ---- | :-----: | :----: | :---: | :----: | :---: | :----------------------------------------------------------------------------------------------------------------------------------------------: |
23
+ | 1 | OFA | 44.9 | 154.9 | 32.5 | 26.6 | [paper](https://arxiv.org/abs/2202.03052), [code](https://github.com/OFA-Sys/OFA) |
24
+ | 2 | LEMON | 42.6 | 145.5 | 31.4 | 25.5 | [paper]() |
25
+ | 3 | CoCa | 40.9 | 143.6 | 33.9 | 24.7 | [paper](https://arxiv.org/pdf/2205.01917.pdf) |
26
+ | 4 | SimVLM | 40.6 | 143.3 | 33.7 | 25.4 | [paper](https://openreview.net/pdf?id=GUrhfTuf_3) |
27
+ | 5 | VinVL | 41.0 | 140.9 | 31.1 | 25.2 | [paper](https://arxiv.org/pdf/2101.00529v2.pdf), [code](https://github.com/microsoft/Oscar) |
28
+ | 6 | OSCAR | 40.7 | 140.0 | 30.6 | 24.5 | [paper](https://arxiv.org/pdf/2004.06165v5.pdf), [code](https://github.com/microsoft/Oscar) |
29
+ | 7 | BLIP | 40.4 | 136.7 | 31.4 | 24.3 | [paper](https://arxiv.org/pdf/2201.12086.pdf), [code](https://github.com/salesforce/BLIP), [demo](https://huggingface.co/spaces/Salesforce/BLIP) |
30
+ | 8 | M^2 | 39.1 | 131.2 | 29.2 | 22.6 | [paper](https://arxiv.org/pdf/1912.08226v2.pdf), [code](https://github.com/aimagelab/meshed-memory-transformer) |
31
+ | 9 | BUTD | 36.5 | 113.5 | 27.0 | 20.3 | [paper](https://arxiv.org/abs/1707.07998?context=cs), [code](https://github.com/peteanderson80/bottom-up-attention) |
32
+ | 10 | ClipCap | 32.2 | 108.4 | 27.1 | 20.1 | [paper](https://arxiv.org/pdf/2111.09734v1.pdf), [code](https://github.com/rmokady/clip_prefix_caption) |
33
+
34
+ ## Auto-Downloading
35
+
36
+ ```
37
+ cd lavis/datasets/download_scripts && python download_coco.py
38
+ ```
39
+
40
+ ## References
41
+ "Microsoft COCO Captions: Data Collection and Evaluation Server", Xinlei Chen, Hao Fang, Tsung-Yi Lin, Ramakrishna Vedantam, Saurabh Gupta, Piotr Dollar, C. Lawrence Zitnick
dataset_card/imgs/coco_caption.png ADDED

Git LFS Details

  • SHA256: 950cfe987d6da5aad5ece7ca774ad73c1a24af7fcfda328536a8ce56eeaaf1b8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.45 MB
dataset_card/protein_function.md ADDED
File without changes
docs/Makefile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = source
9
+ BUILDDIR = build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/_static/Confusing-Pictures.jpg ADDED
docs/_static/architecture.png ADDED
docs/_static/logo_final.png ADDED
docs/_static/merlion.png ADDED

Git LFS Details

  • SHA256: f1f3b6a507ec92e8f47ac6d7c64e11b03fcba8c550bcb6851f80e261e8951431
  • Pointer size: 132 Bytes
  • Size of remote file: 1.6 MB
docs/benchmark.rst ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Benchmark
2
+ ############
3
+
4
+ We provide scripts for evaluating and training models on task datasets. The following benchmark results are included for reference.
5
+
6
+
7
+ ALBEF
8
+ *******
9
+ .. list-table::
10
+ :widths: 30 80 20
11
+
12
+ * - **Pretraining**
13
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
14
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/pretrain.sh>`__
15
+ * -
16
+ - Visual Genome (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_vg.py>`__)
17
+ -
18
+ * -
19
+ - SBU (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_sbu.py>`__)
20
+ -
21
+ * -
22
+ - CC3M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py>`__)
23
+ -
24
+ * -
25
+ - CC12M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py>`__)
26
+ -
27
+
28
+ .. list-table::
29
+ :widths: 30 40 20 20 20 30 30
30
+ :header-rows: 1
31
+
32
+ * -
33
+ - **Retrieval**
34
+ - **R1**
35
+ - **R5**
36
+ - **R10**
37
+ - **Training**
38
+ - **Evaluation**
39
+ * - TR
40
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
41
+ - 77.6
42
+ - 94.1
43
+ - 97.2
44
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_coco_retrieval_albef.sh>`__
45
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_coco_retrieval.sh>`__
46
+ * - IR
47
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
48
+ - 61.0
49
+ - 84.5
50
+ - 90.7
51
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_coco_retrieval_albef.sh>`__
52
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_coco_retrieval.sh>`__
53
+ * - TR
54
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
55
+ - 77.6
56
+ - 94.1
57
+ - 97.2
58
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_flickr30k_retrieval_albef.sh>`__
59
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_flickr30k_retrieval.sh>`__
60
+ * - IR
61
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
62
+ - 61.0
63
+ - 84.5
64
+ - 90.7
65
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_flickr30k_retrieval_albef.sh>`__
66
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_flickr30k_retrieval.sh>`__
67
+
68
+
69
+ .. list-table::
70
+ :widths: 20 20 20 20 20
71
+ :header-rows: 1
72
+
73
+ * - **VQA**
74
+ - **test-dev**
75
+ - **test-std/test**
76
+ - **Training**
77
+ - **Evaluation**
78
+ * - VQAv2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
79
+ - 76.35
80
+ - 76.54
81
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_vqa_albef.sh>`__
82
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/test_albef_vqa.sh>`__
83
+ * - OKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
84
+ - NA
85
+ - 54.7
86
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_okvqa_albef.sh>`__
87
+ - NA
88
+ * - AOKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
89
+ - 54.5
90
+ - NA
91
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_aokvqa_albef.sh>`__
92
+ - NA
93
+
94
+
95
+ .. list-table::
96
+ :widths: 20 20 20 20 20
97
+ :header-rows: 1
98
+
99
+ * - **Multimodal Classification**
100
+ - **val**
101
+ - **test**
102
+ - **Training**
103
+ - **Evaluation**
104
+ * - SNLI-VE (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
105
+ - 80.60
106
+ - 81.04
107
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_ve_albef.sh>`__
108
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_albef_ve.sh>`__
109
+ * - NLVR2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
110
+ - 82.47
111
+ - 82.91
112
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_nlvr_albef.sh>`__
113
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/eval_albef_nlvr.sh>`__
114
+
115
+ BLIP
116
+ *******
117
+ .. list-table::
118
+ :widths: 30 80 20
119
+
120
+ * - **Pretraining (14M)**
121
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
122
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/pretrain.sh>`__
123
+ * -
124
+ - Visual Genome (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_vg.py>`__)
125
+ -
126
+ * -
127
+ - SBU (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_sbu.py>`__)
128
+ -
129
+ * -
130
+ - CC3M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py>`__)
131
+ -
132
+ * -
133
+ - CC12M (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py>`__)
134
+ -
135
+
136
+ .. list-table::
137
+ :widths: 30 40 20 20 20 30 30
138
+ :header-rows: 1
139
+
140
+ * - **Tasks**
141
+ - **Retrieval**
142
+ - **R1**
143
+ - **R5**
144
+ - **R10**
145
+ - **Training**
146
+ - **Evaluation**
147
+ * - TR
148
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
149
+ - 82.0
150
+ - 95.8
151
+ - 98.1
152
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_coco.sh>`__
153
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_coco.sh>`__
154
+ * - IR
155
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
156
+ - 64.5
157
+ - 86.0
158
+ - 91.7
159
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_coco.sh>`__
160
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_coco.sh>`__
161
+ * - TR
162
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
163
+ - 96.9
164
+ - 99.9
165
+ - 100.0
166
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_flickr.sh>`__
167
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_flickr.sh>`__
168
+ * - IR
169
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
170
+ - 87.5
171
+ - 97.6
172
+ - 98.9
173
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_retrieval_flickr.sh>`__
174
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_ret_flickr.sh>`__
175
+
176
+
177
+ .. list-table::
178
+ :widths: 20 20 20 20 20
179
+ :header-rows: 1
180
+
181
+ * - **VQA**
182
+ - **test-dev**
183
+ - **test-std/test**
184
+ - **Training**
185
+ - **Evaluation**
186
+ * - VQAv2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
187
+ - 78.23
188
+ - 78.29
189
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/train/train_vqa_albef.sh>`__
190
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/albef/eval/test_albef_vqa.sh>`__
191
+ * - OKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
192
+ - NA
193
+ - 55.4
194
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_okvqa.sh>`__
195
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_okvqa.sh>`__
196
+ * - AOKVQA (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
197
+ - 56.2
198
+ - 50.1
199
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_aokvqa.sh>`__
200
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_aokvqa.sh>`__
201
+
202
+
203
+ .. list-table::
204
+ :widths: 20 20 20 20 20 20
205
+ :header-rows: 1
206
+
207
+ * - **Image Captioning**
208
+ - **BLEU@4**
209
+ - **CIDEr**
210
+ - **SPICE**
211
+ - **Training**
212
+ - **Evaluation**
213
+ * - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
214
+ - 39.9
215
+ - 133.5
216
+ - 23.7
217
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_caption_coco.sh>`__
218
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_coco_cap.sh>`__
219
+ * - NoCaps (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_nocaps.py>`__)
220
+ - 31.9
221
+ - 109.1
222
+ - 14.7
223
+ - NA
224
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_nocaps.sh>`__
225
+
226
+
227
+ .. list-table::
228
+ :widths: 20 20 20 20 20
229
+ :header-rows: 1
230
+
231
+ * - **Multimodal Classification**
232
+ - **val**
233
+ - **test**
234
+ - **Training**
235
+ - **Evaluation**
236
+ * - NLVR2 (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
237
+ - 82.48
238
+ - 83.25
239
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/train/train_nlvr.sh>`__
240
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/blip/eval/eval_nlvr.sh>`__
241
+
242
+ CLIP
243
+ *******
244
+ .. list-table::
245
+ :widths: 30 40 20 20 20 30
246
+ :header-rows: 1
247
+
248
+ * - **Tasks**
249
+ - **Retrieval (Zero-shot)**
250
+ - **R1**
251
+ - **R5**
252
+ - **R10**
253
+ - **Evaluation**
254
+ * - TR
255
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
256
+ - 57.2
257
+ - 80.5
258
+ - 87.8
259
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_coco.sh>`__
260
+ * - IR
261
+ - COCO (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_coco.py>`__)
262
+ - 36.5
263
+ - 60.8
264
+ - 71.0
265
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_coco.sh>`__
266
+ * - TR
267
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
268
+ - 86.5
269
+ - 98.0
270
+ - 99.1
271
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_flickr.sh>`__
272
+ * - IR
273
+ - Flickr30k (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_flickr.py>`__)
274
+ - 67.0
275
+ - 88.9
276
+ - 93.3
277
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_ret_flickr.sh>`__
278
+
279
+ .. list-table::
280
+ :widths: 20 20 20
281
+ :header-rows: 1
282
+
283
+ * - **Multimodal Classification**
284
+ - **val**
285
+ - **Evaluation**
286
+ * - ImageNet
287
+ - 76.5
288
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/clip/eval/eval_clip_zs_imnet.sh>`__
289
+
290
+
291
+ ALPRO
292
+ *******
293
+ .. list-table::
294
+ :widths: 30 40 20 20 20 20 30
295
+ :header-rows: 1
296
+
297
+ * - **Tasks**
298
+ - **Retrieval**
299
+ - **R1**
300
+ - **R5**
301
+ - **R10**
302
+ - **Training**
303
+ - **Evaluation**
304
+ * - TR
305
+ - MSRVTT (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_msrvtt.py>`__)
306
+ - 33.2
307
+ - 60.5
308
+ - 71.7
309
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_ret.sh>`__
310
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_ret.sh>`__
311
+ * - VR
312
+ - MSRVTT (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_msrvtt.py>`__)
313
+ - 33.8
314
+ - 61.4
315
+ - 72.7
316
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_ret.sh>`__
317
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_ret.sh>`__
318
+ * - TR
319
+ - DiDeMo (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_didemo.py>`__)
320
+ - 38.8
321
+ - 66.4
322
+ - 76.8
323
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_didemo_ret.sh>`__
324
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_didemo_ret.sh>`__
325
+ * - VR
326
+ - DiDeMo (`download <https://github.com/salesforce/LAVIS/blob/main/lavis/datasets/download_scripts/download_didemo.py>`__)
327
+ - 36.6
328
+ - 67.5
329
+ - 77.9
330
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_didemo_ret.sh>`__
331
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_didemo_ret.sh>`__
332
+
333
+ .. list-table::
334
+ :widths: 20 20 20 20
335
+ :header-rows: 1
336
+
337
+ * - **Video QA**
338
+ - **test**
339
+ - **Training**
340
+ - **Evaluation**
341
+ * - MSRVTT
342
+ - 42.1
343
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msrvtt_qa.sh>`__
344
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msrvtt_qa.sh>`__
345
+ * - MSVD
346
+ - 46.0
347
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/train/train_msvd_qa.sh>`__
348
+ - `script <https://github.com/salesforce/LAVIS/blob/main/run_scripts/alpro/eval/eval_msvd_qa.sh>`__
docs/build_docs.sh ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ # Change to root directory of repo
5
+ DIRNAME=$(cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
6
+ cd "${DIRNAME}/.."
7
+
8
+ # # Set up virtual environment
9
+ pip3 install setuptools wheel virtualenv
10
+ if [ ! -d venv ]; then
11
+ rm -f venv
12
+ virtualenv venv
13
+ fi
14
+ source venv/bin/activate
15
+
16
+ # # Get current git branch & stash unsaved changes
17
+ GIT_BRANCH=$(git branch --show-current)
18
+ if [ -z "${GIT_BRANCH}" ]; then
19
+ GIT_BRANCH="main"
20
+ fi
21
+ git stash
22
+
23
+ # Set up exit handler to restore git state & delete temp branches
24
+ # function exit_handler {
25
+ # git reset --hard
26
+ # git checkout "${GIT_BRANCH}" --
27
+ # git stash pop || true
28
+ # for version in $(git tag --list 'v[0-9]*'); do
29
+ # branch="${version}_local_docs_only"
30
+ # if git show-ref --verify --quiet "refs/heads/$branch"; then
31
+ # git branch -D "$branch"
32
+ # fi
33
+ # done
34
+ # }
35
+ # trap exit_handler EXIT
36
+
37
+ # Clean up build directory and install Sphinx requirements
38
+ pip3 install -r "${DIRNAME}/requirements.txt"
39
+ sphinx-build -M clean "${DIRNAME}" "${DIRNAME}/_build"
40
+
41
+ # Build API docs for current head
42
+ export current_version="latest"
43
+ pip3 install "."
44
+ sphinx-build -b html "${DIRNAME}" "${DIRNAME}/_build/html/${current_version}" -W --keep-going
45
+ rm -rf "${DIRNAME}/_build/html/${current_version}/.doctrees"
46
+ #pip3 uninstall -y omnixai
47
+
48
+ # Install all previous released versions
49
+ # and use them to build the appropriate API docs.
50
+ # Uninstall after we're done with each one.
51
+ # versions=()
52
+ # checkout_files=("${DIRNAME}/*.rst" "lavis" "tutorials" "setup.py")
53
+ # for version in $(git tag --list 'v[0-9]*'); do
54
+ # versions+=("$version")
55
+ # git checkout -b "${version}_local_docs_only"
56
+ # for f in $(git diff --name-only --diff-filter=A "tags/${version}" "${DIRNAME}/*.rst"); do
57
+ # git rm "$f"
58
+ # done
59
+ # git checkout "tags/${version}" -- "${checkout_files[@]}"
60
+ # export current_version=${version}
61
+ # pip3 install ".[all]"
62
+ # sphinx-build -b html "${DIRNAME}" "${DIRNAME}/_build/html/${current_version}" -W --keep-going
63
+ # rm -rf "${DIRNAME}/_build/html/${current_version}/.doctrees"
64
+ # #pip3 uninstall -y omnixai
65
+ # git reset --hard
66
+ # git checkout "${GIT_BRANCH}" --
67
+ # done
68
+
69
+ # Determine the latest stable version if there is one
70
+ # if (( ${#versions[@]} > 0 )); then
71
+ # stable_hash=$(git rev-list --tags --max-count=1)
72
+ # stable_version=$(git describe --tags "$stable_hash")
73
+ # export stable_version
74
+ # else
75
+ export stable_version="latest"
76
+ # fi
77
+
78
+ # Create dummy HTML's for the stable version in the base directory
79
+ while read -r filename; do
80
+ filename=$(echo "$filename" | sed "s/\.\///")
81
+ n_sub=$(echo "$filename" | (grep -o "/" || true) | wc -l)
82
+ prefix=""
83
+ for (( i=0; i<n_sub; i++ )); do
84
+ prefix+="../"
85
+ done
86
+ url="${prefix}${stable_version}/$filename"
87
+ mkdir -p "${DIRNAME}/_build/html/$(dirname "$filename")"
88
+ cat > "${DIRNAME}/_build/html/$filename" <<EOF
89
+ <!DOCTYPE html>
90
+ <html>
91
+ <head>
92
+ <title>LAVIS Documentation</title>
93
+ <meta http-equiv = "refresh" content="0; url='$url'" />
94
+ </head>
95
+ <body>
96
+ <p>Please wait while you're redirected to our <a href="$url">documentation</a>.</p>
97
+ </body>
98
+ </html>
99
+ EOF
100
+ done < <(cd "${DIRNAME}/_build/html/$stable_version" && find . -name "*.html")
101
+ echo "Finished writing to _build/html."
docs/conf.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration file for the Sphinx documentation builder.
2
+ #
3
+ # This file only contains a selection of the most common options. For a full
4
+ # list see the documentation:
5
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
6
+
7
+ # -- Path setup --------------------------------------------------------------
8
+
9
+ # If extensions (or modules to document with autodoc) are in another directory,
10
+ # add these directories to sys.path here. If the directory is relative to the
11
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
12
+ #
13
+ # import os
14
+ # import sys
15
+ # sys.path.insert(0, os.path.abspath('.'))
16
+
17
+
18
+ # -- Project information -----------------------------------------------------
19
+
20
+ project = "LAVIS"
21
+ copyright = "2022, salesforce.com inc."
22
+ author = (
23
+ "Dongxu Li, Junnan Li, Hung Le, Guangsen Wang, Silvio Savarese, Steven C.H. Hoi"
24
+ )
25
+
26
+
27
+ # -- General configuration ---------------------------------------------------
28
+
29
+ # Add any Sphinx extension module names here, as strings. They can be
30
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
31
+ # ones.
32
+ extensions = ["nbsphinx"]
33
+
34
+ # Add any paths that contain templates here, relative to this directory.
35
+ templates_path = ["_templates"]
36
+
37
+ # List of patterns, relative to source directory, that match files and
38
+ # directories to ignore when looking for source files.
39
+ # This pattern also affects html_static_path and html_extra_path.
40
+ exclude_patterns = []
41
+
42
+
43
+ # -- Options for HTML output -------------------------------------------------
44
+
45
+ # The theme to use for HTML and HTML Help pages. See the documentation for
46
+ # a list of builtin themes.
47
+ #
48
+ # html_theme = "alabaster"
49
+ html_theme = "sphinx_rtd_theme"
50
+
51
+ # Add any paths that contain custom static files (such as style sheets) here,
52
+ # relative to this directory. They are copied after the builtin static files,
53
+ # so a file named "default.css" will overwrite the builtin "default.css".
54
+ html_static_path = ["_static"]
55
+
56
+ # pygments_style = "sphinx"
docs/getting_started.rst ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Dataset Zoo
2
+ ##################
3
+ LAVIS inherently supports a wide variety of common language-vision datasets by providing automatic download scripts to help download and organize these datasets;
4
+ and implements PyTorch datasets for these datasets. To view supported datasets, use the following code:
5
+
6
+ .. code-block:: python
7
+
8
+ from lavis.datasets.builders import dataset_zoo
9
+ dataset_names = dataset_zoo.get_names()
10
+ print(dataset_names)
11
+ # ['aok_vqa', 'coco_caption', 'coco_retrieval', 'coco_vqa', 'conceptual_caption_12m',
12
+ # 'conceptual_caption_3m', 'didemo_retrieval', 'flickr30k', 'imagenet', 'laion2B_multi',
13
+ # 'msrvtt_caption', 'msrvtt_qa', 'msrvtt_retrieval', 'msvd_caption', 'msvd_qa', 'nlvr',
14
+ # 'nocaps', 'ok_vqa', 'sbu_caption', 'snli_ve', 'vatex_caption', 'vg_caption', 'vg_vqa']
15
+ print(len(dataset_names))
16
+ # 23
17
+
18
+
19
+ Auto-Downloading and Loading Datasets
20
+ ######################################
21
+ We now take COCO caption dataset as an example to demonstrate how to download and prepare the dataset.
22
+
23
+ In ``lavis/datasets/download_scripts/``, we provide tools to download most common public language-vision datasets supported by LAVIS.
24
+ The COCO caption dataset uses images from COCO dataset. Therefore, we first download COCO images via:
25
+
26
+ .. code-block:: bash
27
+
28
+ cd lavis/datasets/download_scripts/ && python download_coco.py
29
+
30
+ This will automatically download and extract COCO images to the default LAVIS cache location.
31
+ The default cache location is ``~/.cache/lavis``, defined in ``lavis/configs/default.yaml``.
32
+
33
+ After downloading the images, we can use ``load_dataset()`` to obtain the dataset. On the first run, this will automatically download and cache annotation files.
34
+
35
+ .. code-block:: python
36
+
37
+ from lavis.datasets.builders import load_dataset
38
+ coco_dataset = load_dataset("coco_caption")
39
+
40
+ print(coco_dataset.keys())
41
+ # dict_keys(['train', 'val', 'test'])
42
+
43
+ print(len(coco_dataset["train"]))
44
+ # 566747
45
+
46
+ print(coco_dataset["train"][0])
47
+ # {'image': <PIL.Image.Image image mode=RGB size=640x480>,
48
+ # 'text_input': 'A woman wearing a net on her head cutting a cake. ',
49
+ # 'image_id': 0}
50
+
51
+ If you already host a local copy of the dataset, you can pass in the ``vis_path`` argument to change the default location to load images.
52
+
53
+ .. code-block:: python
54
+
55
+ coco_dataset = load_dataset("coco_caption", vis_path=YOUR_LOCAL_PATH)
56
+
57
+
58
+ Model Zoo
59
+ ####################################
60
+ LAVIS supports a growing list of pre-trained models for different tasks,
61
+ datatsets and of varying sizes. Let's get started by viewing the supported models.
62
+
63
+ .. code-block:: python
64
+
65
+ from lavis.models import model_zoo
66
+ print(model_zoo)
67
+ # ==================================================
68
+ # Architectures Types
69
+ # ==================================================
70
+ # albef_classification base, ve
71
+ # albef_nlvr base
72
+ # albef_pretrain base
73
+ # albef_retrieval base, coco, flickr
74
+ # albef_vqa base, vqav2
75
+ # alpro_qa base, msrvtt, msvd
76
+ # alpro_retrieval base, msrvtt, didemo
77
+ # blip_caption base, base_coco, large, large_coco
78
+ # blip_classification base
79
+ # blip_feature_extractor base
80
+ # blip_nlvr base
81
+ # blip_pretrain base
82
+ # blip_retrieval base, coco, flickr
83
+ # blip_vqa base, vqav2
84
+ # clip ViT-B-32, ViT-B-16, ViT-L-14, ViT-L-14-336, RN50
85
+
86
+ # show total number of support model variants
87
+ len(model_zoo)
88
+ # 33
89
+
90
+
91
+ Inference with Pre-trained Models
92
+ ####################################
93
+
94
+ Now let's see how to use models in LAVIS to perform inference on example data. We first
95
+ load a sample image from local.
96
+
97
+ .. code-block:: python
98
+
99
+ from PIL import Image
100
+
101
+ # setup device to use
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
+
104
+ # load sample image
105
+ raw_image = Image.open("docs/_static/merlion.png").convert("RGB")
106
+
107
+ This example image shows `Merlion park <https://en.wikipedia.org/wiki/Merlion>`_ (`image credit <https://theculturetrip.com/asia/singapore/articles/what-exactly-is-singapores-merlion-anyway/>`_), a landmark in Singapore.
108
+
109
+ .. image:: _static/merlion.png
110
+
111
+ Image Captioning
112
+ *******************************
113
+ We now use the BLIP model to generate a caption for the image. To make inference even easier, we also associate each
114
+ pre-trained model with its preprocessors (transforms), we use ``load_model_and_preprocess()`` with the following arguments:
115
+
116
+ - ``name``: The name of the model to load. This could be a pre-trained model, task model, or feature extractor. See ``model_zoo`` for a full list of model names.
117
+ - ``model_type``: Each architecture has variants trained on different datasets and at different scale. See Types column in ``model_zoo`` for a full list of model types.
118
+ - ``is_eval``: if `True`, set the model to evaluation mode. This is desired for inference or feature extraction.
119
+ - ``device``: device to load the model to.
120
+
121
+ .. code-block:: python
122
+
123
+ from lavis.models import load_model_and_preprocess
124
+ # loads BLIP caption base model, with finetuned checkpoints on MSCOCO captioning dataset.
125
+ # this also loads the associated image processors
126
+ model, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=device)
127
+
128
+ # preprocess the image
129
+ # vis_processors stores image transforms for "train" and "eval" (validation / testing / inference)
130
+ image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
131
+
132
+ # generate caption
133
+ model.generate({"image": image})
134
+ # ['a large fountain spewing water into the air']
135
+
136
+
137
+ You may also load models and their preprocessors separately via ``load_model()`` and ``load_processor()``.
138
+ In BLIP, you can also generate diverse captions by turning nucleus sampling on.
139
+
140
+ .. code-block:: python
141
+
142
+ from lavis.processors import load_processor
143
+ from lavis.models import load_model
144
+
145
+ # load image preprocesser used for BLIP
146
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
147
+ model = load_model(name="blip_caption", model_type="base_coco", is_eval=True, device=device)
148
+
149
+ image = vis_processor(image).unsqueeze(0).to(device)
150
+ model.generate({"image": raw_image}, use_nucleus_sampling=True)
151
+ # one generated random sample: ['some very pretty buildings and some water jets']
152
+
153
+
154
+ Visual question answering (VQA)
155
+ *******************************
156
+ BLIP model is able to answer free-form questions about images in natural language.
157
+ To access the VQA model, simply replace the ``name`` and ``model_type`` arguments
158
+ passed to ``load_model_and_preprocess()``.
159
+
160
+ .. code-block:: python
161
+
162
+ from lavis.models import load_model_and_preprocess
163
+ model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_vqa", model_type="vqav2", is_eval=True, device=device)
164
+
165
+ # ask a random question.
166
+ question = "Which city is this photo taken?"
167
+
168
+ image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
169
+ question = txt_processors["eval"](question)
170
+
171
+ model.predict_answers(samples={"image": image, "text_input": question}, inference_method="generate")
172
+ # ['singapore']
173
+
174
+
175
+ Unified Feature Extraction Interface
176
+ ####################################
177
+
178
+ LAVIS provides a unified interface to extract multimodal features from each architecture.
179
+ To extract features, we load the feature extractor variants of each model.
180
+ The multimodal feature can be used for multimodal classification. The low-dimensional unimodal features can be used to compute cross-modal similarity.
181
+
182
+ .. code-block:: python
183
+
184
+ from lavis.models import load_model_and_preprocess
185
+
186
+ model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_feature_extractor", model_type="base", is_eval=True, device=device)
187
+ caption = "a large fountain spewing water into the air"
188
+
189
+ image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
190
+ text_input = txt_processors["eval"](caption)
191
+
192
+ sample = {"image": image, "text_input": [text_input]}
193
+
194
+ features_multimodal = model.extract_features(sample)
195
+ print(features_multimodal.keys())
196
+ # odict_keys(['image_embeds', 'multimodal_embeds'])
197
+ print(features_multimodal.multimodal_embeds.shape)
198
+ # torch.Size([1, 12, 768]), use features_multimodal[:, 0, :] for multimodal classification tasks
199
+
200
+ features_image = model.extract_features(sample, mode="image")
201
+ print(features_image.keys())
202
+ # odict_keys(['image_embeds', 'image_embeds_proj'])
203
+ print(features_image.image_embeds.shape)
204
+ # torch.Size([1, 197, 768])
205
+ print(features_image.image_embeds_proj.shape)
206
+ # torch.Size([1, 197, 256])
207
+
208
+ features_text = model.extract_features(sample, mode="text")
209
+ print(features_text.keys())
210
+ # odict_keys(['text_embeds', 'text_embeds_proj'])
211
+ print(features_text.text_embeds.shape)
212
+ # torch.Size([1, 12, 768])
213
+ print(features_text.text_embeds_proj.shape)
214
+ # torch.Size([1, 12, 256])
215
+
216
+ similarity = features_image.image_embeds_proj[:, 0, :] @ features_text.text_embeds_proj[:, 0, :].t()
217
+ print(similarity)
218
+ # tensor([[0.2622]])
219
+
220
+ Since LAVIS supports a unified feature extraction interface, minimal changes are necessary to use a different model as feature extractor. For example,
221
+ to use ALBEF as the feature extractor, one only needs to change the following line:
222
+
223
+ .. code-block:: python
224
+
225
+ model, vis_processors, txt_processors = load_model_and_preprocess(name="albef_feature_extractor", model_type="base", is_eval=True, device=device)
226
+
227
+ Similarly, to use CLIP as feature extractor:
228
+
229
+ .. code-block:: python
230
+
231
+ model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="base", is_eval=True, device=device)
232
+ # model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="RN50", is_eval=True, device=device)
233
+ # model, vis_processors, txt_processors = load_model_and_preprocess(name="clip_feature_extractor", model_type="ViT-L-14", is_eval=True, device=device)
docs/index.rst ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. LAVIS documentation master file, created by
2
+ sphinx-quickstart on Sun Jul 31 10:32:27 2022.
3
+ You can adapt this file completely to your liking, but it should at least
4
+ contain the root `toctree` directive.
5
+
6
+ Welcome to LAVIS's documentation!
7
+ =================================
8
+
9
+ .. toctree::
10
+ :maxdepth: 1
11
+ :caption: Introduction
12
+
13
+ intro
14
+
15
+
16
+ .. toctree::
17
+ :maxdepth: 1
18
+ :caption: Getting Started
19
+
20
+ getting_started
21
+
22
+
23
+ .. :maxdepth: 1
24
+ .. :caption: Advanced Training
25
+
26
+ .. advanced_training
27
+
28
+
29
+ .. toctree::
30
+ :maxdepth: 2
31
+ :caption: Advanced Usage
32
+
33
+ benchmark
34
+ tutorial
35
+
36
+
37
+ .. Documentations
38
+ .. ===================
39
+
40
+
41
+ Indices and tables
42
+ ==================
43
+
44
+ * :ref:`genindex`
45
+ * :ref:`modindex`
46
+ * :ref:`search`
docs/intro.rst ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ What is LAVIS?
2
+ ####################################
3
+
4
+ LAVIS is a Python deep learning library for LAnguage-and-VISion research and applications.
5
+ It features a unified design to access state-of-the-art foundation language-vision models (`ALBEF <https://arxiv.org/pdf/2107.07651.pdf>`_,
6
+ `BLIP <https://arxiv.org/pdf/2201.12086.pdf>`_, `ALPRO <https://arxiv.org/pdf/2112.09583.pdf>`_, `CLIP <https://arxiv.org/pdf/2103.00020.pdf>`_), common tasks
7
+ (retrieval, captioning, visual question answering, multimodal classification etc.) and datasets (COCO, Flickr, Nocaps, Conceptual
8
+ Commons, SBU, etc.).
9
+
10
+ This library aims to provide engineers and researchers with a one-stop solution to rapidly develop models for their specific multimodal
11
+ scenarios, and benchmark them across standard and customized datasets.
12
+
13
+ Key features of LAVIS include:
14
+
15
+ - **Modular and Extensible Library Design**: facilitating to easily utilize and repurpose existing modules (datasets, models, preprocessors), also to add new modules.
16
+
17
+ - **Easy Off-the-shelf Inference and Feature Extraction**: readily available pre-trained models let you take advantage of state-of-the-art multimodal understanding and generation capabilities on your own data.
18
+
19
+ - **Reproducible Model Zoo**: provided training/pre-training recipies to easily replicate and extend state-of-the-art models.
20
+
21
+ - **Dataset Zoo and Automatic Downloading Tools**: it can be a hassle to prepare the many language-vision datasets. LAVIS provides automatic downloaing scripts to help prepare a large variety of datasets and their annotations.
22
+
23
+ Other features include:
24
+
25
+ - **Distributed Training** using multiple GPUs on one machine or across multiple machines.
26
+
27
+ - **Web Demo**: try supported models on your own pictures, questions etc.
28
+
29
+ - **Leaderboard**: comparing state-of-the-art models across standard datasets.
30
+
31
+ - **Dataset Explorer**: help browse and understand language-vision datasets.
32
+
33
+ Supported Tasks, Models and Datasets
34
+ ####################################
35
+
36
+ The following table shows the supported models and language-vision tasks by LAVIS. Adapting existing models to more tasks is possible and next to come in future releases.
37
+
38
+ ======================================== =========================== ============================================= ============
39
+ Tasks Supported Models Supported Datasets Modalities
40
+ ======================================== =========================== ============================================= ============
41
+ Image-text Pre-training ALBEF, BLIP COCO, VisualGenome, SBU, ConceptualCaptions image, text
42
+ Image-text Retrieval ALBEF, BLIP, CLIP COCO, Flickr30k image, text
43
+ Text-image Retrieval ALBEF, BLIP, CLIP COCO, Flickr30k image, text
44
+ Visual Question Answering ALBEF, BLIP VQAv2, OKVQA, A-OKVQA image, text
45
+ Image Captioning BLIP COCO, NoCaps image, text
46
+ Image Classification CLIP ImageNet image
47
+ Natural Language Visual Reasoning (NLVR) ALBEF, BLIP NLVR2 image, text
48
+ Visual Entailment (VE) ALBEF SNLI-VE image, text
49
+ Visual Dialogue BLIP VisDial image, text
50
+ Video-text Retrieval BLIP, ALPRO MSRVTT, DiDeMo video, text
51
+ Text-video Retrieval BLIP, ALPRO MSRVTT, DiDeMo video, text
52
+ Video Question Answering (VideoQA) BLIP, ALPRO MSRVTT, MSVD video, text
53
+ Video Dialogue VGD-GPT AVSD video, text
54
+ Multimodal Feature Extraction ALBEF, CLIP, BLIP, ALPRO customized image, text
55
+ ======================================== =========================== ============================================= ============
56
+
57
+ Library Design
58
+ ####################################
59
+
60
+ .. image:: _static/architecture.png
61
+ :width: 550
62
+
63
+ LAVIS has six key modules.
64
+
65
+ - ``lavis.runners`` manages the overall training and evaluation lifecycle. It is also responsible for creating required components lazily as per demand, such as optimizers, learning rate schedulers and dataloaders. Currently ``RunnerBase`` implements epoch-based training and ``RunerIters`` implements iteration-based training.
66
+ - ``lavis.tasks`` implements concrete training and evaluation logic per task. A task could be, for example, retrieval, captioning, pre-training. The rationale to have an abstraction of task is to accommodate task-specific training and evaluation. For example, evaluating a retrieval model is different from a classification model.
67
+ - ``lavis.datasets`` is responsible for creating datasets, where ``lavis.datasets.builders`` loads dataset configurations, downloads annotations and returns a dataset object; ``lavis.datasets.datasets`` defines the supported datasets, each is a ``torch.utils.data.Dataset`` instance. We also provide `automatic dataset downloading tools` in ``datasets/download_scripts`` to help prepare common public datasets.
68
+ - ``lavis.models`` holds definition for the supported models and shared model layers.
69
+ - ``lavis.processors`` handles preprocessing of text and images/videos before feeding the model. For images and videos, a processor can be thought as transfroms in torchvision; for text input, this may include lowering case, truncation etc.
70
+ - ``lavis.common`` module contains shared classes and methods used by multiple other modules. For example,
71
+
72
+ - ``lavis.common.config`` contains classes to store and manipulate configuration files used by LAVIS. In particular, we use a hierarchical configuration design, to allow highly customizable training and evaluation.
73
+ - ``lavis.common.registry`` serves as a centralized place to manage modules that share the same functionalities. It allows building datasets, models, tasks, and learning rate schedulers during runtime, by specifying their names as string in the configuration file.
74
+ - ``lavis.common.optims`` contains definitions of learning rate schedulers.
75
+ - ``lavis.common.dist_utils`` contains utilities for distributed training and evaluation.
76
+ - ``lavis.common.utils`` contains miscellaneous utilities, mostly IO-related helper functions.
77
+
78
+
79
+ Installation
80
+ ############
81
+ 1. (Optional) Creating conda environment
82
+
83
+ .. code-block:: bash
84
+
85
+ conda create -n lavis python=3.8
86
+ conda activate lavis
87
+
88
+ 2. Cloning and building from source
89
+
90
+ .. code-block:: bash
91
+
92
+ git clone https://github.com/salesforce/LAVIS.git
93
+ cd LAVIS
94
+ pip install .
95
+
96
+ If you would like to develop on LAVIS, you may find it easier to build with editable mode::
97
+
98
+ pip install -e .
99
+
docs/make.bat ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=sphinx-build
9
+ )
10
+ set SOURCEDIR=source
11
+ set BUILDDIR=build
12
+
13
+ if "%1" == "" goto help
14
+
15
+ %SPHINXBUILD% >NUL 2>NUL
16
+ if errorlevel 9009 (
17
+ echo.
18
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19
+ echo.installed, then set the SPHINXBUILD environment variable to point
20
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
21
+ echo.may add the Sphinx directory to PATH.
22
+ echo.
23
+ echo.If you don't have Sphinx installed, grab it from
24
+ echo.http://sphinx-doc.org/
25
+ exit /b 1
26
+ )
27
+
28
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29
+ goto end
30
+
31
+ :help
32
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33
+
34
+ :end
35
+ popd
docs/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ GitPython
2
+ ipykernel
3
+ nbsphinx==0.8.7
4
+ pandoc
5
+ sphinx
6
+ sphinx_autodoc_typehints
7
+ sphinx_rtd_theme
docs/tutorial.configs.rst ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. _config:
2
+
3
+ Training Models on Task Datasets (Commands and Configurations)
4
+ #################################################################
5
+
6
+ LAVIS provides scripts to pre-train and finetune supported models on standard language-vision tasks, stored at ``lavis/run_scripts/``.
7
+ To replicate the experiments, just run these bash scripts. For example, to train BLIP model on the image-text retrieval task with MSCOCO dataset, we can run
8
+
9
+ .. code-block::
10
+
11
+ bash run_scripts/blip/train/train_retrieval_coco.sh
12
+
13
+ Inside the scripts, we can see
14
+
15
+ .. code-block:: bash
16
+
17
+ python -m torch.distributed.run --nproc_per_node=8 train.py --cfg-path lavis/projects/blip/train/retrieval_coco_ft.yaml
18
+
19
+ where we start a pytorch distributed training on 8 GPUs (you may change according to your own hardware setup). The ``--cfg-path`` specifys a `runtime configuration file`, specifying
20
+ the task, model, dataset and training recipes.
21
+
22
+ Available options and their descriptions are as below.
23
+
24
+ .. LAVIS executes training and evaluation based on arguments specified in the configuration files. The default model and dataset configurations are defined in ``lavis/configs``. The task-specific configurations are defined in ``lavis/projects``. Task-specific configurations have higher priority over the default configurations.
25
+
26
+ .. The following tables provide explanations for the arguments in the configuration files.
27
+
28
+ .. list-table::
29
+ :widths: 30 40
30
+ :header-rows: 1
31
+
32
+ * - Model Configurations
33
+ - Functionalities
34
+ * - arch
35
+ - | name of the model from the model zoo
36
+ | default: task-dependent
37
+ * - model_type
38
+ - | the type of the model (e.g., base)
39
+ | default: task-dependent
40
+ * - load_pretrained
41
+ - | load pretrained weights
42
+ | default: True (for finetuning task) | False (for pretraining task)
43
+ * - load_finetuned
44
+ - | load task-specific finetuned weights
45
+ | default: False (for finetuning task) | True (for evaluation)
46
+ * - pretrained
47
+ - | URL or local path which stores the pretrained model, defined in the default model configuration file
48
+ | default: task-dependent
49
+ * - finetuned
50
+ - | URL or local path which stores the finetuned model, defined in the default model configuration file
51
+ | default: task-dependent
52
+
53
+ .. list-table::
54
+ :widths: 30 50
55
+ :header-rows: 1
56
+
57
+ * - Dataset Configurations
58
+ - Functionalities
59
+ * - vis_processor
60
+ - | pre-processing of visual input
61
+ | default: task-dependent
62
+ * - text_processor
63
+ - | pre-processing of text input
64
+ | default: task-dependent
65
+ * - build_info
66
+ - | dataset information including the storage location, defined in the default dataset configuration file
67
+ | default: task-dependent
68
+
69
+ .. list-table::
70
+ :widths: 30 50
71
+ :header-rows: 1
72
+
73
+ * - Runtime Configurations
74
+ - Functionalities
75
+ * - task
76
+ - | name of the task
77
+ | default: task-dependent
78
+ * - lr_sched
79
+ - | learning rate schedular
80
+ | default: linear_warmup_cosine_lr
81
+ * - init_lr
82
+ - | initial learning rate (after warmup)
83
+ | default: task-dependent
84
+ * - min_lr
85
+ - | final learning rate after decay
86
+ | default: task-dependent
87
+ * - warmup_lr
88
+ - | starting learning rate for warmup
89
+ | default: init_lr (no warmup)
90
+ * - lr_decay_rate
91
+ - | learning rate decay per epoch for step_lr_shedule
92
+ | default: 0.9
93
+ * - warmup_steps
94
+ - | number of steps for learning rate warmup
95
+ | default: 0
96
+ * - max_epoch
97
+ - | total number of training epochs
98
+ | default: task-dependent
99
+ * - weight_decay
100
+ - | weight decay coefficient for the optimizer
101
+ | default: 0.05
102
+ * - batch_size_train
103
+ - | batch size during training
104
+ | default: task-dependent
105
+ * - batch_size_eval
106
+ - | batch size during evaluation
107
+ | default: task-dependent
108
+ * - seed
109
+ - | pseudo random number generator seed
110
+ | default: 42
111
+ * - output_dir
112
+ - | directory to store logs, results and checkpoints
113
+ | default: task-dependent
114
+ * - resume_ckpt_path
115
+ - | path of the checkpoint to resume training from
116
+ | default: None
117
+ * - evaluate
118
+ - | only perform evaluation without training
119
+ | default: False
120
+ * - train_splits
121
+ - | dataset splits used for training
122
+ | default: ["train"]
123
+ * - valid_splits
124
+ - | dataset splits used for validation
125
+ | default: ["val"]
126
+ * - test
127
+ - | dataset splits used for test
128
+ | default: ["test"]
129
+ * - device
130
+ - | use cpu or gpu (cuda)
131
+ | default: cuda
132
+ * - world_size
133
+ - | number of processes participating in the job
134
+ | default: 1
135
+ * - dist_url
136
+ - | URL specifying how to initialize the process group
137
+ | default: "env://"
138
+ * - distributed
139
+ - | use distributed training
140
+ | default: True
141
+ * - amp
142
+ - | use automatic mixed precision training
143
+ | default: False
144
+
145
+ .. list-table::
146
+ :widths: 40 50
147
+ :header-rows: 1
148
+
149
+ * - Text Generation Configurations
150
+ - Functionalities
151
+ * - max_len
152
+ - | maximum number of text tokens to generate
153
+ | default: 20 (for image captioning)
154
+ * - min_len
155
+ - | minimum number of text tokens to generate
156
+ | default: 5 (for image captioning)
157
+ * - num_beams
158
+ - | number of beams to perform beam search
159
+ | default: 3
160
+
161
+ .. list-table::
162
+ :widths: 40 50
163
+ :header-rows: 1
164
+
165
+ * - Multimodal Retrieval Configurations
166
+ - Functionalities
167
+ * - negative_all_rank
168
+ - | collect negatives from all processes for the image-text matching loss
169
+ | default: True (for coco)
170
+ * - k_test
171
+ - | number of retrieval candidates ranked from contrastive similarity
172
+ | default: 256 (for coco)
docs/tutorial.datasets.rst ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Adding Datasets
2
+ ################################################
3
+
4
+ This is a tutorial on adding a new dataset using ``lavis.datasets`` module.
5
+
6
+ The LAVIS library includes a standard dataset module, which allows customization to add new datasets.
7
+ The ``lavis.datasets`` module is designed such that any new dataset class can be easily added and adapted from our code base, including creating dataset configuration, and defining and associating new dataset classes.
8
+
9
+ In this tutorial, we will replicate the steps to add a dataset class for the `Audio-Visual Scene-Aware Dialogue (AVSD) <https://arxiv.org/pdf/1901.09107.pdf>`_ benchmark for the video-grounded dialogue task.
10
+
11
+ Dataset Configuration ``lavis.configs.datasets``
12
+ **************************************************************
13
+
14
+ First, we define the basic configurations for this dataset, including a new dataset class ``avsd_dialogue``, dataset card, and data types.
15
+ We can define any new dataset configuration in ``lavis.configs.datasets``. For instance, under this module, we can set up a configuration file ``avsd/defaults_dial.yaml`` as follows:
16
+
17
+ .. code-block:: yaml
18
+
19
+ datasets:
20
+ avsd_dialogue: # name of the dataset builder
21
+ dataset_card: dataset_card/avsd_dialogue.md # path to the dataset card
22
+ data_type: features # [images|videos|features] we use features in this case for extracted video features
23
+
24
+ build_info:
25
+ # Be careful not to append minus sign (-) before split to avoid itemizing
26
+ annotations:
27
+ train:
28
+ url: /export/home/data/avsd/train_set4DSTC7-AVSD.json
29
+ storage: avsd/annotations/train.json
30
+ val:
31
+ url: /export/home/data/avsd/valid_set4DSTC7-AVSD.json
32
+ storage: avsd/annotations/val.json
33
+ test:
34
+ url: /export/home/data/avsd/test_set4DSTC7-AVSD.json
35
+ storage: avsd/annotations/test.json
36
+ features:
37
+ storage: /export/home/data/avsd/features/
38
+
39
+
40
+ Dataset Card
41
+ ===============
42
+ One optional step to set up dataset configuration is defining a dataset card, which contains more details about the dataset such as description, tasks, and metrics.
43
+ For instance, we can define a dataset card for the AVSD benchmark in ``dataset_card/avsd_dialogue.md``.
44
+ Depending on the dataset, we included in its corresponding dataset card the command for auto-downloading data (with python code defined in ``lavis.datasets.download_scripts``) that will automatically load the data and store it in a specific folder.
45
+ Else, you should describe in the dataset card the external download instructions from the original data source to load the dataset properly.
46
+
47
+ One example of a dataset card for the AVSD benchmark is:
48
+
49
+ .. code-block:: md
50
+
51
+ ![Samples from the AVSD dataset (Image credit: "https://arxiv.org/pdf/1901.09107.pdf").](imgs/avsd_dialogue.png)(Samples from the AVSD dataset. Image credit: "https://arxiv.org/pdf/1901.09107.pdf")
52
+
53
+ # Audio-Visual Scene-Aware Dialogues (AVSD)
54
+
55
+ ## Description
56
+ [Audio-Visual Scene-Aware Dialogues (AVSD)](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) contains more than 10,000 dialogues, each of which is grounded on a unique video. In the test split, for each test sample, 6 reference dialogue responses are provided.
57
+
58
+
59
+ ## Task
60
+
61
+ (https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge)
62
+
63
+ In a **video-grounded dialogue task**, the system must generate responses to user input in the context of a given dialog.
64
+ This context consists of a dialog history (previous utterances by both user and system) in addition to video and audio information that comprise the scene. The quality of a system’s automatically generated sentences is evaluated using objective measures to determine whether or not the generated responses are natural and informative
65
+
66
+ ## Metrics
67
+ Models are typically evaluated according to [BLEU](https://aclanthology.org/P02-1040/), [CIDER](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Vedantam_CIDEr_Consensus-Based_Image_2015_CVPR_paper.pdf), [METEOR](https://aclanthology.org/W05-0909/), and [ROUGE-L](https://aclanthology.org/W04-1013/) metrics.
68
+
69
+ ## Leaderboard
70
+
71
+ ....
72
+
73
+
74
+ ## Auto-Downloading
75
+
76
+ Please refer to [benchmark webite](https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge) for instructions to download the dataset.
77
+
78
+
79
+ ## References
80
+ "Audio Visual Scene-Aware Dialog", Huda Alamri, Vincent Cartillier, Abhishek Das, Jue Wang, Anoop Cherian, Irfan Essa, Dhruv Batra, Tim K. Marks, Chiori Hori, Peter Anderson, Stefan Lee, Devi Parikh
81
+
82
+ Visual Data Type
83
+ ==============================
84
+ We currently limit the visual data types to one of three options: ``images``, ``videos``, and ``features``.
85
+ "Images" and "videos" refer to the raw visual data, which is appropriate for models processing visual data in their original forms (e.g. ViT models).
86
+ "Features" are visual representations extracted from pretrained models (e.g. CNN models).
87
+ In this tutorial, the AVSD benchmark consists of video features extracted from 3D-CNN models.
88
+
89
+ Build Info
90
+ ==============================
91
+ Build info refers to the specific locations where data is stored and cached.
92
+
93
+ For text annotations (e.g. captioning or dialogues), by default, we include three data splits, namely "train", "val", and "test", typically used in all machine learning projects.
94
+ For each split, we specify 2 parameters: ``url`` and ``storage``.
95
+ ``url`` can be either an online URL where the dataset can be loaded automatically (e.g. from *googleapis*), or a local directory where data is already downloaded beforehand.
96
+ ``storage`` is the directory where the data will be cached over time, avoiding downloading data repeatedly.
97
+
98
+ For visual data annotations, ensure the field name matches the data types defined earlier (e.g. one of "images", "videos" or features").
99
+ As visual features are usually large and should be downloaded beforehand, we maintain only a ``storage`` parameter where visual data is cached.
100
+
101
+ Dataset ``lavis.datasets.datasets``
102
+ **************************************************************
103
+
104
+ Base Dataset ``lavis.datasets.datasets.base_dataset``
105
+ =======================================================
106
+ In this step, we want to define new dataset classes that inherit our base dataset class ``lavis.datasets.datasets.base_dataset``. This base dataset class already defines standard methods such as ``collater`` which uses the default collator from Pytorch.
107
+
108
+ .. code-block:: python
109
+
110
+ import json
111
+ from typing import Iterable
112
+
113
+ from torch.utils.data import Dataset, ConcatDataset
114
+ from torch.utils.data.dataloader import default_collate
115
+
116
+ class BaseDataset(Dataset):
117
+ def __init__(
118
+ self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
119
+ ):
120
+ """
121
+ vis_root (string): Root directory of images (e.g. coco/images/)
122
+ ann_root (string): directory to store the annotation file
123
+ """
124
+ self.vis_root = vis_root
125
+
126
+ self.annotation = []
127
+ for ann_path in ann_paths:
128
+ self.annotation.extend(json.load(open(ann_path, "r")))
129
+
130
+ self.vis_processor = vis_processor
131
+ self.text_processor = text_processor
132
+
133
+ self._add_instance_ids()
134
+
135
+ def __len__(self):
136
+ return len(self.annotation)
137
+
138
+ def collater(self, samples):
139
+ return default_collate(samples)
140
+
141
+ def set_processors(self, vis_processor, text_processor):
142
+ self.vis_processor = vis_processor
143
+ self.text_processor = text_processor
144
+
145
+ def _add_instance_ids(self, key="instance_id"):
146
+ for idx, ann in enumerate(self.annotation):
147
+ ann[key] = str(idx)
148
+
149
+ Any dataset subclass will inherit these methods and it is optional to define and overwrite these methods accordingly to the specifications of the dataset.
150
+ We encourage users not to modify the base dataset class as any modification will have cascading impacts on any other dataset classes that inherit this base dataset.
151
+ Instead, the users should independently create new dataset classes to cater to their specific requirements.
152
+
153
+ Dialogue Datasets ``lavis.datasets.datasets.dialogue_datasets``
154
+ ======================================================================
155
+
156
+ For example, for the AVSD dataset, we want to define a new dataset subclass ``DialogueDataset`` for dialogue tasks. We can define this dataset class in ``lavis.datasets.datasets.dialogue_datasets`` as following:
157
+
158
+ .. code-block:: python
159
+
160
+ import os
161
+ from collections import OrderedDict
162
+
163
+ from lavis.datasets.datasets.base_dataset import BaseDataset
164
+
165
+ import json
166
+ import copy
167
+
168
+ class DialogueDataset(BaseDataset):
169
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
170
+ """
171
+ vis_processor (string): visual processor
172
+ text_processor (string): textual processor
173
+ vis_root (string): Root directory of images (e.g. coco/images/)
174
+ ann_paths (string): Root directory of images (e.g. coco/images/)
175
+ """
176
+
177
+ self.vis_root = vis_root
178
+
179
+ self.annotation = []
180
+ for ann_path in ann_paths:
181
+ dialogs = json.load(open(ann_path, "r"))['dialogs']
182
+ for dialog in dialogs:
183
+ all_turns = dialog['dialog']
184
+ dialogue_context = []
185
+ for turn in all_turns:
186
+ dialog_instance = copy.deepcopy(dialog)
187
+ question = turn['question']
188
+ answer = turn['answer']
189
+
190
+ dialog_instance['dialog'] = copy.deepcopy(dialogue_context)
191
+ dialog_instance['question'] = question
192
+ dialog_instance['answer'] = answer
193
+ self.annotation.append(dialog_instance)
194
+ dialogue_context.append(turn)
195
+
196
+ self.vis_processor = vis_processor
197
+ self.text_processor = text_processor
198
+
199
+ self._add_instance_ids()
200
+
201
+ self.img_ids = {}
202
+ n = 0
203
+ for ann in self.annotation:
204
+ img_id = ann["image_id"]
205
+ if img_id not in self.img_ids.keys():
206
+ self.img_ids[img_id] = n
207
+ n += 1
208
+
209
+ Class inheritance allows us to define multiple subclasses. For instance, we want another dialogue dataset class that is defined only for the test split. We can define another dataset class ``DialogueEvalDataset`` as similarly defined above but the annotations are processed differently.
210
+ Typically, in dialogue tasks, during test time, only a single test sample is constructed per dialogue (rather than decomposing all dialogue turns as samples during training time).
211
+ The dataset class can then be defined as:
212
+
213
+ .. code-block:: python
214
+
215
+ class DialogueEvalDataset(BaseDataset):
216
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
217
+ # ...
218
+ # defined similarly as DialogueDataset above
219
+ # except for the loading of dialogue annotation data
220
+
221
+ self.annotation = []
222
+ for ann_path in ann_paths:
223
+ dialogs = json.load(open(ann_path, "r"))['dialogs']
224
+ for dialog in dialogs:
225
+ all_turns = dialog['dialog']
226
+ dialogue_context = all_turns[:-1]
227
+ last_turn = all_turns[-1]
228
+
229
+ question = last_turn['question']
230
+ answer = last_turn['answer']
231
+
232
+ dialog['dialog'] = dialogue_context
233
+ dialog['question'] = question
234
+ dialog['answer'] = answer
235
+
236
+ self.annotation.append(dialog)
237
+
238
+
239
+ Using class inheritance to define datasets also allows us to develop more fine-grain class implementations, each of which is specifically designated for a benchmark.
240
+ For instance, under the dialogue-based tasks, we can further define another dataset subclass that is specified for the AVSD dataset.
241
+ We can define a new class ``AVSDDialDataset`` that further specifies how to load individual samples and collate them accordingly to specific requirements:
242
+
243
+ .. code-block:: python
244
+
245
+ import os
246
+ from lavis.datasets.datasets.base_dataset import BaseDataset
247
+ from lavis.datasets.datasets.dialogue_datasets import DialogueDataset, DialogueEvalDataset
248
+
249
+ import torch
250
+
251
+ class AVSDDialDataset(DialogueDataset):
252
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
253
+
254
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
255
+
256
+ def __getitem__(self, index):
257
+
258
+ ann = self.annotation[index]
259
+
260
+ vname = ann["image_id"]
261
+
262
+ video = self.vis_processor(self.vis_root, vname)
263
+
264
+ dialogue = self.text_processor(ann)
265
+
266
+ return {
267
+ "video_fts": video['video_fts'],
268
+ "video_token_type_ids": video['token_type_ids'],
269
+ "input_ids": dialogue['input_ids'],
270
+ "token_type_ids": dialogue['token_type_ids'],
271
+ "labels": dialogue['labels'],
272
+ "image_id": ann["image_id"],
273
+ "instance_id": ann["instance_id"]
274
+ }
275
+
276
+ def collater(self, samples):
277
+
278
+ input_ids, token_type_ids, labels, video_fts, video_token_type_ids = [], [], [], [], []
279
+
280
+ for i in samples:
281
+ input_ids.append(i['input_ids'])
282
+ token_type_ids.append(i['token_type_ids'])
283
+ labels.append(i['labels'])
284
+ video_fts.append(i['video_fts'])
285
+ video_token_type_ids.append(i['video_token_type_ids'])
286
+
287
+ input_ids = self.text_processor.padding(input_ids)
288
+
289
+ labels = self.text_processor.padding(labels, -1)
290
+ video_fts = self.vis_processor.padding(video_fts)
291
+
292
+ token_type_ids = self.text_processor.padding(token_type_ids)
293
+ video_token_type_ids = self.text_processor.padding(video_token_type_ids)
294
+ token_type_ids = torch.cat([video_token_type_ids, token_type_ids], dim=1)
295
+
296
+ attn_mask = self.text_processor.get_attention_mask(input_ids)
297
+ video_mask = self.vis_processor.get_attention_mask(video_fts)
298
+ attn_mask = torch.cat([video_mask, attn_mask], dim=1)
299
+
300
+ video_labels = torch.ones((video_fts.size(0), video_fts.size(1))).long() * -1 # ignore token indice -1 by default
301
+
302
+ labels = torch.cat([video_labels, labels], dim=1)
303
+
304
+ samples = {}
305
+ samples['input_ids'] = input_ids
306
+ samples['token_type_ids'] = token_type_ids
307
+ samples['labels'] = labels
308
+ samples['video_fts'] = video_fts
309
+ samples['attn_mask'] = attn_mask
310
+
311
+ return samples
312
+
313
+ Note that in a dataset subclass, if methods such as ``__getitem__`` and ``collater`` are not defined, the same functions from the corresponding superclass will be used.
314
+ For instance, by default, we always use the collater from the ``BaseDataset`` class to collate data samples.
315
+
316
+ Dataset Builder ``lavis.datasets.builders``
317
+ **************************************************************
318
+ Dataset Builder is the data processing module that controls the dataset classes (by training or evaluation split) and associates the specific dataset configurations to these dataset classes.
319
+
320
+ Base Dataset Builder ``lavis.datasets.builders.base_dataset_builder``
321
+ ======================================================================
322
+
323
+ Note that any new builder class definition should inherit the base dataset builder class ``lavis.datasets.builders.base_dataset_builder``:
324
+
325
+ .. code-block:: python
326
+
327
+ class BaseDatasetBuilder:
328
+ train_dataset_cls, eval_dataset_cls = None, None
329
+ ...
330
+
331
+ This allows us to standardize the operations of dataset builders across all builder classes. We advise the users to carefully review the standard methods defined in the base builder class, including methods such as ``_download_data`` and ``build_dataset`` that will load download the data and create instances of dataset classes:
332
+
333
+ .. code-block:: python
334
+
335
+ class BaseDatasetBuilder:
336
+ ...
337
+
338
+ def build_datasets(self):
339
+ # download, split, etc...
340
+ # only called on 1 GPU/TPU in distributed
341
+
342
+ if is_main_process():
343
+ self._download_data()
344
+
345
+ if is_dist_avail_and_initialized():
346
+ dist.barrier()
347
+
348
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
349
+ logging.info("Building datasets...")
350
+ datasets = self.build() # dataset['train'/'val'/'test']
351
+
352
+ return datasets
353
+
354
+ def _download_data(self):
355
+ self._download_ann()
356
+ self._download_vis()
357
+
358
+ We encourage users not to modify the implementation of the base dataset builder class as this will affect all existing dataset builder subclasses.
359
+
360
+ Dialogue Dataset Builder ``lavis.datasets.builders.dialogue_builder``
361
+ ======================================================================
362
+ We can define any new builder subclass and associate this builder with the corresponding dataset classes and dataset configurations.
363
+ For instance, for the AVSD dataset, we can define a builder ``lavis.datasets.builders.dialogue_builder`` for dialogue-based datasets as follows:
364
+
365
+ .. code-block:: python
366
+
367
+ from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
368
+ from lavis.datasets.datasets.avsd_dialogue_datasets import (
369
+ AVSDDialDataset,
370
+ AVSDDialEvalDataset
371
+ )
372
+
373
+ from lavis.common.registry import registry
374
+
375
+
376
+ @registry.register_builder("avsd_dialogue")
377
+ class AVSDDialBuilder(BaseDatasetBuilder):
378
+ train_dataset_cls = AVSDDialDataset
379
+ eval_dataset_cls = AVSDDialEvalDataset
380
+
381
+ DATASET_CONFIG_DICT = {
382
+ "default": "configs/datasets/avsd/defaults_dial.yaml"
383
+ }
384
+
385
+ Note that we chose to separately define the parameters ``train_dataset_cls`` and ``eval_dataset_cls`` to consider cases where data is processed differently between training and test time.
386
+ For instance, in captioning tasks, during test time, each data sample often includes multiple ground-truth captions rather than just a single ground-truth during training time.
387
+ If the data processing is the same in both training and test time, the two parameters can be linked to the same dataset class.
388
+
389
+ Finally, define ``DATASET_CONFIG_DICT`` to associate the dataset configurations to the assigned dataset classes.
390
+
391
+ Registering Builder ``lavis.datasets.builders.__init__``
392
+ ======================================================================
393
+
394
+ To add a new builder class, ensure to first include the class within the ``__init__.py``. For instance, to define a new builder for the AVSD dataset:
395
+
396
+ .. code-block:: python
397
+
398
+ from lavis.datasets.builders.dialogue_builder import (
399
+ AVSDDialBuilder
400
+ )
401
+
402
+ __all__ = [
403
+ ...,
404
+ "AVSDDialBuilder"
405
+ ]
406
+
407
+ Assigning Builder
408
+ ======================================================================
409
+ Note that during data loading and processing, the builder being assigned must have the correct registry to be able to load it properly.
410
+ For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
411
+
412
+ .. code-block:: yaml
413
+
414
+ datasets:
415
+ avsd_dialogue: # name of the dataset builder
416
+ ...
417
+ # processor configuration
418
+ ...
419
+
420
+ Subsequently, any processes (e.g. training) should load this configuration file to assign the correct builder which will then associate the correct dataset classes to construct data samples.
421
+
422
+ .. code-block:: sh
423
+
424
+ python train.py --cfg-path dialogue_avsd_ft.yaml
docs/tutorial.evaluation.rst ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Evaluating Pre-trained Models on Task Datasets
2
+ ###############################################
3
+ LAVIS provides pre-trained and finetuned model for off-the-shelf evaluation on task dataset.
4
+ Let's now see an example to evaluate BLIP model on the captioning task, using MSCOCO dataset.
5
+
6
+ .. _prep coco:
7
+
8
+ Preparing Datasets
9
+ ******************
10
+ First, let's download the dataset. LAVIS provides `automatic downloading scripts` to help prepare
11
+ most of the public dataset, to download MSCOCO dataset, simply run
12
+
13
+ .. code-block:: bash
14
+
15
+ cd lavis/datasets/download_scripts && python download_coco.py
16
+
17
+ This will put the downloaded dataset at a default cache location ``cache`` used by LAVIS.
18
+
19
+ If you want to use a different cache location, you can specify it by updating ``cache_root`` in ``lavis/configs/default.yaml``.
20
+
21
+ If you have a local copy of the dataset, it is recommended to create a symlink from the cache location to the local copy, e.g.
22
+
23
+ .. code-block:: bash
24
+
25
+ ln -s /path/to/local/coco cache/coco
26
+
27
+ Evaluating pre-trained models
28
+ ******************************
29
+
30
+ To evaluate pre-trained model, simply run
31
+
32
+ .. code-block:: bash
33
+
34
+ bash run_scripts/blip/eval/eval_coco_cap.sh
35
+
36
+ Or to evaluate a large model:
37
+
38
+ .. code-block:: bash
39
+
40
+ bash run_scripts/blip/eval/eval_coco_cap_large.sh
docs/tutorial.models.rst ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Adding Models
2
+ ####################################
3
+
4
+ This is a tutorial on adding new models using ``lavis.models`` module.
5
+
6
+ The LAVIS library includes a standard model module that builds the foundation for many major language-vision models such as `ALBEF <https://arxiv.org/pdf/2107.07651.pdf>`_,
7
+ `BLIP <https://arxiv.org/pdf/2201.12086.pdf>`_, `ALPRO <https://arxiv.org/pdf/2112.09583.pdf>`_, and `CLIP <https://arxiv.org/pdf/2103.00020.pdf>`_.
8
+ The ``lavis.models`` module is designed such that any new models can be added and integrated into the LAVIS library, with minimal steps to develop training and testing procedures.
9
+ In this tutorial, we will replicate the steps to add a GPT-style model specifically for `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_.
10
+
11
+ Base Model ``lavis.models.base_model``
12
+ **************************************************************
13
+
14
+ Note that any new model definition should inherit the base model class ``BaseModel``:
15
+
16
+ .. code-block:: python
17
+
18
+ from omegaconf import OmegaConf
19
+
20
+ import numpy as np
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from lavis.common.utils import get_abs_path
26
+
27
+ class BaseModel(nn.Module):
28
+ """Base class for models."""
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+
33
+ def forward_features(self, *args, **kwargs):
34
+ """Similar to *forward* but only return features."""
35
+ raise NotImplementedError
36
+
37
+ def load_from_pretrained(self, url_or_filename):
38
+ raise NotImplementedError
39
+
40
+ @classmethod
41
+ def _from_config(cls, cfg=None, model_type="base"):
42
+ if not cfg:
43
+ # useful when building model without a provided configuration file
44
+ cfg = OmegaConf.load(cls.default_config_path(model_type)).model
45
+
46
+ return cls.from_config(cfg)
47
+
48
+ @classmethod
49
+ def from_pretrained(cls, model_type="base"):
50
+ """
51
+ Build a pretrained model from the default configuration file, specified by model_type.
52
+ """
53
+ return cls._from_config(cfg=None, model_type=model_type)
54
+
55
+ @property
56
+ def device(self):
57
+ return list(self.parameters())[0].device
58
+
59
+ @classmethod
60
+ def default_config_path(cls, model_type="base"):
61
+ assert (
62
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
63
+ ), "Unknown model type {}".format(model_type)
64
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
65
+
66
+ def before_evaluation(self, **kwargs):
67
+ pass
68
+
69
+ def show_n_params(self, return_str=True):
70
+ tot = 0
71
+ for p in self.parameters():
72
+ w = 1
73
+ for x in p.shape:
74
+ w *= x
75
+ tot += w
76
+ if return_str:
77
+ if tot >= 1e6:
78
+ return "{:.1f}M".format(tot / 1e6)
79
+ else:
80
+ return "{:.1f}K".format(tot / 1e3)
81
+ else:
82
+ return tot
83
+
84
+
85
+ In this base model, we already declare and standardize many common methods such as ``_from_config`` and ``_from_pretrained``.
86
+ Inheriting this base model class allows us to standardize operations of models across all model classes while still allowing customizations.
87
+ We advise users not to change the implementation of the base model class as this will affect all existing model subclasses.
88
+
89
+ GPT-style Video-grounded Dialogue Model ``lavis.models.gpt_models.gpt_dialogue``
90
+ ********************************************************************************
91
+
92
+ In this step, we can define a new model class, e.g. under ``lavis.models.gpt_models.gpt_dialogue``, for GPT-based dialogue models designed specifically for video-grounded dialogues.
93
+ Note that we assume the model class inherits from the standard model super class ``GPT2LMHeadModel`` from the ``transformers`` `library <https://huggingface.co/docs/transformers/index>`_.
94
+ We also enforce model integration to the LAVIS framework through the inheritance of the ``BaseModel`` from the LAVIS library, as the secondary super class.
95
+
96
+ .. code-block:: python
97
+
98
+ import torch
99
+ from lavis.common.registry import registry
100
+ from lavis.models.base_model import BaseModel
101
+
102
+ from transformers import GPT2Model, GPT2LMHeadModel
103
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
104
+ import math
105
+ import torch
106
+ import torch.nn as nn
107
+ from torch.nn import CrossEntropyLoss, MSELoss
108
+
109
+ @registry.register_model("gpt_dialogue")
110
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
111
+ ...
112
+
113
+ Next, we can modify the architecture of the model during model initialization to fit the tasks of interest, i.e. video-grounded dialogues.
114
+ In this case, we want to add additional model parameters for a linear network to transform the video feature representations to the model dimension.
115
+
116
+ .. code-block:: python
117
+
118
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
119
+
120
+ def __init__(self, config, len_video_ft=4224):
121
+
122
+ super().__init__(config)
123
+
124
+ self.video_ff = nn.Linear(len_video_ft, config.n_embd)
125
+
126
+ # Model parallel
127
+ self.model_parallel = False
128
+ self.device_map = None
129
+
130
+ # Initialize weights and apply final processing
131
+ self.post_init()
132
+
133
+ Note that for each new model class, we advise redefining the ``from_config`` method which is inherited from the ``BaseModel`` class.
134
+ As each model usually has its own unique configurations, redefining the method will ensure the model instances are created properly.
135
+ For instance, ``GPTDialogue`` requires an additional parameter of video feature length (``len_video_ft``) which should be part of the model initialization procedure.
136
+ Another additional parameter is the number of tokens/words (as we include additional special tokens in the vocabulary for dialogue tasks).
137
+
138
+ .. code-block:: python
139
+
140
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
141
+ ...
142
+ @classmethod
143
+ def from_config(cls, cfg):
144
+ model = cls.from_pretrained('gpt2', len_video_ft=cfg['len_video_ft'])
145
+ model.resize_token_embeddings(cfg['len_tokenizer'])
146
+ return model
147
+
148
+ Other basic methods should also be defined explicitly in the new model class, including the ``forward`` function.
149
+ For instance, in GPT models for video-grounded dialogue tasks, we want the forward operation also includes the transformation and integration of video features before passing the representations to the Transformer layers.
150
+
151
+ .. code-block:: python
152
+
153
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
154
+ ...
155
+
156
+ def forward(self, samples,
157
+ past_key_values=None,
158
+ position_ids=None,
159
+ head_mask=None,
160
+ encoder_hidden_states=None,
161
+ encoder_attention_mask=None,
162
+ use_cache=None,
163
+ output_attentions=None,
164
+ output_hidden_states=None,
165
+ return_dict=None):
166
+
167
+ input_embs = self.transformer.wte(samples['input_ids'])
168
+ video_embs = self.video_ff(samples['video_fts'])
169
+ input_embs = torch.cat([video_embs, input_embs], dim=1)
170
+
171
+ transformer_outputs = self.transformer(
172
+ attention_mask=samples['attn_mask'],
173
+ token_type_ids=samples['token_type_ids'],
174
+ inputs_embeds=input_embs,
175
+ position_ids=position_ids,
176
+ head_mask=head_mask,
177
+ encoder_hidden_states=encoder_hidden_states,
178
+ encoder_attention_mask=encoder_attention_mask,
179
+ use_cache=use_cache,
180
+ output_attentions=output_attentions,
181
+ output_hidden_states=output_hidden_states,
182
+ return_dict=return_dict,
183
+ )
184
+ hidden_states = transformer_outputs[0]
185
+
186
+ lm_logits = self.lm_head(hidden_states)
187
+ ...
188
+
189
+ Registering New Model ``lavis.models.__init__``
190
+ ********************************************************************************
191
+
192
+ Any new model must be officially registered as part of the ``lavis.models`` module.
193
+ For instance, to add a model class for GPT-based dialogue models, we can modify the ``__init__.py`` as follows:
194
+
195
+ .. code-block:: python
196
+
197
+ from lavis.models.gpt_models.gpt_dialogue import GPTDialogue
198
+
199
+ __all__ = [
200
+ ...
201
+ "GPTDialogue"
202
+ ]
203
+
204
+ Assigning Model
205
+ ********************************************************************************
206
+
207
+ From the above example of a model class, note that we define a ``from_config method`` for the new model class.
208
+ This method will process a configuration file and pass specific parameters to initialize the model classes properly.
209
+ To do this, we can assign/ associate the correct registry of model classes in a configuration file.
210
+ For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
211
+
212
+ .. code-block:: yaml
213
+
214
+ model:
215
+ arch: gpt_dialogue # name of the model
216
+ model_type: base
217
+
218
+
219
+ Subsequently, any processes (e.g. training) should load this configuration file to assign the correct model.
220
+
221
+ .. code-block:: sh
222
+
223
+ python train.py --cfg-path dialogue_avsd_ft.yaml
224
+
225
+ Note that to simplify the model configuration, we only enable two main parameters here: ``arch`` and ``model_type``. ``arch`` refers to the model class registry, and ``model_type`` is the corresponding model type under this model family.
226
+ For instance, with ``gpt_dialogue``, we have a model ``base`` which has its own configuration in a separate configuration file e.g. ``gpt_dialogue_base.yaml``:
227
+
228
+ .. code-block:: yaml
229
+
230
+ model:
231
+ arch: gpt_dialogue
232
+ len_tokenizer: 50264 # 50257 tokens from gpt2 default tokenizer + additional special tokens
233
+ len_video_ft: 4224 # i3d_rgb: 2048 i3d_flow: 2048 vggish: 128
234
+
235
+ We can pass load this configuration and pass the parameters to the above ``from_config`` method to initialize the model accordingly.
236
+ We advise the users to maintain a dictionary that contains default paths to model configurations, in the model class definition.
237
+ By default, the LAVIS framework will search for configurations from each model class defined as ``model.PRETRAINED_MODEL_CONFIG_DICT``.
238
+
239
+ .. code-block:: python
240
+
241
+ class GPTDialogue(GPT2LMHeadModel, BaseModel):
242
+ PRETRAINED_MODEL_CONFIG_DICT = {
243
+ "base": "configs/models/gpt_dialogue_base.yaml"
244
+ }
245
+ ...
docs/tutorial.processors.rst ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Adding Processors
2
+ ################################################
3
+
4
+ This is a tutorial on adding new processors using ``lavis.processors`` module.
5
+
6
+ The LAVIS library includes a standard processor module that preprocesses data e.g. image transformation and sequence concatenation.
7
+ The ``lavis.processors`` module is designed such that any processors can be added, specifically to the requirements of corresponding models of interest.
8
+ In this tutorial, we will replicate the steps to add visual and textual processors specifically for `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_.
9
+ In addition, we also want the processors to have processing features to make the data samples compatible with GPT-style models.
10
+
11
+ Base Processor ``lavis.processors.base_processors``
12
+ *****************************************************
13
+
14
+ Note that any new processor definition should inherit the base processor class ``BaseProcessor``:
15
+
16
+ .. code-block:: python
17
+
18
+ from omegaconf import OmegaConf
19
+
20
+ class BaseProcessor:
21
+ def __init__(self):
22
+ self.transform = lambda x: x
23
+ return
24
+
25
+ def __call__(self, item):
26
+ return self.transform(item)
27
+
28
+ @classmethod
29
+ def from_config(cls, cfg=None):
30
+ return cls()
31
+
32
+ def build(self, **kwargs):
33
+ cfg = OmegaConf.create(kwargs)
34
+
35
+ return self.from_config(cfg)
36
+
37
+ This allows us to standardize operations of processors across all processor classes while still allowing customization of processors specifically to data and model types.
38
+ We encourage users not to modify the implementation of the base processor class as this will have an impact on all existing processor subclasses.
39
+
40
+ GPT-style Processors ``lavis.processors.gpt_processors``
41
+ **************************************************************
42
+ In this step, we can define new processor classes, e.g. under ``lavis.processors.gpt_processors``, for GPT models designed specifically for video-grounded dialogues.
43
+ First, we want to process video features by defining ``GPTVideoFeatureProcessor`` class.
44
+ In this tutorial, we assume video features are extracted beforehand and this processor simply loads the features from ``npy`` files.
45
+ Other methods that are specifically defined are ``padding`` (which is used by dataset instances to pad multiple video samples) and ``get_attention_mask`` (which creates an attention mask for Transformer attention in GPT models).
46
+
47
+ .. code-block:: python
48
+
49
+ SPECIAL_TOKENS_DICT = {'bos_token': "<bos>", 'eos_token': "<eos>", 'additional_special_tokens': ["<speaker1>", "<speaker2>", "<video>", "<cap>"], 'pad_token': "<pad>"}
50
+ ...
51
+
52
+ @registry.register_processor("gpt_video_ft")
53
+ class GPTVideoFeatureProcessor(BaseProcessor):
54
+ def __init__(self, visual_ft, audio_ft):
55
+
56
+ self.visual_ft = visual_ft
57
+ self.audio_ft = audio_ft
58
+
59
+ self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
60
+ self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
61
+
62
+ def padding(self, seq):
63
+ padded_seq = torch.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=1.0)
64
+ return padded_seq
65
+
66
+ def get_attention_mask(self, seq):
67
+ return torch.sum(seq != 1, dim=2) != 0
68
+
69
+ def __call__(self, ft_root, vname):
70
+ all_ft = []
71
+
72
+ for ft_name in self.visual_ft:
73
+ ft_path = os.path.join(ft_root, ft_name, vname)
74
+ all_ft.append(np.load(ft_path + '.npy'))
75
+
76
+ for ft_name in self.audio_ft:
77
+ ft_path = os.path.join(ft_root, ft_name, vname)
78
+ all_ft.append(np.load(ft_path + '.npy'))
79
+
80
+ min_len = min([len(ft) for ft in all_ft])
81
+
82
+ sampled_ft = [ft[:min_len] for ft in all_ft]
83
+ sampled_ft = np.concatenate(sampled_ft, axis=1)
84
+ item = {}
85
+ item['video_fts'] = torch.Tensor(sampled_ft)
86
+
87
+ video_type_token = self.tokenizer.convert_tokens_to_ids('<video>')
88
+ item['token_type_ids'] = torch.Tensor([video_type_token] * len(sampled_ft)).long()
89
+
90
+ return item
91
+
92
+ @classmethod
93
+ def from_config(cls, cfg=None):
94
+ if cfg is None:
95
+ cfg = OmegaConf.create()
96
+
97
+ visual_ft = cfg.get("visual_ft", ["i3d_rgb"])
98
+ audio_ft = cfg.get("audio_ft", ["vggish"])
99
+
100
+ return cls(
101
+ visual_ft=visual_ft,
102
+ audio_ft=audio_ft
103
+ )
104
+
105
+ Another processor class that will be useful to have is to process dialogue data. Here we can define a ``GPTDialogueProcessor`` class.
106
+ This processor class receives raw annotations and constructs inputs as a concatenation of input sequences (questions, dialogue contexts, and responses) to facilitate application in GPT models.
107
+ Other methods that are specifically defined are ``padding`` (which is used by dataset instances to pad multiple sequence samples) and ``get_attention_mask`` (which creates an attention mask for Transformer attention in GPT models).
108
+
109
+ .. code-block:: python
110
+
111
+ SPECIAL_TOKENS_DICT = {'bos_token': "<bos>", 'eos_token': "<eos>", 'additional_special_tokens': ["<speaker1>", "<speaker2>", "<video>", "<cap>"], 'pad_token': "<pad>"}
112
+ ...
113
+
114
+ @registry.register_processor("gpt_dialogue")
115
+ class GPTDialogueProcessor(BaseProcessor):
116
+ def __init__(self, max_turns=3, use_caption=True):
117
+ self.max_turns = max_turns
118
+ self.use_caption = use_caption
119
+ self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
120
+ self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
121
+
122
+ def sample_sequence(self, caption, history, answer):
123
+ bos, eos, speaker1, speaker2, cap = self.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-2])
124
+ instance = {}
125
+ sequence = [caption] + history + [answer]
126
+ sequence = [s + [eos] for s in sequence]
127
+
128
+ instance["input_ids"] = list(chain(*sequence))
129
+ instance["token_type_ids"] = [cap] * len(sequence[0]) + [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence[1:]) for _ in s]
130
+ instance["labels"] = ([-1]*sum(len(s) for s in sequence[:-1])) + sequence[-1]
131
+
132
+ assert len(instance["input_ids"])==len(instance["token_type_ids"])
133
+ assert len(instance["token_type_ids"])==len(instance["labels"])
134
+
135
+ for k,v in instance.items():
136
+ instance[k] = torch.Tensor(v).long()
137
+
138
+ return instance
139
+
140
+ def padding(self, seq, pad_token=-1):
141
+ if pad_token==-1: pad_token = self.tokenizer.pad_token_id
142
+ padded_seq = torch.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=pad_token)
143
+ return padded_seq
144
+
145
+ def get_attention_mask(self, seq, pad_token=-1):
146
+ if pad_token==-1: pad_token = self.tokenizer.pad_token_id
147
+ return seq != pad_token
148
+
149
+ def __call__(self, ann):
150
+ if self.use_caption:
151
+ caption = ' '.join([ann['caption'], ann['summary']])
152
+ caption = self.tokenizer.encode(caption)
153
+ else:
154
+ caption = []
155
+
156
+ dial_history = []
157
+ for turn in ann['dialog'][-self.max_turns:]:
158
+ dial_history.append(turn['question'])
159
+ dial_history.append(turn['answer'])
160
+ dial_history.append(ann['question'])
161
+ dial_history = [self.tokenizer.encode(t) for t in dial_history]
162
+
163
+ answer = self.tokenizer.encode(ann['answer'])
164
+
165
+ item = self.sample_sequence(caption, dial_history, answer)
166
+
167
+ return item
168
+
169
+ @classmethod
170
+ def from_config(cls, cfg=None):
171
+ if cfg is None:
172
+ cfg = OmegaConf.create()
173
+
174
+ use_caption = cfg.get("use_caption", True)
175
+ max_turns = cfg.get("max_turns", 3)
176
+
177
+ return cls(max_turns=max_turns, use_caption=use_caption)
178
+
179
+ Registering New Processors ``lavis.processors.__init__``
180
+ **************************************************************
181
+
182
+ Finally, any new processor must be officially registered as part of the ``lavis.processors`` module.
183
+ For instance, to add processor classes for GPT-based dialogue models, including one for dialogue data ``GPTDialogueProcessor`` and one for video features ``GPTVideoFeatureProcessor``, we can modify the ``__init__.py`` as follows:
184
+
185
+ .. code-block:: python
186
+
187
+ from lavis.processors.gpt_processors import (
188
+ GPTVideoFeatureProcessor,
189
+ GPTDialogueProcessor,
190
+ )
191
+
192
+ __all__ = [
193
+ ...
194
+ # GPT
195
+ "GPTVideoFeatureProcessor",
196
+ "GPTDialogueProcessor"
197
+ ]
198
+
199
+ Assigning Processors
200
+ **************************************************************
201
+ From the above example of processor classes, note that we define a ``from_config`` method for each class.
202
+ This method will process a configuration file and pass specific parameters e.g. ``max_turns``, ``visual_ft``, to initialize the processor classes properly.
203
+ To do this, we can assign/ associate the correct registry of processor classes in a configuration file.
204
+ For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
205
+
206
+ .. code-block:: yaml
207
+
208
+ datasets:
209
+ avsd_dialogue: # name of the dataset builder
210
+ vis_processor:
211
+ train:
212
+ name: "gpt_video_ft" # name of the visual processor for training data
213
+ visual_ft: ["i3d_flow", "i3d_rgb"]
214
+ audio_ft: ["vggish"]
215
+ eval:
216
+ name: "gpt_video_ft" # name of the visual processor for evaluation data
217
+ visual_ft: ["i3d_flow", "i3d_rgb"]
218
+ audio_ft: ["vggish"]
219
+ text_processor:
220
+ train:
221
+ name: "gpt_dialogue" # name of the textual processor for training data
222
+ max_turns: 3
223
+ use_caption: True
224
+ eval:
225
+ name: "gpt_dialogue" # name of the textual processor for evaluation data
226
+ max_turns: 3
227
+ use_caption: True
228
+
229
+ Subsequently, any processes (e.g. training) should load this configuration file to assign the correct processors.
230
+
231
+ .. code-block:: sh
232
+
233
+ python train.py --cfg-path dialogue_avsd_ft.yaml
docs/tutorial.rst ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tutorials
2
+ ==============================
3
+
4
+ .. toctree::
5
+ :maxdepth: 1
6
+
7
+ tutorial.evaluation
8
+ tutorial.training-example
9
+ tutorial.configs
10
+ tutorial.datasets
11
+ tutorial.processors
12
+ tutorial.models
13
+ tutorial.tasks
docs/tutorial.tasks.rst ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Adding Tasks
2
+ ####################################
3
+
4
+ This is a tutorial on adding new machine learning tasks using ``lavis.tasks`` module.
5
+
6
+ The LAVIS library includes a standard task module that centralizes the model training and evaluation procedure of machine learning tasks.
7
+ The ``lavis.tasks`` module is designed such that any new tasks can be added and integrated, catering to any customization in the training and testing procedures.
8
+ In this tutorial, we will replicate the steps to add a new task into LAVIS for the `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_.
9
+
10
+ Base Task ``lavis.tasks.base_task``
11
+ ********************************************************************************
12
+
13
+ Note that any new model definition should inherit the base task class ``BaseTask``:
14
+
15
+ .. code-block:: python
16
+
17
+ import logging
18
+ import os
19
+
20
+ import torch.distributed as dist
21
+ from lavis.common.dist_utils import get_rank, get_world_size, is_main_process
22
+ from lavis.common.logger import MetricLogger, SmoothedValue
23
+ from lavis.common.registry import registry
24
+ from lavis.datasets.data_utils import prepare_sample
25
+
26
+ class BaseTask:
27
+ def __init__(self, **kwargs):
28
+ super().__init__()
29
+
30
+ self.inst_id_key = "instance_id"
31
+
32
+ @classmethod
33
+ def setup_task(cls, **kwargs):
34
+ return cls()
35
+
36
+ def build_model(self, cfg):
37
+ model_config = cfg.model_cfg
38
+
39
+ model_cls = registry.get_model_class(model_config.arch)
40
+ return model_cls.from_config(model_config)
41
+
42
+ def build_datasets(self, cfg):
43
+ """
44
+ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
45
+ Download dataset and annotations automatically if not exist.
46
+
47
+ Args:
48
+ cfg (common.config.Config): _description_
49
+
50
+ Returns:
51
+ dict: Dictionary of torch.utils.data.Dataset objects by split.
52
+ """
53
+
54
+ datasets = dict()
55
+
56
+ datasets_config = cfg.datasets_cfg
57
+
58
+ assert len(datasets_config) > 0, "At least one dataset has to be specified."
59
+
60
+ for name in datasets_config:
61
+ dataset_config = datasets_config[name]
62
+
63
+ builder = registry.get_builder_class(name)(dataset_config)
64
+ dataset = builder.build_datasets()
65
+
66
+ datasets[name] = dataset
67
+
68
+ return datasets
69
+
70
+ def train_step(self, model, samples):
71
+ loss = model(samples)["loss"]
72
+ return loss
73
+
74
+ ...
75
+
76
+ In this base task, we already declare and standardize many common methods such as ``train_step``, ``build_model``, and ``build_datasets``.
77
+ Inheriting this base task class allows us to standardize operations of tasks across all task classes.
78
+ We recommend users not change the implementation of the base task class as this will have an impact on all existing task subclasses.
79
+
80
+ Dialogue Task ``lavis.tasks.dialogue``
81
+ ********************************************************************************
82
+
83
+ In this step, we can define a new task class, e.g. under ``lavis.tasks.dialogue``, for video-grounded dialogues.
84
+ For instance, we define a new task class ``DialogueTask`` that inherits the super task class ``BaseTask``.
85
+
86
+ .. code-block:: python
87
+
88
+ import json
89
+ import os
90
+
91
+ from lavis.common.dist_utils import main_process
92
+ from lavis.common.logger import MetricLogger
93
+ from lavis.common.registry import registry
94
+ from lavis.tasks.base_task import BaseTask
95
+ from lavis.datasets.data_utils import prepare_sample
96
+
97
+ import numpy as np
98
+
99
+ @registry.register_task("dialogue")
100
+ class DialogueTask(BaseTask):
101
+ def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
102
+ super().__init__()
103
+
104
+ self.num_beams = num_beams
105
+ self.max_len = max_len
106
+ self.min_len = min_len
107
+ self.evaluate = evaluate
108
+
109
+ self.report_metric = report_metric
110
+
111
+ @classmethod
112
+ def setup_task(cls, cfg):
113
+ run_cfg = cfg.run_cfg
114
+
115
+ num_beams = run_cfg.num_beams
116
+ max_len = run_cfg.max_len
117
+ min_len = run_cfg.min_len
118
+ evaluate = run_cfg.evaluate
119
+
120
+ report_metric = run_cfg.get("report_metric", True)
121
+
122
+ return cls(
123
+ num_beams=num_beams,
124
+ max_len=max_len,
125
+ min_len=min_len,
126
+ evaluate=evaluate,
127
+ report_metric=report_metric,
128
+ )
129
+
130
+ def valid_step(self, model, samples):
131
+ results = []
132
+ loss = model(samples)["loss"].item()
133
+
134
+ return [loss]
135
+ ...
136
+
137
+ Note that for any new task, we advise the users to review carefully the functions implemented within ``BaseTask`` and consider which methods should be modified.
138
+ For instance, the base task class already contains a standard implementation of model training steps that are common among machine learning steps.
139
+ Some major methods we want to emphasize and should be customized by each task are the ``valid_step`` and ``evaluation``.
140
+ These operations were not fully implemented in the base task class due to the differences in evaluation procedures among many machine learning tasks.
141
+ Another method that should be considered is the ``setup_task`` method.
142
+ This method will receive configurations that set task-specific parameters to initialize any task instance.
143
+
144
+ Registering New Task ``lavis.tasks.__init__``
145
+ ********************************************************************************
146
+
147
+ Any new task must be officially registered as part of the ``lavis.tasks`` module. For instance, to add a new task for video-grounded dialogues, we can modify the ``__init__.py`` as follows:
148
+
149
+ .. code-block:: python
150
+
151
+ from lavis.tasks.dialogue import DialogueTask
152
+
153
+ ...
154
+ __all__ = [
155
+ ...
156
+ "DialogueTask"
157
+ ]
158
+
159
+ Assigning Task
160
+ ***************
161
+
162
+ From the above example of task class, note that we define a ``setup_task`` method for each task class.
163
+ This method will process a configuration file and pass specific parameters e.g. ``num_beams`` (for beam search generative tasks during the inference stage), to initialize the task classes properly.
164
+ To assign and associate any task, we need to specify the correct registry of task classes in a configuration file.
165
+ For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
166
+
167
+ .. code-block:: yaml
168
+
169
+ run:
170
+ task: dialogue # name of the task
171
+
172
+ # optimizer
173
+ ...
174
+
175
+ max_len: 20
176
+ min_len: 5
177
+ num_beams: 3
178
+ ...
179
+
180
+ Subsequently, any processes (e.g. training) should load this configuration file to assign the correct task.
181
+
182
+ .. code-block:: sh
183
+
184
+ python train.py --cfg-path dialogue_avsd_ft.yaml
docs/tutorial.training-example.rst ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Example on Finetuning BLIP on COCO-Captioning
2
+ ################################################
3
+
4
+ To finetune BLIP model on the coco caption dataset, first refer to :ref:`prep coco` to prepare the dataset if you have not done so.
5
+
6
+ To finetune the model, we have prepared a run script for you, which can run as follows:
7
+
8
+ .. code-block:: bash
9
+
10
+ bash run_scripts/blip/train/train_caption_coco_large.sh
11
+
12
+ This will finetune the pre-trained BLIP large model into a new model that can be used for captioning.
13
+
14
+ Deep Dive
15
+ **********
16
+ Now let's take a closer look at the script and see what it does.
17
+
18
+ .. code-block:: bash
19
+
20
+ python -m torch.distributed.run --nproc_per_node=8 train.py --cfg-path lavis/projects/blip/train/caption_coco_large_ft.yaml
21
+
22
+ As can be seen, the script simply calls the :code:`train.py` with PyTorch distributed training enabled.
23
+ The :code:`--cfg-path` argument specifies the **runtime config** file to use. The config file is a YAML file that specifies the training parameters, shown as follows:
24
+
25
+ .. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
26
+ :language: yaml
27
+ :linenos:
28
+
29
+ The runtime config file is divided into 3 sections:
30
+ - :code:`model`: specifies the model architecture and type to use.
31
+ - :code:`data`: specifies the dataset to use.
32
+ - :code:`run`: specifies the runner arguments, such as tasks, optimizer, learning rate scheduler, etc.
33
+
34
+ We describe each section in detail below.
35
+
36
+ Model configurations
37
+ =====================
38
+
39
+ .. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
40
+ :language: yaml
41
+ :linenos:
42
+ :lines: 6-10
43
+
44
+ The :code:`arch` argument specifies the model architecture to use. In this case, we use the :code:`blip_caption` architecture.
45
+ You can find available architectures by inspecting the :code:`model_zoo`.
46
+ Once the architecture is specified, the runner will look for the model class registered with the name and try to instantiate a model instance.
47
+ In this case :code:`BlipCaption` is the model registered with the name :code:`blip_caption`.
48
+
49
+ The registry maintains a mapping from the name string to the model class.
50
+ This allows the runner to find the model class dynamically based on the name string from the config file.
51
+ The following segment in :code:`lavis/models/blip_models/blip_caption.py` shows how :code:`BlipCaption` is registered with the name string :code:`blip_caption`:
52
+
53
+ .. literalinclude:: ../lavis/models/blip_models/blip_caption.py
54
+ :language: python
55
+ :linenos:
56
+ :lines: 20-38
57
+
58
+ One same model architecture may be pre-trained or finetuned on different datasets or have different model configurations.
59
+ For example, :code:`BlipCaption` have:
60
+
61
+ - :code:`base_coco`: pre-trained base BLIP model adapated for COCO captioning finetuning.
62
+
63
+ - :code:`large_coco`: pre-trained large BLIP model adapated for COCO captioning finetuning.
64
+
65
+ Therefore, we also need to specify :code:`model_type`. Here we use :code:`large_coco`.
66
+ And we set :code:`load_finetuned` to :code:`False` to indicate that we are finetuning the model from the pre-trained weights.
67
+ If :code:`load_finetuned` set to :code:`True` as by default, the model will load finetuned weights on coco captioning.
68
+
69
+ Given the model architecture and type, the library will then look for the default model config for :code:`large_coco` in :code:`lavis/models/blip_models/blip_caption.py`.
70
+ As can be seen in the above code snippet, the corresponding config path is stored in :code:`BlipCaption.PRETRAINED_MODEL_CONFIG_DICT`.
71
+ Then the library will load :code:`lavis/configs/models/blip_caption_large_coco.yaml` as the configuration to build the model.
72
+
73
+ *Priority of Configs*: Note that the priority of the run config is higher than the default model config, meaning that arguments in the run config will override the default model config.
74
+ For example, in the default model config, :code:`load_finetuned` is set to :code:`True` by default, while in the run config, we set it to :code:`False` and finetuning from the pre-trained weights only.
75
+
76
+
77
+ Dataset configurations
78
+ =========================
79
+
80
+ The second section of the config file specifies the dataset(s) to use.
81
+
82
+ .. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
83
+ :language: yaml
84
+ :linenos:
85
+ :lines: 12-24
86
+
87
+ We associate each dataset with a :code:`vis_processor` and a :code:`text_processor`, responsible for processing the visual and textual input respectively.
88
+ Here we again use the registry mechanism to dynamically load the processor class based on the name string.
89
+ For example, :code:`blip_image_train` is the name string for the :code:`BlipImageTrainProcessor` class, which is registered in :code:`lavis/processors/blip_processors.py`.
90
+
91
+ Similarly, the dataset name string is also registered in the registry, pointing to a dataset builder :code:`COCOCapBuilder` class.
92
+ By default, the builder will load the default dataset configuration as in :code:`DATASET_CONFIG_DICT`. You may also add new dataset types by adding new entries to the dictionary.
93
+
94
+ The dataset configuration used here is:
95
+
96
+ .. literalinclude:: ../lavis/configs/datasets/coco/defaults_cap.yaml
97
+ :language: yaml
98
+ :linenos:
99
+ :lines: 6-28
100
+
101
+ In this configuration file, we specify the dataset name and mainly its building information.
102
+ The build information is divided into two parts: :code:`annotation` and :code:`images`. The annotation files will be automatically downloaded upon loading the dataset for the first time.
103
+ The :code:`images` part specifies the image root directory. This is a relative path to the cache directory, which is :code:`cache` by default. If you have a local copy of the dataset, you can specify the path to the local copy by
104
+ overwriting the :code:`images` part in the runtime config file. For example, you may alter the run config as below to use your local dataset copy:
105
+
106
+ .. code:: yaml
107
+
108
+ datasets:
109
+ coco_caption: # name of the dataset builder
110
+ vis_processor:
111
+ train:
112
+ name: "blip_image_train"
113
+ eval:
114
+ name: "blip_image_eval"
115
+ text_processor:
116
+ train:
117
+ name: "blip_caption"
118
+ prompt: "a picture of "
119
+ eval:
120
+ name: "blip_caption"
121
+ images:
122
+ YOUR_LOCAL_IMAGE_ROOT_DIR
123
+
124
+ LAVIS supports using multiple datasets for training. See an example in :code:`lavis/projects/blip/train/pretrain_14m.yaml`.
125
+
126
+
127
+ Runner configurations
128
+ =========================
129
+ The last section of the config file specifies the arguments for the runner, shown below:
130
+
131
+ .. literalinclude:: ../lavis/projects/blip/train/caption_coco_large_ft.yaml
132
+ :language: yaml
133
+ :linenos:
134
+ :lines: 26-56
135
+
136
+ Here we specify runner-related arguments, including
137
+ - task-specific arguments, such as :code:`task`, :code:`max_len`, :code:`min_len`, etc.
138
+ - learning rate schedulers, optimizer;
139
+ - distributed training settings;
140
+ - logging and checkpointing settings.
141
+
142
+ Available Configurations
143
+ #########################
144
+
145
+ See :ref:`config` for the full list of available configurations and their descriptions.
examples/blip2_itm.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import torch
4
+ from PIL import Image
5
+
6
+ from lavis.models import load_model_and_preprocess
7
+ from lavis.processors import load_processor
8
+ from lavis.common.registry import registry
9
+ from torch.nn import functional as F
10
+ from lavis.models.base_model import all_gather_with_grad, concat_all_gather
11
+ import numpy as np
12
+ import pandas as pd
13
+ import time
14
+ from fuzzywuzzy import process
15
+ from multiprocessing import Pool, Queue, Process
16
+ import difflib
17
+ import Levenshtein
18
+ import os
19
+ # import obonet
20
+
21
+
22
+ def fuzzy_match(texts):
23
+ text_dict = {}
24
+ for context in texts:
25
+ if context not in choices:
26
+ # txt_dict[txt] = process.extractOne(txt, choices)[0]
27
+ text_dict[context] = difflib.get_close_matches(context, choices, n=1, cutoff=0.)[0]
28
+ return text_dict
29
+
30
+
31
+ def txt_map(x, txt_dict):
32
+ if type(x) == str:
33
+ x = eval(x)
34
+ x_ = []
35
+ for i in x:
36
+ if i in txt_dict:
37
+ x_.append(txt_dict[i])
38
+ else:
39
+ x_.append(i)
40
+ return x_
41
+
42
+
43
+ def levenshtein_sim(text, label):
44
+ all_s = []
45
+ for x in label:
46
+ s = 0
47
+ for y in text:
48
+ temp = Levenshtein.ratio(x, y)
49
+ if temp > s:
50
+ s = temp
51
+ all_s.append(s)
52
+ all_s = [round(i, 3) for i in all_s]
53
+ return all_s
54
+
55
+ def func(text, label):
56
+ all_s = []
57
+ for x in label:
58
+ s = 0
59
+ for y in text:
60
+ temp = Levenshtein.ratio(x, y)
61
+ if temp > s:
62
+ s = temp
63
+ all_s.append(s)
64
+ all_s = [round(i, 3) for i in all_s]
65
+ return all_s
66
+
67
+
68
+ def stage2_output(df_test):
69
+ config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
70
+ 'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20230924220/checkpoint_5.pth',
71
+ 'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
72
+ 'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
73
+ 'max_protein_len': 600,
74
+ 'max_txt_len': 25}
75
+
76
+ model_cls = registry.get_model_class(config['arch'])
77
+ model = model_cls.from_config(config)
78
+ model.to(device)
79
+ model.eval()
80
+
81
+ images = df_test['protein'].tolist()
82
+ n = len(images)
83
+ bsz = 12
84
+ iter = n // bsz + 1
85
+
86
+ for i in range(iter):
87
+ image = images[i*bsz: min(n, (i+1)*bsz)]
88
+ image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
89
+
90
+ with model.maybe_autocast():
91
+ _, _, batch_tokens = model.visual_encoder(image)
92
+ image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)["representations"][model.vis_layers].contiguous()
93
+
94
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
95
+
96
+ query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
97
+ query_output = model.Qformer.bert(
98
+ query_embeds=query_tokens,
99
+ encoder_hidden_states=image_embeds,
100
+ encoder_attention_mask=image_atts,
101
+ return_dict=True,
102
+ )
103
+
104
+ inputs_opt = model.opt_proj(query_output.last_hidden_state)
105
+ atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
106
+
107
+ model.opt_tokenizer.padding_side = "right"
108
+
109
+ text = ['' for i in range(len(image))]
110
+ opt_tokens = model.opt_tokenizer(
111
+ text,
112
+ return_tensors="pt",
113
+ padding="longest",
114
+ truncation=True,
115
+ max_length=model.max_txt_len,
116
+ ).to(device)
117
+ inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
118
+ inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
119
+ attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
120
+ num_txt = 10
121
+ return_num_txt = 5
122
+ with model.maybe_autocast():
123
+ outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=3,
124
+ max_length=30,
125
+ repetition_penalty=5., num_beams=num_txt, eos_token_id=50118,
126
+ length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
127
+ output_text = model.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
128
+ output_text = [text.strip() for text in output_text]
129
+ output_text_ = []
130
+ for i in range(len(image)):
131
+ output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
132
+ with open('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), 'a+') as f:
133
+ for i in range(len(image)):
134
+ f.write(image[i][1] + "|" + output_text_[i] + '\n')
135
+
136
+
137
+ cat = 'mf'
138
+ fix = '_mf'
139
+ if cat == 'bp':
140
+ fix = '_bp'
141
+ if cat == 'cc':
142
+ fix = '_cc'
143
+
144
+ # model_pth = {'mf': 'uniprot_swissprot_mf_stage1_epo19.pth', 'bp': 'checkpoint17_GO_swissprot_reviewed_bp_stage1.pth', 'cc': ''}
145
+
146
+ # graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
147
+
148
+ # setup device to use
149
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
150
+ # device = 'cpu'
151
+
152
+ ### Levenshtein similarity
153
+ test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test{}.csv'.format(fix), sep='|')[:10000]
154
+ test['function'] = test['function'].apply(lambda x: x.lower())
155
+
156
+
157
+ if os.path.exists('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix)):
158
+ os.remove('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix))
159
+ print("stage 2 predict starting")
160
+ stage2_output(test)
161
+ print("stage 2 predict completed")
162
+
163
+
164
+
165
+ df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), sep='|', header=None, on_bad_lines='warn')
166
+ df_pred.columns = ['protein', 'function']
167
+ df_pred = df_pred.drop_duplicates()
168
+ df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
169
+ df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])
170
+
171
+ test.columns
172
+ test_g = test.groupby(['protein']).agg({'function': lambda x: list(x)}).reset_index()
173
+ test_g.columns = ['protein', 'label']
174
+
175
+ data = pd.merge(df_pred, test_g, on='protein', how='left')
176
+ data = data[data['label'].notnull()]
177
+
178
+ sim = []
179
+ for text, label in zip(data['function'].tolist(), data['label'].tolist()):
180
+ sim.append(func(text, label))
181
+
182
+ data['sim'] = sim
183
+ data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
184
+ print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
185
+ # data.to_csv('/home/nilin/LAVIS/predict_{}.csv'.format(cat), index=False, sep='|')
186
+
187
+
188
+ test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test{}.csv'.format(fix), sep='|', usecols=['function', 'GO_label'])
189
+ test['function'] = test['function'].apply(lambda x: x.lower())
190
+ test = test.drop_duplicates()
191
+ test_dict = dict(zip(test['function'], test['GO_label']))
192
+ val = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/val{}.csv'.format(fix), sep='|', usecols=['function', 'GO_label'])
193
+ val['function'] = val['function'].apply(lambda x: x.lower())
194
+ val = val.drop_duplicates()
195
+ val_dict = dict(zip(val['function'], val['GO_label']))
196
+ train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/train{}.csv'.format(fix), sep='|', usecols=['function', 'GO_label'])
197
+ train['function'] = train['function'].apply(lambda x: x.lower())
198
+ train = train.drop_duplicates()
199
+ train_dict = dict(zip(train['function'], train['GO_label']))
200
+
201
+
202
+ # go_des = pd.read_csv('/home/nilin/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
203
+ # # go_des = pd.read_csv('/home/nilin/LAVIS/data/go_descriptions.txt', sep='|', header=None)
204
+ # go_des.columns = ['GO', 'function']
205
+ # go_des = go_des[go_des['function'].notnull()]
206
+ # go_des['function'] = go_des['function'].apply(lambda x: x.lower())
207
+ # GO_dict = dict(zip(go_des['function'], go_des['GO']))
208
+ GO_dict = {}
209
+ GO_dict.update(train_dict)
210
+ GO_dict.update(val_dict)
211
+ GO_dict.update(test_dict)
212
+ choices = list(GO_dict.keys())
213
+
214
+
215
+
216
+ # data = pd.read_csv('/home/nilin/LAVIS/predict_{}.csv'.format(cat), sep='|')
217
+ data = data.sort_values(by='protein')
218
+ data = data.drop_duplicates('protein')
219
+ # data = data.sample(1000)
220
+
221
+ ### 预测的文本如果不在GO标签词中,则算作最相似的GO标签
222
+ t0 = time.time()
223
+ txt_dict = {}
224
+
225
+ all_txt = []
226
+ for txt in data['function']:
227
+ if type(txt) == str:
228
+ all_txt.extend(eval(txt))
229
+ else:
230
+ all_txt.extend(txt)
231
+ all_txt = list(set(all_txt))
232
+
233
+ n = len(all_txt)
234
+ thread = 20
235
+ size = int(n/thread)
236
+ inds = list(range(0, n, size))
237
+ inds.append(n)
238
+ all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]
239
+
240
+ with Pool(processes=thread) as pool:
241
+ result = pool.map(fuzzy_match, all_txt_sep)
242
+ pool.close()
243
+ pool.join()
244
+ for d in result:
245
+ txt_dict.update(d)
246
+
247
+ # for txt in all_txt[:10]:
248
+ # fuzzy_match(txt)
249
+
250
+ data['function'] = data['function'].apply(lambda x: txt_map(x, txt_dict))
251
+ data['function'] = data['function'].apply(lambda x: list(set(x)))
252
+ print("fuzzy matching time: {}".format(time.time() - t0))
253
+
254
+
255
+
256
+
257
+ ### Find the generated GO text that not included in the ground truth. Then generate pairs between them.
258
+ # pair_a, pair_b = [], []
259
+ # for preds, labels in zip(data['function'], data['label']):
260
+ # if type(preds) == str:
261
+ # preds = eval(preds)
262
+ # if type(labels) == str:
263
+ # labels = eval(labels)
264
+ # l = len(labels)
265
+ # for pred in preds:
266
+ # if pred not in labels:
267
+ # pair_a.extend([pred]*l)
268
+ # pair_b.extend(labels[:])
269
+ # pair_a = [re.sub('_', ':', GO_dict[i]) for i in pair_a]
270
+ # pair_b = [re.sub('_', ':', GO_dict[i]) for i in pair_b]
271
+ # with open('/home/nilin/LAVIS/examples/GO_pair{}.txt'.format(fix), 'w+') as f:
272
+ # for i, j in zip(pair_a, pair_b):
273
+ # f.write(i+' '+j+'\n')
274
+
275
+
276
+ # load model
277
+ model_config = {'arch': 'blip2_protein', 'load_finetuned': False,
278
+ 'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage1/20230922185/checkpoint_15.pth',
279
+ 'finetuned': '', 'num_query_token': 32, 'prompt': '',
280
+ 'model_type': 'pretrain', 'load_pretrained': True, 'freeze_vit': False,
281
+ 'max_protein_len': 512, 'max_txt_len': 25}
282
+
283
+ model_cls = registry.get_model_class(model_config['arch'])
284
+ model = model_cls.from_config(model_config)
285
+ model = model.to(device)
286
+ model.eval()
287
+
288
+ # evaluate
289
+ t0 = time.time()
290
+ proteins = list(data['protein'])
291
+ txts = list(data['function'])
292
+ scores = []
293
+ for seq, txt in zip(proteins, txts):
294
+ image = [('protein1', seq)]
295
+ _, _, batch_tokens = model.visual_encoder(image)
296
+ image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][
297
+ 30].contiguous()
298
+
299
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
300
+
301
+ query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
302
+
303
+ query_output = model.Qformer.bert(
304
+ query_embeds=query_tokens,
305
+ encoder_hidden_states=image_embeds,
306
+ encoder_attention_mask=image_atts,
307
+ use_cache=True,
308
+ return_dict=True,
309
+ )
310
+
311
+ image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1)
312
+
313
+ image_feats_all = concat_all_gather(image_feats)
314
+
315
+ if type(txt) == str:
316
+ txt = eval(txt)
317
+ length = len(txt)
318
+ with torch.no_grad():
319
+ text_tokens = model.tokenizer(
320
+ txt,
321
+ padding="max_length",
322
+ truncation=True,
323
+ max_length=model.max_txt_len,
324
+ return_tensors="pt",
325
+ ).to(device)
326
+ text_output = model.Qformer.bert(
327
+ text_tokens.input_ids,
328
+ attention_mask=text_tokens.attention_mask,
329
+ return_dict=True,
330
+ )
331
+
332
+ text_feat = F.normalize(
333
+ model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
334
+ )
335
+
336
+ text_feat_all = concat_all_gather(text_feat)
337
+ sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
338
+ sim_i2t, _ = sim_q2t.max(-1)
339
+ # print('sim_i2t: {}'.format(sim_i2t))
340
+ if length > 1:
341
+ scores.append(list(sim_i2t.detach().cpu().numpy()))
342
+ else:
343
+ scores.append([sim_i2t.item()])
344
+ print("model evaluate time: {}".format(time.time() - t0))
345
+ data['score'] = scores
346
+
347
+ # precision and recall top-k
348
+ topk = 2
349
+ threshould = 0.1
350
+ labels = []
351
+ pred_labels = []
352
+ for l in data['label']:
353
+ if type(l) == str:
354
+ l = eval(l)
355
+ labels.extend(l)
356
+
357
+ labels = list(set(labels))
358
+ total = len(labels)
359
+ for topk in range(1,7):
360
+ for threshould in range(1, 25, 1):
361
+ threshould /= 100
362
+ filter_txts = []
363
+ recalls = []
364
+ precisions = []
365
+ f1 = []
366
+ tp_dict, fp_dict, fn_dict = dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels)))
367
+ for txts, scores, label in zip(data['function'], data['score'], data['label']):
368
+ if type(label) == str:
369
+ label = eval(label)
370
+ txts_ = np.array(txts)
371
+ scores = np.array(scores)
372
+ txts = txts_[scores > threshould]
373
+ if len(txts) < 1:
374
+ txts = txts_[np.argmax(scores)]
375
+ scores = scores[scores > threshould]
376
+
377
+ l = len(scores)
378
+ ll = len(label)
379
+ if l <= topk:
380
+ filter_txts.append(list(txts))
381
+ else:
382
+ ind = np.argpartition(scores, -topk)[-topk:]
383
+ txts = txts[ind]
384
+ filter_txts.append(list(txts))
385
+ l = topk
386
+ for t in label:
387
+ if t in txts:
388
+ tp_dict[t] += 1
389
+ else:
390
+ fn_dict[t] += 1
391
+ for p in txts:
392
+ if p not in label:
393
+ if p in fp_dict:
394
+ fp_dict[p] += 1
395
+ else:
396
+ fp_dict[p] = 1
397
+ pred_labels.extend(txts)
398
+ p_total = len(set(pred_labels))
399
+ re, pr = 0., 0.
400
+ for x in labels:
401
+ re += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
402
+ pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x]+1e-8))
403
+ r = re / total
404
+ p = pr / total
405
+ f1 = 2 * p * r / (p + r)
406
+ print("Topk: {}, threshould: {}, macro_recall: {}, macro_precision: {}, micro_f1: {}".format(topk, threshould, r, p, f1))
407
+ # num_r = 0
408
+ # num_p = 0
409
+ # for x in label:
410
+ # if x in txts:
411
+ # num_r += 1
412
+ # for x in txts:
413
+ # if x in label:
414
+ # num_p += 1
415
+ # recall = num_r/ll
416
+ # precision = num_p/(l+0.0001)
417
+ # recalls.append(recall)
418
+ # precisions.append(precision)
419
+ # f1.append((2*recall*precision)/(recall+precision+0.0001))
420
+ #
421
+ # data['predict'] = filter_txts
422
+ # data['precision'] = precisions
423
+ # data['recall'] = recalls
424
+ # data['f1'] = f1
425
+ # print("Topk: {}, threshould: {}, macro_recall: {}, macro_precision: {}, micro_f1: {}".format(topk, threshould, round(data['recall'].mean(), 4), round(data['precision'].mean(), 4), round(data['f1'].mean(), 4)))
426
+
427
+
428
+
429
+
430
+
431
+
432
+ # sim = []
433
+ # for text, label in zip(data['predict'].tolist(), data['label'].tolist()):
434
+ # sim.append(levenshtein_sim(text, label))
435
+ #
436
+ # data['sim_filter'] = sim
437
+ # data['avg_score'] = data['sim_filter'].apply(lambda x: round(np.mean(x), 3))
438
+
439
+
440
+ # data['function'] = data['function'].apply(lambda x: eval(re.sub(';', ',', str(x))))
441
+ # data['label'] = data['label'].apply(lambda x: eval(re.sub(';', ',', str(x))))
442
+ # data['sim'] = data['sim'].apply(lambda x: eval(re.sub(';', ',', str(x))))
443
+ #
444
+ # data['function'] = data['function'].apply(lambda x: re.sub(',', ';', str(x)))
445
+ # data['label'] = data['label'].apply(lambda x: re.sub(',', ';', str(x)))
446
+ # data['sim'] = data['sim'].apply(lambda x: re.sub(',', ';', str(x)))
447
+ # data['predict'] = data['predict'].apply(lambda x: re.sub(',', ';', str(x)))
448
+ # data['sim_filter'] = data['sim_filter'].apply(lambda x: re.sub(',', ';', str(x)))
449
+
450
+ data.to_csv('/cluster/home/wenkai/LAVIS/output/predict_sim{}.csv'.format(fix), sep='|', index=False)
451
+ # data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_sim{}.csv'.format(fix), sep='|')
452
+
453
+
454
+
455
+
456
+
457
+
458
+
459
+
460
+ #
461
+ # # example
462
+ # image = ['MIELKHVTFGYNKKQMVLQDINITIPDGENVGILGESGCGKSTLASLVLGLFKPVKGEIYLSDNAVLTIFQHPLTSFNPDWTIETSLKEALYYYRGLTDNTAQDQLLLQHLSTFELNAQLLTKLPSEVSGGQLQRFNVMRSLLAQPRVLICDEITSNLDVIAEQNVINILKAQTITNLNHFIVISHDLSVLQRLVNRIIVLKDGMIVDDFAIEELFNVDRHPYTKELVQTFSY']
463
+ # image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
464
+ #
465
+ # _, _, batch_tokens = model.visual_encoder(image)
466
+ # image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][30].contiguous()
467
+ #
468
+ # image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
469
+ #
470
+ # query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
471
+ #
472
+ # query_output = model.Qformer.bert(
473
+ # query_embeds=query_tokens,
474
+ # encoder_hidden_states=image_embeds,
475
+ # encoder_attention_mask=image_atts,
476
+ # use_cache=True,
477
+ # return_dict=True,
478
+ # )
479
+ #
480
+ # image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1)
481
+ #
482
+ # image_feats_all = concat_all_gather(image_feats)
483
+ #
484
+ # functions = ['transmembrane transporter activity', 'nickel cation transmembrane transporter activity', 'nickel cation binding', 'atp hydrolysis activity', 'atp hydrolysis', 'cadmium binding', 'abc-type nickel transmembrane transporter activity', 'abc-type nickel transporter activity', 'nickel transmembrane transporter activity', 'atp binding']
485
+ # for text in functions:
486
+ # with torch.no_grad():
487
+ # # text = 'flavin adenine dinucleotide binding'
488
+ # text_tokens = model.tokenizer(
489
+ # text,
490
+ # padding="max_length",
491
+ # truncation=True,
492
+ # max_length=model.max_txt_len,
493
+ # return_tensors="pt",
494
+ # ).to(device)
495
+ # text_output = model.Qformer.bert(
496
+ # text_tokens.input_ids,
497
+ # attention_mask=text_tokens.attention_mask,
498
+ # return_dict=True,
499
+ # )
500
+ #
501
+ # text_feat = F.normalize(
502
+ # model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
503
+ # )
504
+ #
505
+ # text_feat_all = concat_all_gather(text_feat)
506
+ # sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
507
+ # sim_i2t, _ = sim_q2t.max(-1)
508
+ # print('sim_i2t: {}'.format(sim_i2t))
509
+ #
510
+ # # # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
511
+ # # sim_t2q = torch.matmul(
512
+ # # text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
513
+ # # ).squeeze()
514
+ # #
515
+ # # # text-image similarity: aggregate across all query tokens
516
+ # # sim_t2i, _ = sim_t2q.max(-1)
517
+ # # print('sim_t2i: {}'.format(sim_t2i))
518
+
519
+
520
+
examples/blip2_predict_func.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from lavis.models import load_model_and_preprocess
8
+ from lavis.processors import load_processor
9
+ from lavis.common.registry import registry
10
+ from torch.nn import functional as F
11
+ from lavis.models.base_model import all_gather_with_grad, concat_all_gather
12
+ import numpy as np
13
+ import pandas as pd
14
+ import time
15
+ from fuzzywuzzy import process
16
+ from multiprocessing import Pool, Queue, Process
17
+ import difflib
18
+ import Levenshtein
19
+ # import obonet
20
+
21
+
22
+ # setup device to use
23
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
24
+ # device = 'cpu'
25
+
26
+
27
+ def txt_map(x, txt_dict):
28
+ if type(x) == str:
29
+ x = eval(x)
30
+ x_ = []
31
+ for i in x:
32
+ if i in txt_dict:
33
+ x_.append(txt_dict[i])
34
+ else:
35
+ x_.append(i)
36
+ return x_
37
+
38
+
39
+ def levenshtein_sim(text, label):
40
+ all_s = []
41
+ for x in label:
42
+ s = 0
43
+ for y in text:
44
+ temp = Levenshtein.ratio(x, y)
45
+ if temp > s:
46
+ s = temp
47
+ all_s.append(s)
48
+ all_s = [round(i, 3) for i in all_s]
49
+ return all_s
50
+
51
+ def func(text, label):
52
+ all_s = []
53
+ for x in text:
54
+ s = 0
55
+ for y in label:
56
+ temp = Levenshtein.ratio(x, y)
57
+ if temp > s:
58
+ s = temp
59
+ all_s.append(s)
60
+ all_s = [round(i, 3) for i in all_s]
61
+ return all_s
62
+
63
+
64
+ def stage2_output(df_test, return_num_txt=1):
65
+ config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
66
+ 'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20230924220/checkpoint_5.pth',
67
+ 'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
68
+ 'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
69
+ 'max_protein_len': 600,
70
+ 'max_txt_len': 25}
71
+
72
+ model_cls = registry.get_model_class(config['arch'])
73
+ model = model_cls.from_config(config)
74
+ model.to(device)
75
+ model.eval()
76
+
77
+ images = df_test['protein'].tolist()
78
+ n = len(images)
79
+ bsz = 12
80
+ iter = n // bsz + 1
81
+
82
+ for i in range(iter):
83
+ image = images[i*bsz: min(n, (i+1)*bsz)]
84
+ image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
85
+
86
+ with model.maybe_autocast():
87
+ _, _, batch_tokens = model.visual_encoder(image)
88
+ image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)["representations"][model.vis_layers].contiguous()
89
+
90
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
91
+
92
+ query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
93
+ query_output = model.Qformer.bert(
94
+ query_embeds=query_tokens,
95
+ encoder_hidden_states=image_embeds,
96
+ encoder_attention_mask=image_atts,
97
+ return_dict=True,
98
+ )
99
+
100
+ inputs_opt = model.opt_proj(query_output.last_hidden_state)
101
+ atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
102
+
103
+ model.opt_tokenizer.padding_side = "right"
104
+
105
+ text = ['' for i in range(len(image))]
106
+ opt_tokens = model.opt_tokenizer(
107
+ text,
108
+ return_tensors="pt",
109
+ padding="longest",
110
+ truncation=True,
111
+ max_length=model.max_txt_len,
112
+ ).to(device)
113
+ inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
114
+ inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
115
+ attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
116
+ num_txt = 6
117
+ with model.maybe_autocast():
118
+ outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=3,
119
+ max_length=30,
120
+ repetition_penalty=1., num_beams=num_txt, eos_token_id=50118,
121
+ length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
122
+ output_text = model.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
123
+ output_text = [text.strip() for text in output_text]
124
+ output_text_ = []
125
+ for i in range(len(image)):
126
+ output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
127
+ with open('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), 'a+') as f:
128
+ for i in range(len(image)):
129
+ f.write(image[i][1] + "|" + output_text_[i] + '\n')
130
+
131
+
132
+ cat = 'mf'
133
+ fix = '_mf'
134
+ if cat == 'bp':
135
+ fix = '_bp'
136
+ if cat == 'cc':
137
+ fix = '_cc'
138
+
139
+ return_num_txt = 1
140
+ # graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
141
+
142
+ ### Levenshtein similarity
143
+ test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test{}.csv'.format(fix), sep='|')
144
+ test['function'] = test['function'].apply(lambda x: x.lower())
145
+
146
+
147
+ if os.path.exists('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix)):
148
+ os.remove('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix))
149
+ print("stage 2 predict starting")
150
+ stage2_output(test)
151
+ print("stage 2 predict completed")
152
+
153
+ df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), sep='|', header=None, on_bad_lines='warn')
154
+ df_pred.columns = ['protein', 'function']
155
+ df_pred = df_pred.drop_duplicates()
156
+ df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
157
+ df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])
158
+
159
+ test_g = test.groupby(['protein']).agg({'function': lambda x: list(x)}).reset_index()
160
+ test_g.columns = ['protein', 'label']
161
+
162
+ data = pd.merge(df_pred, test_g, on='protein', how='left')
163
+ data = data[data['label'].notnull()]
164
+
165
+ sim = []
166
+ for text, label in zip(data['function'].tolist(), data['label'].tolist()):
167
+ sim.append(func(text, label))
168
+
169
+ data['sim'] = sim
170
+ data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
171
+ data['count'] = data['sim'].apply(lambda x: x.count(1.))
172
+ print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
173
+ print("Return texts: {}; Accuracy: {}".format(return_num_txt, data['count'].sum()/(return_num_txt*data.shape[0])))
174
+ data.to_csv('/cluster/home/wenkai/LAVIS/output/predict_{}.csv'.format(cat), index=False, sep='|')
175
+
176
+
177
+
178
+
examples/blip2_predict_func_concat.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from lavis.models import load_model_and_preprocess
8
+ from lavis.processors import load_processor
9
+ from lavis.common.registry import registry
10
+ from torch.nn import functional as F
11
+ from lavis.models.base_model import all_gather_with_grad, concat_all_gather
12
+ import numpy as np
13
+ import pandas as pd
14
+ import time
15
+ from fuzzywuzzy import process
16
+ from multiprocessing import Pool, Queue, Process
17
+ import difflib
18
+ import Levenshtein
19
+
20
+ # import obonet
21
+
22
+
23
+ # setup device to use
24
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
25
+
26
+
27
+ # device = torch.device("cuda")
28
+
29
+
30
+ def txt_map(x, txt_dict):
31
+ if type(x) == str:
32
+ x = eval(x)
33
+ x_ = []
34
+ for i in x:
35
+ if i in txt_dict:
36
+ x_.append(txt_dict[i])
37
+ else:
38
+ x_.append(i)
39
+ return x_
40
+
41
+
42
+ def levenshtein_sim(text, label):
43
+ all_s = []
44
+ for x in label:
45
+ s = 0
46
+ for y in text:
47
+ temp = Levenshtein.ratio(x, y)
48
+ if temp > s:
49
+ s = temp
50
+ all_s.append(s)
51
+ all_s = [round(i, 3) for i in all_s]
52
+ return all_s
53
+
54
+
55
+ def func(text, label):
56
+ all_s = []
57
+ for x in text:
58
+ s = 0
59
+ for y in label:
60
+ temp = Levenshtein.ratio(x, y)
61
+ if temp > s:
62
+ s = temp
63
+ all_s.append(s)
64
+ all_s = [round(i, 3) for i in all_s]
65
+ return all_s
66
+
67
+
68
+ def stage2_output(df_test, return_num_txt=1):
69
+ config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
70
+ 'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20231029182/checkpoint_0.pth',
71
+ 'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
72
+ 'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
73
+ 'max_protein_len': 600,
74
+ 'max_txt_len': 256}
75
+
76
+ model_cls = registry.get_model_class(config['arch'])
77
+ model = model_cls.from_config(config)
78
+ model.to(device)
79
+ model.eval()
80
+
81
+ images = df_test['protein'].tolist()
82
+ n = len(images)
83
+ bsz = 8
84
+ iter = n // bsz + 1
85
+ with open('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix), 'a+') as f:
86
+ for i in range(iter):
87
+ image = images[i * bsz: min(n, (i + 1) * bsz)]
88
+ image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
89
+
90
+ with model.maybe_autocast():
91
+ _, _, batch_tokens = model.visual_encoder(image)
92
+ image_embeds = \
93
+ model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)[
94
+ "representations"][model.vis_layers].contiguous()
95
+
96
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
97
+
98
+ query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
99
+ query_output = model.Qformer.bert(
100
+ query_embeds=query_tokens,
101
+ encoder_hidden_states=image_embeds,
102
+ encoder_attention_mask=image_atts,
103
+ return_dict=True,
104
+ )
105
+
106
+ inputs_opt = model.opt_proj(query_output.last_hidden_state)
107
+ atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
108
+
109
+ model.opt_tokenizer.padding_side = "right"
110
+
111
+ text = ['' for i in range(len(image))]
112
+ opt_tokens = model.opt_tokenizer(
113
+ text,
114
+ return_tensors="pt",
115
+ padding="longest",
116
+ truncation=True,
117
+ max_length=model.max_txt_len,
118
+ ).to(device)
119
+ inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
120
+ inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
121
+ attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
122
+ num_txt = 5
123
+ with model.maybe_autocast():
124
+ outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=1,
125
+ max_length=256,
126
+ repetition_penalty=1., num_beams=num_txt, eos_token_id=50118,
127
+ length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
128
+ output_text = model.opt_tokenizer.batch_decode(outputs)
129
+
130
+ output_text = [re.sub('\t', '', str(x)) for x in output_text]
131
+ output_text = [text.strip() for text in output_text]
132
+ output_text_ = []
133
+ for i in range(len(image)):
134
+ output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
135
+
136
+ for i in range(len(image)):
137
+ f.write(image[i][1] + "|" + output_text_[i] + '\n')
138
+
139
+
140
+ if __name__=="__main__":
141
+ split = 'test'
142
+ cat = 'bp'
143
+ fix = '_mf'
144
+ type_fix = ''
145
+ if cat == 'bp':
146
+ fix = '_bp'
147
+ if cat == 'cc':
148
+ fix = '_cc'
149
+
150
+ print(device)
151
+ return_num_txt = 1
152
+ # graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
153
+
154
+ ### Levenshtein similarity
155
+ print("reading file ...")
156
+ test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split_concat/{}{}.csv'.format(split, fix),
157
+ usecols=['name', 'protein', 'function'], sep='|')
158
+ # test['function'] = test['function'].apply(lambda x: x.lower().split('; '))
159
+ test.columns = ['name', 'protein', 'label']
160
+
161
+ if os.path.exists('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix)):
162
+ os.remove('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix))
163
+ print("stage 2 predict starting")
164
+ stage2_output(test)
165
+ print("stage 2 predict completed")
166
+
167
+ df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix), sep='|',
168
+ header=None, on_bad_lines='warn')
169
+ df_pred.columns = ['protein', 'pred']
170
+ df_pred = df_pred.drop_duplicates()
171
+ # df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
172
+ # df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])
173
+
174
+
175
+ data = pd.merge(df_pred, test, on='protein', how='left')
176
+ data = data[data['label'].notnull()]
177
+
178
+ # sim = []
179
+ # for text, label in zip(data['function'].tolist(), data['label'].tolist()):
180
+ # sim.append(func(text, label))
181
+
182
+ # data['sim'] = sim
183
+ # data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
184
+ # data['count'] = data['sim'].apply(lambda x: x.count(1.))
185
+ # print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
186
+ # print("Return texts: {}; Accuracy: {}".format(return_num_txt, data['count'].sum()/(return_num_txt*data.shape[0])))
187
+ data[['name', 'label', 'pred']].to_csv(
188
+ '/cluster/home/wenkai/LAVIS/output/predict_concat_{}{}{}.csv'.format(split, cat, type_fix), index=False, sep='|')
189
+
190
+
191
+
192
+
193
+
examples/blip2_predict_func_concat_pretrain.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from lavis.models import load_model_and_preprocess
8
+ from lavis.processors import load_processor
9
+ from lavis.common.registry import registry
10
+ from torch.nn import functional as F
11
+ from lavis.models.base_model import all_gather_with_grad, concat_all_gather
12
+ import numpy as np
13
+ import pandas as pd
14
+ import time
15
+ from fuzzywuzzy import process
16
+ from multiprocessing import Pool, Queue, Process
17
+ import difflib
18
+ import Levenshtein
19
+
20
+ # import obonet
21
+
22
+
23
+ # setup device to use
24
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
25
+
26
+
27
+ # device = torch.device("cuda")
28
+
29
+
30
+ def txt_map(x, txt_dict):
31
+ if type(x) == str:
32
+ x = eval(x)
33
+ x_ = []
34
+ for i in x:
35
+ if i in txt_dict:
36
+ x_.append(txt_dict[i])
37
+ else:
38
+ x_.append(i)
39
+ return x_
40
+
41
+
42
+ def levenshtein_sim(text, label):
43
+ all_s = []
44
+ for x in label:
45
+ s = 0
46
+ for y in text:
47
+ temp = Levenshtein.ratio(x, y)
48
+ if temp > s:
49
+ s = temp
50
+ all_s.append(s)
51
+ all_s = [round(i, 3) for i in all_s]
52
+ return all_s
53
+
54
+
55
+ def func(text, label):
56
+ all_s = []
57
+ for x in text:
58
+ s = 0
59
+ for y in label:
60
+ temp = Levenshtein.ratio(x, y)
61
+ if temp > s:
62
+ s = temp
63
+ all_s.append(s)
64
+ all_s = [round(i, 3) for i in all_s]
65
+ return all_s
66
+
67
+
68
+ def stage2_output(df_test, return_num_txt=1):
69
+ config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
70
+ 'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20231029182/checkpoint_0.pth',
71
+ 'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
72
+ 'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
73
+ 'max_protein_len': 600,
74
+ 'max_txt_len': 256}
75
+
76
+ model_cls = registry.get_model_class(config['arch'])
77
+ model = model_cls.from_config(config)
78
+ model.to(device)
79
+ model.eval()
80
+
81
+ images = df_test['protein'].tolist()
82
+ n = len(images)
83
+ bsz = 8
84
+ iter = n // bsz + 1
85
+ if n > 0:
86
+ for i in range(iter):
87
+ image = images[i * bsz: min(n, (i + 1) * bsz)]
88
+ image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
89
+
90
+ with model.maybe_autocast():
91
+ _, _, batch_tokens = model.visual_encoder(image)
92
+ image_embeds = \
93
+ model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)[
94
+ "representations"][model.vis_layers].contiguous()
95
+
96
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
97
+
98
+ query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
99
+ query_output = model.Qformer.bert(
100
+ query_embeds=query_tokens,
101
+ encoder_hidden_states=image_embeds,
102
+ encoder_attention_mask=image_atts,
103
+ return_dict=True,
104
+ )
105
+
106
+ inputs_opt = model.opt_proj(query_output.last_hidden_state)
107
+ atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
108
+
109
+ model.opt_tokenizer.padding_side = "right"
110
+
111
+ text = ['' for i in range(len(image))]
112
+ opt_tokens = model.opt_tokenizer(
113
+ text,
114
+ return_tensors="pt",
115
+ padding="longest",
116
+ truncation=True,
117
+ max_length=model.max_txt_len,
118
+ ).to(device)
119
+ inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
120
+ inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
121
+ attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
122
+ num_txt = 5
123
+ with model.maybe_autocast():
124
+ outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=1,
125
+ max_length=256,
126
+ repetition_penalty=1., num_beams=num_txt, eos_token_id=50118,
127
+ length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
128
+ output_text = model.opt_tokenizer.batch_decode(outputs)
129
+
130
+ output_text = [re.sub('\t', '', str(x)) for x in output_text]
131
+ output_text = [text.strip() for text in output_text]
132
+ output_text_ = []
133
+ for i in range(len(image)):
134
+ output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
135
+
136
+ f = open('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix), 'a+')
137
+ for i in range(len(image)):
138
+ f.write(image[i][1] + "|" + output_text_[i] + '\n')
139
+ f.close()
140
+
141
+
142
+
143
+
144
+ if __name__=="__main__":
145
+ split = 'test'
146
+ cat = ''
147
+ fix = ''
148
+ type_fix = '_pretrain'
149
+ if cat == 'bp':
150
+ fix = '_bp'
151
+ if cat == 'cc':
152
+ fix = '_cc'
153
+
154
+ print(device)
155
+ return_num_txt = 1
156
+ # graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
157
+
158
+ ### Levenshtein similarity
159
+ print("reading file ...")
160
+ test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/{}_sample10000.csv'.format(split),
161
+ usecols=['name', 'protein', 'function'], sep='|')
162
+ # test['function'] = test['function'].apply(lambda x: x.lower().split('; '))
163
+ test.columns = ['name', 'protein', 'label']
164
+
165
+ if os.path.exists('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix)):
166
+ os.remove('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix))
167
+ print("stage 2 predict starting")
168
+ stage2_output(test)
169
+ print("stage 2 predict completed")
170
+
171
+ df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix), sep='|',
172
+ header=None, on_bad_lines='warn')
173
+ df_pred.columns = ['protein', 'pred']
174
+ df_pred = df_pred.drop_duplicates()
175
+ # df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
176
+ # df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])
177
+
178
+
179
+ data = pd.merge(df_pred, test, on='protein', how='left')
180
+ data = data[data['label'].notnull()]
181
+
182
+ # sim = []
183
+ # for text, label in zip(data['function'].tolist(), data['label'].tolist()):
184
+ # sim.append(func(text, label))
185
+
186
+ # data['sim'] = sim
187
+ # data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
188
+ # data['count'] = data['sim'].apply(lambda x: x.count(1.))
189
+ # print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
190
+ # print("Return texts: {}; Accuracy: {}".format(return_num_txt, data['count'].sum()/(return_num_txt*data.shape[0])))
191
+ data[['name', 'label', 'pred']].to_csv(
192
+ '/cluster/home/wenkai/LAVIS/output/predict_concat_{}{}{}.csv'.format(split, cat, type_fix), index=False, sep='|')
193
+
194
+
195
+
196
+
197
+
examples/blip2_predict_func_concat_timesplit.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from lavis.models import load_model_and_preprocess
8
+ from lavis.processors import load_processor
9
+ from lavis.common.registry import registry
10
+ from torch.nn import functional as F
11
+ from lavis.models.base_model import all_gather_with_grad, concat_all_gather
12
+ import numpy as np
13
+ import pandas as pd
14
+ import time
15
+ from fuzzywuzzy import process
16
+ from multiprocessing import Pool, Queue, Process
17
+ import difflib
18
+ import Levenshtein
19
+ # import obonet
20
+
21
+
22
+ # setup device to use
23
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
24
+ # device = 'cpu'
25
+
26
+
27
+ def txt_map(x, txt_dict):
28
+ if type(x) == str:
29
+ x = eval(x)
30
+ x_ = []
31
+ for i in x:
32
+ if i in txt_dict:
33
+ x_.append(txt_dict[i])
34
+ else:
35
+ x_.append(i)
36
+ return x_
37
+
38
+
39
+ def levenshtein_sim(text, label):
40
+ all_s = []
41
+ for x in label:
42
+ s = 0
43
+ for y in text:
44
+ temp = Levenshtein.ratio(x, y)
45
+ if temp > s:
46
+ s = temp
47
+ all_s.append(s)
48
+ all_s = [round(i, 3) for i in all_s]
49
+ return all_s
50
+
51
+ def func(text, label):
52
+ all_s = []
53
+ for x in text:
54
+ s = 0
55
+ for y in label:
56
+ temp = Levenshtein.ratio(x, y)
57
+ if temp > s:
58
+ s = temp
59
+ all_s.append(s)
60
+ all_s = [round(i, 3) for i in all_s]
61
+ return all_s
62
+
63
+
64
+ def stage2_output(df_test, return_num_txt=1):
65
+ config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
66
+ 'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20231007085/checkpoint_19.pth',
67
+ 'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
68
+ 'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
69
+ 'max_protein_len': 600,
70
+ 'max_txt_len': 256}
71
+
72
+ model_cls = registry.get_model_class(config['arch'])
73
+ model = model_cls.from_config(config)
74
+ model.to(device)
75
+ model.eval()
76
+
77
+ images = df_test['protein'].tolist()
78
+ n = len(images)
79
+ bsz = 12
80
+ iter = n // bsz + 1
81
+
82
+ for i in range(iter):
83
+ image = images[i*bsz: min(n, (i+1)*bsz)]
84
+ image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
85
+
86
+ with model.maybe_autocast():
87
+ _, _, batch_tokens = model.visual_encoder(image)
88
+ image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)["representations"][model.vis_layers].contiguous()
89
+
90
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
91
+
92
+ query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
93
+ query_output = model.Qformer.bert(
94
+ query_embeds=query_tokens,
95
+ encoder_hidden_states=image_embeds,
96
+ encoder_attention_mask=image_atts,
97
+ return_dict=True,
98
+ )
99
+
100
+ inputs_opt = model.opt_proj(query_output.last_hidden_state)
101
+ atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
102
+
103
+ model.opt_tokenizer.padding_side = "right"
104
+
105
+ text = ['' for i in range(len(image))]
106
+ opt_tokens = model.opt_tokenizer(
107
+ text,
108
+ return_tensors="pt",
109
+ padding="longest",
110
+ truncation=True,
111
+ max_length=model.max_txt_len,
112
+ ).to(device)
113
+ inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
114
+ inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
115
+ attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
116
+ num_txt = 5
117
+ with model.maybe_autocast():
118
+ outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=3,
119
+ max_length=256,
120
+ repetition_penalty=1., num_beams=num_txt, eos_token_id=50118,
121
+ length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
122
+ output_text = model.opt_tokenizer.batch_decode(outputs)
123
+ output_text = [re.sub('\t', '', x) for x in output_text]
124
+
125
+ output_text = [text.strip() for text in output_text]
126
+ output_text_ = []
127
+ for i in range(len(image)):
128
+ output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
129
+ with open('/cluster/home/wenkai/LAVIS/output/output_timeconcat{}.txt'.format(fix), 'a+') as f:
130
+ for i in range(len(image)):
131
+ f.write(image[i][1] + "|" + output_text_[i] + '\n')
132
+
133
+
134
+ cat = 'mf'
135
+ fix = '_mf'
136
+ if cat == 'bp':
137
+ fix = '_bp'
138
+ if cat == 'cc':
139
+ fix = '_cc'
140
+
141
+ return_num_txt = 1
142
+ # graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
143
+
144
+ ### Levenshtein similarity
145
+ test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/review_time_concat/test{}.csv'.format(fix), usecols=['name', 'protein', 'function'], sep='|')
146
+ #test['function'] = test['function'].apply(lambda x: x.lower().split('; '))
147
+ test.columns = ['name', 'protein', 'label']
148
+
149
+ if os.path.exists('/cluster/home/wenkai/LAVIS/output/output_timeconcat{}.txt'.format(fix)):
150
+ os.remove('/cluster/home/wenkai/LAVIS/output/output_timeconcat{}.txt'.format(fix))
151
+ print("stage 2 predict starting")
152
+ stage2_output(test)
153
+ print("stage 2 predict completed")
154
+
155
+ df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_timeconcat{}.txt'.format(fix), sep='|', header=None, on_bad_lines='warn')
156
+ df_pred.columns = ['protein', 'pred']
157
+ df_pred = df_pred.drop_duplicates()
158
+
159
+ data = pd.merge(df_pred, test, on='protein', how='left')
160
+ data = data[data['label'].notnull()]
161
+
162
+ data[['name', 'label', 'pred']].to_csv('/cluster/home/wenkai/LAVIS/output/predict_timeconcat_{}.csv'.format(cat), index=False, sep='|')
163
+
164
+
165
+
166
+
examples/blip2_predict_names.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from lavis.models import load_model_and_preprocess
8
+ from lavis.processors import load_processor
9
+ from lavis.common.registry import registry
10
+ from torch.nn import functional as F
11
+ from lavis.models.base_model import all_gather_with_grad, concat_all_gather
12
+ import numpy as np
13
+ import pandas as pd
14
+ import time
15
+ from fuzzywuzzy import process
16
+ from multiprocessing import Pool, Queue, Process
17
+ import difflib
18
+ import Levenshtein
19
+ # import obonet
20
+
21
+
22
+ # setup device to use
23
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
24
+ # device = 'cpu'
25
+
26
+
27
+ def txt_map(x, txt_dict):
28
+ if type(x) == str:
29
+ x = eval(x)
30
+ x_ = []
31
+ for i in x:
32
+ if i in txt_dict:
33
+ x_.append(txt_dict[i])
34
+ else:
35
+ x_.append(i)
36
+ return x_
37
+
38
+
39
+ def levenshtein_sim(text, label):
40
+ all_s = []
41
+ for x in label:
42
+ s = 0
43
+ for y in text:
44
+ temp = Levenshtein.ratio(x, y)
45
+ if temp > s:
46
+ s = temp
47
+ all_s.append(s)
48
+ all_s = [round(i, 3) for i in all_s]
49
+ return all_s
50
+
51
+ def func(text, label):
52
+ all_s = []
53
+ for x in label:
54
+ s = 0
55
+ for y in text:
56
+ temp = Levenshtein.ratio(x, y)
57
+ if temp > s:
58
+ s = temp
59
+ all_s.append(s)
60
+ all_s = [round(i, 3) for i in all_s]
61
+ return all_s
62
+
63
+
64
+ def stage2_output(df_test):
65
+ config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
66
+ 'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20230926091/checkpoint_3.pth',
67
+ 'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
68
+ 'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
69
+ 'max_protein_len': 600,
70
+ 'max_txt_len': 25}
71
+
72
+ model_cls = registry.get_model_class(config['arch'])
73
+ model = model_cls.from_config(config)
74
+ model.to(device)
75
+ model.eval()
76
+
77
+ images = df_test['protein'].tolist()
78
+ n = len(images)
79
+ bsz = 12
80
+ iter = n // bsz + 1
81
+
82
+ for i in range(iter):
83
+ image = images[i*bsz: min(n, (i+1)*bsz)]
84
+ image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
85
+
86
+ with model.maybe_autocast():
87
+ _, _, batch_tokens = model.visual_encoder(image)
88
+ image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)["representations"][model.vis_layers].contiguous()
89
+
90
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
91
+
92
+ query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
93
+ query_output = model.Qformer.bert(
94
+ query_embeds=query_tokens,
95
+ encoder_hidden_states=image_embeds,
96
+ encoder_attention_mask=image_atts,
97
+ return_dict=True,
98
+ )
99
+
100
+ inputs_opt = model.opt_proj(query_output.last_hidden_state)
101
+ atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
102
+
103
+ model.opt_tokenizer.padding_side = "right"
104
+
105
+ text = ['' for i in range(len(image))]
106
+ opt_tokens = model.opt_tokenizer(
107
+ text,
108
+ return_tensors="pt",
109
+ padding="longest",
110
+ truncation=True,
111
+ max_length=model.max_txt_len,
112
+ ).to(device)
113
+ inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
114
+ inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
115
+ attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
116
+ num_txt = 5
117
+ return_num_txt = 2
118
+ with model.maybe_autocast():
119
+ outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=3,
120
+ max_length=30,
121
+ repetition_penalty=5., num_beams=num_txt, eos_token_id=50118,
122
+ length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
123
+ output_text = model.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
124
+ output_text = [text.strip() for text in output_text]
125
+ output_text_ = []
126
+ for i in range(len(image)):
127
+ output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
128
+ with open('/cluster/home/wenkai/LAVIS/output/output_names.txt', 'a+') as f:
129
+ for i in range(len(image)):
130
+ f.write(image[i][1] + "|" + output_text_[i] + '\n')
131
+
132
+
133
+ def evaluate_score(data):
134
+ model_config = {'arch': 'blip2_protein', 'load_finetuned': False,
135
+ 'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage1/20230925102/checkpoint_6.pth',
136
+ 'finetuned': '', 'num_query_token': 32, 'prompt': '',
137
+ 'model_type': 'pretrain', 'load_pretrained': True, 'freeze_vit': False,
138
+ 'max_protein_len': 512, 'max_txt_len': 30}
139
+
140
+ model_cls = registry.get_model_class(model_config['arch'])
141
+ model = model_cls.from_config(model_config)
142
+ model = model.to(device)
143
+ model.eval()
144
+
145
+ # evaluate
146
+ t0 = time.time()
147
+ proteins = list(data['protein'])
148
+ txts = list(data['function'])
149
+ scores = []
150
+ for seq, txt in zip(proteins, txts):
151
+ image = [('protein1', seq)]
152
+ _, _, batch_tokens = model.visual_encoder(image)
153
+ image_embeds = \
154
+ model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][
155
+ 30].contiguous()
156
+
157
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
158
+
159
+ query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
160
+
161
+ query_output = model.Qformer.bert(
162
+ query_embeds=query_tokens,
163
+ encoder_hidden_states=image_embeds,
164
+ encoder_attention_mask=image_atts,
165
+ use_cache=True,
166
+ return_dict=True,
167
+ )
168
+
169
+ image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1)
170
+
171
+ image_feats_all = concat_all_gather(image_feats)
172
+
173
+ if type(txt) == str:
174
+ txt = eval(txt)
175
+ length = len(txt)
176
+ with torch.no_grad():
177
+ text_tokens = model.tokenizer(
178
+ txt,
179
+ padding="max_length",
180
+ truncation=True,
181
+ max_length=model.max_txt_len,
182
+ return_tensors="pt",
183
+ ).to(device)
184
+ text_output = model.Qformer.bert(
185
+ text_tokens.input_ids,
186
+ attention_mask=text_tokens.attention_mask,
187
+ return_dict=True,
188
+ )
189
+
190
+ text_feat = F.normalize(
191
+ model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
192
+ )
193
+
194
+ text_feat_all = concat_all_gather(text_feat)
195
+ sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
196
+ sim_i2t, _ = sim_q2t.max(-1)
197
+ # print('sim_i2t: {}'.format(sim_i2t))
198
+ if length > 1:
199
+ scores.append(list(sim_i2t.detach().cpu().numpy()))
200
+ else:
201
+ scores.append([sim_i2t.item()])
202
+ print("model evaluate time: {}".format(time.time() - t0))
203
+ data['sim'] = scores
204
+ return data
205
+
206
+
207
+
208
+ # graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
209
+
210
+ ### Levenshtein similarity
211
+ test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/raw_time_split/reviewed//test.csv', sep='|')
212
+ test['function'] = test['function'].apply(lambda x: x.lower())
213
+
214
+
215
+ if os.path.exists('/cluster/home/wenkai/LAVIS/output/output_names.txt'):
216
+ os.remove('/cluster/home/wenkai/LAVIS/output/output_names.txt')
217
+ print("stage 2 predict starting")
218
+ stage2_output(test)
219
+ print("stage 2 predict completed")
220
+
221
+ df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_names.txt', sep='|', header=None, on_bad_lines='warn')
222
+ df_pred.columns = ['protein', 'function']
223
+ df_pred = df_pred.drop_duplicates()
224
+ df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
225
+ df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])
226
+
227
+ test.columns
228
+ test_g = test.groupby(['protein']).agg({'function': lambda x: list(x)}).reset_index()
229
+ test_g.columns = ['protein', 'label']
230
+
231
+ data = pd.merge(df_pred, test_g, on='protein', how='left')
232
+ data = data[data['label'].notnull()]
233
+
234
+ sim = []
235
+ for text, label in zip(data['function'].tolist(), data['label'].tolist()):
236
+ sim.append(func(text, label))
237
+
238
+ data['sim'] = sim
239
+ data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
240
+ print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
241
+ data.to_csv('/cluster/home/wenkai/LAVIS/output/output_names.csv', index=False, sep='|')
242
+
243
+
244
+
245
+
246
+
247
+
examples/predict_test.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J infer_test
3
+ #SBATCH -p gpu1
4
+ #SBATCH -N 1
5
+ #SBATCH -w node[84]
6
+ #SBATCH --mem 80G
7
+ #SBATCH --gres=gpu:1
8
+ #SBATCH --output=log_predict_test.out
9
+ #SBATCH --error=log_predict_test.err
10
+ #SBATCH --cpus-per-task=8
11
+ module load anaconda3/2021.05
12
+ source activate LAVIS
13
+
14
+ python blip2_predict_func_concat_pretrain.py
examples/predict_train.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J infer_cc
3
+ #SBATCH -p gpu1
4
+ #SBATCH -N 1
5
+ #SBATCH -w node[84]
6
+ #SBATCH --mem 80G
7
+ #SBATCH --gres=gpu:1
8
+ #SBATCH --output=log_predict.out
9
+ #SBATCH --error=log_predict.err
10
+ #SBATCH --cpus-per-task=8
11
+ module load anaconda3/2021.05
12
+ source activate LAVIS
13
+
14
+ python blip2_predict_func_concat.py